基于keras的残差网络
1.使用RBF神经网络实现函数逼近2.使用BP神经网络实现函数逼近3.使用CNN实现MNIST数据集分类4.tensorflow解决回归问题简单案列5.自然语言建模与词向量6.使用多层RNN-LSTM网络实现MNIST数据集分类及常见坑汇总7.基于tensorflow的RBF神经网络案例8.快速傅里叶变换(FFT)和小波分析在信号处理上的应用9.Python三次样条插值与MATLAB三次样条插值简单案例10.使用TensorFlow实现MNIST数据集分类11.数值积分原理与应用12.keras建模的3种方式——序列模型、函数模型、子类模型
13.基于keras的残差网络
14.基于keras的双层LSTM网络和双向LSTM网络15.seq2seq模型案例分析16.基于keras的卷积神经网络(CNN)17.基于keras的时域卷积网络(TCN)18.基于keras的胶囊网络(CapsNet)1 前言
理论上,网络层数越深,拟合效果越好。但是,层数加深也会导致梯度消失或梯度爆炸现象产生。当网络层数已经过深时,深层网络表现为“恒等映射”。实践表明,神经网络对残差的学习比对恒等关系的学习表现更好,因此,残差网络在深层模型中广泛应用。
本文以MNIST手写数字分类为例,为方便读者认识残差网络,网络中只有全连接层,没有卷积层。关于MNIST数据集的说明,见使用TensorFlow实现MNIST数据集分类
笔者工作空间如下:
代码资源见-->残差网络(ResNet)案例分析
2 实验
renet.py
from tensorflow.examples.tutorials.mnist import input_data
from keras.models import Model
from keras.layers import add,Input,Dense,Activation
#载入数据
def read_data(path):
mnist=input_data.read_data_sets(path,one_hot=True)
train_x,train_y=mnist.train.images,mnist.train.labels,
valid_x,valid_y=mnist.validation.images,mnist.validation.labels,
test_x,test_y=mnist.test.images,mnist.test.labels
return train_x,train_y,valid_x,valid_y,test_x,test_y
#残差块
def ResBlock(x,hidden_size1,hidden_size2):
r=Dense(hidden_size1,activation='relu')(x) #第一隐层
r=Dense(hidden_size2)(r) #第二隐层
if x.shape[1]==hidden_size2:
shortcut=x
else:
shortcut=Dense(hidden_size2)(x) #shortcut(捷径)
o=add([r,shortcut])
o=Activation('relu')(o) #激活函数
return o
#残差网络
def ResNet(train_x,train_y,valid_x,valid_y,test_x,test_y):
inputs=Input(shape=(784,))
x=ResBlock(inputs,30,30)
x=ResBlock(x,30,30)
x=ResBlock(x,20,20)
x=Dense(10,activation='softmax')(x)
model=Model(input=inputs,output=x)
#查看网络结构
model.summary()
#编译模型
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
#训练模型
model.fit(train_x,train_y,batch_size=500,nb_epoch=50,verbose=2,validation_data=(valid_x,valid_y))
#评估模型
pre=model.evaluate(test_x,test_y,batch_size=500,verbose=2)
print('test_loss:',pre[0],'- test_acc:',pre[1])
train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data')
ResNet(train_x,train_y,valid_x,valid_y,test_x,test_y)
网络各层输出尺寸:
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 784) 0
__________________________________________________________________________________________________
dense_1 (Dense) (None, 30) 23550 input_1[0][0]
__________________________________________________________________________________________________
dense_2 (Dense) (None, 30) 930 dense_1[0][0]
__________________________________________________________________________________________________
dense_3 (Dense) (None, 30) 23550 input_1[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 30) 0 dense_2[0][0]
dense_3[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 30) 0 add_1[0][0]
__________________________________________________________________________________________________
dense_4 (Dense) (None, 30) 930 activation_1[0][0]
__________________________________________________________________________________________________
dense_5 (Dense) (None, 30) 930 dense_4[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 30) 0 dense_5[0][0]
activation_1[0][0]
__________________________________________________________________________________________________
activation_2 (Activation) (None, 30) 0 add_2[0][0]
__________________________________________________________________________________________________
dense_6 (Dense) (None, 20) 620 activation_2[0][0]
__________________________________________________________________________________________________
dense_7 (Dense) (None, 20) 420 dense_6[0][0]
__________________________________________________________________________________________________
dense_8 (Dense) (None, 20) 620 activation_2[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 20) 0 dense_7[0][0]
dense_8[0][0]
__________________________________________________________________________________________________
activation_3 (Activation) (None, 20) 0 add_3[0][0]
__________________________________________________________________________________________________
dense_9 (Dense) (None, 10) 210 activation_3[0][0]
==================================================================================================
Total params: 51,760
Trainable params: 51,760
Non-trainable params: 0
网络训练结果:
Epoch 48/50
- 1s - loss: 0.0019 - acc: 0.9999 - val_loss: 0.1463 - val_acc: 0.9706
Epoch 49/50
- 1s - loss: 0.0016 - acc: 0.9999 - val_loss: 0.1502 - val_acc: 0.9722
Epoch 50/50
- 1s - loss: 0.0013 - acc: 0.9999 - val_loss: 0.1542 - val_acc: 0.9728
test_loss: 0.16228994959965348 - test_acc: 0.9721000045537949
声明:本文转自基于keras的残差网络
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 分享4款.NET开源、免费、实用的商城系统
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· 上周热点回顾(2.24-3.2)