tensorflow中使用Batch Normalization
在深度学习中为了提高训练速度,经常会使用一些正正则化方法,如L2、dropout,后来Sergey Ioffe 等人提出Batch Normalization方法,可以防止数据分布的变化,影响神经网络需要重新学习分布带来的影响,会降低学习速率,训练时间等问题。提出使用batch normalization方法,使输入数据分布规律保持一致。实验证明可以提升训练速度,提高识别精度。下面讲解一下在Tensorflow中如何使用Batch Normalization
有关Batch Normalization详细内容请查看论文:
Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
关键函数
tf.layers.batch_normalization、tf.contrib.layers.batch_norm
这两个函数用法一致,以 tf.layers.batch_normalization 为例进行讲解
layer1_conv = tf.layers.batch_normalization(layer1_conv,axis=0,training=in_training)
其中 axis 参数表示沿着哪个轴进行正则化,一般而言Tensor是[batch, width_x, width_y, channel],如果是[width_x, width_y, channel,batch]则axis应该设为3
1 在训练阶段
训练的时候需要注意两点,(1)输入参数training=True,(2)计算loss时,要添加以下代码(即添加
update_ops到最后的train_op中)。这样才能计算μ和σ的滑动平均(测试时会用到)
update_op = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_op):
train_op = optimizer.minimize(loss)
2 在测试阶段
测试时需要注意一点,输入参数training=False,
如果觉得有用,想赞助一下请移步赞助页面:赞助一下
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· .NET10 - 预览版1新功能体验(一)