Implementing VGG13 in TensorFlow 2.0

Import the necessary libraries:

import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = '2'   # set before importing tensorflow so it takes effect

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics

tf.random.set_seed(2345)

The os.environ line cuts down the amount of log output TensorFlow prints; it has to be set before tensorflow is imported in order to take effect.
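For reference, the values this variable accepts (to my knowledge) are:

# "0": print everything (default)
# "1": hide INFO messages
# "2": hide INFO and WARNING messages
# "3": hide INFO, WARNING, and ERROR messages
os.environ["TF_CPP_MIN_LOG_LEVEL"] = '2'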
Build the network structure:

conv_layers = [
    # VGG13 conv stack: five blocks of (Conv-Conv-MaxPool); channels 64 -> 128 -> 256 -> 512 -> 512
    layers.Conv2D(64, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.Conv2D(64, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.MaxPool2D(pool_size=[2, 2], strides=2, padding="same"),

    layers.Conv2D(128, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.Conv2D(128, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.MaxPool2D(pool_size=[2, 2], strides=2, padding="same"),

    layers.Conv2D(256, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.Conv2D(256, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.MaxPool2D(pool_size=[2, 2], strides=2, padding="same"),

    layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.MaxPool2D(pool_size=[2, 2], strides=2, padding="same"),

    layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.MaxPool2D(pool_size=[2, 2], strides=2, padding="same"),
]
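As a quick sanity check (this mirrors the commented-out lines in main() below, and can be run on its own), a dummy batch can be pushed through the stack; each of the five pooling layers halves the spatial size, so a 32x32 input shrinks to 1x1:

conv_net = Sequential(conv_layers)
conv_net.build(input_shape=[None, 32, 32, 3])
x = tf.random.normal([4, 32, 32, 3])    # dummy batch of four CIFAR-10-sized images
out = conv_net(x)
print(out.shape)                        # (4, 1, 1, 512): 32 halved five times -> 1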

Data preprocessing:

def preprocess(x, y):
    # scale uint8 pixels from [0, 255] to float32 in [0, 1]; cast labels to int32
    x = tf.cast(x, dtype=tf.float32) / 255.
    y = tf.cast(y, dtype=tf.int32)
    return x, y
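The int32 cast matters later: tf.equal in the test loop compares the labels against int32 predictions. A quick illustration of what preprocess does (the demo tensors here are made up):

x_demo = tf.constant([[0, 128, 255]], dtype=tf.uint8)
y_demo = tf.constant([3], dtype=tf.int64)
x_out, y_out = preprocess(x_demo, y_demo)
print(x_out.numpy())   # [[0.        0.5019608 1.       ]]
print(y_out.dtype)     # <dtype: 'int32'>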

Load the data:
Here we use the common CIFAR-10 dataset.

(x_train,y_train),(x_test,y_test)=datasets.cifar10.load_data()
# the labels are loaded with shape [N, 1]; squeeze them to [N]
y_train=tf.squeeze(y_train,axis=1)
y_test=tf.squeeze(y_test,axis=1)
# print(x_train.shape,y_train.shape,x_test.shape,y_test.shape)
train_data=tf.data.Dataset.from_tensor_slices((x_train,y_train))
train_data=train_data.shuffle(1000).map(preprocess).batch(64)

test_data=tf.data.Dataset.from_tensor_slices((x_test,y_test))
test_data=test_data.map(preprocess).batch(64)
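Optionally, a prefetch step lets the CPU prepare the next batch while the GPU trains on the current one; a minimal sketch (TF 2.0 exposes AUTOTUNE under tf.data.experimental, while newer releases also offer tf.data.AUTOTUNE):

train_data = train_data.prefetch(tf.data.experimental.AUTOTUNE)
test_data = test_data.prefetch(tf.data.experimental.AUTOTUNE)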

sample=next(iter(train_data))
print('sample:',sample[0].shape,sample[1].shape,
      tf.reduce_min(sample[0]),tf.reduce_max(sample[0]))

This part prints the shape and value range of one batch from train_data.
Assemble the full network:

def main():
    conv_net=Sequential(conv_layers)
    # x=tf.random.normal([4,32,32,3])
    # out=conv_net(x)
    # print(out.shape)
    fc_net = Sequential([
        layers.Dense(256, activation=tf.nn.relu),
        layers.Dense(128, activation=tf.nn.relu),
        layers.Dense(10, activation=None),   # raw logits; softmax is applied inside the loss
    ])
    conv_net.build(input_shape=[None, 32, 32, 3])
    fc_net.build(input_shape=[None, 512])
    optimizer = optimizers.Adam(learning_rate=1e-4)

Compute the loss and train:

    # optimize the parameters of both sub-networks jointly
    variables = conv_net.trainable_variables + fc_net.trainable_variables
    for epoch in range(50):
        for step,(x,y) in enumerate(train_data):
            with tf.GradientTape() as tape:
                out=conv_net(x)
                out=tf.reshape(out,[-1,512])
                logits=fc_net(out)
                y_onehot=tf.one_hot(y,depth=10)
                loss=tf.losses.categorical_crossentropy(y_onehot,logits,from_logits=True)
                loss=tf.reduce_mean(loss)
            grads=tape.gradient(loss,variables)
            optimizer.apply_gradients(zip(grads,variables))
            if step%100==0:
                print(epoch,step,'loss',float(loss))
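Note that fc_net's last Dense layer has no activation, so it outputs raw logits; from_logits=True tells the loss to apply a numerically stable softmax internally. Passing probabilities with from_logits=True (or logits without it) silently gives wrong losses. The two equivalent forms, with made-up values:

logits_demo = tf.constant([[2.0, 1.0, 0.1]])
y_demo = tf.one_hot([0], depth=3)
loss_a = tf.losses.categorical_crossentropy(y_demo, logits_demo, from_logits=True)
loss_b = tf.losses.categorical_crossentropy(y_demo, tf.nn.softmax(logits_demo), from_logits=False)
print(float(loss_a[0]), float(loss_b[0]))   # both about 0.417 -- the two forms agree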

Evaluate on the test set:

        total_num=0
        total_correct=0
        for x,y in test_data:
            out=conv_net(x)
            out=tf.reshape(out,[-1,512])
            logits=fc_net(out)
            prob=tf.nn.softmax(logits,axis=1)
            pred=tf.argmax(prob,axis=1)
            pred=tf.cast(pred,dtype=tf.int32)
            correct=tf.cast(tf.equal(pred,y),dtype=tf.int32)
            correct=tf.reduce_sum(correct)
            total_num+=x.shape[0]
            total_correct+=int(correct)
        acc=total_correct/total_num
        print(epoch,'acc:',acc)
if __name__ == '__main__':
    main()
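As an aside, the manual counting above can also be done with a Keras metric; a sketch assuming the conv_net and fc_net from main() are in scope:

acc_meter = metrics.SparseCategoricalAccuracy()
for x, y in test_data:
    out = tf.reshape(conv_net(x), [-1, 512])
    logits = fc_net(out)
    acc_meter.update_state(y, logits)   # accepts integer labels and raw logits
print('acc:', float(acc_meter.result()))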

Training log:

0 0 loss 2.302990436553955
0 100 loss 1.9521405696868896
0 200 loss 1.9435423612594604
0 300 loss 1.6067744493484497
0 400 loss 1.5959546566009521
0 500 loss 1.734712839126587
0 600 loss 1.2384529113769531
0 700 loss 1.3307044506072998
0 acc: 0.4787
5 0 loss 0.6936513185501099
5 100 loss 0.7874761819839478
5 200 loss 0.7884306907653809
5 300 loss 0.6663026809692383
5 400 loss 0.4075947105884552
5 500 loss 0.6752095222473145
5 600 loss 0.5246847867965698
5 700 loss 0.5275574922561646
5 acc: 0.7299
10 0 loss 0.7874808311462402
10 100 loss 0.5072851181030273
10 200 loss 0.4451877772808075
10 300 loss 0.177499920129776
10 400 loss 0.13723205029964447
10 500 loss 0.2971668243408203
10 600 loss 0.25279730558395386
10 700 loss 0.36453887820243835
10 acc: 0.7355
15 0 loss 0.2800075113773346
15 100 loss 0.1841358095407486
15 200 loss 0.040746696293354034
15 300 loss 0.06615383923053741
15 400 loss 0.1183178648352623
15 500 loss 0.07481158524751663
15 600 loss 0.09398414194583893
15 700 loss 0.03665520250797272
15 acc: 0.7469
20 0 loss 0.02290465496480465
20 100 loss 0.008633529767394066
20 200 loss 0.21534058451652527
20 300 loss 0.011568240821361542
20 400 loss 0.08179830759763718
20 500 loss 0.02673691138625145
20 600 loss 0.06506452709436417
20 700 loss 0.026200752705335617
20 acc: 0.7621

Training runs for about 50 epochs; only the first 20 are shown here. As the log shows, the test accuracy keeps climbing over the epochs shown. I did not finish the full run myself; if you are interested, you can save the model and continue training from that checkpoint when you have time.
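To pick up training where it left off, the weights of the two sub-networks can be saved at the end of each epoch and restored later; a minimal sketch (the checkpoint file names here are just examples):

# at the end of an epoch
conv_net.save_weights('vgg13_conv.ckpt')
fc_net.save_weights('vgg13_fc.ckpt')

# in a later session: rebuild conv_net and fc_net exactly as above, then restore
conv_net.load_weights('vgg13_conv.ckpt')
fc_net.load_weights('vgg13_fc.ckpt')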
