感知机(perceptron) 学习笔记
前言:偶尔回想自己学过的算法,反复看反复忘,故又重复看一遍并记下笔记,供后续学习参考。
感知机是一个二分类算法,是深度学习的简化版,只有一层网络,建模思想跟支持向量机类似,是两算法的基础。
分类原理 :y =(编辑公式还有更好的办法吗?), 满足wx+b>=0的输入被分类为标签1,否则被分类为标签-1。
建模:分类的是超平面wx+b=0,以输入点到超平面的距离为判断标准。距离公式为d = |wx+b|/ ||w|| = y(wx+b)/||w||。||.||是二范数。
损失函数:损失的自然选择是错误分类的总点数,但是这样的损失不是参数w,b的连续可到函数,不易优化,所以将错误分类的总点数到超平面的距离的总距离定义为损失函数,因此损失函数deltaL = -1*y(wx+b)/||w||,此处根据SVM中几何间隔,知需要对w进行约束,避免b与w同比例增加,参数变了但是超平面本身并没有变动,所以约定||w||=1,损失函数为deltaL = -1*y(wx+b),更细致的推理可参见其他资料,(我用||w||计算损失进行训练没见特别冥想的错误?)。
优化算法:优化算法是为了求得最优解,这里使用随机梯度下降算法SGD,已知损失函数,根据其对变量w,b求导得到,dw = -yx, db = -y,设学习率为u = 0.5。w= w-dw = w+uyx, b = b - db = b+uy
代码如下:
import math import random ''' 感知机是二分类模型 判断条件 y = 1 when w1x1+w2x2+b>0 y = -1 when w1x1+w2x2+b<0 ''' #z准备数据集 X = [[3,2], [12,10], [33,62], [8,16], [23,45], [7,13], [78,65], [35,54], [77,55], [89,23]] Y = [-1, -1, 1, -1, -1, -1,1, -1, 1, 1] #训练网络/推理逻辑 # y= x1*w1 + x2*w2 +b w1 = 0.1 w2 = 0.1 b = 0 u = 0.5 # for [x1,x2] in X: # y = x1*w1 + x2*w2 +b #计算损失函数 #点到平面wx+b=0 的距离 d = |wx+b| / ||w||,对于误分类的则需要被统计作为损失,目标是使损失变为0 #d = -yi*(wxi+b)/math.sqrt(sum(pow(w,2))) #优化算法 梯度函数 #dw = yixi ; db = yi, u是步长(学习率) #w = w+ udw #b = b+ udb #每个epoch遇到误分类点后更新一次变量 deltal = 0 for i in range(10000): print(i) deltal = 0 for i in range(20): index = random.randint(0,len(X)-1) # print((X[index][0]*w1 + X[index][1]*w2 +b)* Y[index]) # print((X[index][0]*w1 + X[index][1]*w2 +b)* Y[index]*(-1)) if((X[index][0]*w1 + X[index][1]*w2 +b)* Y[index]<=0): print("error") w1 = w1+ u*Y[index]*X[index][0] w2 = w2 + u*Y[index]*X[index][1] b = b + u*Y[index] # print(math.sqrt(pow(w1,2)+ pow(w2,2))) deltal =deltal + (X[index][0]*w1 + X[index][1]*w2 +b)* Y[index]*(-1)/math.sqrt(pow(w1,2)+ pow(w2,2)) # print(deltal) print('损失函数是:{}, w1 is:{}, w2 is:{}, b is:{}'.format(deltal, w1, w2, b)) if(deltal==0): break
得到的结果如下:
0 error error error error error error error 损失函数是:-58.3676307788504, w1 is:9.600000000000001, w2 is:-11.899999999999999, b is:-2.5 1 error error error error error error error 损失函数是:-79.6906079022878, w1 is:16.1, w2 is:-27.4, b is:-5.0 2 error error error error error error error error error error error 损失函数是:-181.72511941015193, w1 is:19.1, w2 is:-35.4, b is:-7.5 3 error error error error error error error error error 损失函数是:-134.1050074696168, w1 is:52.1, w2 is:-21.4, b is:-9.0 4 error error error error error error error error error error error 损失函数是:-81.79882696995062, w1 is:21.1, w2 is:-49.9, b is:-12.5 5 error error error error error error error error error error error 损失函数是:-119.81313066808934, w1 is:23.6, w2 is:-19.9, b is:-14.0 6 error error error error 损失函数是:-53.31304070579386, w1 is:21.1, w2 is:-17.4, b is:-15.0 7 error error error error error 损失函数是:-41.93949987798639, w1 is:31.6, w2 is:9.600000000000001, b is:-16.5 8 error error error 损失函数是:-5.054308599796722, w1 is:17.1, w2 is:-14.899999999999999, b is:-18.0 9 error error 损失函数是:-5.061614611137055, w1 is:9.600000000000001, w2 is:-20.9, b is:-19.0 10 error error error error error error error error error error error 损失函数是:-176.0012076253418, w1 is:30.1, w2 is:6.100000000000001, b is:-21.5 11 error error error error error error error error error error 损失函数是:-125.48127046860397, w1 is:37.1, w2 is:-13.899999999999999, b is:-24.5 12 error error error error error error error error error error error 损失函数是:-125.1563990295305, w1 is:7.100000000000001, w2 is:-41.9, b is:-28.0 13 error error error error error error error error 损失函数是:-99.9588692981525, w1 is:31.6, w2 is:-34.4, b is:-30.0 14 error error error error error error error error 损失函数是:-123.83992037821682, w1 is:39.1, w2 is:-43.4, b is:-32.0 15 error error error error error error error error error error 损失函数是:-107.52054118479901, w1 is:38.099999999999994, w2 is:-61.4, b is:-35.0 16 error error error error error error 损失函数是:-12.707209628891478, w1 is:53.599999999999994, w2 is:-18.4, b is:-36.0 17 error error error error error error error 损失函数是:-23.948721417884677, w1 is:40.099999999999994, w2 is:-28.4, b is:-38.5 18 error error error error error error error error 损失函数是:-139.31131437239534, w1 is:31.099999999999994, w2 is:-56.4, b is:-40.5 19 error error error error error error error 损失函数是:-95.98994407835531, w1 is:48.099999999999994, w2 is:-62.9, b is:-42.0 20 error error error error error error error error error 损失函数是:-92.55322975282904, w1 is:78.6, w2 is:-39.9, b is:-43.5 21 error error 损失函数是:-17.88492872798486, w1 is:55.099999999999994, w2 is:-71.9, b is:-44.5 22 error error 损失函数是:-18.30431876055715, w1 is:54.099999999999994, w2 is:-67.9, b is:-44.5 23 error error 损失函数是:-19.673217617324354, w1 is:53.099999999999994, w2 is:-63.900000000000006, b is:-44.5 24 error error 损失函数是:-21.10211583473629, w1 is:52.099999999999994, w2 is:-59.900000000000006, b is:-44.5 25 error error error error error error error error 损失函数是:-65.34583910969604, w1 is:66.6, w2 is:-47.400000000000006, b is:-46.5 26 error error error error error error error error error 损失函数是:-74.24796184936555, w1 is:81.6, w2 is:-42.400000000000006, b is:-49.0 27 error error error 损失函数是:-34.124857156842516, w1 is:63.099999999999994, w2 is:-65.4, b is:-49.5 28 error error error error error error error error error error 损失函数是:-20.592229958037294, w1 is:52.099999999999994, w2 is:-56.400000000000006, b is:-52.5 29 error error error error error error error error 损失函数是:-15.858697189511131, w1 is:58.099999999999994, w2 is:-24.900000000000006, b is:-54.5 30 error error error error error 损失函数是:-59.07572886107646, w1 is:52.099999999999994, w2 is:-24.900000000000006, b is:-55.0 31 error error error error error 损失函数是:-33.3306474537016, w1 is:35.099999999999994, w2 is:-38.900000000000006, b is:-56.5 32 损失函数是:0, w1 is:35.099999999999994, w2 is:-38.900000000000006, b is:-56.5 Process finished with exit code 0
重复运行多次,会有不同的结果。感知机算法由于采用不同的初值和选取不同的误分类点,解可以不同。
以上为个人理解,如有不对的地方,欢迎交流指正~