学习笔记428—Keras实现简单BP神经网络
Keras实现简单BP神经网络
BP 神经网络的简单实现
1
2
3
4
5
6
7
8
9
10
|
from keras.models import Sequential #导入模型 from keras.layers.core import Dense #导入常用层 train_x,train_y #训练集 test_x,text_y #测试集 model = Sequential() #初始化模型 model.add(Dense( 3 ,input_shape = ( 32 ,),activation = 'sigmoid' ,init = 'uniform' ))) #添加一个隐含层,注:只是第一个隐含层需指定input_dim model.add(Dense( 1 ,activation = 'sigmoid' )) #添加输出层 model. compile (loss = 'binary_crossentropy' , optimizer = 'sgd' , metrics = [ 'accuracy' ]) # 编译,指定目标函数与优化方法 model.fit(train_x,train_y ) # 模型训练 model.evaluate(test_x,text_y ) #模型测试 |
注意:在这个例子中,3: 这个数字代表该层的神经元数量,意味着该层将有 3 个神经元。 input_shape=(32,): 这个参数指定输入数据的形状。在这里,(32,) 表示输入数据是一个包含 32 个特征的一维数组。这意味着每个输入样本都是一个 32 维向量。 activation='sigmoid': 这个参数设置层激活函数的类型。sigmoid 激活函数将输出压缩到 (0, 1) 区间内,它对于二分类问题很有用,因为可以输出概率值。 init='uniform': 这个参数指定权重的初始化方式。uniform 表示权重将从一个均匀分布中随机初始化。然而,这个参数在 Keras 中已经被弃用,因为 Keras 默认使用一个合适的初始化器。如果你使用的是较新版本的 Keras,你不需要显式设置这个参数。
常用层
常用层对应于core模块,core内部定义了一系列常用的网络层,包括全连接、激活层等
Dense层
1
|
keras.layers.core.Dense(units, activation = None , use_bias = True , kernel_initializer = 'glorot_uniform' , bias_initializer = 'zeros' , kernel_regularizer = None , bias_regularizer = None , activity_regularizer = None , kernel_constraint = None , bias_constraint = None ) |
Dense就是常用的全连接层,所实现的运算是output = activation(dot(input, kernel)+bias)
。其中activation
是逐元素计算的激活函数,kernel
是本层的权值矩阵,bias
为偏置向量,只有当use_bias=True
才会添加。
如果本层的输入数据的维度大于2,则会先被压为与kernel
相匹配的大小。
1 2 3 4 5 6 7 8 9 10 11 | #example # as first layer in a sequential model: | model = Sequential() | model.add(Dense( 32 , input_shape = ( 16 ,))) | # now the model will take as input arrays of shape (*, 16) | # and output arrays of shape (*, 32) | | # after the first layer, you don't need to specify | # the size of the input anymore: | model.add(Dense( 32 )) |
Keras主要包括14个模块,本文主要对Models、layers、Initializations、Activations、Objectives、Optimizers、Preprocessing、metrics共计8个模块分别展开介绍。
1. Model
包:keras.models
这是Keras中最主要的一个模块,用于对各个组件进行组装
eg:
1 2 3 | from keras.models import Sequential model = Sequential() #初始化模型 model.add(...) #可使用add方法组装组件 |
2. layers
包:keras.layers
该模块主要用于生成神经网络层,包含多种类型,如Core layers、Convolutional layers等
eg:
1 2 | from keras.layers import Dense #Dense表示Bp层 model.add(Dense(input_dim = 3 ,output_dim = 5 )) #加入隐含层 |
3. Initializations
包:keras.initializations
该模块主要负责对模型参数(权重)进行初始化,初始化方法包括:uniform、lecun_uniform、normal、orthogonal、zero、glorot_normal、he_normal等
详细说明:http://keras.io/initializations/
eg:
1 | model.add(Dense(input_dim = 3 ,output_dim = 5 ,init = 'uniform' )) #加入带初始化(uniform)的隐含层 |
4. Activations
包:keras.activations、keras.layers.advanced_activations(新激活函数)
该模块主要负责为神经层附加激活函数,如linear、sigmoid、hard_sigmoid、tanh、softplus、relu、 softplus以及LeakyReLU等比较新的激活函数
详细说明:http://keras.io/activations/
eg:
1 | model.add(Dense(input_dim = 3 ,output_dim = 5 ,activation = 'sigmoid' )) 加入带激活函数(sigmoid)的隐含层 |
Equal to:
1 2 | model.add(Dense(input_dim = 3 ,output_dim = 5 )) model.add(Activation( 'sigmoid' )) |
5. Objectives
包:keras.objectives
该模块主要负责为神经网络附加损失函数,即目标函数。如mean_squared_error,mean_absolute_error ,squared_hinge,hinge,binary_crossentropy,categorical_crossentropy等,其中binary_crossentropy,categorical_crossentropy是指logloss
注:目标函数的设定是在模型编译阶段
详细说明:http://keras.io/objectives/
eg:
1 | model. compile (loss = 'binary_crossentropy' , optimizer = 'sgd' ) #loss是指目标函数 |
6. Optimizers
包:keras.optimizers
该模块主要负责设定神经网络的优化方法,如sgd。
注:优化函数的设定是在模型编译阶段
详细说明:http://keras.io/optimizers/
eg:
1 | model. compile (loss = 'binary_crossentropy' , optimizer = 'sgd' ) #optimizer是指优化方法 |
7. Preprocessing
包:keras.preprocessing.(image\sequence\text)
数据预处理模块,不过本人目前尚未用过
8. metrics
包:keras.metrics
与sklearn中metrics包基本相同,主要包含一些如binary_accuracy、mae、mse等的评价方法
eg:
1 2 | predict = model.predict_classes(test_x) #输出预测结果 keras.metrics.binary_accuracy(test_y,predict) #计算预测精度 |
参考链接:https://www.cnblogs.com/eniac1946/p/7424737.html
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 分享4款.NET开源、免费、实用的商城系统
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· 上周热点回顾(2.24-3.2)