【482】Implementing LSTM & BiLSTM with Keras
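LSTM is a well-regarded recurrent neural network (RNN) architecture, but its structure is fairly complex. If you are still unsure about RNNs and LSTMs, see: Recurrent Neural Networks vs LSTM [based on Prof. Hung-yi Lee's lecture slides]
Here we will build an LSTM with Keras. Keras wraps the low-level implementations of several deep learning frameworks and is very concise to use. You can put a network together quickly even without much deep learning theory, so I strongly recommend it to anyone just starting out (I am still a beginner myself). Keras: https://keras.io/. Keras backends include Theano, TensorFlow, and CNTK; here I use TensorFlow.
Below we build an LSTM and a BiLSTM to classify the MNIST data.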
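1. Load packages and define parameters
Each MNIST image has shape 28*28. We feed the image to the LSTM cell one row at a time, so the input vector at each step has shape (28,) and time_step is 28 (one step per image row). The LSTM layer has 30 output units.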
```python
import mnist  # assumed to be the `mnist` package from PyPI (pip install mnist)
from tensorflow.keras.layers import Dense, LSTM, Bidirectional
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential

# parameters for LSTM
nb_lstm_outputs = 30   # number of output units
nb_time_steps = 28     # length of the time sequence
nb_input_vectors = 28  # dimension of the input vector at each time step
```
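To make the "one row per time step" idea concrete, here is a minimal pure-Python sketch. The 4*4 toy image and the helper `image_to_sequence` are made up for illustration; a real MNIST image would give 28 steps of 28 values each.

```python
# A toy 4x4 "image" standing in for a 28x28 MNIST digit (hypothetical data).
image = [
    [0, 0, 1, 0],
    [0, 1, 1, 0],
    [0, 0, 1, 0],
    [0, 1, 1, 1],
]

def image_to_sequence(img):
    """Return the image as a list of time steps, one row vector per step."""
    return [list(row) for row in img]

sequence = image_to_sequence(image)
time_steps = len(sequence)      # number of rows, plays the role of nb_time_steps
input_dim = len(sequence[0])    # pixels per row, plays the role of nb_input_vectors

# The LSTM would consume x_0, x_1, ... in this order.
for t, x_t in enumerate(sequence):
    print(f"t={t}: x_t={x_t}")
```

With MNIST, the same reading gives `input_shape = (28, 28)`: 28 time steps, each a 28-dimensional vector.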
2. Data preprocessing
Note that the labels must be one-hot encoded. x_train has shape (60000, 28, 28).
```python
# data preprocessing
x_train = mnist.train_images()
y_train = mnist.train_labels()
x_test = mnist.test_images()
y_test = mnist.test_labels()

# Normalize the images from [0, 255] to [-0.5, 0.5]
x_train = (x_train / 255) - 0.5
x_test = (x_test / 255) - 0.5

# one_hot encoding
y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)
```
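The two preprocessing steps can be illustrated without any library. The sketch below is not how Keras implements `to_categorical`; `normalize` and `one_hot` are hypothetical helpers that only show the arithmetic on single values.

```python
def normalize(pixels):
    """Map raw pixel values in [0, 255] to the range [-0.5, 0.5]."""
    return [p / 255 - 0.5 for p in pixels]

def one_hot(label, num_classes=10):
    """Return a vector with 1.0 at index `label` and 0.0 elsewhere."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec

print(normalize([0, 255]))   # darkest and brightest pixel
print(one_hot(3))            # label 3 as a 10-dimensional vector
```

One-hot labels are what the 10-unit softmax output layer is compared against during training.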
3. Build the model (LSTM, BiLSTM)
Building a model in Keras is straightforward: keep adding new layers to a Sequential container.
```python
# building model
model = Sequential()
model.add(LSTM(units=nb_lstm_outputs, input_shape=(nb_time_steps, nb_input_vectors)))
model.add(Dense(10, activation='softmax'))
```
The BiLSTM model is built as follows; the implementation differs very little:
```python
# building model
model = Sequential()
model.add(
    Bidirectional(
        LSTM(
            units=nb_lstm_outputs,
            return_sequences=True
        ),
        input_shape=(nb_time_steps, nb_input_vectors)
    )
)
model.add(
    Bidirectional(
        LSTM(units=nb_lstm_outputs)
    )
)
model.add(
    Dense(
        10,
        activation='softmax'
    )
)
```
4. compile
Compile the model, specifying the loss function, optimizer, and metrics.
```python
# compile: loss, optimizer, metrics
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)
```
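For intuition about the loss chosen above, categorical cross-entropy for one sample is -sum(y_true[i] * log(y_pred[i])). The sketch below computes it by hand; the `eps` guard against log(0) and the example probability vectors are assumptions for illustration, not Keras internals.

```python
import math

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Loss for one sample: -sum(y_true[i] * log(y_pred[i]))."""
    return -sum(t * math.log(max(p, eps)) for t, p in zip(y_true, y_pred))

y_true = [0.0, 0.0, 1.0]           # one-hot label: the true class is index 2
confident = [0.05, 0.05, 0.90]     # prediction that favours the true class
unsure = [0.30, 0.30, 0.40]        # prediction that is nearly uniform

print(categorical_crossentropy(y_true, confident))
print(categorical_crossentropy(y_true, unsure))
```

Because the label is one-hot, only the predicted probability of the true class contributes, so the confident prediction gets the smaller loss.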
5. summary
Use model.summary() to inspect the network's architecture, parameter counts, and related information.
```
model.summary()

output:
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
lstm_1 (LSTM)                (None, 30)                7080
_________________________________________________________________
dense_1 (Dense)              (None, 10)                310
=================================================================
Total params: 7,390
Trainable params: 7,390
Non-trainable params: 0
_________________________________________________________________
```
The BiLSTM summary is below; it has one more layer:
```
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
bidirectional_1 (Bidirection (None, 28, 60)            14160
_________________________________________________________________
bidirectional_2 (Bidirection (None, 60)                21840
_________________________________________________________________
dense_2 (Dense)              (None, 10)                610
=================================================================
Total params: 36,610
Trainable params: 36,610
Non-trainable params: 0
_________________________________________________________________
```
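The parameter counts in both summaries can be checked by hand. A standard LSTM layer has four gate blocks, each with input weights, recurrent weights, and a bias; a bidirectional wrapper holds two such LSTMs and concatenates their outputs, which is why the output width is 60 rather than 30. The helper names below are made up for this sketch.

```python
def lstm_params(input_dim, units):
    """4 gate blocks, each with input weights, recurrent weights and a bias."""
    return 4 * (units * (input_dim + units) + units)

def bilstm_params(input_dim, units):
    """A forward and a backward LSTM of the same size."""
    return 2 * lstm_params(input_dim, units)

def dense_params(input_dim, units):
    """Weight matrix plus one bias per output unit."""
    return input_dim * units + units

# LSTM model: 28-dim rows -> 30 units -> 10 classes
print(lstm_params(28, 30), dense_params(30, 10))

# BiLSTM model: the second layer sees 60 = 2 * 30 features per step
print(bilstm_params(28, 30), bilstm_params(60, 30), dense_params(60, 10))
```

The printed values match the Param # column above: 7080 and 310 for the LSTM model, and 14160, 21840, and 610 for the BiLSTM model.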
6. train
Training the model requires specifying epochs (the number of passes over the training set) and batch_size.
```
model.fit(
    x_train,
    y_train,
    epochs=20,
    batch_size=128,
    verbose=1
)

output:
Epoch 1/20
60000/60000 [==============================] - 11s 184us/step - loss: 0.9702 - acc: 0.6919
Epoch 2/20
60000/60000 [==============================] - 9s 152us/step - loss: 0.3681 - acc: 0.8921
Epoch 3/20
60000/60000 [==============================] - 9s 143us/step - loss: 0.2505 - acc: 0.9263
Epoch 4/20
60000/60000 [==============================] - 9s 147us/step - loss: 0.1985 - acc: 0.9411
Epoch 5/20
60000/60000 [==============================] - 9s 156us/step - loss: 0.1673 - acc: 0.9508
Epoch 6/20
60000/60000 [==============================] - 10s 163us/step - loss: 0.1473 - acc: 0.9563
Epoch 7/20
60000/60000 [==============================] - 10s 162us/step - loss: 0.1311 - acc: 0.9605
Epoch 8/20
60000/60000 [==============================] - 10s 162us/step - loss: 0.1176 - acc: 0.9650
Epoch 9/20
60000/60000 [==============================] - 10s 167us/step - loss: 0.1054 - acc: 0.9688
Epoch 10/20
60000/60000 [==============================] - 10s 165us/step - loss: 0.0991 - acc: 0.9702
Epoch 11/20
60000/60000 [==============================] - 10s 164us/step - loss: 0.0899 - acc: 0.9730
Epoch 12/20
60000/60000 [==============================] - 10s 169us/step - loss: 0.0857 - acc: 0.9741
Epoch 13/20
60000/60000 [==============================] - 10s 166us/step - loss: 0.0781 - acc: 0.9758
Epoch 14/20
60000/60000 [==============================] - 10s 167us/step - loss: 0.0740 - acc: 0.9776
Epoch 15/20
60000/60000 [==============================] - 10s 172us/step - loss: 0.0697 - acc: 0.9786
Epoch 16/20
60000/60000 [==============================] - 10s 171us/step - loss: 0.0678 - acc: 0.9795
Epoch 17/20
60000/60000 [==============================] - 10s 170us/step - loss: 0.0639 - acc: 0.9798
Epoch 18/20
60000/60000 [==============================] - 10s 169us/step - loss: 0.0589 - acc: 0.9817
Epoch 19/20
60000/60000 [==============================] - 10s 172us/step - loss: 0.0597 - acc: 0.9817
Epoch 20/20
60000/60000 [==============================] - 10s 168us/step - loss: 0.0558 - acc: 0.9825
```
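The BiLSTM training log is below. Training accuracy is higher (0.9935 vs 0.9825 after 20 epochs), although each epoch takes about 45 s instead of about 10 s: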
```
Epoch 1/20
60000/60000 [==============================] - 46s 767us/step - loss: 0.6845 - acc: 0.7782
Epoch 2/20
60000/60000 [==============================] - 48s 799us/step - loss: 0.1843 - acc: 0.9435
Epoch 3/20
60000/60000 [==============================] - 45s 751us/step - loss: 0.1241 - acc: 0.9627
Epoch 4/20
60000/60000 [==============================] - 45s 747us/step - loss: 0.0956 - acc: 0.9712
Epoch 5/20
60000/60000 [==============================] - 46s 766us/step - loss: 0.0806 - acc: 0.9754
Epoch 6/20
60000/60000 [==============================] - 46s 771us/step - loss: 0.0667 - acc: 0.9793
Epoch 7/20
60000/60000 [==============================] - 45s 754us/step - loss: 0.0584 - acc: 0.9820
Epoch 8/20
60000/60000 [==============================] - 44s 741us/step - loss: 0.0513 - acc: 0.9835
Epoch 9/20
60000/60000 [==============================] - 45s 742us/step - loss: 0.0445 - acc: 0.9863
Epoch 10/20
60000/60000 [==============================] - 46s 767us/step - loss: 0.0419 - acc: 0.9874
Epoch 11/20
60000/60000 [==============================] - 45s 755us/step - loss: 0.0378 - acc: 0.9885
Epoch 12/20
60000/60000 [==============================] - 46s 758us/step - loss: 0.0332 - acc: 0.9894
Epoch 13/20
60000/60000 [==============================] - 45s 750us/step - loss: 0.0318 - acc: 0.9894
Epoch 14/20
60000/60000 [==============================] - 45s 756us/step - loss: 0.0279 - acc: 0.9911
Epoch 15/20
60000/60000 [==============================] - 45s 745us/step - loss: 0.0262 - acc: 0.9917
Epoch 16/20
60000/60000 [==============================] - 45s 758us/step - loss: 0.0258 - acc: 0.9916
Epoch 17/20
60000/60000 [==============================] - 47s 791us/step - loss: 0.0226 - acc: 0.9923
Epoch 18/20
60000/60000 [==============================] - 47s 791us/step - loss: 0.0223 - acc: 0.9930
Epoch 19/20
60000/60000 [==============================] - 46s 773us/step - loss: 0.0179 - acc: 0.9943
Epoch 20/20
60000/60000 [==============================] - 45s 747us/step - loss: 0.0199 - acc: 0.9935
```
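7. evaluate
Evaluate on the test set with model.evaluate(). With the compile settings above, the returned score is a list of the loss followed by the accuracy.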
```
score = model.evaluate(x_test, y_test, batch_size=128, verbose=1)
print(score)

output:
10000/10000 [==============================] - 0s 49us/step
[0.06827456439994276, 0.9802]
```
BiLSTM result: slightly better, with test accuracy 0.9838 vs 0.9802 and test loss 0.0553 vs 0.0683:
```
10000/10000 [==============================] - 2s 250us/step
[0.055307343754824254, 0.9838]
```
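The accuracy figure in the score is the fraction of test images whose highest-probability class matches the label. Here is a minimal sketch of that calculation on made-up predictions; `argmax` and `accuracy` are hypothetical helpers, not Keras functions.

```python
def argmax(values):
    """Index of the largest value in a list."""
    return max(range(len(values)), key=lambda i: values[i])

def accuracy(y_true_one_hot, y_pred_probs):
    """Fraction of samples whose predicted class equals the true class."""
    correct = sum(
        argmax(t) == argmax(p) for t, p in zip(y_true_one_hot, y_pred_probs)
    )
    return correct / len(y_true_one_hot)

# Four toy samples over three classes (hypothetical data).
y_true = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]]
y_pred = [
    [0.8, 0.1, 0.1],   # correct: class 0
    [0.2, 0.7, 0.1],   # correct: class 1
    [0.3, 0.4, 0.3],   # wrong: predicts class 1, label is class 2
    [0.1, 0.6, 0.3],   # correct: class 1
]

print(accuracy(y_true, y_pred))
```

Read the same way, the scores above mean the LSTM misclassifies about 198 of the 10000 test images and the BiLSTM about 162.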
posted on 2020-09-24 22:16 McDelfino