注:这里的练习鉴于当时理解不完全,可能会有些错误,关于神经网络的实践可以参考我的这篇博文
这里的代码只是简单的练习,不涉及代码优化,也不涉及神经网络优化,所以我用了最能体现原理的方式来写代码。
激活函数用的是h = 1/(1+exp(-y)),其中y=sum([X Y].*w)+b,b为偏置量。
代价函数用的是E = 1/2*(t-h)^2,其中t为目标值,t为1代表是该类,t为0代表不是该类。
权值更新采用BP算法。
网络1形式如下,没有隐含层,1个偏置量,输入直接连接输出:
分类结果:
代码如下:
% Network 1: no hidden layer. The two inputs (x, y) feed three sigmoid
% output units directly; the three outputs share a single bias term.
%
% One independent set of weights is trained per training sample: row i of
% w1..w6 and B belongs to sample i. Prediction averages the outputs of all
% 3*n per-sample models and picks the class with the largest mean output.
%
% Requires mvnrnd (Statistics and Machine Learning Toolbox).
clear all;
close all;
clc;

n=5;                        % samples per class
numSamples = 3*n;

% Legacy seeding syntax, kept so the sampled points stay the same as before.
% NOTE(review): presumably this seeds randn only, so the rand-based weight
% initialisation below is not made reproducible by it -- confirm.
randn('seed',1);

% Three Gaussian clusters, one per class.
mu1=[0 0];
S1=[0.5 0;
    0 0.5];
P1=mvnrnd(mu1,S1,n);

mu2=[0 6];
S2=[0.5 0;
    0 0.5];
P2=mvnrnd(mu2,S2,n);

mu3=[6 3];
S3=[0.5 0;
    0 0.5];
P3=mvnrnd(mu3,S3,n);

% Stack the classes (rows 1..n, n+1..2n, 2n+1..3n) and centre the data.
P=[P1;P2;P3];
meanP=mean(P);
P=[P(:,1)-meanP(1) P(:,2)-meanP(2)];

sigma = 5;                  % learning rate
tol = 0.0000001;            % stop once the true-class squared error is this small
maxIter = 1e7;              % safety cap so a non-converging sample cannot hang the script

X=P(:,1);
Y=P(:,2);

% Shared output bias and weights, one row per training sample.
% w1..w3 connect x to outputs 1..3, w4..w6 connect y to outputs 1..3.
B=rand(numSamples,1);

w1 = rand(numSamples,1);
w2 = rand(numSamples,1);
w3 = rand(numSamples,1);

w4 = rand(numSamples,1);
w5 = rand(numSamples,1);
w6 = rand(numSamples,1);

for i=1:numSamples
    i   % no semicolon on purpose: echoes the sample index as progress

    % One-hot target: samples are stored class by class, n per class.
    classIdx = ceil(i/n);
    t = zeros(1,3);
    t(classIdx) = 1;

    converged = false;
    for iter = 1:maxIter
        % Forward pass.
        y1 = X(i)*w1(i) + Y(i)*w4(i) + B(i);
        y2 = X(i)*w2(i) + Y(i)*w5(i) + B(i);
        y3 = X(i)*w3(i) + Y(i)*w6(i) + B(i);

        h1 = 1/(1+exp(-y1));
        h2 = 1/(1+exp(-y2));
        h3 = 1/(1+exp(-y3));

        % Only the output of the sample's own class (target 1) decides
        % when training of this sample stops.
        outputs = [h1 h2 h3];
        e = 1/2*(1 - outputs(classIdx))^2;
        if e <= tol
            converged = true;
            break;
        end

        % Gradient step for E = 1/2*(t-h)^2 with sigmoid outputs:
        % dE/dw = (h-t)*h*(1-h)*input.
        w1(i) = w1(i)-sigma*(h1-t(1))*h1*(1-h1)*X(i);
        w2(i) = w2(i)-sigma*(h2-t(2))*h2*(1-h2)*X(i);
        w3(i) = w3(i)-sigma*(h3-t(3))*h3*(1-h3)*X(i);

        w4(i) = w4(i)-sigma*(h1-t(1))*h1*(1-h1)*Y(i);
        w5(i) = w5(i)-sigma*(h2-t(2))*h2*(1-h2)*Y(i);
        w6(i) = w6(i)-sigma*(h3-t(3))*h3*(1-h3)*Y(i);

        % The bias is shared, so it accumulates the error of all three outputs.
        B(i) = B(i)- sigma*((h1-t(1))*h1*(1-h1)+(h2-t(2))*h2*(1-h2)+(h3-t(3))*h3*(1-h3));
    end

    if ~converged
        warning('bpdemo:notConverged', ...
            'Sample %d did not reach the error tolerance within %d iterations.', ...
            i, maxIter);
    end
