基于BP神经网络的手写数字识别
一.BP神经网络原理及结构
本片博客偏向于BP神经网络的MATLAB程序实现讲解,详细原理请参考:http://www.cnblogs.com/wentingtu/archive/2012/06/05/2536425.html
1.神经元
神经元作为神经网络的基本单元,对于外界输入具有简单的反应能力,在数学上表征为简单的函数映射。如下图是一个神经元的基本结构,
神经元结构
图中是神经元的输入,是神经元输入的权重,是神经元的激活函数,y是神经元的输出,其函数映射关系为
激活函数来描述层与层输出之间的关系,从而模拟各层神经元之间的交互反应。激活函数必须满足处处可导的条件。常用的神经元函数有四种,分别是线性函数、指示函数、Sigmoid函数(S型函数)、径向基函数RBF(Radial Basis Function)。本次仿真使用Sigmoid函数,其表达式为
其导数为
其相应的曲线如图2-2
图2-2 Sigmoid函数
2.BP神经网络
BP神经网络是前向多层神经网络(Multilayer Feedforward Neural Network, MLFN),其每一层包含若干个神经元,同一层的神经元之间没有互相联系,层间信息的传送只沿一个方向进行。主要包括输入层、隐含层和输出层,其中隐含层的数目随精确度的要求而有所变化。BP网络能学习和存贮大量的输入-输出模式映射关系,而无需事前揭示描述这种映射关系的数学方程。建立模型之后,就可以通过减少输出误差来训练网络中的参数,从而逼近正确的函数模型。
BP神经网络主要分为信号的正向传播和误差反向传播两个过程,正向传播是指在样本数据输入后,计算输入层到隐含层,再到输出层的数据;误差反向传播就是将输出误差以某种形式通过隐含层向输入层逐层反传,将误差分摊给各层的单元,再使用梯度下降法修正各单元的权值。一般当网络输出的误差减少到可接受的程度或者进行到预先设定的学习次数时,训练结束。
这里以下图2-2简单的三层BP神经网络为例:
图2-2三层BP神经网络
假设输入层有n个神经元,隐含层有p个神经元,输出层有q个神经元。定义变量:
输入向量:
隐含层的输入向量:
隐含层的输出向量:
输出层的输入:
输出层的输出:
期望输出向量:
输入层与隐含层的连接权值:
隐含层与输出成的连接权值:
样本数据个数:
输出误差函数:
前向传播计算与误差反向传播具体步骤:
第一步:会对网络进行初始化,对各连接权值分别赋一个区间(-1,1)内的随机数,设定误差函数e,给定计算精度值和最大学习次数M。
第二步:随机选取第k个输入样本及对应的期望输出
第三步:计算隐含层各神经元的输入和输出
第四步:利用网络期望输出和实际输出,计算误差函数对输出层的各神经元的偏导数
第五步:利用输出层各神经元的
第六步:计算全局误差
第七步:当第六步计算得到的全局误差小于可接受的程度,或训练进行到预先设定的学习次数时,训练结束。
通过上述的前向传播计算与误差反向传播,BP神经网络就可以实现对样本数据的有监督训练,得到可以实现数字图像识别的逼近函数以及必要的参数。
3.MATLAB仿真及结果
本次BP神经网络数字图像识别使用MATLAB语言进行数值仿真,在进行BP神经网络创建之前,首选需要进行数字图像的hash特征提取。获得样本图像的hash特征之后,在进行BP神经网络的创建、训练和测试。
特征提取
图像的特征对于图像的标识有着非常大的重要性,好的特征能够大大提高图像的识别率。
本次MATLAB仿真提取数字图像hash特征,具体步骤如下:
Hash特征提取步骤图
参数训练
将数字图像的hash特征作为网络的输入数据,按照第二节BP神经网络原理建立BP神经网络,定义各变量并设置各参数,再按照前向传播计算和误差反向传播的原理计算各层的值并迭代更新输入层和隐含层之间的权矩阵、隐含层和输出层之间的权矩阵。知道计算的全局误差小于允许误差或者训练达到预定次数。
参数测试
将测试图像的hash特征作为网络输入数据,计算输出,判断正确与否,并统计测试识别率。
仿真结果
由于BP神经网络的输入数据是的64维输入特征,故输入层神经元个数为65(加上偏置神经元),输出层神经元个数为10(只有一个位置输出为1,其他位置输出为0),设置的允许误差为0.000001,训练样本数为60000,测试样本数为10000,隐含层神经元个数在三次实验中分别为43、101、11个,相应的训练及测试实验结果如下。
输入层 |
隐含层 |
输出层 |
允许误差 |
输出全局误差 |
训练次数 |
识别率 |
65 |
43 |
10 |
0.000001 |
6.014958e-07 |
40816 |
92.9% |
65 |
101 |
10 |
0.000001 |
3.678565e-07 |
34499 |
92.0% |
65 |
11 |
10 |
0.000001 |
1.349866e-02 |
60000 |
89.95% |
BP神经网络仿真表格
附件为本次"基于BP神经网络的数字图像识别"的MATLAB程序,主要分为两个函数,GetFeature函数提取图像的hash特征,BpDigitsRecog函数创建、训练BP神经网络,并进行测试。训练样本数据和标签可在网盘http://pan.baidu.com/s/1bn4h0gZ下载。
提取特征:
%2015.4.13 by anchor %This function is to get hash feature of an image clc;clear all;close inputs = load('mnist_train.mat'); inputs = inputs.mnist_train; [sample_n] = size( inputs,1); for input_n = 1:sample_n hash_feature = inputs(input_n,:); image = reshape(hash_feature,28,28); [x,y]=find(image~=0); image=image(min(x):max(x),min(y):max(y)); image = imresize(image,[8 8]); hash_features(input_n,:) = image(:)'; end save train_hash hash_features clear hash_features inputs = load('mnist_test.mat'); inputs = inputs.mnist_test; [sample_n] = size( inputs,1); for input_n = 1:sample_n hash_feature = inputs(input_n,:); image = reshape(hash_feature,28,28); [x,y]=find(image~=0); image=image(min(x):max(x),min(y):max(y)); image = imresize(image,[8 8]); hash_features(input_n,:) = image(:)'; end save test_hash hash_features
BP主程序:
1 %2015.4.25 by anchor 2 % ============= Part 1: Input Images And Get Hash Features =============== 3 %This function GetFeature is to get hash feature of 28*28 image 4 %function GetFeatures 5 %%=================== Part 2: Create BPNN And Train it===================== 6 clc;clear all; 7 inputs = load('train_hash.mat'); 8 inputs = inputs.hash_features/256; 9 class = load('mnist_train_labels.mat'); 10 class = class.mnist_train_labels; 11 [length_input,feture_length]=size(inputs); 12 num_hide=10; 13 weight_input_hide=0.0001*(rand(feture_length,num_hide)); 14 weight_hide_output=0.0001*(rand(num_hide,10)); 15 input_hide_offset = ones(num_hide,1); 16 hide_out_offset = ones(10,1); 17 fun_hide=@sigmf; %%%a need to be modulate 18 fun_out=@sigmf; 19 learn_rate=0.3; 20 allow_error = 0.000001; 21 ideal_out_list = eye(10); 22 for input_n=1:length_input %show the progress once trained 2000 images 23 % ======Multilayer Feedforward Neural Network(MLFN)====== 24 input = inputs(input_n,:)'; 25 ideal_out = ideal_out_list(1:10,class(input_n)+1); 26 hide_input = weight_input_hide'*input+input_hide_offset; 27 hide_output = fun_hide(hide_input,[1 0]); 28 out_input = weight_hide_output'*hide_output+hide_out_offset; 29 out_out = fun_out(out_input,[1,0]); 30 out_err = 1/2*sum((ideal_out-out_out).^2); 31 % ======Back Propagation of Errot to modulate weight====== 32 out_delta=(ideal_out-out_out).*out_out.*(1-out_out); %%%the derived function need to be modulate 33 hide_out_offset =hide_out_offset + learn_rate*out_delta; 34 hide_delta = weight_hide_output*out_delta.*hide_output.*(1-hide_output); 35 weight_hide_output = weight_hide_output+learn_rate*hide_output*out_delta';%update weight_hide_output 36 weight_input_hide = weight_input_hide+learn_rate*input*hide_delta';%update weight_hide_output 37 input_hide_offset = input_hide_offset + learn_rate*hide_delta; 38 if out_err<=allow_error 39 break; 40 end 41 end 42 fprintf('The error is %d ,the iteration is %d\n',out_err,input_n); 43 save weight_input_hide weight_input_hide 44 save weight_hide_output weight_hide_output 45 % ======================== Part 3: Test The BPNN=========================== 46 inputs = load('test_hash.mat'); 47 inputs = inputs.hash_features/256; 48 class = load('mnist_test_labels.mat'); 49 class = class.mnist_test_labels; 50 [length_input,feture_length]=size(inputs); 51 correct = 0; 52 corrct_rate = zeros(1,length_input); 53 % ======Multilayer Feedforward Neural Network(MLFN)====== 54 for input_n=1:length_input 55 input = inputs(input_n,:)'; 56 ideal_out = class(input_n)+1; 57 hide_input = weight_input_hide'*input+input_hide_offset; 58 hide_output = fun_hide(hide_input,[1 0]); 59 out_input = weight_hide_output'*hide_output+hide_out_offset; 60 out_out = fun_out(out_input,[1,0]); 61 similar = sum((repmat(out_out,1,10)-ideal_out_list).^2); 62 output = find(similar == min(similar)); 63 if output == ideal_out 64 correct =correct +1; 65 end 66 corrct_rate(input_n) = correct/input_n; 67 end 68 hold on; 69 plot(corrct_rate,'y') 70 xlabel('test number'); 71 ylabel('corrct rate'); 72 fprintf('The recognition rate of test is %f%% \n',corrct_rate(input_n)*100);