SMO（Sequential Minimal Optimization，序列最小最优化）算法
SMO 算法的 Python 实现示例：
import sys

import numpy as np


def loadDataSet(fileName):
    """Load a comma-separated dataset of 2-D points with labels.

    Every non-blank line must start with ``x1,x2,label``; any further
    columns are ignored and blank lines are skipped.

    Args:
        fileName: path of the text file to read.

    Returns:
        ``(dataMat, labelMat)``: a list of ``[x1, x2]`` float pairs and the
        list of float labels, in file order.
    """
    dataMat = []
    labelMat = []
    # ``with`` guarantees the handle is closed, even if a line fails to parse.
    with open(fileName, encoding="utf-8") as fr:
        for line in fr:
            stripped = line.strip()
            if not stripped:
                continue  # a blank line would otherwise crash in float('')
            lineArr = stripped.split(',')
            dataMat.append([float(lineArr[0]), float(lineArr[1])])
            labelMat.append(float(lineArr[2]))
    return dataMat, labelMat


def selectJrand(i, m):
    """Return a uniformly random index in ``[0, m)`` different from ``i``.

    Raises:
        ValueError: if ``m < 2``; no other index exists, so the rejection
            loop below could never terminate.
    """
    if m < 2:
        raise ValueError("selectJrand needs at least 2 samples, got m=%d" % m)
    j = i
    while j == i:
        j = int(np.random.uniform(0, m))
    return j


def clipAlpha(aj, H, L):
    """Clip ``aj`` into ``[L, H]``.

    The upper bound is applied first and the lower bound second, so if
    ``L > H`` the result is ``L``.
    """
    if aj > H:
        aj = H
    if L > aj:
        aj = L
    return aj


def smoSimple(dataMatIn, classLabels, C, toler, maxIter):
    """Train a linear-kernel SVM with the simplified SMO algorithm.

    Each sample whose multiplier violates the KKT conditions by more than
    ``toler`` is paired with a random partner from ``selectJrand`` and the
    pair is optimised analytically.

    Args:
        dataMatIn: ``m`` samples, each a sequence of numeric features.
        classLabels: ``m`` labels, expected to be +1 / -1.
        C: box constraint; every multiplier is kept in ``[0, C]``.
        toler: KKT-violation tolerance.
        maxIter: number of consecutive passes without any update after
            which training stops.

    Returns:
        ``(b, alphas)``: the bias (``0`` if it was never updated, otherwise
        a 1x1 ``numpy.matrix``) and the ``m x 1`` matrix of multipliers.

    Raises:
        ValueError: propagated from ``selectJrand`` when fewer than two
            samples are given and a KKT violation is found.
    """
    # np.asmatrix exists on NumPy 1.x and 2.x; the ``mat`` alias was
    # removed in NumPy 2.0.
    dataMatrix = np.asmatrix(dataMatIn)
    labelMat = np.asmatrix(classLabels).transpose()
    b = 0
    m = np.shape(dataMatrix)[0]
    alphas = np.asmatrix(np.zeros((m, 1)))
    passes = 0  # consecutive passes with no alpha update
    while passes < maxIter:
        alphaPairsChanged = 0  # alpha pairs optimised during this pass
        for i in range(m):
            # 1 x m row of alpha_k * y_k; alphas do not change before fXj.
            ay = np.multiply(alphas, labelMat).T
            # f(x_i) = sum_k alpha_k * y_k * <x_k, x_i> + b.  ``.item()``
            # extracts the 1x1 product (float() on a 2-D array is deprecated).
            fXi = (ay * (dataMatrix * dataMatrix[i, :].T)).item() + b
            Ei = fXi - labelMat[i].item()  # prediction error on sample i
            # Only optimise alpha_i when it violates the KKT conditions.
            if not (((labelMat[i] * Ei < -toler) and (alphas[i] < C))
                    or ((labelMat[i] * Ei > toler) and (alphas[i] > 0))):
                continue
            j = selectJrand(i, m)
            fXj = (ay * (dataMatrix * dataMatrix[j, :].T)).item() + b
            Ej = fXj - labelMat[j].item()
            alphaIold = alphas[i].copy()
            alphaJold = alphas[j].copy()
            # [L, H]: feasible range of alpha_j keeping both multipliers in
            # [0, C] on the line y_i*alpha_i + y_j*alpha_j = const.
            if labelMat[i] != labelMat[j]:
                L = max(0, alphas[j] - alphas[i])
                H = min(C, C + alphas[j] - alphas[i])
            else:
                L = max(0, alphas[j] + alphas[i] - C)
                H = min(C, alphas[j] + alphas[i])
            if L == H:
                print("L == H")
                continue
            Kii = dataMatrix[i, :] * dataMatrix[i, :].T
            Kjj = dataMatrix[j, :] * dataMatrix[j, :].T
            Kij = dataMatrix[i, :] * dataMatrix[j, :].T
            # eta = 2*K_ij - K_ii - K_jj = -||x_i - x_j||^2 for this linear
            # kernel; the pair is skipped unless it is strictly negative.
            eta = 2.0 * Kij - Kii - Kjj
            if eta >= 0:
                print("eta >= 0")
                continue
            alphas[j] -= labelMat[j] * (Ei - Ej) / eta
            alphas[j] = clipAlpha(alphas[j], H, L)
            if abs(alphas[j] - alphaJold) < 0.00001:
                print("j not moving enough")
                continue
            # Move alpha_i the opposite way so y_i*alpha_i + y_j*alpha_j is
            # unchanged (labels are +/-1).
            alphas[i] += labelMat[j] * labelMat[i] * (alphaJold - alphas[j])
            dI = labelMat[i] * (alphas[i] - alphaIold)
            dJ = labelMat[j] * (alphas[j] - alphaJold)
            # b1 / b2 are the biases that make f(x_i) = y_i / f(x_j) = y_j.
            b1 = b - Ei - dI * Kii - dJ * Kij
            b2 = b - Ej - dI * Kij - dJ * Kjj
            if (0 < alphas[i]) and (C > alphas[i]):
                b = b1
            elif (0 < alphas[j]) and (C > alphas[j]):
                b = b2
            else:
                b = (b1 + b2) / 2.0
            alphaPairsChanged += 1
            print("iter: %d i:%d, pairs changed %d" % (passes, i, alphaPairsChanged))
        if alphaPairsChanged == 0:
            passes += 1
        else:
            passes = 0
        print("iteration number: %d" % passes)
    return b, alphas


