% Perceptron classification algorithm (感知机分类算法)
function [test_targets, a] = Perceptron(train_patterns, train_targets, test_patterns, alg_param)
% PERCEPTRON Classify using the Perceptron algorithm
% (fixed-increment single-sample perceptron, optionally sample-weighted).
%
% Inputs:
%   train_patterns - Train patterns, one sample per column (d x N)
%   train_targets  - Train targets, one per sample; a target equal to 0 marks
%                    the negative class, any other value the positive class
%   test_patterns  - Test patterns, one sample per column (d x M)
%   alg_param      - One of (first matching length wins):
%                      N+1 values : N per-sample weights followed by the iteration limit
%                      N values   : N per-sample weights, iteration limit 5000
%                      []         : unit weights, iteration limit 5000
%                      scalar     : iteration limit, unit weights
%                                   (when N == 1 a scalar is read as the weight)
% Outputs:
%   test_targets - Predicted targets (logical 1 x M, true = positive class)
%   a            - Perceptron weight vector ((d+1) x 1); the last element is the bias
%
% NOTE: Works for only two classes.
%
% Example:
%   train_patterns = [-0.5, -0.5,  0.3, 0.1, -0.1, 0.8, 0.2, 0.3;
%                      0.3, -0.2, -0.6, 0.1, -0.5, 1.0, 0.3, 0.9];
%   train_targets  = [0, 0, 0, 1, 0, 1, 1, 1];
%   test_patterns  = [0.2 -0.3]';   % the original example reports class 0 for this point
%   % Choose ONE form of alg_param:
%   % alg_param = 100;                                        % iteration limit only
%   % alg_param = [0.01;0.01;0.0;0.0;0.01;0.01;0.01;0.01];    % per-sample weights
%   alg_param = [0.01;0.01;0.0;0.0;0.01;0.01;0.01;0.01;100];  % weights + iteration limit
%
%   [test_targets, a] = Perceptron(train_patterns, train_targets, test_patterns, alg_param)
%   % plotpv/plotpc are toolbox functions (presumably Deep Learning / Neural Network Toolbox)
%   plotpv(train_patterns, train_targets);   % plot the labeled training points
%   a = a';
%   plotpc(a(1:end-1), a(end));              % plot the decision boundary (weights, bias)

[c, r] = size(train_patterns);   % c = feature dimension, r = number of training samples

% Weighted perceptron or not? (MATLAB switch takes the first matching case)
switch length(alg_param)
    case r + 1
        % AdaBoost form: per-sample weights followed by the iteration limit
        p        = alg_param(1:end-1);
        max_iter = alg_param(end);
    case r
        % Per-sample weights only, default iteration limit
        p        = alg_param;
        max_iter = 5000;
    case 0
        % No parameter given
        p        = ones(1, r);
        max_iter = 5000;
    otherwise
        % Number of iterations given
        p        = ones(1, r);
        max_iter = alg_param;
end

% Augment every sample with a constant 1 so the last weight acts as the bias
train_patterns = [train_patterns; ones(1, r)];

% Preprocessing: negate negative-class samples so that a correctly
% classified sample always satisfies a' * y > 0
y = train_patterns;
negative = (train_targets == 0);
y(:, negative) = -y(:, negative);

% Initial weights: sum of the sign-normalized samples. sum(y, 2) keeps a
% column vector even when there is only a single training sample.
a = sum(y, 2);

% Single-sample updates until every sample has a strictly positive margin
% (the same "<= 0" test the update rule uses) or the iteration budget is spent
iter = 0;
while any(a' * y <= 0) && (iter < max_iter)
    iter = iter + 1;
    k = 1 + floor(rand(1) * r);        % pick a random training sample
    if a' * y(:, k) <= 0
        a = a + p(k) * y(:, k);        % weighted fixed-increment correction
    end
end

% Report only a genuine failure to converge (not a convergence on the last step);
% the AdaBoost form stays silent as before
converged = ~any(a' * y <= 0);
if ~converged && (length(alg_param) ~= r + 1)
    disp(['Maximum iteration (' num2str(max_iter) ') reached']);
end

% Classify test patterns (augmented with the same bias row); true = positive class
test_targets = a' * [test_patterns; ones(1, size(test_patterns, 2))] > 0;