一杯清酒邀明月
天下本无事,庸人扰之而烦耳。

注:这里的练习完成时我对神经网络的理解尚不完全,可能存在一些错误;关于神经网络的实践,可以参考我的这篇博文。

这里的代码只是简单的练习,既不涉及代码优化,也不涉及神经网络本身的优化,因此采用了最能直观体现原理的写法。

激活函数采用 Sigmoid 函数 $h = \frac{1}{1+e^{-y}}$,其中 $y = \sum_j w_j x_j + b$ 为上一层输入 $x_j$(对第一层即坐标 $X$、$Y$)的加权和再加上偏置 $b$。

代价函数采用平方误差 $E = \frac{1}{2}\sum_{k=1}^{3}(t_k - h_k)^2$,其中 $t_k$ 为第 $k$ 个输出的目标值:$t_k=1$ 表示样本属于第 $k$ 类,$t_k=0$ 表示不属于该类。

权值更新采用BP(误差反向传播)算法,即沿 $-\partial E/\partial w$ 方向做梯度下降,步长(学习率)为代码中的 sigma。

网络1形式如下,没有隐含层,1个偏置量,输入直接连接输出:

分类结果:

代码如下:

  % Network 1: no hidden layer. The inputs (X, Y) feed three sigmoid outputs
  % h1, h2, h3 directly, and the three outputs share one bias term.
  %
  % NOTE(review): every training sample i owns a private weight set
  % (w1(i)..w6(i), B(i)) that is trained only on that sample. The decision
  % map then averages the outputs of these 3*n single-sample networks.
  % A conventional network would share one weight set across all samples.
  clear;
  close all;
  clc;

  n = 5;                           % samples per class
  rng(1);                          % seeds rand (weights) as well as randn (mvnrnd data)
  mu1 = [0 0];
  mu2 = [0 6];
  mu3 = [6 3];
  S = [0.5 0;
       0 0.5];                     % covariance shared by all three classes
  P1 = mvnrnd(mu1, S, n);
  P2 = mvnrnd(mu2, S, n);
  P3 = mvnrnd(mu3, S, n);

  % Rows 1..n are class 1, n+1..2n class 2, 2n+1..3n class 3.
  P = [P1; P2; P3];
  meanP = mean(P);
  P = [P(:,1)-meanP(1) P(:,2)-meanP(2)];   % center the data at the origin

  sigma = 5;                       % learning rate
  tol = 1e-7;                      % stop once 1/2*(1-h_target)^2 < tol
  maxIter = 1e6;                   % safety cap on gradient steps per sample

  X = P(:,1);
  Y = P(:,2);
  B = rand(3*n,1);                 % bias shared by h1, h2, h3

  w1 = rand(3*n,1);                % X -> h1
  w2 = rand(3*n,1);                % X -> h2
  w3 = rand(3*n,1);                % X -> h3
  w4 = rand(3*n,1);                % Y -> h1
  w5 = rand(3*n,1);                % Y -> h2
  w6 = rand(3*n,1);                % Y -> h3

  for i = 1:3*n
      fprintf('Training sample %d of %d\n', i, 3*n);
      cls = ceil(i/n);             % class of sample i: 1, 2 or 3
      t = zeros(1,3);
      t(cls) = 1;                  % one-hot target
      converged = false;

      for iter = 1:maxIter
          y1 = X(i)*w1(i) + Y(i)*w4(i) + B(i);
          y2 = X(i)*w2(i) + Y(i)*w5(i) + B(i);
          y3 = X(i)*w3(i) + Y(i)*w6(i) + B(i);
          h = 1 ./ (1 + exp(-[y1 y2 y3]));

          % Only the target output's error decides convergence.
          if 1/2*(1 - h(cls))^2 < tol
              converged = true;
              break;
          end

          % delta(k) = dE/dy_k for E = 1/2*sum((h - t).^2) with sigmoid outputs.
          delta = (h - t) .* h .* (1 - h);

          w1(i) = w1(i) - sigma*delta(1)*X(i);
          w2(i) = w2(i) - sigma*delta(2)*X(i);
          w3(i) = w3(i) - sigma*delta(3)*X(i);
          w4(i) = w4(i) - sigma*delta(1)*Y(i);
          w5(i) = w5(i) - sigma*delta(2)*Y(i);
          w6(i) = w6(i) - sigma*delta(3)*Y(i);
          B(i)  = B(i)  - sigma*sum(delta);   % shared bias: all three gradients add up
      end

      if ~converged
          warning('Sample %d did not reach tol=%g within %d iterations.', i, tol, maxIter);
      end
  end

  plot(P(:,1), P(:,2), 'o');
  hold on;

  % Classify a grid of points by the mean output over all 3*n per-sample
  % networks. Points where no output is strictly the largest stay unlabeled (0).
  % meshgrid's column-major order visits x in the outer loop and y in the inner loop.
  [GX, GY] = meshgrid(-8:0.3:8);
  label = zeros(size(GX));
  for k = 1:numel(GX)
      x = GX(k);
      y = GY(k);
      H1 = mean(1 ./ (1 + exp(-(x*w1 + y*w4 + B))));
      H2 = mean(1 ./ (1 + exp(-(x*w2 + y*w5 + B))));
      H3 = mean(1 ./ (1 + exp(-(x*w3 + y*w6 + B))));
      if H1 > H2 && H1 > H3
          label(k) = 1;
      elseif H2 > H1 && H2 > H3
          label(k) = 2;
      elseif H3 > H1 && H3 > H2
          label(k) = 3;
      end
  end

  % One plot call per class instead of one per grid point.
  markers = {'g.', 'r.', 'b.'};
  for c = 1:3
      plot(GX(label == c), GY(label == c), markers{c});
  end

