tf.function函数转换

596 阅读2分钟

携手创作,共同成长!这是我参与「掘金日新计划 · 8 月更文挑战」的第1天,点击查看活动详情

tf.function 可以把普通语法写的python函数变成tensorflow的图

autograph 是 tf.function依赖的一个机制

定义一个普通的python函数

# tf.function and auto-graph.
def scaled_elu(z, scale=1.0, alpha=1.0):
    """Return ``scale * z`` where ``z >= 0`` and ``scale * alpha * elu(z)`` elsewhere.

    Args:
        z: Value handed to ``tf.greater_equal``, ``tf.nn.elu`` and ``tf.where``.
        scale: Factor applied to the whole result.
        alpha: Factor applied only to the ELU branch.
    """
    non_negative = tf.greater_equal(z, 0.0)
    elu_branch = alpha * tf.nn.elu(z)
    # Element-wise selection: keep z where the mask is true, the ELU branch otherwise.
    selected = tf.where(non_negative, z, elu_branch)
    return scale * selected

# Call the plain Python function directly: once with a scalar tensor,
# once with a two-element tensor.
print(scaled_elu(tf.constant(-3.)))
print(scaled_elu(tf.constant([-3., -2.5])))

运行结果:

tf.Tensor(-0.95021296, shape=(), dtype=float32)
tf.Tensor([-0.95021296 -0.917915  ], shape=(2,), dtype=float32)

使用tf.function将python函数转变成tensorflow的图

# Wrap the existing Python function by calling tf.function on it
# (the non-decorator form).
scaled_elu_tf = tf.function(scaled_elu)
# Same two inputs as the plain-Python calls, now going through the wrapped callable.
print(scaled_elu_tf(tf.constant(-3.)))
print(scaled_elu_tf(tf.constant([-3., -2.5])))

# python_function exposes the callable that was wrapped; this prints whether
# it is the very same object as scaled_elu.
print(scaled_elu_tf.python_function is scaled_elu)

通过python_function属性可以取回原始的python函数

tf.Tensor(-0.95021296, shape=(), dtype=float32)
tf.Tensor([-0.95021296 -0.917915  ], shape=(2,), dtype=float32)
True

转化为图结构的优势:

运算速度快

# NOTE: each timed statement also builds a fresh 1000x1000 random tensor, so the
# reported time covers tf.random.normal as well as the function being compared.
%timeit scaled_elu(tf.random.normal((1000, 1000)))
%timeit scaled_elu_tf(tf.random.normal((1000, 1000)))
38 ms ± 4.56 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
34.1 ms ± 263 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

可以使用@tf.function将python函数转成图结构

# Partial sum of the geometric series 1 + 1/2 + 1/2^2 + ... + 1/2^(n_iters - 1),
# which approaches 2 as n_iters grows.

@tf.function
def converge_to_2(n_iters):
    """Add up the first ``n_iters`` terms of 1, 1/2, 1/4, ... and return the sum.

    Args:
        n_iters: Number of terms to accumulate; passed to ``range``.
    """
    partial_sum = tf.constant(0.)
    term = tf.constant(1.)
    for _ in range(n_iters):
        partial_sum = partial_sum + term
        # Each term is half of the previous one.
        term = term / 2.0
    return partial_sum

# With 20 terms the float32 result is already very close to the limit of 2.
print(converge_to_2(20))
tf.Tensor(1.9999981, shape=(), dtype=float32)

还可以使用tf.autograph.to_code查看python函数经过AutoGraph转换后生成的代码

def display_tf_code(func):
    """Render the AutoGraph-generated source of ``func`` as a Markdown Python block.

    Args:
        func: Python function handed to ``tf.autograph.to_code``.
    """
    generated_source = tf.autograph.to_code(func)
    # Imported inside the function, so IPython is only required when this helper runs.
    from IPython.display import display, Markdown
    fenced_source = '```python\n{}\n```'.format(generated_source)
    display(Markdown(fenced_source))
# Show the code AutoGraph generates for the scaled_elu function defined earlier.
display_tf_code(scaled_elu)

输出结果:

def tf__scaled_elu(z, scale=None, alpha=None):
    with ag__.FunctionScope('scaled_elu', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()
        is_positive = ag__.converted_call(ag__.ld(tf).greater_equal, (ag__.ld(z), 0.0), None, fscope)
        try:
            do_return = True
            retval_ = (ag__.ld(scale) * ag__.converted_call(ag__.ld(tf).where, (ag__.ld(is_positive), ag__.ld(z), (ag__.ld(alpha) * ag__.converted_call(ag__.ld(tf).nn.elu, (ag__.ld(z),), None, fscope))), None, fscope))
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)

tf.Variable需要定义在函数的外面

# Created once at module level so that add_21 below updates this same variable
# on every call instead of creating a new one inside the tf.function.
var = tf.Variable(0.)

@tf.function
def add_21():
    """Add 21 in place to the module-level ``var`` and return what ``assign_add`` returns."""
    updated = var.assign_add(21)  # in-place counterpart of ``var += 21``
    return updated

# Runs the in-place update on var and prints the value returned by assign_add.
print(add_21())

使用input_signature限制函数输入参数的类型

@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.int32, name='x')])
def cube(z):
    """Raise every element of ``z`` to the third power.

    The input signature restricts calls to rank-1 int32 tensors of any length.

    Args:
        z: Tensor to cube; must be compatible with the declared ``TensorSpec``.
    """
    exponent = 3
    return tf.pow(z, exponent)

# A float32 tensor violates the int32 input_signature, so this call is rejected.
# Older TF 2.x releases raise ValueError for the mismatch, while newer releases
# raise TypeError; both are caught so the demo prints the message either way.
try:
    print(cube(tf.constant([1., 2., 3.])))
except (ValueError, TypeError) as ex:
    print(ex)

# An int32 tensor matches the signature and is cubed element-wise.
print(cube(tf.constant([1, 2, 3])))
Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2. 3.], shape=(3,), dtype=float32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name='x'))
tf.Tensor([ 1  8 27], shape=(3,), dtype=int32)

@tf.function py func -> tf graph

@tf.function 可以把python中的func转为tf中的graph

get_concrete_function -> add input signature -> SavedModel

get_concrete_function 加上 input signature 可以变成tf中可保存的图结构SavedModel

# Ask cube for the concrete function matching a rank-1 int32 spec of unknown
# length, then print its summary.
cube_func_int32 = cube.get_concrete_function(
    tf.TensorSpec([None], tf.int32))
print(cube_func_int32)
ConcreteFunction cube(z)
  Args:
    z: int32 Tensor, shape=(None,)
  Returns:
    int32 Tensor, shape=(None,)
# Requesting a concrete function with a fixed-length int32 spec, or with an
# actual int32 tensor, gives back the same object as cube_func_int32
# (both identity checks print True).
print(cube_func_int32 is cube.get_concrete_function(
    tf.TensorSpec([5], tf.int32)))
print(cube_func_int32 is cube.get_concrete_function(
    tf.constant([1, 2, 3])))
True
True