【509】NLP实战系列(六)—— 通过 LSTM 来做分类
参考:LSTM层
1. 语法
1 | keras.layers.recurrent.LSTM(units, activation = 'tanh' , recurrent_activation = 'hard_sigmoid' , use_bias = True , kernel_initializer = 'glorot_uniform' , recurrent_initializer = 'orthogonal' , bias_initializer = 'zeros' , unit_forget_bias = True , kernel_regularizer = None , recurrent_regularizer = None , bias_regularizer = None , activity_regularizer = None , kernel_constraint = None , recurrent_constraint = None , bias_constraint = None , dropout = 0.0 , recurrent_dropout = 0.0 ) |
2. 参数
-
units:输出维度
-
activation:激活函数,为预定义的激活函数名(参考激活函数)
-
recurrent_activation: 为循环步施加的激活函数(参考激活函数)
-
use_bias: 布尔值,是否使用偏置项
-
kernel_initializer:权值初始化方法,为预定义初始化方法名的字符串,或用于初始化权重的初始化器。参考initializers
-
recurrent_initializer:循环核的初始化方法,为预定义初始化方法名的字符串,或用于初始化权重的初始化器。参考initializers
-
bias_initializer:权值初始化方法,为预定义初始化方法名的字符串,或用于初始化权重的初始化器。参考initializers
-
kernel_regularizer:施加在权重上的正则项,为Regularizer对象
-
bias_regularizer:施加在偏置向量上的正则项,为Regularizer对象
-
recurrent_regularizer:施加在循环核上的正则项,为Regularizer对象
-
activity_regularizer:施加在输出上的正则项,为Regularizer对象
-
kernel_constraints:施加在权重上的约束项,为Constraints对象
-
recurrent_constraints:施加在循环核上的约束项,为Constraints对象
-
bias_constraints:施加在偏置上的约束项,为Constraints对象
-
dropout:0~1之间的浮点数,控制输入线性变换的神经元断开比例
-
recurrent_dropout:0~1之间的浮点数,控制循环状态的线性变换的神经元断开比例
-
其他参数参考Recurrent的说明
3. 具体实现
3.1 加载数据
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 | from keras.datasets import imdb from keras.preprocessing import sequence max_features = 10000 # number of words to consider as features maxlen = 500 # cut texts after this number of words (among top max_features most common words) batch_size = 32 print ( 'Loading data...' ) (input_train, y_train), (input_test, y_test) = imdb.load_data(num_words = max_features) print ( len (input_train), 'train sequences' ) print ( len (input_test), 'test sequences' ) print ( 'Pad sequences (samples x time)' ) input_train = sequence.pad_sequences(input_train, maxlen = maxlen) input_test = sequence.pad_sequences(input_test, maxlen = maxlen) print ( 'input_train shape:' , input_train.shape) print ( 'input_test shape:' , input_test.shape) |
output:
Loading data... 25000 train sequences 25000 test sequences Pad sequences (samples x time) input_train shape: (25000, 500) input_test shape: (25000, 500)
3.2 数据训练
1 2 3 4 5 6 7 8 9 10 11 12 13 14 | from keras.layers import LSTM model = Sequential() model.add(Embedding(max_features, 32 )) model.add(LSTM( 32 )) model.add(Dense( 1 , activation = 'sigmoid' )) model. compile (optimizer = 'rmsprop' , loss = 'binary_crossentropy' , metrics = [ 'acc' ]) history = model.fit(input_train, y_train, epochs = 10 , batch_size = 128 , validation_split = 0.2 ) |
outputs:
Train on 20000 samples, validate on 5000 samples Epoch 1/10 20000/20000 [==============================] - 108s - loss: 0.5038 - acc: 0.7574 - val_loss: 0.3853 - val_acc: 0.8346 Epoch 2/10 20000/20000 [==============================] - 108s - loss: 0.2917 - acc: 0.8866 - val_loss: 0.3020 - val_acc: 0.8794 Epoch 3/10 20000/20000 [==============================] - 107s - loss: 0.2305 - acc: 0.9105 - val_loss: 0.3125 - val_acc: 0.8688 Epoch 4/10 20000/20000 [==============================] - 107s - loss: 0.2033 - acc: 0.9261 - val_loss: 0.4013 - val_acc: 0.8574 Epoch 5/10 20000/20000 [==============================] - 107s - loss: 0.1749 - acc: 0.9385 - val_loss: 0.3273 - val_acc: 0.8912 Epoch 6/10 20000/20000 [==============================] - 107s - loss: 0.1543 - acc: 0.9457 - val_loss: 0.3505 - val_acc: 0.8774 Epoch 7/10 20000/20000 [==============================] - 107s - loss: 0.1417 - acc: 0.9493 - val_loss: 0.4485 - val_acc: 0.8396 Epoch 8/10 20000/20000 [==============================] - 106s - loss: 0.1331 - acc: 0.9522 - val_loss: 0.3242 - val_acc: 0.8928 Epoch 9/10 20000/20000 [==============================] - 106s - loss: 0.1147 - acc: 0.9618 - val_loss: 0.4216 - val_acc: 0.8746 Epoch 10/10 20000/20000 [==============================] - 106s - loss: 0.1092 - acc: 0.9628 - val_loss: 0.3972 - val_acc: 0.8758
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· .NET10 - 预览版1新功能体验(一)
2017-12-27 【280】◀▶ ArcPy 常用工具说明