网络2形式如下,有1个隐含层(该层只有1个隐含神经元),2个偏置量:

 

分类结果:

代码如下:

  % Network 2: one hidden layer with a single sigmoid neuron h0, fed by
  % (X, Y) and bias B1. h0 drives three sigmoid outputs h1, h2, h3, which
  % share bias B2.
  %
  % NOTE(review): every training sample i owns a private weight set
  % (w1(i)..w5(i), B1(i), B2(i)) that is trained only on that sample. The
  % decision map averages the outputs of these 3*n single-sample networks.
  clear;
  close all;
  clc;

  n = 5;                           % samples per class
  rng(1);                          % seeds rand (weights) as well as randn (mvnrnd data)
  mu1 = [0 0];
  mu2 = [0 6];
  mu3 = [6 3];
  S = [0.5 0;
       0 0.5];                     % covariance shared by all three classes
  P1 = mvnrnd(mu1, S, n);
  P2 = mvnrnd(mu2, S, n);
  P3 = mvnrnd(mu3, S, n);

  % Rows 1..n are class 1, n+1..2n class 2, 2n+1..3n class 3.
  P = [P1; P2; P3];
  meanP = mean(P);
  P = [P(:,1)-meanP(1) P(:,2)-meanP(2)];   % center the data at the origin

  sigma = 5;                       % learning rate
  tol = 1e-7;                      % stop once 1/2*(1-h_target)^2 < tol
  maxIter = 1e6;                   % safety cap on gradient steps per sample

  X = P(:,1);
  Y = P(:,2);

  B1 = rand(3*n,1);                % hidden-layer bias
  B2 = rand(3*n,1);                % output-layer bias shared by h1, h2, h3

  w1 = rand(3*n,1);                % X  -> h0
  w2 = rand(3*n,1);                % Y  -> h0
  w3 = rand(3*n,1);                % h0 -> h1
  w4 = rand(3*n,1);                % h0 -> h2
  w5 = rand(3*n,1);                % h0 -> h3

  for i = 1:3*n
      fprintf('Training sample %d of %d\n', i, 3*n);
      cls = ceil(i/n);             % class of sample i: 1, 2 or 3
      t = zeros(1,3);
      t(cls) = 1;                  % one-hot target
      converged = false;

      for iter = 1:maxIter
          h0 = 1/(1 + exp(-(X(i)*w1(i) + Y(i)*w2(i) + B1(i))));
          y = h0*[w3(i) w4(i) w5(i)] + B2(i);
          h = 1 ./ (1 + exp(-y));

          % Only the target output's error decides convergence.
          if 1/2*(1 - h(cls))^2 < tol
              converged = true;
              break;
          end

          % Output deltas: dE/dy_k for E = 1/2*sum((h - t).^2).
          delta = (h - t) .* h .* (1 - h);
          % Hidden delta, back-propagated through the output weights
          % before they are updated below.
          delta0 = (delta(1)*w3(i) + delta(2)*w4(i) + delta(3)*w5(i)) * h0*(1 - h0);

          w1(i) = w1(i) - sigma*delta0*X(i);
          w2(i) = w2(i) - sigma*delta0*Y(i);
          B1(i) = B1(i) - sigma*delta0;

          w3(i) = w3(i) - sigma*delta(1)*h0;
          w4(i) = w4(i) - sigma*delta(2)*h0;
          w5(i) = w5(i) - sigma*delta(3)*h0;
          B2(i) = B2(i) - sigma*sum(delta);   % shared bias: all three gradients add up
      end

      if ~converged
          warning('Sample %d did not reach tol=%g within %d iterations.', i, tol, maxIter);
      end
  end

  plot(P(:,1), P(:,2), 'o');
  hold on;

  % Classify a grid of points by the mean output over all 3*n per-sample
  % networks. Points where no output is strictly the largest stay unlabeled (0).
  % meshgrid's column-major order visits x in the outer loop and y in the inner loop.
  [GX, GY] = meshgrid(-8:0.3:8);
  label = zeros(size(GX));
  for k = 1:numel(GX)
      x = GX(k);
      y = GY(k);
      h0 = 1 ./ (1 + exp(-(x*w1 + y*w2 + B1)));
      H1 = mean(1 ./ (1 + exp(-(h0.*w3 + B2))));
      H2 = mean(1 ./ (1 + exp(-(h0.*w4 + B2))));
      H3 = mean(1 ./ (1 + exp(-(h0.*w5 + B2))));
      if H1 > H2 && H1 > H3
          label(k) = 1;
      elseif H2 > H1 && H2 > H3
          label(k) = 2;
      elseif H3 > H1 && H3 > H2
          label(k) = 3;
      end
  end

  % One plot call per class instead of one per grid point.
  markers = {'g.', 'r.', 'b.'};
  for c = 1:3
      plot(GX(label == c), GY(label == c), markers{c});
  end

