用python编写BP神经网络

 1 #神经网络至少有三个函数
 2 #1. 初始化函数 ————设定输入节点、隐藏层节点、输出层节点的数量
 3 #2. 训练      ————学习给定训练集样本后,优化权重
 4 #3. 查询      ————给定输入,从输出节点给出答案
 5 
 6 #神经网络类的定义
 7 import numpy 
 8 import scipy.special
 9 
10 __all__ = ["NeuralNetwork"]
11 
12 class Neural_network(object):
13 
14 # 初始化神经网络
15 def __init__(self, input_nodes, hidden_nodes, output_nodes, learning_rate):
16 #设置输入节点、隐藏节点、输出节点的个数
17 self.inodes = input_nodes
18 self.hnodes = hidden_nodes
19 self.onodes = output_nodes
20 #设置初始权重(参数:中心值,标准方差,矩阵大小)
21 self.wih = numpy.random.normal(0.0, pow(self.hnodes, -0.5), (self.hnodes, self.inodes))
22 self.who = numpy.random.normal(0.0, pow(self.onodes, -0.5), (self.onodes, self.hnodes))
23 #设置学习率
24 self.lr = learning_rate
25 
26 #设置sigmoid(x)为激活函数
27 self.activation_function = lambda x: scipy.special.expit(x)
28 
29 #训练
30 def train(self, input_list, targets_list):
31 #转化输入列表为数组
32 inputs = numpy.array(input_list, ndmin=2).T
33 #转化真值列表为数组
34 targets = numpy.array(targets_list, ndmin=2).T
35 #计算隐藏层输入
36 hidden_inputs = numpy.dot(self.wih, inputs)
37 #计算隐藏层输出
38 hidden_outputs = self.activation_function(hidden_inputs)
39 #计算输出层输入
40 final_inputs = numpy.dot(self.who, hidden_outputs)
41 #计算输出最终结果
42 final_outputs = self.activation_function(final_inputs)
43 #计算误差
44 output_errors = targets - final_outputs
45 #计算隐藏层误差数组
46 hidden_errors = numpy.dot(self.who.T, output_errors)
47 #更新权重--隐藏层、输出层
48 self.who += self.lr * numpy.dot((output_errors * final_outputs * (1 - final_outputs)), numpy.transpose(hidden_outputs))
49 #更新权重--输入层、隐藏层
50 self.wih += self.lr * numpy.dot((hidden_errors * hidden_outputs * (1 - hidden_outputs)), numpy.transpose(inputs) )
51 #查询
52 def query(self, input_list):
53 #设置输入列表
54 inputs = numpy.array(input_list, ndmin=2).T
55 #计算隐藏层输入
56 hidden_inputs = numpy.dot(self.wih, inputs)
57 #计算隐藏层输出
58 hidden_outputs = self.activation_function(hidden_inputs)
59 #计算输出层输入
60 final_inputs = numpy.dot(self.who, hidden_outputs)
61 #计算输出最终结果
62 final_outputs = self.activation_function(final_inputs)
63 return final_outputs

 

posted @ 2018-11-20 19:17  溜肉段小能手  阅读(496)  评论(0编辑  收藏  举报