TensorFlow (4) Graphs and Functions

Graph

tf.Graph is used in TensorFlow primarily for performance optimization. When we write TensorFlow code, computations can be carried out directly through Python's own machinery; this is called eager execution. Graph execution instead runs the computation through TensorFlow's own graph runtime, which is more efficient by comparison.

A graph is a data structure that holds tf.Operation objects, and tf.Tensor is the basic unit of data that flows between them; both are defined within a graph context. Because graphs are plain data structures, they can also be exchanged with and executed by code outside of Python.
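
One common way to use a graph outside Python is to export it as a SavedModel. The following is only an illustrative sketch (the Adder module and the /tmp/adder path are names made up for this example), showing how a traced graph can be serialized and reloaded:

import tensorflow as tf

class Adder(tf.Module):
  # Fixing the input signature lets the graph be traced once and exported.
  @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
  def add_one(self, x):
    return x + 1.0

# Serialize the traced graph; the SavedModel can then be loaded by runtimes
# without a Python interpreter, such as TensorFlow Serving.
tf.saved_model.save(Adder(), "/tmp/adder")
restored = tf.saved_model.load("/tmp/adder")
print(restored.add_one(tf.constant(2.0)))  # tf.Tensor(3.0, shape=(), dtype=float32)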

In more detail, graphs enable optimizations such as:

  • Statically inferring the values of tensors by folding constant nodes in the computation ("constant folding").
  • Separating sub-parts of the computation that are independent of each other and splitting them between threads or devices.
  • Simplifying arithmetic operations by eliminating common subexpressions.

In short, the main benefits of using graphs are speed, support for parallelism, support for multiple devices, and optimization. The quick benchmark below gives a feel for the speed difference.
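
The sketch below (the matrix size and iteration counts are arbitrary choices for this example, and real speedups depend on the workload) times the same computation run eagerly and wrapped in tf.function:

import timeit
import tensorflow as tf

x = tf.random.uniform(shape=[10, 10], minval=-1, maxval=2, dtype=tf.int32)

def power(x, y):
  # Repeated matrix multiplication: many small ops, which benefits from graph execution.
  result = tf.eye(10, dtype=tf.int32)
  for _ in range(y):
    result = tf.matmul(x, result)
  return result

print("Eager execution:", timeit.timeit(lambda: power(x, 100), number=100))

# The same function, wrapped so it runs as a graph.
power_as_graph = tf.function(power)
print("Graph execution:", timeit.timeit(lambda: power_as_graph(x, 100), number=100))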

Defining a Graph

You define and use a graph through tf.function: either wrap an existing function by calling tf.function(), or apply the @tf.function decorator.

import tensorflow as tf

# Define a Python function.
def a_regular_function(x, y, b):
  x = tf.matmul(x, y)
  x = x + b
  return x

# `a_function_that_uses_a_graph` is a TensorFlow `Function`.
a_function_that_uses_a_graph = tf.function(a_regular_function)

# Make some tensors.
x1 = tf.constant([[1.0, 2.0]])
y1 = tf.constant([[2.0], [3.0]])
b1 = tf.constant(4.0)

orig_value = a_regular_function(x1, y1, b1).numpy()
# Call a `Function` like a Python function.
tf_function_value = a_function_that_uses_a_graph(x1, y1, b1).numpy()
assert(orig_value == tf_function_value)

In use, a tf.function seems no different from a regular Python function, but in fact a Function may encapsulate several tf.Graphs, which makes a large difference in computational efficiency and deployability.

A tf.function applies graph execution not only to the function it wraps but also to any other functions that function calls.

def inner_function(x, y, b):
  x = tf.matmul(x, y)
  x = x + b
  return x

# Use the decorator to make `outer_function` a `Function`.
@tf.function
def outer_function(x):
  y = tf.constant([[2.0], [3.0]])
  b = tf.constant(4.0)

  return inner_function(x, y, b)

# Note that the callable will create a graph that
# includes `inner_function` as well as `outer_function`.
outer_function(tf.constant([[1.0, 2.0]])).numpy()

AutoGraph

When we define a tf.function, TensorFlow internally uses AutoGraph (tf.autograph) to convert it into a new, graph-compatible function.
The example below writes a simple ReLU function so we can see what the conversion looks like:

def simple_relu(x):
  if tf.greater(x, 0):
    return x
  else:
    return 0

# `tf_simple_relu` is a TensorFlow `Function` that wraps `simple_relu`.
tf_simple_relu = tf.function(simple_relu)

print("First branch, with graph:", tf_simple_relu(tf.constant(1)).numpy())
print("Second branch, with graph:", tf_simple_relu(tf.constant(-1)).numpy())

# This is the graph-generating output of AutoGraph.
print(tf.autograph.to_code(simple_relu))

First branch, with graph: 1
Second branch, with graph: 0

The converted function:

def tf__simple_relu(x):
    with ag__.FunctionScope('simple_relu', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (do_return, retval_)

        def set_state(vars_):
            nonlocal do_return, retval_
            (do_return, retval_) = vars_

        def if_body():
            nonlocal do_return, retval_
            try:
                do_return = True
                retval_ = ag__.ld(x)
            except:
                do_return = False
                raise

        def else_body():
            nonlocal do_return, retval_
            try:
                do_return = True
                retval_ = 0
            except:
                do_return = False
                raise
        ag__.if_stmt(ag__.converted_call(ag__.ld(tf).greater, (ag__.ld(x), 0), None, fscope), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
        return fscope.ret(retval_, do_return)

We can also look at the graph itself:

# This is the graph itself.
print(tf_simple_relu.get_concrete_function(tf.constant(1)).graph.as_graph_def())

Output:

node {
  name: "x"
  op: "Placeholder"
  attr {
    key: "_user_specified_name"
    value {
      s: "x"
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "shape"
    value {
      shape {
      }
    }
  }
}
node {
  name: "Greater/y"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
}
node {
  name: "Greater"
  op: "Greater"
  input: "x"
  input: "Greater/y"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "cond"
  op: "StatelessIf"
  input: "Greater"
  input: "x"
  attr {
    key: "Tcond"
    value {
      type: DT_BOOL
    }
  }
  attr {
    key: "Tin"
    value {
      list {
        type: DT_INT32
      }
    }
  }
  attr {
    key: "Tout"
    value {
      list {
        type: DT_BOOL
        type: DT_INT32
      }
    }
  }
  attr {
    key: "_lower_using_switch_merge"
    value {
      b: true
    }
  }
  attr {
    key: "_read_only_resource_inputs"
    value {
      list {
      }
    }
  }
  attr {
    key: "else_branch"
    value {
      func {
        name: "cond_false_34"
      }
    }
  }
  attr {
    key: "output_shapes"
    value {
      list {
        shape {
        }
        shape {
        }
      }
    }
  }
  attr {
    key: "then_branch"
    value {
      func {
        name: "cond_true_33"
      }
    }
  }
}
node {
  name: "cond/Identity"
  op: "Identity"
  input: "cond"
  attr {
    key: "T"
    value {
      type: DT_BOOL
    }
  }
}
node {
  name: "cond/Identity_1"
  op: "Identity"
  input: "cond:1"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "Identity"
  op: "Identity"
  input: "cond/Identity_1"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
library {
  function {
    signature {
      name: "cond_false_34"
      input_arg {
        name: "cond_placeholder"
        type: DT_INT32
      }
      output_arg {
        name: "cond_identity"
        type: DT_BOOL
      }
      output_arg {
        name: "cond_identity_1"
        type: DT_INT32
      }
    }
    node_def {
      name: "cond/Const"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
    }
    node_def {
      name: "cond/Const_1"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
    }
    node_def {
      name: "cond/Const_2"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: 0
          }
        }
      }
    }
    node_def {
      name: "cond/Const_3"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
    }
    node_def {
      name: "cond/Identity"
      op: "Identity"
      input: "cond/Const_3:output:0"
      attr {
        key: "T"
        value {
          type: DT_BOOL
        }
      }
    }
    node_def {
      name: "cond/Const_4"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: 0
          }
        }
      }
    }
    node_def {
      name: "cond/Identity_1"
      op: "Identity"
      input: "cond/Const_4:output:0"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    ret {
      key: "cond_identity"
      value: "cond/Identity:output:0"
    }
    ret {
      key: "cond_identity_1"
      value: "cond/Identity_1:output:0"
    }
    attr {
      key: "_construction_context"
      value {
        s: "kEagerRuntime"
      }
    }
    arg_attr {
      key: 0
      value {
        attr {
          key: "_output_shapes"
          value {
            list {
              shape {
              }
            }
          }
        }
      }
    }
  }
  function {
    signature {
      name: "cond_true_33"
      input_arg {
        name: "cond_identity_1_x"
        type: DT_INT32
      }
      output_arg {
        name: "cond_identity"
        type: DT_BOOL
      }
      output_arg {
        name: "cond_identity_1"
        type: DT_INT32
      }
    }
    node_def {
      name: "cond/Const"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
    }
    node_def {
      name: "cond/Identity"
      op: "Identity"
      input: "cond/Const:output:0"
      attr {
        key: "T"
        value {
          type: DT_BOOL
        }
      }
    }
    node_def {
      name: "cond/Identity_1"
      op: "Identity"
      input: "cond_identity_1_x"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    ret {
      key: "cond_identity"
      value: "cond/Identity:output:0"
    }
    ret {
      key: "cond_identity_1"
      value: "cond/Identity_1:output:0"
    }
    attr {
      key: "_construction_context"
      value {
        s: "kEagerRuntime"
      }
    }
    arg_attr {
      key: 0
      value {
        attr {
          key: "_output_shapes"
          value {
            list {
              shape {
              }
            }
          }
        }
      }
    }
  }
}
versions {
  producer: 987
  min_consumer: 12
}

Graph Polymorphism

Graphs are polymorphic. When arguments are passed to a tf.function, TensorFlow traces a graph for that input signature and caches it under an identifier. When arguments of the same signature are passed later, the tf.function matches them to the previously created graph instead of tracing a new one.
Note that only TensorFlow types (tensors with matching dtype and shape) are matched as the same signature; plain Python values are treated as distinct signatures!

@tf.function
def my_relu(x):
  return tf.maximum(0., x)

# `my_relu` creates new graphs as it observes more signatures.
print(my_relu(tf.constant(5.5)))
print(my_relu([1, -1]))
print(my_relu(tf.constant([3., -3.])))

# These two calls do *not* create new graphs.
print(my_relu(tf.constant(-2.5))) # Signature matches `tf.constant(5.5)`.
print(my_relu(tf.constant([-1., 1.]))) # Signature matches `tf.constant([3., -3.])`.

tf.Tensor(5.5, shape=(), dtype=float32)
tf.Tensor([1. 0.], shape=(2,), dtype=float32)
tf.Tensor([3. 0.], shape=(2,), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor([0. 1.], shape=(2,), dtype=float32)

The Function selects one of its cached graphs according to the arguments it is passed, which is how repeated calls avoid re-tracing and gain performance.

# There are three `ConcreteFunction`s (one for each graph) in `my_relu`.
# The `ConcreteFunction` also knows the return type and shape!
print(my_relu.pretty_printed_concrete_signatures())

my_relu(x)
  Args:
    x: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

my_relu(x=[1, -1])
  Returns:
    float32 Tensor, shape=(2,)

my_relu(x)
  Args:
    x: float32 Tensor, shape=(2,)
  Returns:
    float32 Tensor, shape=(2,)
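
If you need one of these cached graphs directly, it can be retrieved as a ConcreteFunction. A small sketch continuing the my_relu example above:

# Fetch the ConcreteFunction that was traced for scalar float32 tensors.
scalar_relu = my_relu.get_concrete_function(tf.constant(5.5))

# Calling it reuses the already-traced graph; no new trace is created.
print(scalar_relu(tf.constant(-2.5)))  # tf.Tensor(0.0, shape=(), dtype=float32)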

Tracing

Tracing is how a Function captures the TensorFlow operations it needs into a graph. In the example below, print() is a Python call that is not captured into the graph, so it does not execute on every call.

@tf.function
def get_MSE(y_true, y_pred):
  print("Calculating MSE!")
  sq_diff = tf.pow(y_true - y_pred, 2)
  return tf.reduce_mean(sq_diff)
  
tf.config.run_functions_eagerly(False)

y_true = tf.random.uniform([5], maxval=10, dtype=tf.int32)
y_pred = tf.random.uniform([5], maxval=10, dtype=tf.int32)

error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)

Calculating MSE!

Notice that "Calculating MSE!" is printed only once.
Here, tf.config.run_functions_eagerly(False) controls whether Functions run eagerly or as graphs.
If you use tf.print() instead, the message is printed on every call, because tf.print is captured during tracing and becomes part of the graph.
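
For comparison, here is a sketch using tf.print (the name get_MSE_tf_print is just made up for this example). Because tf.print is a TensorFlow op inside the traced graph, the message appears on every call:

@tf.function
def get_MSE_tf_print(y_true, y_pred):
  tf.print("Calculating MSE!")  # a graph op, so it runs each time the Function is called
  sq_diff = tf.pow(y_true - y_pred, 2)
  return tf.reduce_mean(sq_diff)

error = get_MSE_tf_print(y_true, y_pred)
error = get_MSE_tf_print(y_true, y_pred)
error = get_MSE_tf_print(y_true, y_pred)
# "Calculating MSE!" is printed three times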

Non-strict execution

When we define a tf.function, graph execution only runs the operations whose effects are observable, which include:

  • The return value of the function
  • Certain well-known side-effect operations, such as:
    • Input/output operations, e.g. tf.print
    • Debugging operations, e.g. those in tf.debugging
    • Mutations of tf.Variable

This behavior is known as non-strict execution: operations that do not contribute to the observable effects are skipped during graph execution, even if they would raise an error. Under eager execution, by contrast, every operation runs, as shown below:

def unused_return_eager(x):
  # Get index 1 will fail when `len(x) == 1`
  tf.gather(x, [1]) # unused 
  return x

try:
  print(unused_return_eager(tf.constant([0.0])))
except tf.errors.InvalidArgumentError as e:
  # All operations are run during eager execution so an error is raised.
  print(f'{type(e).__name__}: {e}')

InvalidArgumentError: indices[0] = 1 is not in [0, 1) [Op:GatherV2]

If the same function is defined as a tf.function, the out-of-range gather is not needed for the output, so graph execution skips it and no error is raised:

@tf.function
def unused_return_graph(x):
  tf.gather(x, [1]) # unused
  return x

# Only needed operations are run during graph execution. The error is not raised.
print(unused_return_graph(tf.constant([0.0])))

tf.Tensor([0.], shape=(1,), dtype=float32)

posted @ 2022-05-12 21:46  Asp1rant