网络3形式如下,有1个隐含层(该层含2个隐含神经元),2个偏置量:

 

 

分类结果:

代码如下:

  % Network 3: one hidden layer with two sigmoid neurons h1, h2, both fed by
  % (X, Y) and sharing bias B1. The hidden neurons drive three sigmoid
  % outputs, which share bias B2.
  %
  % NOTE(review): every training sample i owns a private weight set
  % (w1(i)..w10(i), B1(i), B2(i)) that is trained only on that sample. The
  % decision map averages the outputs of these 3*n single-sample networks.
  clear;
  close all;
  clc;

  n = 5;                           % samples per class
  rng(1);                          % seeds rand (weights) as well as randn (mvnrnd data)
  mu1 = [0 0];
  mu2 = [0 6];
  mu3 = [6 3];
  S = [0.5 0;
       0 0.5];                     % covariance shared by all three classes
  P1 = mvnrnd(mu1, S, n);
  P2 = mvnrnd(mu2, S, n);
  P3 = mvnrnd(mu3, S, n);

  % Rows 1..n are class 1, n+1..2n class 2, 2n+1..3n class 3.
  P = [P1; P2; P3];
  meanP = mean(P);
  P = [P(:,1)-meanP(1) P(:,2)-meanP(2)];   % center the data at the origin

  sigma = 20;                      % learning rate
  tol = 1e-7;                      % stop once 1/2*(1-h_target)^2 < tol
  maxIter = 1e6;                   % safety cap on gradient steps per sample

  X = P(:,1);
  Y = P(:,2);

  B1 = rand(3*n,1);                % hidden bias shared by h1, h2
  B2 = rand(3*n,1);                % output bias shared by the three outputs

  w1 = rand(3*n,1);                % X  -> h1
  w2 = rand(3*n,1);                % X  -> h2
  w3 = rand(3*n,1);                % Y  -> h1
  w4 = rand(3*n,1);                % Y  -> h2

  w5 = rand(3*n,1);                % h1 -> output 1
  w6 = rand(3*n,1);                % h1 -> output 2
  w7 = rand(3*n,1);                % h1 -> output 3

  w8 = rand(3*n,1);                % h2 -> output 1
  w9 = rand(3*n,1);                % h2 -> output 2
  w10 = rand(3*n,1);               % h2 -> output 3

  for i = 1:3*n
      fprintf('Training sample %d of %d\n', i, 3*n);
      cls = ceil(i/n);             % class of sample i: 1, 2 or 3
      t = zeros(1,3);
      t(cls) = 1;                  % one-hot target
      converged = false;

      for iter = 1:maxIter
          h1 = 1/(1 + exp(-(X(i)*w1(i) + Y(i)*w3(i) + B1(i))));
          h2 = 1/(1 + exp(-(X(i)*w2(i) + Y(i)*w4(i) + B1(i))));

          y = [h1*w5(i) + h2*w8(i), ...
               h1*w6(i) + h2*w9(i), ...
               h1*w7(i) + h2*w10(i)] + B2(i);
          h = 1 ./ (1 + exp(-y));  % the three outputs

          % Only the target output's error decides convergence.
          if 1/2*(1 - h(cls))^2 < tol
              converged = true;
              break;
          end

          % Output deltas: dE/dy_k for E = 1/2*sum((h - t).^2).
          delta = (h - t) .* h .* (1 - h);
          % Hidden deltas, back-propagated through each hidden neuron's
          % outgoing weights before those weights are updated below.
          d1 = (delta(1)*w5(i) + delta(2)*w6(i) + delta(3)*w7(i))  * h1*(1 - h1);
          d2 = (delta(1)*w8(i) + delta(2)*w9(i) + delta(3)*w10(i)) * h2*(1 - h2);

          w1(i) = w1(i) - sigma*d1*X(i);
          w2(i) = w2(i) - sigma*d2*X(i);
          w3(i) = w3(i) - sigma*d1*Y(i);
          w4(i) = w4(i) - sigma*d2*Y(i);
          B1(i) = B1(i) - sigma*(d1 + d2);    % shared hidden bias

          w5(i)  = w5(i)  - sigma*delta(1)*h1;
          w6(i)  = w6(i)  - sigma*delta(2)*h1;
          w7(i)  = w7(i)  - sigma*delta(3)*h1;
          w8(i)  = w8(i)  - sigma*delta(1)*h2;
          w9(i)  = w9(i)  - sigma*delta(2)*h2;
          w10(i) = w10(i) - sigma*delta(3)*h2;
          B2(i)  = B2(i)  - sigma*sum(delta);  % shared output bias
      end

      if ~converged
          warning('Sample %d did not reach tol=%g within %d iterations.', i, tol, maxIter);
      end
  end

  plot(P(:,1), P(:,2), 'o');
  hold on;

  % Classify a grid of points by the mean output over all 3*n per-sample
  % networks. Points where no output is strictly the largest stay unlabeled (0).
  % meshgrid's column-major order visits x in the outer loop and y in the
  % inner loop, so M's rows are [H1 H2 H3 x y] in that order.
  [GX, GY] = meshgrid(-8:0.3:8);
  label = zeros(size(GX));
  M = zeros(numel(GX), 5);
  for k = 1:numel(GX)
      x = GX(k);
      y = GY(k);
      h1 = 1 ./ (1 + exp(-(x*w1 + y*w3 + B1)));
      h2 = 1 ./ (1 + exp(-(x*w2 + y*w4 + B1)));
      H1 = mean(1 ./ (1 + exp(-(h1.*w5 + h2.*w8  + B2))));
      H2 = mean(1 ./ (1 + exp(-(h1.*w6 + h2.*w9  + B2))));
      H3 = mean(1 ./ (1 + exp(-(h1.*w7 + h2.*w10 + B2))));
      M(k,:) = [H1 H2 H3 x y];
      if H1 > H2 && H1 > H3
          label(k) = 1;
      elseif H2 > H1 && H2 > H3
          label(k) = 2;
      elseif H3 > H1 && H3 > H2
          label(k) = 3;
      end
  end

  % One plot call per class instead of one per grid point.
  markers = {'g.', 'r.', 'b.'};
  for c = 1:3
      plot(GX(label == c), GY(label == c), markers{c});
  end

后面我计划分别从 softmax、权重初始化、正则化、ReLU 激活函数、交叉熵代价函数以及卷积等方面对网络进行优化。

posted on 2020-09-10 14:57  一杯清酒邀明月  阅读(1020)  评论(0)  编辑  收藏  举报