tflearn 在每一个epoch完毕保存模型
1 2 3 | 关键代码:<br>tflearn.DNN(net, checkpoint_path = 'model_resnet_cifar10' , max_checkpoints = 10 , tensorboard_verbose = 0 , clip_gradients = 0. ) |
1 | snapshot_epoch = True , # Snapshot (save & evaluate) model every epoch.<br>我的demo: |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 | def get_model(width, height, classes = 40 ): # TODO, modify model network = input_data(shape = [ None , width, height, 3 ]) # if RGB, 224,224,3 # Residual blocks # 32 layers: n=5, 56 layers: n=9, 110 layers: n=18 n = 2 net = tflearn.conv_2d(network, 16 , 3 , regularizer = 'L2' , weight_decay = 0.0001 ) net = tflearn.residual_block(net, n, 16 ) net = tflearn.residual_block(net, 1 , 32 , downsample = True ) net = tflearn.residual_block(net, n - 1 , 32 ) net = tflearn.residual_block(net, 1 , 64 , downsample = True ) net = tflearn.residual_block(net, n - 1 , 64 ) net = tflearn.batch_normalization(net) net = tflearn.activation(net, 'relu' ) net = tflearn.global_avg_pool(net) # Regression net = tflearn.fully_connected(net, classes, activation = 'softmax' ) #mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=32000, staircase=True) mom = tflearn.Momentum( 0.01 , lr_decay = 0.1 , decay_step = 2000 , staircase = True ) net = tflearn.regression(net, optimizer = mom, loss = 'categorical_crossentropy' ) # Training model = tflearn.DNN(net, checkpoint_path = 'model_resnet_cifar10' , max_checkpoints = 10 , tensorboard_verbose = 0 , clip_gradients = 0. ) return model def main(): trainX, trainY = image_preloader( "data/train" , image_shape = (width, height, 3 ), mode = 'folder' , categorical_labels = True , normalize = True ) testX, testY = image_preloader( "data/test" , image_shape = (width, height, 3 ), mode = 'folder' , categorical_labels = True , normalize = True ) #trainX = trainX.reshape([-1, width, height, 1]) #testX = testX.reshape([-1, width, height, 1]) print ( "sample data:" ) print (trainX[ 0 ]) print (trainY[ 0 ]) print (testX[ - 1 ]) print (testY[ - 1 ]) model = get_model(width, height, classes = 3755 ) filename = 'tflearn_resnet/model.tflearn' # try to load model and resume training try : #model.load(filename) model.load( "model_resnet_cifar10-195804" ) print ( "Model loaded OK. Resume training!" ) except : pass early_stopping_cb = EarlyStoppingCallback(val_acc_thresh = 0.94 ) try : model.fit(trainX, trainY, validation_set = (testX, testY), n_epoch = 500 , shuffle = True , snapshot_epoch = True , # Snapshot (save & evaluate) model every epoch. show_metric = True , batch_size = 1024 , callbacks = early_stopping_cb, run_id = 'cnn_handwrite' ) except StopIteration as e: print ( "OK, stop iterate!Good!" ) model.save(filename) del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:] filename = 'tflearn_resnet/model-infer.tflearn' model.save(filename) |
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 记一次.NET内存居高不下排查解决与启示
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· DeepSeek 开源周回顾「GitHub 热点速览」
2017-05-08 elasticsearch best_fields most_fields cross_fields从内在实现看区别——本质就是前两者是以field为中心,后者是词条为中心