end

plot(P(:,1),P(:,2),'o');
hold on;

% Classify a regular grid. Each row of the matrices below is one grid point,
% each column is one of the per-sample models.
gridAxis = -8:0.3:8;
[GX, GY] = meshgrid(gridAxis, gridAxis);
gx = GX(:);
gy = GY(:);
numGrid = numel(gx);

sigmoid = @(z) 1./(1+exp(-z));
biasRows = repmat(B.', numGrid, 1);

% Mean output of every class over all per-sample models.
% (The original listing also had a commented-out variant that averaged only
% the models trained on the class in question.)
H1 = mean(sigmoid(gx*w1.' + gy*w4.' + biasRows), 2);
H2 = mean(sigmoid(gx*w2.' + gy*w5.' + biasRows), 2);
H3 = mean(sigmoid(gx*w3.' + gy*w6.' + biasRows), 2);

% Strict comparisons: a grid point with tied scores is left unplotted.
isClass1 = H1 > H2 & H1 > H3;
isClass2 = H2 > H1 & H2 > H3;
isClass3 = H3 > H1 & H3 > H2;

plot(gx(isClass1), gy(isClass1), 'g.')
plot(gx(isClass2), gy(isClass2), 'r.')
plot(gx(isClass3), gy(isClass3), 'b.')
网络2形式如下,有1个隐含层,2个偏置量:
分类结果:
代码如下:
% Network 2: one hidden layer with a single sigmoid unit h0, followed by
% three sigmoid output units. B1 is the hidden-layer bias and B2 is the
% bias shared by the three outputs.
%
% One independent set of weights is trained per training sample: row i of
% w1..w5, B1 and B2 belongs to sample i. Prediction averages the outputs of
% all 3*n per-sample models and picks the class with the largest mean output.
%
% Requires mvnrnd (Statistics and Machine Learning Toolbox).
clear all;
close all;
clc;

n=5;                        % samples per class
numSamples = 3*n;

% Legacy seeding syntax, kept so the sampled points stay the same as before.
% NOTE(review): presumably this seeds randn only, so the rand-based weight
% initialisation below is not made reproducible by it -- confirm.
randn('seed',1);

% Three Gaussian clusters, one per class.
mu1=[0 0];
S1=[0.5 0;
    0 0.5];
P1=mvnrnd(mu1,S1,n);

mu2=[0 6];
S2=[0.5 0;
    0 0.5];
P2=mvnrnd(mu2,S2,n);

mu3=[6 3];
S3=[0.5 0;
    0 0.5];
P3=mvnrnd(mu3,S3,n);

% Stack the classes (rows 1..n, n+1..2n, 2n+1..3n) and centre the data.
P=[P1;P2;P3];
meanP=mean(P);
P=[P(:,1)-meanP(1) P(:,2)-meanP(2)];

sigma = 5;                  % learning rate
tol = 0.0000001;            % stop once the true-class squared error is this small
maxIter = 1e7;              % safety cap so a non-converging sample cannot hang the script

X=P(:,1);
Y=P(:,2);

% Biases and weights, one row per training sample.
B1=rand(numSamples,1);      % hidden-layer bias
B2=rand(numSamples,1);      % bias shared by the three outputs

w1 = rand(numSamples,1);    % x  -> h0
w2 = rand(numSamples,1);    % y  -> h0

w3 = rand(numSamples,1);    % h0 -> output 1
w4 = rand(numSamples,1);    % h0 -> output 2
w5 = rand(numSamples,1);    % h0 -> output 3

for i=1:numSamples
    i   % no semicolon on purpose: echoes the sample index as progress

    % One-hot target: samples are stored class by class, n per class.
    classIdx = ceil(i/n);
    t = zeros(1,3);
    t(classIdx) = 1;

    converged = false;
    for iter = 1:maxIter
        % Forward pass.
        y0 = X(i)*w1(i) + Y(i)*w2(i) + B1(i);
        h0 = 1/(1+exp(-y0));

        y1 = h0*w3(i) + B2(i);
        y2 = h0*w4(i) + B2(i);
        y3 = h0*w5(i) + B2(i);

        h1 = 1/(1+exp(-y1));
        h2 = 1/(1+exp(-y2));
        h3 = 1/(1+exp(-y3));

        % Only the output of the sample's own class (target 1) decides
        % when training of this sample stops.
        outputs = [h1 h2 h3];
        e = 1/2*(1 - outputs(classIdx))^2;
        if e <= tol
            converged = true;
            break;
        end

        % Hidden-layer parameters first: they must be updated with the
        % output weights w3..w5 as they were during the forward pass.
        w1(i) = w1(i)- sigma*((h1-t(1))*h1*(1-h1)*w3(i)*h0*(1-h0)*X(i) ...
                            + (h2-t(2))*h2*(1-h2)*w4(i)*h0*(1-h0)*X(i) ...
                            + (h3-t(3))*h3*(1-h3)*w5(i)*h0*(1-h0)*X(i));
        w2(i) = w2(i)- sigma*((h1-t(1))*h1*(1-h1)*w3(i)*h0*(1-h0)*Y(i) ...
                            + (h2-t(2))*h2*(1-h2)*w4(i)*h0*(1-h0)*Y(i) ...
                            + (h3-t(3))*h3*(1-h3)*w5(i)*h0*(1-h0)*Y(i));
        B1(i) = B1(i)- sigma*((h1-t(1))*h1*(1-h1)*w3(i)*h0*(1-h0) ...
                            + (h2-t(2))*h2*(1-h2)*w4(i)*h0*(1-h0) ...
                            + (h3-t(3))*h3*(1-h3)*w5(i)*h0*(1-h0));

        % Output-layer parameters.
        w3(i) = w3(i)-sigma*(h1-t(1))*h1*(1-h1)*h0;
        w4(i) = w4(i)-sigma*(h2-t(2))*h2*(1-h2)*h0;
        w5(i) = w5(i)-sigma*(h3-t(3))*h3*(1-h3)*h0;
        B2(i) = B2(i)- sigma*((h1-t(1))*h1*(1-h1)+(h2-t(2))*h2*(1-h2)+(h3-t(3))*h3*(1-h3));
    end

    if ~converged
        warning('bpdemo:notConverged', ...
            'Sample %d did not reach the error tolerance within %d iterations.', ...
            i, maxIter);
    end
end


plot(P(:,1),P(:,2),'o');
hold on;

% Classify a regular grid. Each row of the matrices below is one grid point,
% each column is one of the per-sample models.
gridAxis = -8:0.3:8;
[GX, GY] = meshgrid(gridAxis, gridAxis);
gx = GX(:);
gy = GY(:);
numGrid = numel(gx);

sigmoid = @(z) 1./(1+exp(-z));
perGrid = @(v) repmat(v.', numGrid, 1);   % repeat a per-model column as rows

hidden = sigmoid(gx*w1.' + gy*w2.' + perGrid(B1));

% Mean output of every class over all per-sample models.
H1 = mean(sigmoid(hidden.*perGrid(w3) + perGrid(B2)), 2);
H2 = mean(sigmoid(hidden.*perGrid(w4) + perGrid(B2)), 2);
H3 = mean(sigmoid(hidden.*perGrid(w5) + perGrid(B2)), 2);

% Strict comparisons: a grid point with tied scores is left unplotted.
isClass1 = H1 > H2 & H1 > H3;
isClass2 = H2 > H1 & H2 > H3;
isClass3 = H3 > H1 & H3 > H2;

plot(gx(isClass1), gy(isClass1), 'g.')
plot(gx(isClass2), gy(isClass2), 'r.')
plot(gx(isClass3), gy(isClass3), 'b.')
网络3形式如下,有1个隐含层(含2个隐含节点),2个偏置量:
分类结果:
代码如下:
% Network 3: one hidden layer with two sigmoid units (h1, h2), followed by
% three sigmoid output units (h3, h4, h5). B1 is the bias shared by the two
% hidden units and B2 is the bias shared by the three outputs.
%
% One independent set of weights is trained per training sample: row i of
% w1..w10, B1 and B2 belongs to sample i. Prediction averages the outputs of
% all 3*n per-sample models and picks the class with the largest mean output.
%
% Requires mvnrnd (Statistics and Machine Learning Toolbox).
clear all;
close all;
clc;

n=5;                        % samples per class
numSamples = 3*n;

% Legacy seeding syntax, kept so the sampled points stay the same as before.
% NOTE(review): presumably this seeds randn only, so the rand-based weight
% initialisation below is not made reproducible by it -- confirm.
randn('seed',1);

% Three Gaussian clusters, one per class.
mu1=[0 0];
S1=[0.5 0;
    0 0.5];
P1=mvnrnd(mu1,S1,n);

mu2=[0 6];
S2=[0.5 0;
    0 0.5];
P2=mvnrnd(mu2,S2,n);

mu3=[6 3];
S3=[0.5 0;
    0 0.5];
P3=mvnrnd(mu3,S3,n);

% Stack the classes (rows 1..n, n+1..2n, 2n+1..3n) and centre the data.
P=[P1;P2;P3];
meanP=mean(P);
P=[P(:,1)-meanP(1) P(:,2)-meanP(2)];

sigma = 20;                 % learning rate
tol = 0.0000001;            % stop once the true-class squared error is this small
maxIter = 1e7;              % safety cap so a non-converging sample cannot hang the script

X=P(:,1);
Y=P(:,2);

% Biases and weights, one row per training sample.
B1=rand(numSamples,1);      % bias shared by the two hidden units
B2=rand(numSamples,1);      % bias shared by the three outputs

w1 = rand(numSamples,1);    % x  -> hidden 1
w2 = rand(numSamples,1);    % x  -> hidden 2

w3 = rand(numSamples,1);    % y  -> hidden 1
w4 = rand(numSamples,1);    % y  -> hidden 2

w5 = rand(numSamples,1);    % hidden 1 -> output 1
w6 = rand(numSamples,1);    % hidden 1 -> output 2
w7 = rand(numSamples,1);    % hidden 1 -> output 3

w8 = rand(numSamples,1);    % hidden 2 -> output 1
w9 = rand(numSamples,1);    % hidden 2 -> output 2
w10 = rand(numSamples,1);   % hidden 2 -> output 3

for i=1:numSamples
    i   % no semicolon on purpose: echoes the sample index as progress

    % One-hot target: samples are stored class by class, n per class.
    classIdx = ceil(i/n);
    t = zeros(1,3);
    t(classIdx) = 1;

    converged = false;
    for iter = 1:maxIter
        % Forward pass: hidden layer.
        y1 = X(i)*w1(i) + Y(i)*w3(i) + B1(i);
        y2 = X(i)*w2(i) + Y(i)*w4(i) + B1(i);

        h1 = 1/(1+exp(-y1));
        h2 = 1/(1+exp(-y2));

        dh1 = h1*(1-h1);    % sigmoid derivatives of the hidden units
        dh2 = h2*(1-h2);

        % Forward pass: output layer.
        y3 = h1*w5(i) + h2*w8(i)+ B2(i);
        y4 = h1*w6(i) + h2*w9(i)+ B2(i);
        y5 = h1*w7(i) + h2*w10(i)+ B2(i);

        h3 = 1/(1+exp(-y3));
        h4 = 1/(1+exp(-y4));
        h5 = 1/(1+exp(-y5));

        dh3 = h3*(1-h3);    % sigmoid derivatives of the outputs
        dh4 = h4*(1-h4);
        dh5 = h5*(1-h5);

        % Only the output of the sample's own class (target 1) decides
        % when training of this sample stops.
        outputs = [h3 h4 h5];
        e = 1/2*(1 - outputs(classIdx))^2;
        if e <= tol
            converged = true;
            break;
        end

        % Hidden-layer parameters first: they must be updated with the
        % output weights w5..w10 as they were during the forward pass.
        w1(i) = w1(i) -sigma * ((h3-t(1))*dh3*w5(i)+(h4-t(2))*dh4*w6(i)+(h5-t(3))*dh5*w7(i)) * dh1*X(i);
        w2(i) = w2(i) -sigma * ((h3-t(1))*dh3*w8(i)+(h4-t(2))*dh4*w9(i)+(h5-t(3))*dh5*w10(i)) * dh2*X(i);

        w3(i) = w3(i) -sigma * ((h3-t(1))*dh3*w5(i)+(h4-t(2))*dh4*w6(i)+(h5-t(3))*dh5*w7(i)) * dh1*Y(i);
        w4(i) = w4(i) -sigma * ((h3-t(1))*dh3*w8(i)+(h4-t(2))*dh4*w9(i)+(h5-t(3))*dh5*w10(i)) * dh2*Y(i);

        % B1 is shared by both hidden units, so it collects both error terms.
        B1(i) = B1(i)- sigma*(((h3-t(1))*dh3*w5(i)+(h4-t(2))*dh4*w6(i)+(h5-t(3))*dh5*w7(i))*dh1 ...
                            +((h3-t(1))*dh3*w8(i)+(h4-t(2))*dh4*w9(i)+(h5-t(3))*dh5*w10(i))*dh2);

        % Output-layer parameters.
        w5(i) = w5(i)-sigma*(h3-t(1))*dh3*h1;
        w6(i) = w6(i)-sigma*(h4-t(2))*dh4*h1;
        w7(i) = w7(i)-sigma*(h5-t(3))*dh5*h1;

        w8(i) = w8(i)-sigma*(h3-t(1))*dh3*h2;
        w9(i) = w9(i)-sigma*(h4-t(2))*dh4*h2;
        w10(i) = w10(i)-sigma*(h5-t(3))*dh5*h2;

        B2(i) = B2(i)- sigma*((h3-t(1))*dh3+(h4-t(2))*dh4+(h5-t(3))*dh5);
    end

    if ~converged
        warning('bpdemo:notConverged', ...
            'Sample %d did not reach the error tolerance within %d iterations.', ...
            i, maxIter);
    end
end


plot(P(:,1),P(:,2),'o');
hold on;

% Classify a regular grid. Each row of the matrices below is one grid point,
% each column is one of the per-sample models.
gridAxis = -8:0.3:8;
[GX, GY] = meshgrid(gridAxis, gridAxis);
gx = GX(:);
gy = GY(:);
numGrid = numel(gx);

sigmoid = @(z) 1./(1+exp(-z));
perGrid = @(v) repmat(v.', numGrid, 1);   % repeat a per-model column as rows

hidden1 = sigmoid(gx*w1.' + gy*w3.' + perGrid(B1));
hidden2 = sigmoid(gx*w2.' + gy*w4.' + perGrid(B1));

% Mean output of every class over all per-sample models.
% (The original listing also had a commented-out variant that averaged only
% the models trained on the class in question.)
H1 = mean(sigmoid(hidden1.*perGrid(w5) + hidden2.*perGrid(w8)  + perGrid(B2)), 2);
H2 = mean(sigmoid(hidden1.*perGrid(w6) + hidden2.*perGrid(w9)  + perGrid(B2)), 2);
H3 = mean(sigmoid(hidden1.*perGrid(w7) + hidden2.*perGrid(w10) + perGrid(B2)), 2);

% Per-grid-point scores and coordinates, rows ordered with x as the outer
% and y as the inner index. Not used further in this script; left in the
% workspace for inspection.
M = [H1 H2 H3 gx gy];

% Strict comparisons: a grid point with tied scores is left unplotted.
isClass1 = H1 > H2 & H1 > H3;
isClass2 = H2 > H1 & H2 > H3;
isClass3 = H3 > H1 & H3 > H2;

plot(gx(isClass1), gy(isClass1), 'g.')
plot(gx(isClass2), gy(isClass2), 'r.')
plot(gx(isClass3), gy(isClass3), 'b.')
后面我计划对网络分别使用softmax,权重初始化,正则化,ReLU激活函数,交叉熵代价函数与卷积的形式进行优化。