简单粗暴的 TensorFlow：变量的保存与恢复
# ---- train.py: training phase ----
model = MyModel()
# Create a Checkpoint that tracks the model. Pass more keyword arguments
# (e.g. optimizer=optimizer) if the optimizer state should be saved as well.
checkpoint = tf.train.Checkpoint(myModel=model)
# ... (model training code)
# Save the parameters once training is finished (or periodically while
# training). The argument is directory + file prefix; save() derives the
# actual file names from this prefix.
checkpoint.save('./save/model.ckpt')

# ---- test.py: inference phase ----
model = MyModel()
# The keyword (myModel) must match the one used when the checkpoint was saved.
checkpoint = tf.train.Checkpoint(myModel=model)
latest_path = tf.train.latest_checkpoint('./save')
if latest_path is None:
    # latest_checkpoint() returns None when nothing was saved; restoring None
    # would load nothing and leave the model untrained without any error.
    raise FileNotFoundError("no checkpoint found under './save'")
checkpoint.restore(latest_path)  # restore the model parameters from file
# ... (code that uses the model)
"""Train an MLP on MNIST and save/restore it with ``tf.train.Checkpoint``.

Run with ``--mode train`` to train and periodically write checkpoints under
``./save``, then with ``--mode test`` to restore the latest checkpoint into a
fresh model and print its test accuracy.
"""
import argparse

import numpy as np
import tensorflow as tf

from zh.model.mnist.mlp import MLP
from zh.model.utils import MNISTLoader

CHECKPOINT_DIR = './save'
# Directory + file prefix handed to Checkpoint.save(); the saved file names
# are derived from this prefix.
CHECKPOINT_PREFIX = CHECKPOINT_DIR + '/model.ckpt'
SAVE_EVERY = 100  # write a checkpoint every this many batches

parser = argparse.ArgumentParser(
    description='Train or evaluate an MLP on MNIST with checkpointing.')
# type= is required: without it, values supplied on the command line arrive
# as str and break the integer/float arithmetic in train().
parser.add_argument('--mode', default='train', choices=['train', 'test'],
                    help='train or test')
parser.add_argument('--num_epochs', type=int, default=1)
parser.add_argument('--batch_size', type=int, default=50)
parser.add_argument('--learning_rate', type=float, default=0.001)
args = parser.parse_args()
data_loader = MNISTLoader()


def _save(checkpoint):
    """Write one checkpoint under CHECKPOINT_PREFIX and report its path."""
    path = checkpoint.save(CHECKPOINT_PREFIX)
    print("model saved to %s" % path)


def train():
    """Train a new MLP with Adam, checkpointing the model as it goes.

    Runs (num_train_data // batch_size) * num_epochs batches. A checkpoint
    is written every SAVE_EVERY batches, plus once more after the last
    batch when that batch did not fall on a SAVE_EVERY boundary, so the
    final weights are always persisted.
    """
    model = MLP()
    optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)
    num_batches = int(data_loader.num_train_data // args.batch_size * args.num_epochs)
    # Only the model is tracked (key "myAwesomeModel", which test() relies on).
    checkpoint = tf.train.Checkpoint(myAwesomeModel=model)
    for batch_index in range(1, num_batches + 1):
        X, y = data_loader.get_batch(args.batch_size)
        with tf.GradientTape() as tape:
            y_pred = model(X)
            loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)
            loss = tf.reduce_mean(loss)
        print("batch %d: loss %f" % (batch_index, loss.numpy()))
        # Differentiate only w.r.t. trainable variables; non-trainable ones
        # would receive None gradients.
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        if batch_index % SAVE_EVERY == 0:
            _save(checkpoint)
    if num_batches % SAVE_EVERY != 0:
        # The last stretch of batches was not covered by a periodic save.
        _save(checkpoint)


def test():
    """Restore the latest checkpoint into a fresh MLP and print test accuracy.

    Raises:
        FileNotFoundError: if no checkpoint exists in CHECKPOINT_DIR. Without
            this check, restore(None) would load nothing and the reported
            accuracy would be that of a randomly initialised model.
    """
    model_to_be_restored = MLP()
    # The keyword must match the one used in train() for the restore to match.
    checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored)
    latest_path = tf.train.latest_checkpoint(CHECKPOINT_DIR)
    if latest_path is None:
        raise FileNotFoundError(
            "no checkpoint found in %s; run with --mode train first" % CHECKPOINT_DIR)
    checkpoint.restore(latest_path)
    y_pred = np.argmax(model_to_be_restored.predict(data_loader.test_data), axis=-1)
    accuracy = np.sum(y_pred == data_loader.test_label) / data_loader.num_test_data
    print("test accuracy: %f" % accuracy)


if __name__ == '__main__':
    if args.mode == 'train':
        train()
    elif args.mode == 'test':
        test()
天道酬勤 循序渐进 技压群雄
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律