利用感知机做线性分类小demo
如题啦,单个感知机分类
结果展示:
代码:
import numpy as np import pandas as pd import matplotlib.pyplot as plt import random def sgn(x): if x>=0: return 1; else: return -1; class point: x1=0 x2=0 y=1 def __init__(self,x1,x2,y): self.x1=x1 self.x2=x2 self.y=y #def __init__(para): # self.x1=para[0] # self.x2=para[1] # self.y=para[2] def cauy(self,para): return (self.x1*para[0]+self.x2*para[1]+para[2]) p1=point(1,1,-1) p2=point(3,3,1) p3=point(4,3,1) #画三个点 plt.scatter(1,1,c='r') plt.scatter(3,3,c='b') plt.scatter(4,3,c='b') #参数与训练集 w1=w2=b=0 n=1 trainset=[p1,p2,p3] parameter=[w1,w2,b] count=0 #循环训练参数 while(1): random.shuffle(trainset) for i in trainset: #预测值 y_=sgn(i.cauy(parameter)) #遇到误分类点 if(y_!=i.y): #参数修正 w1+=n*(i.y-y_)*i.x1 w2+=n*(i.y-y_)*i.x2 b+=n*i.y parameter=[w1,w2,b] #绘图 x0=np.array([1,2,3,4,5,6,7,8,9,10]) if w2!=0: plt.plot(x0,x0*w1/(-w2)-b/w2) elif w1!=0: plt.plot([-b/w1]*10,x0) else: plt.scatter(0,0,c="g") print(w1,'x1 + ',w2,'x2 + ',b) count+=1 break else: break print("you have try",count,"times") plt.show()
ps:下面还有随机产生数据集的版本(但是点多了就很难线性可分,一般5个随机点就很难分了)
import numpy as np import pandas as pd import matplotlib.pyplot as plt import random def sgn(x): if x>=0: return 1; else: return -1; class point: x1=0 x2=0 y=1 def __init__(self,x1,x2,y): self.x1=x1 self.x2=x2 self.y=y #def __init__(para): # self.x1=para[0] # self.x2=para[1] # self.y=para[2] def cauy(self,para): return (self.x1*para[0]+self.x2*para[1]+para[2]) #随机生成数据集 trainset=[point(int(random.randrange(1,10,1)),int(random.randrange(1,10,1)),sgn(random.randrange(-5,5,1))) for i in range(4)] #画点 for p in trainset: if p.y==1: plt.scatter(p.x1,p.x2,c='b') else: plt.scatter(p.x1,p.x2,c='r') #参数与训练集 w1=w2=b=0 n=1 parameter=[w1,w2,b] count=0 res=False #循环训练参数 while(1and count<20): #理论上不打乱也行 #random.shuffle(trainset) for i in trainset: #预测值 y_=sgn(i.cauy(parameter)) #遇到误分类点 if(y_!=i.y): #参数修正 w1+=n*(i.y-y_)*i.x1 w2+=n*(i.y-y_)*i.x2 b+=n*i.y parameter=[w1,w2,b] #绘图 x0=np.array([1,2,3,4,5,6,7,8,9,10]) if w2!=0: plt.plot(x0,x0*w1/(-w2)-b/w2) elif w1!=0: plt.plot([-b/w1]*10,x0) else: plt.scatter(0,0,c='g') print(w1,'x1 + ',w2,'x2 + ',b) count+=1 break else: res=True break if(res): print("you have try",count,"times and you are succeed") else : print("it is unclassifiable in 20 times training !!") plt.show()