单层感知器 - 坐标点二分类问题
单层感知器是神经网络的入门常识,基本的单层感知器可以解决线性分类问题。这里我们通过实例体验感知器是如何运作的。本次实例参照教材《MATLAB神经网络原理与实例精解》。
单层感知器的基本结构
如图,单层感知器可以有多个输入,它们通过与权值相乘,再相加(即加权求和)后,经过一定的偏置,再由激活函数处理,最后输出得到预测结果。这里面存在两种变化:线性变化与非线性变化。其中,加权求和属于线性变化,激活函数做的是非线性变化。通过上述两种变化 ,可以把输入的数据空间扭曲,使得只需要一个超平面就可以将其分开(线性可分),从而达到分类的目的。
单层感知器的工作原理
与其他的优化算法一样,感知器做的工作就是不断的调整权值,使得输入的数据空间扭曲到适当的程度,然后再利用超平面一刀切开,达到二分类的效果。所有的算法都会有一个迭代终止指标,对于单层感知器来说,当输出的预测值与期望值之间的误差达到一定的精度要求,或者迭代次数超过一定的次数时(计算机也不可以无限的运行下去),算法结束。
单层感知器解决坐标的二分类问题
我们给出6个点的坐标,并给每个点的坐标设置分类,标签为0(第一类)和1(第二类)。利用单层感知器,找到一个超平面(就是一根直线)将两类坐标分开(即两类坐标分别处在直线的两边)。
代码实现
Python代码
import numpy as np
import matplotlib.pyplot as plt
# 参数初始化
n = 0.2 # 学习率
w = np.array([0, 0, 0]) # 权值
p = np.array([[-9, 1, -12, -4, 0, 5],
[15, -8, 4, 5, 11, 9]]) # 坐标
d = np.array([0, 1, 0, 0, 0, 1]) # 坐标分类标签
P = np.vstack((np.ones((1, 6)), p)) # 输入矩阵
MAX = 20 # 最大迭代次数
ee = [] # 误差
i = 0 # 记录迭代次数
# 定义激活函数
def hardlim(a):
for i in range(len(a)):
if a[i] >= 0:
a[i] = 1
else:
a[i] = 0
return a
# 定义平均绝对误差
def mae(a):
return sum(abs(a))/len(a)
while 1:
v = np.matmul(w, P)
y = hardlim(v) # 实际输出
# 更新
e = (d - y)
ee.append(mae(e))
if ee[i] < 0.001:
print('we have got it:')
print(w)
break
w = w + n*np.matmul(d-y,P.T)
i = i + 1
if i >= MAX:
print('MAX times loop')
print(w)
print(ee[i])
break
# 画图
plt.figure()
plt.rcParams['font.sans-serif'] = ['Simhei']
plt.rcParams['axes.unicode_minus'] = False
plt.subplot(211)
plt.xlim(-13, 6)
plt.ylim(-10, 16)
plt.plot([-9, -12, -4, 0], [15, 4, 5, 11], 'o', label='第一类')
plt.plot([1, 5], [-8, 9], '*', label='第二类')
plt.legend(loc='lower right')
plt.title('6个坐标点的二分类')
x = np.arange(-13, 6, 0.2)
y = x * (-w[1]/w[2]) - w[0]/w[2]
plt.plot(x, y)
plt.subplot(212)
x = np.arange(0,len(ee))
plt.plot(x, ee, 'o-')
plt.title('mae的值(迭代次数:%.0f)'%len(ee))
plt.subplots_adjust(wspace =0, hspace =0.5)
plt.show()
输出画面
Matlab代码
% perception_hand.m
%% 清理
clear,clc
close all
%%
n=0.2; % 学习率
w=[0,0,0];
P=[ -9, 1, -12, -4, 0, 5;...
15, -8, 4, 5, 11, 9];
d=[0,1,0,0,0,1]; % 期望输出
P=[ones(1,6);P];
MAX=20; % 最大迭代次数为20次
%% 训练
i=0;
while 1
v=w*P;
y=hardlim(v); % 实际输出
%更新
e=(d-y);
ee(i+1)=mae(e);
if (ee(i+1)<0.001) % 判断
disp('we have got it:');
disp(w);
break;
end
% 更新权值和偏置
w=w+n*(d-y)*P';
if (i>=MAX) % 达到最大迭代次数,退出
disp('MAX times loop');
disp(w);
disp(ee(i+1));
break;
end
i= i+1;
end
%% 显示
figure;
subplot(2,1,1); % 显示待分类的点和分类结果
plot([-9 , -12 -4 0],[15, 4 5 11],'o');
hold on;
plot([1,5],[-8,9],'*');
axis([-13,6,-10,16]);
legend('第一类','第二类');
title('6个坐标点的二分类');
x=-13:.2:6;
y=x*(-w(2)/w(3))-w(1)/w(3);
plot(x,y);
hold off;
subplot(2,1,2); % 显示mae值的变化
x=0:i;
plot(x,ee,'o-');
s=sprintf('mae的值(迭代次数:%d)', i+1);
title(s);
输出画面
未经作者授权,禁止转载
THE END