def draw(alpha, bet, data, label, x1_range=(10, 90), xlim=(0, 100)):
    """Plot the samples, the separating line and its two margin lines.

    Args:
        alpha: ``m x 1`` multipliers as returned by ``smoSimple``.
        bet: bias term as returned by ``smoSimple``.
        data: ``m`` samples of ``[x1, x2]``.
        label: ``m`` labels; samples with a positive label are drawn red,
            the others green.
        x1_range: x1 values between which the lines are drawn.
        xlim: horizontal axis limits.
    """
    # Imported lazily so the training code works without matplotlib.
    import matplotlib.pyplot as plt

    plt.xlabel("x1")
    plt.xlim(*xlim)
    plt.ylabel("x2")
    for point, y in zip(data, label):
        plt.plot(point[0], point[1], 'or' if y > 0 else 'og')

    # Linear-kernel weight vector w = sum_i alpha_i * y_i * x_i.
    w1 = 0.0
    w2 = 0.0
    for a_i, y, point in zip(np.asarray(alpha, dtype=float).ravel(), label, data):
        w1 += a_i * y * point[0]
        w2 += a_i * y * point[1]
    bias = np.asarray(bet, dtype=float).item()

    if w2 != 0:
        # Solve w1*x1 + w2*x2 + b = 0 (boundary) and = +/-1 (margins) for x2.
        slope = -w1 / w2
        intercept = -bias / w2
        r = 1.0 / w2
        xs = list(x1_range)
        plt.plot(xs, [slope * x + intercept for x in xs], 'b')
        plt.plot(xs, [slope * x + intercept + r for x in xs], 'b--')
        plt.plot(xs, [slope * x + intercept - r for x in xs], 'b--')
    elif w1 != 0:
        # Vertical boundary: x1 = (-b) / w1, margins at (-b +/- 1) / w1.
        plt.axvline(-bias / w1, color='b')
        plt.axvline((-bias + 1) / w1, color='b', linestyle='--')
        plt.axvline((-bias - 1) / w1, color='b', linestyle='--')
    # Otherwise w == 0 (all multipliers zero): there is no line to draw.
    plt.show()


if __name__ == "__main__":
    # Dataset path: first command-line argument, else the author's original file.
    filestr = sys.argv[1] if len(sys.argv) > 1 else "E:\\Kaggle\\Digit Recognizer\\svmtest.txt"

    dataArr, labelArr = loadDataSet(filestr)
    print(dataArr)
    print(labelArr)
    b, alphas = smoSimple(dataArr, labelArr, 0.6, 0.001, 40)
    print(b)
    print(alphas)
    draw(alphas, b, dataArr, labelArr)
下面是测试集（即脚本中读取的 svmtest.txt，每行格式为 x1,x2,label，label 取 1 或 -1）：
27,53,-1
49,37,-1
56,39,-1
28,60,-1
68,75,1
57,69,1
64,62,1
77,68,1
70,54,1
56,63,1
25,41,-1
66,34,1
55,79,1
77,31,-1
46,66,1
30,23,-1
21,45,-1
68,42,-1
43,43,-1
56,59,1
79,68,1
60,34,-1
49,32,-1
80,79,1
77,46,1
26,66,1
29,29,-1
77,34,1
20,71,-1
49,25,-1
58,65,1
33,57,-1
31,79,1
20,78,1
77,37,-1
73,34,-1
60,26,-1
77,66,1
71,75,1
35,36,-1
49,61,1
26,37,-1
42,73,1
36,50,-1
66,73,1
71,43,1
33,62,1
43,41,-1
42,29,-1
58,20,-1
下面是运行结果：
以上推导内容转自:http://liuhongjiang.github.io/tech/blog/2012/12/28/svm-smo/
posted on 2016-03-16 14:37 JustForCS 阅读(2649) 评论(0) 编辑 收藏 举报