tflearn 在每一个epoch完毕保存模型

1
2
3
关键代码:
tflearn.DNN(net, checkpoint_path='model_resnet_cifar10',
                        max_checkpoints=10, tensorboard_verbose=0,
                        clip_gradients=0.)
1
snapshot_epoch=True, # Snapshot (save & evaluate) model every epoch.
我的demo:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def get_model(width, height, classes=40):
    """Build a small ResNet-style classifier wrapped in a tflearn DNN trainer.

    Args:
        width: input image width in pixels.
        height: input image height in pixels.
        classes: number of softmax output units (default 40).

    Returns:
        A tflearn.DNN trainer that checkpoints to 'model_resnet_cifar10'
        and keeps at most the 10 most recent checkpoints.
    """
    # TODO, modify model
    inputs = input_data(shape=[None, width, height, 3])  # RGB input, e.g. 224x224x3

    # Residual stack depth parameter (ResNet convention: 6n+2 total layers,
    # so n=5 -> 32 layers, n=9 -> 56, n=18 -> 110). Here n=2.
    n = 2

    x = tflearn.conv_2d(inputs, 16, 3, regularizer='L2', weight_decay=0.0001)
    x = tflearn.residual_block(x, n, 16)
    x = tflearn.residual_block(x, 1, 32, downsample=True)
    x = tflearn.residual_block(x, n - 1, 32)
    x = tflearn.residual_block(x, 1, 64, downsample=True)
    x = tflearn.residual_block(x, n - 1, 64)
    x = tflearn.batch_normalization(x)
    x = tflearn.activation(x, 'relu')
    x = tflearn.global_avg_pool(x)

    # Classification head.
    x = tflearn.fully_connected(x, classes, activation='softmax')

    # Momentum optimizer with staircase learning-rate decay.
    #opt = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=32000, staircase=True)
    opt = tflearn.Momentum(0.01, lr_decay=0.1, decay_step=2000, staircase=True)
    x = tflearn.regression(x, optimizer=opt,
                           loss='categorical_crossentropy')

    # Trainer: snapshots model state, no gradient clipping.
    return tflearn.DNN(x, checkpoint_path='model_resnet_cifar10',
                       max_checkpoints=10, tensorboard_verbose=0,
                       clip_gradients=0.)
 
 
 
def main():
    """Load image data, build (and optionally resume) the model, train, and save.

    NOTE(review): relies on module-level globals `width` and `height`
    (defined elsewhere in this file) for the input image dimensions.
    """
    # Preload train/test sets from folder layout; labels come back one-hot.
    trainX, trainY = image_preloader("data/train", image_shape=(width, height, 3), mode='folder', categorical_labels=True, normalize=True)
    testX, testY = image_preloader("data/test", image_shape=(width, height, 3), mode='folder', categorical_labels=True, normalize=True)
    #trainX = trainX.reshape([-1, width, height, 1])
    #testX = testX.reshape([-1, width, height, 1])
    print("sample data:")
    print(trainX[0])
    print(trainY[0])
    print(testX[-1])
    print(testY[-1])

    model = get_model(width, height, classes=3755)

    filename = 'tflearn_resnet/model.tflearn'
    # Best-effort resume: if the checkpoint exists, continue training from it.
    # FIX: was a bare `except: pass`, which silently swallowed every error
    # (including KeyboardInterrupt/SystemExit). Narrowed to Exception and
    # logged so a missing/corrupt checkpoint is visible but non-fatal.
    try:
        #model.load(filename)
        model.load("model_resnet_cifar10-195804")
        print("Model loaded OK. Resume training!")
    except Exception as err:
        print("No usable checkpoint, training from scratch: %s" % err)

    early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.94)
    try:
        model.fit(trainX, trainY, validation_set=(testX, testY), n_epoch=500, shuffle=True,
                  snapshot_epoch=True, # Snapshot (save & evaluate) model every epoch.
                  show_metric=True, batch_size=1024, callbacks=early_stopping_cb, run_id='cnn_handwrite')
    except StopIteration:
        # Raised by the early-stopping callback once val_acc_thresh is reached.
        print("OK, stop iterate!Good!")

    model.save(filename)

    # Strip training ops so the second save produces a leaner inference graph.
    del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]
    filename = 'tflearn_resnet/model-infer.tflearn'
    model.save(filename)

 

posted @   bonelee  阅读(1819)  评论(0编辑  收藏  举报
编辑推荐:
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 记一次.NET内存居高不下排查解决与启示
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· DeepSeek 开源周回顾「GitHub 热点速览」
历史上的今天:
2017-05-08 elasticsearch best_fields most_fields cross_fields从内在实现看区别——本质就是前两者是以field为中心,后者是词条为中心
点击右上角即可分享
微信分享提示