tflearn 保存模型重新训练
from:https://stackoverflow.com/questions/41616292/how-to-load-and-retrain-tflean-model
The following code builds the model inside its own graph and saves it:
# Build the model inside a dedicated graph, then save it.
# Everything below must run inside the `with` block so that the layers are
# created in `graph1` rather than in the process-wide default graph.
graph1 = tf.Graph()
with graph1.as_default():
    # Input -> embedding.
    network = input_data(shape=[None, MAX_DOCUMENT_LENGTH])
    network = tflearn.embedding(network, input_dim=n_words, output_dim=128)

    # Three parallel 1-D convolutions with filter sizes 3, 4 and 5.
    branch1 = conv_1d(network, 128, 3, padding='valid',
                      activation='relu', regularizer="L2")
    branch2 = conv_1d(network, 128, 4, padding='valid',
                      activation='relu', regularizer="L2")
    branch3 = conv_1d(network, 128, 5, padding='valid',
                      activation='relu', regularizer="L2")

    # Concatenate the branches, pool, and classify into 2 classes.
    network = merge([branch1, branch2, branch3], mode='concat', axis=1)
    network = tf.expand_dims(network, 2)
    network = global_max_pool(network)
    network = dropout(network, 0.5)
    network = fully_connected(network, 2, activation='softmax')
    network = regression(network, optimizer='adam', learning_rate=0.001,
                         loss='categorical_crossentropy', name='target')

    model = tflearn.DNN(network, tensorboard_verbose=0)

    # NOTE(review): `classify_DNN` is a user-defined helper that is not shown
    # here; presumably it trains `model` on (data, clas) and returns the
    # fitted model together with evaluation metrics -- confirm.
    clf, acc, roc_auc, fpr, tpr = classify_DNN(data, clas, model)

    # Persist the model to `model_path`.
    clf.save(model_path)
To reload the model, either to retrain it or to use it for prediction:
# Rebuild the same architecture in a fresh graph and load the saved model.
# `MODEL` is declared first so that it remains available after the `with`
# block has finished.
MODEL = None
with tf.Graph().as_default():
    # Build the deep neural network. The layers mirror the ones used when
    # the model was saved.
    network = input_data(shape=[None, MAX_DOCUMENT_LENGTH])
    network = tflearn.embedding(network, input_dim=n_words, output_dim=128)

    # Three parallel 1-D convolutions with filter sizes 3, 4 and 5.
    branch1 = conv_1d(network, 128, 3, padding='valid',
                      activation='relu', regularizer="L2")
    branch2 = conv_1d(network, 128, 4, padding='valid',
                      activation='relu', regularizer="L2")
    branch3 = conv_1d(network, 128, 5, padding='valid',
                      activation='relu', regularizer="L2")

    # Concatenate the branches, pool, and classify into 2 classes.
    network = merge([branch1, branch2, branch3], mode='concat', axis=1)
    network = tf.expand_dims(network, 2)
    network = global_max_pool(network)
    network = dropout(network, 0.5)
    network = fully_connected(network, 2, activation='softmax')
    network = regression(network, optimizer='adam', learning_rate=0.001,
                         loss='categorical_crossentropy', name='target')

    new_model = tflearn.DNN(network, tensorboard_verbose=3)

    # Load the model previously saved at `model_path`.
    new_model.load(model_path)
    MODEL = new_model
Use `MODEL` for prediction or retraining. The first line and the `with` block were the important parts. Posted here for anyone who might need help.
官方例子:
""" An example showing how to save/restore models and retrieve weights. """ from __future__ import absolute_import, division, print_function import tflearn import tflearn.datasets.mnist as mnist # MNIST Data X, Y, testX, testY = mnist.load_data(one_hot=True) # Model input_layer = tflearn.input_data(shape=[None, 784], name='input') dense1 = tflearn.fully_connected(input_layer, 128, name='dense1') dense2 = tflearn.fully_connected(dense1, 256, name='dense2') softmax = tflearn.fully_connected(dense2, 10, activation='softmax') regression = tflearn.regression(softmax, optimizer='adam', learning_rate=0.001, loss='categorical_crossentropy') # Define classifier, with model checkpoint (autosave) model = tflearn.DNN(regression, checkpoint_path='model.tfl.ckpt') # Train model, with model checkpoint every epoch and every 200 training steps. model.fit(X, Y, n_epoch=1, validation_set=(testX, testY), show_metric=True, snapshot_epoch=True, # Snapshot (save & evaluate) model every epoch. snapshot_step=500, # Snapshot (save & evalaute) model every 500 steps. 
run_id='model_and_weights') # --------------------- # Save and load a model # --------------------- # Manually save model model.save("model.tfl") # Load a model model.load("model.tfl") # Or Load a model from auto-generated checkpoint # >> model.load("model.tfl.ckpt-500") # Resume training model.fit(X, Y, n_epoch=1, validation_set=(testX, testY), show_metric=True, snapshot_epoch=True, run_id='model_and_weights') # ------------------ # Retrieving weights # ------------------ # Retrieve a layer weights, by layer name: dense1_vars = tflearn.variables.get_layer_variables_by_name('dense1') # Get a variable's value, using model `get_weights` method: print("Dense1 layer weights:") print(model.get_weights(dense1_vars[0])) # Or using generic tflearn function: print("Dense1 layer biases:") with model.session.as_default(): print(tflearn.variables.get_value(dense1_vars[1])) # It is also possible to retrieve a layer weights through its attributes `W` # and `b` (if available). # Get variable's value, using model `get_weights` method: print("Dense2 layer weights:") print(model.get_weights(dense2.W)) # Or using generic tflearn function: print("Dense2 layer biases:") with model.session.as_default(): print(tflearn.variables.get_value(dense2.b))
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 记一次.NET内存居高不下排查解决与启示
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· DeepSeek 开源周回顾「GitHub 热点速览」