tensorflow1.x版本代码迁移到2.0
由于3090显卡只支持tf2.0以后的版本,而且随着显卡的更新换代,tf1.x版本也不支持更高级的显卡,所以有必要将1.x版本的代码转成2.0后的版本。
Tf2.0版本和tf1.0版本的主要区别
主要区别在于tf1.x是静态图,需要先把模型结构先定好,再进行训练
Tf2.0版本则是动态图,训练前不用先构建完整的结构,而是按流程一步步构建,因此在训练的时候tf1.x相比于tf2.0占cpu内存大,训练的速度更快
代码转换主要分几个方面:输入、模型网络、训练、模型保存
1.输入
在1.x的代码中,对于输入需要首先加placeholder,作为整个网络的入口。而tf2.0取消了这个部分,因此修改的方法是去掉这部分代码,直接在训练的时候输入数据,例如:
Tf1.x:
self.inputs = tf.placeholder(tf.int32, [None, None], name="inputs") # 数据输入 self.labels = tf.placeholder(tf.float32, [None, None], name="labels") # 标签
修改后直接在训练的时候赋值就行:
self.inputs = batch["x"] self.labels = batch["y"] self.keep_prob = dropout_prob
2.模型网络
这部分比较好改,因为很多api可以在tensorflow官方文档上找到相应的替换函数,几个常用的如下:
tf.get_variable()变成tf.variable()
Initializer的改变
# embedding_w = tf.compat.v1.get_variable("embedding_w", shape=[self.vocab_size, self.config["embedding_size"]], # initializer=tf.compat.v1.contrib.layers.xavier_initializer()) embedding_w = tf.Variable(tf.keras.initializers.glorot_normal()(shape=[self.vocab_size, self.config["embedding_size"]], dtype=tf.float32), name='embedding')
3.训练
训练过程包括梯度的操作、优化算法的选择,主要的操作如下:
模型训练要继承tf.Module这个api,因为训练的时候要选择状态容器以便存储模型的参数,如果用keras或estimator模块写模型也可以继承其他的api,具体的继承规则可以参考这个树形结构:
https://zhuanlan.zhihu.com/p/73575776
Trackable
|
|-- tf.Variable
|
|-- MutableHashTable
|
|-- AutoTrackable
|
|-- ListWrapper/DictWrapper
|
|-- tf.train.Checkpoint
|
|-- tf.Module
|
|-- tf.keras.layers.Layer
|
|-- tf.keras.Model
|
|-- tf.keras.Sequential
几种状态容器的选择准则一般为:
仅在学习和深入研究状态容器(或基于对象的储存)时使用Trackable和AutoTrackable
tf.Module: 适合自定义训练循环时使用
tf.keras.layers.Layer:适合实现一些中间层,比如Attention之类的,可以配合tf.keras.Sequential使用,极少看见大的模型继承自这个类型。
tf.keras.Model:适合一些固定套路的模型(使用compile + fit)。虽然也可以自定义训练循环,但是有一种杀鸡用牛刀的感觉。
tf.keras.Sequential:适合一条路走到黑的(子)模型。
选择完状态容器后则要进行对应的训练循环,也就是梯度下降的操作:
Tf1.x首先定义好train_op,然后session.run
Tf2.0则直接在epoch循环内使用
with tf.GradientTape() as t: grads = t.gradient(self.model.loss, self.model.trainable_variables) optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
也就是将sess.run里面的操作换成一步步执行的函数流程
4.模型保存
Tf1.x和2.0的模型保存变化不大,都可以保存成checkepoint和pb这两种格式,根据文档将api换一下就可以了,但是需要注意的是保存的模型加载的时候版本需要和之前一致,否则在模型加载的时候可能会报错。Summary的保存也是一样,需要把api替换掉。