tf.function (TensorFlow > API > TensorFlow Core v2.2.0 > Python)
tf.function 是 tf 2.x新增的主要功能,函数的装饰器(decorator),将函数编译为可调用的TensorFlow图。
tf.function(
func=None, input_signature=None, autograph=True, experimental_implements=None,
experimental_autograph_options=None, experimental_relax_shapes=False,
experimental_compile=None
)
通过对func
中的TensorFlow操作跟踪编译,创建出一个TensorFlow图(tf.Graph),tf.function
构建一个可调用函数,来执行这个图,从而将func
当作TensorFlow图实现高效的执行。
使用实例,
>>> @tf.function
... def f(x, y):
... return x ** 2 + y
>>> x = tf.constant([2, 3])
>>> y = tf.constant([3, -2])
>>> f(x, y)
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([7, 7])>
- 特征
func
可以使用数据依赖的控制流,包括if
,for
,while
,break
,continue
,return
语句:
>>> @tf.function
... def f(x):
... if tf.reduce_sum(x) > 0:
... return x * x
... else:
... return -x // 2
>>> f(tf.constant(-2))
<tf.Tensor: shape=(), dtype=int32, numpy=1>
func
函数的闭合可以包含tf.Tensor和tf.Variable对象,
>>> @tf.function
... def f():
... return x ** 2 + y
>>> x = tf.constant([-2, -3])
>>> y = tf.Variable([3, -2])
>>> f()
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([7, 7])>
func
也可以使用有副作用的ops,比如 tf.print,tf.Variable等,
>>> v = tf.Variable(1)
>>> @tf.function
... def f(x):
... for i in tf.range(x):
... v.assign_add(i)
>>> f(3)
>>> v
<tf.Variable 'Variable:0' shape=() dtype=int32, numpy=4>
重点:
python的任何副作用(list append,print等)只会在func
被追踪时执行一次。想要在tf.function中执行副作用,需要以TF ops的形式去写这些代码。比如,
>>> l = []
>>> @tf.function
... def f(x):
... for i in x:
... l.append(i + 1) # Caution! Will only happen once when tracing
>>> f(tf.constant([1, 2, 3]))
>>> l
[<tf.Tensor 'add:0' shape=() dtype=int32>]
列表l扩展只会在追踪(图构建)时发生一次,使用TF collections (tf.TensorArray)可以实现每次迭代都运行,
>>> @tf.function
... def f(x):
... ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
... for i in range(len(x)):
... ta = ta.write(i, x[i] + 1)
... return ta.stack()
>>> f(tf.constant([1, 2, 3]))
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([2, 3, 4])>
tf.function是多态的(tf.function is polymorphic)
Tensorflow建立的指定形状和类型的图会更加高效。对于不同的数据类型和形状的参数,tf.function可以构建多个图,对它们进行支持。tf.function将任何的纯python数值作为未知对象,然后为它所遇到的每个python参数集合都建立一个独立的图。
为了获取一个单独图,使用tf.function创建的get_concrete_function
方法,它可以被与func
相同的参数所调用,返回一个特殊的tf.Graph对象,
>>> @tf.function
... def f(x):
... return x + 1
>>> isinstance(f.get_concrete_function(1).graph, tf.Graph)
True
注意:将python数值或列表作为参数传递给tf.function,tf.function总是会建立新的图。为了避免总是新建图,将数值参数作为Tensor传递:
>>> @tf.function
... def f(x):
... return tf.abs(x)
>>> f1 = f.get_concrete_function(1)
>>> f2 = f.get_concrete_function(2) # Slow - builds new graph
>>> f1 is f2
False
>>> f1 = f.get_concrete_function(tf.constant(1))
>>> f2 = f.get_concrete_function(tf.constant(2)) # Fast - reuses f1
>>> f1 is f2
True
只又在参数取很少几个不同值的时候,才使用python数值参数,比如超参数:神经网络中的层数。
输入签名(Input signatures)
对于Tensor参数来说,tf.function会为每个独特的输入形状和输入类型的集合,实例化一个单独的图,也就是对于同一类型的输入形状和输入数据类型,只实例化一个图。下买你的例子实例化了2个图,每个都有不同的形状,
>>> @tf.function
... def f(x):
... return x + 1
>>> vector = tf.constant([1.0, 1.0])
>>> matrix = tf.constant([[3.0]])
>>> f.get_concrete_function(vector) is f.get_concrete_function(matrix)
False
提供给tf.function的输入签名是可选的而非必须的,以控制正在追踪的图。输入签名使用tf.TensorSpec
对象 指定每个Tensor参数的形状和类型,也可以使用更通用的形状。当Tensor具有动态形状时,这可以避免创建多个图。但使用同一个图,同时也限制了可以使用的Tensor大小和数据类型,
>>> @tf.function(
... input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
... def f(x):
... return x + 1
>>> vector = tf.constant([1.0, 1.0])
>>> matrix = tf.constant([[3.0]])
>>> f.get_concrete_function(vector) is f.get_concrete_function(matrix)
True
变量也许只创建一次(Variables may only be created once)
tf.function只允许在它第一次被调用的时候创建tf.Variable对象。
>>> class MyModule(tf.Module):
... def __init__(self):
... self.v = None
... @tf.function
... def call(self, x):
... if self.v is None:
... self.v = tf.Variable(tf.ones_like(x))
... return self.v * x
通常,更推荐的方式是,在tf.function之外创建有状态的对象比如tf.Variable,然后将它们作为参数传递,如:
v = tf.Variable(1.0)
@tf.function
def f(x):
return v.assign_add(x)
print(f(1.0)) # 2.0
print(f(2.0)) # 4.0
Args | |
---|---|
func | the function to be compiled. If func is None, tf.function returns a decorator that can be invoked with a single argument - func . In other words, tf.function(input_signature=...)(func) is equivalent to tf.function(func, input_signature=...) . The former can be used as decorator. |
input_signature | A possibly nested sequence of tf.TensorSpec objects specifying the shapes and dtypes of the Tensors that will be supplied to this function. If None, a separate function is instantiated for each inferred input signature. If input_signature is specified, every input to func must be a Tensor, and func cannot accept **kwargs. |
autograph | Whether autograph should be applied on func before tracing a graph. Data-dependent control flow requires autograph=True. For more information, see the tf.function and AutoGraph guide. |
experimental_implements | If provided, contains a name of a "known" function this implements. For example "mycompany.my_recurrent_cell". This is stored as an attribute in inference function, which can then be detected when processing serialized function. See standardizing composite ops for details. For an example of utilizing this attribute see this example The code above automatically detects and substitutes function that implements "embedded_matmul" and allows TFLite to substitute its own implementations. For instance, a tensorflow user can use this attribute to mark that their function also implements embedded_matmul (perhaps more efficiently!) by specifying it using this parameter: @tf.function(experimental_implements="embedded_matmul") |
experimental_autograph_options | Optional tuple of tf.autograph.experimental.Feature values. |
experimental_relax_shapes | When True, tf.function may generate fewer, graphs that are less specialized on input shapes. |
experimental_compile | If True, the function is always compiled by XLA . XLA may be more efficient in some cases (e.g. TPU, XLA_GPU, dense tensor computations). |
Returns |
---|
If func is not None, returns a callable that will execute the compiled function (and return zero or more tf.Tensor objects). If func is None, returns a decorator that, when invoked with a single func argument, returns a callable equivalent to the case above. |
Raises |
---|
ValueError when attempting to use experimental_compile, but XLA support is not enabled. |