TensorFlow保存和载入模型
首先定义一个tf.train.Saver类:
saver = tf.train.Saver(max_to_keep=1)
其中,max_to_keep参数设定只保存最后一个参数,默认值是5,即保存最后5个模型,如果设置成0,训练过程中的所有模型都会被保存。
模型训练好以后,保存模型:
saver.save(sess, ckpt_dir + "/nn_model.ckpt", global_step=1)
其中,sess是Session,ckpt_dir + "/nn_model.ckpt"是保存的路径和名称,global_step是模型名称的后缀名,由于我们只保存最后一个模型,所以可以设置为1,如果每一个模型都想保存,可以设置成训练的epoch。
载入模型比较简单:
saver.restore(sess, model_file)
其中,sess是Session,model_file是模型的路径和名称。
【推荐】还在用 ECharts 开发大屏?试试这款永久免费的开源 BI 工具!
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步