4-3AutoGraph的使用规范——eat_tensorflow2_in_30_days

4-3 AutoGraph的使用规范

有三种计算图的构建方式:静态计算图,动态计算图,以及Autograph

TensorFlow2.0主要使用的是动态计算图和Autograph

  • 动态计算图易于调试,编码效率较高,但执行效率较低
  • 静态计算图执行效率很高,但较难调试
  • 而Autograph机制可以将动态计算图转换成静态计算图,兼收执行效率和编码效率之利

当然Autograph 机制能够转换的代码并不是没哟任何约束的,有一些编码规范需要遵循,否则可能会转换失败或者不符合预期

这里将着重介绍Autograph的编码规范和Autograph转换成静态图的原理

并介绍使用tf.Module来更好地构建Autograph

Autograph编码规范总结#

  • 被@tf.function修饰的函数应尽可能使用TensorFlow中的函数而不是Python中的其他函数。例如使用tf.print而不是print,使用tf.range而不是range,使用tf.constant(True)而不是True
  • 避免在@tf.function修饰的函数内部定义tf.Variable
  • 被@tf.function修饰的函数不可修改该函数外部的Python列表或字典等数据结构变量

Autograph编码规范解析#

  • 被@tf.function修饰的函数应尽可能使用TensorFlow中的函数而不是Python中的其他函数
import numpy as np
import tensorflow as tf

@tf.function
def np_random():
    # np.random.randn函数返回一个或一组样本,具有标准正太分布
    a = np.random.randn(3, 3)
    tf.print(a)
    
@tf.function
def tf_random():
    
    # tf.random.normal服从指定正态分布的序列
    a = tf.random.normal((3, 3))
    tf.print(a)
    
# np.random每次执行都是一样的结果
np_random()
np_random()

"""
array([[-0.84051143, -0.11712408, -0.17738803],
       [ 0.7147196 ,  1.42842053, -0.56037017],
       [-0.0487268 ,  1.05235275,  1.01622511]])
array([[-0.84051143, -0.11712408, -0.17738803],
       [ 0.7147196 ,  1.42842053, -0.56037017],
       [-0.0487268 ,  1.05235275,  1.01622511]])
"""

# tf_random每次执行都会有重新生成随机数
tf_random()
tf_random()

"""
[[1.19916523 0.203395322 1.3903774]
 [-2.06304955 -0.38222155 -1.46414936]
 [0.491630137 0.0822804719 -0.254222572]]
[[-0.549568892 2.08878803 0.558463752]
 [-0.36475572 0.136399537 -0.0849579573]
 [0.253954887 -0.276775241 -1.54324198]]
"""
  • 避免在@tf.function修饰的函数内部定义tf.Variable
# 避免在@tf.function修饰的函数内部定义tf.Variable
x = tf.Variable(1.0, dtype=tf.float32)

@tf.function
def outer_var():
    x.assign_add(1.0)
    tf.print(x)
    return x

outer_var()
outer_var()

"""
2
3
<tf.Tensor: shape=(), dtype=float32, numpy=3.0>
"""
@tf.function
def inner_var():
    x = tf.Variable(1.0, dtype=tf.float32)
    x.assign_add(1.0)
    tf.print(x)
    return x

#  执行将报错
# inner_var()
  • 被@tf.function修饰的函数不可修改该函数外部的Python列表或字典等数据结构变量
tensor_list = []

@tf.function  # 加上这一行换成Autograph结果将不符合预期
def append_tensor(x):
    tensor_list.append(x)
    return tensor_list

append_tensor(tf.constant(5.0))
append_tensor(tf.constant(6.0))
print(tensor_list)

"""
[<tf.Tensor 'x:0' shape=() dtype=float32>]
"""
tensor_list = []

# @tf.function  # 加上这一行换成Autograph结果将不符合预期
def append_tensor(x):
    tensor_list.append(x)
    return tensor_list

append_tensor(tf.constant(5.0))
append_tensor(tf.constant(6.0))
print(tensor_list)

"""
[<tf.Tensor: shape=(), dtype=float32, numpy=5.0>, <tf.Tensor: shape=(), dtype=float32, numpy=6.0>]
"""

作者:lotuslaw

出处:https://www.cnblogs.com/lotuslaw/p/16398919.html

版权:本作品采用「署名-非商业性使用-相同方式共享 4.0 国际」许可协议进行许可。

posted @   lotuslaw  阅读(35)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 开源Multi-agent AI智能体框架aevatar.ai,欢迎大家贡献代码
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南
· 被坑几百块钱后,我竟然真的恢复了删除的微信聊天记录!
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
more_horiz
keyboard_arrow_up light_mode palette
选择主题
menu
点击右上角即可分享
微信分享提示