SMO序列最小最优化算法
SMO例子:
1 from numpy import *
2 import matplotlib
3 import matplotlib.pyplot as plt
4
def loadDataSet(fileName):
    """Read a comma-separated file of 2-D samples and their labels.

    Each non-blank line must contain at least three comma-separated
    numeric fields: the two feature values followed by the class label.

    Args:
        fileName: Path of the text file to read.

    Returns:
        A ``(dataMat, labelMat)`` tuple where ``dataMat`` is a list of
        ``[x1, x2]`` float pairs and ``labelMat`` is the list of float
        labels, both in file order.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a field cannot be parsed as a float.
        IndexError: If a non-blank line has fewer than three fields.
    """
    dataMat = []
    labelMat = []
    # The context manager closes the handle even if parsing fails midway.
    with open(fileName) as fr:
        for line in fr:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            lineArr = line.split(',')
            dataMat.append([float(lineArr[0]), float(lineArr[1])])
            labelMat.append(float(lineArr[2]))
    return dataMat, labelMat
13
def selectJrand(i, m):
    """Pick a random index in ``[0, m)`` that differs from ``i``.

    Args:
        i: Index that must not be returned.
        m: Number of candidate indices.

    Returns:
        An ``int`` ``j`` drawn with ``random.uniform(0, m)`` such that
        ``j != i``.

    Raises:
        ValueError: If ``m < 2``; with fewer than two candidates there is
            no second index to pick and the retry loop could never end.
    """
    if m < 2:
        raise ValueError("selectJrand needs m >= 2 to pick an index other than i")
    j = i
    # Redraw until the candidate differs from i.
    while j == i:
        j = int(random.uniform(0, m))
    return j
19
def clipAlpha(aj, H, L):
    """Clip ``aj`` into the closed interval ``[L, H]``.

    Args:
        aj: Value to clip.
        H: Upper bound.
        L: Lower bound.

    Returns:
        ``H`` if ``aj > H``, ``L`` if the (possibly capped) value is below
        ``L``, otherwise ``aj`` unchanged.
    """
    # Plain comparisons are used on purpose: the file star-imports numpy,
    # which in some NumPy versions may shadow the builtin max/min.
    if aj > H:
        aj = H
    if L > aj:
        aj = L
    return aj
26
27
def smoSimple(dataMatIn, classLabels, C, toler, maxIter):
    """Train a linear soft-margin SVM with the simplified SMO algorithm.

    For every sample ``i`` that violates the KKT conditions by more than
    ``toler``, a second index ``j`` is chosen at random and the pair
    ``(alpha_i, alpha_j)`` is optimised analytically.  Training stops after
    ``maxIter`` consecutive passes over the data in which no pair changed.

    Args:
        dataMatIn: Sequence of samples, one row of features per sample.
        classLabels: Sequence of class labels, one per sample.
        C: Upper bound of every Lagrange multiplier.
        toler: Tolerance used in the KKT violation check.
        maxIter: Number of consecutive passes without any change required
            before training stops.

    Returns:
        A ``(b, alphas)`` tuple: ``b`` is the bias as a 1x1 matrix and
        ``alphas`` is an ``(m, 1)`` matrix of Lagrange multipliers.

    Raises:
        ValueError: If the number of labels differs from the number of
            samples.
    """
    # Work on plain float arrays/scalars internally; this avoids the
    # deprecated implicit float() conversion of 1x1 matrices.
    samples = asarray(dataMatIn, dtype=float)
    labels = asarray(classLabels, dtype=float).ravel()
    m = samples.shape[0]
    if labels.shape[0] != m:
        raise ValueError("dataMatIn and classLabels must have the same length")

    alphas = zeros(m)
    b = 0.0
    passes = 0  # consecutive full passes without any alpha update
    while passes < maxIter:
        alphaPairsChanged = 0  # number of alpha pairs optimised in this pass
        for i in range(m):
            # Decision value f(x_i) = sum_k alpha_k * y_k * <x_k, x_i> + b
            fXi = float((alphas * labels).dot(samples.dot(samples[i]))) + b
            # Prediction error; a large error means alpha_i needs optimising.
            Ei = fXi - labels[i]
            violatesKKT = ((labels[i] * Ei < -toler and alphas[i] < C)
                           or (labels[i] * Ei > toler and alphas[i] > 0))
            if not violatesKKT:
                continue

            j = selectJrand(i, m)
            fXj = float((alphas * labels).dot(samples.dot(samples[j]))) + b
            Ej = fXj - labels[j]
            alphaIold = float(alphas[i])
            alphaJold = float(alphas[j])

            # Feasible segment [L, H] for alpha_j.  numpy's maximum/minimum
            # are used instead of the builtins, which the file's star import
            # may shadow in some NumPy versions.
            if labels[i] != labels[j]:
                L = float(maximum(0.0, alphaJold - alphaIold))
                H = float(minimum(C, C + alphaJold - alphaIold))
            else:
                L = float(maximum(0.0, alphaJold + alphaIold - C))
                H = float(minimum(C, alphaJold + alphaIold))
            if L == H:
                print("L == H")
                continue

            kii = float(samples[i].dot(samples[i]))
            kij = float(samples[i].dot(samples[j]))
            kjj = float(samples[j].dot(samples[j]))
            eta = 2.0 * kij - kii - kjj
            if eta >= 0:
                print("eta >= 0")
                continue

            alphaJnew = clipAlpha(alphaJold - labels[j] * (Ei - Ej) / eta, H, L)
            if abs(alphaJnew - alphaJold) < 0.00001:
                # Reject the step *before* storing it, so alpha_j is never
                # changed without the compensating change to alpha_i.
                print("j not moving enough")
                continue

            alphas[j] = alphaJnew
            # Move alpha_i by the same amount in the opposite direction.
            alphas[i] = alphaIold + labels[j] * labels[i] * (alphaJold - alphaJnew)

            deltaI = alphas[i] - alphaIold
            deltaJ = alphas[j] - alphaJold
            b1 = b - Ei - labels[i] * deltaI * kii - labels[j] * deltaJ * kij
            b2 = b - Ej - labels[i] * deltaI * kij - labels[j] * deltaJ * kjj
            if 0 < alphas[i] < C:
                b = float(b1)
            elif 0 < alphas[j] < C:
                b = float(b2)
            else:
                b = float(b1 + b2) / 2.0

            alphaPairsChanged += 1
            print("iter: %d i:%d, pairs changed %d" % (passes, i, alphaPairsChanged))

        # Any change restarts the count of quiet passes.
        if alphaPairsChanged == 0:
            passes += 1
        else:
            passes = 0
        print("iteration number: %d" % passes)

    # Hand results back in matrix form: b as 1x1, alphas as (m, 1).
    return asmatrix([[b]]), asmatrix(alphas).T
68
def draw(alpha, bet, data, label):
    """Plot the samples, the separating line and its two margin lines.

    Samples with a positive label are drawn as red circles, all others as
    green circles.  The weight vector is ``w = sum_i alpha_i * y_i * x_i``
    (first two features only); the solid blue line is ``w . x + bet = 0``
    and the dashed blue lines are ``w . x + bet = +/-1``.

    Args:
        alpha: Lagrange multipliers, one per sample (sequence, array or
            column matrix).
        bet: Bias term (scalar or 1x1 matrix).
        data: Sequence of samples; only the first two features are used.
        label: Sequence of class labels, one per sample.

    Returns:
        None.  The figure is displayed with ``plt.show()``.
    """
    points = asarray(data, dtype=float)
    if points.size == 0:
        # Keep column indexing valid when there is nothing to plot.
        points = points.reshape(0, 2)
    labels = asarray(label, dtype=float).ravel()
    multipliers = asarray(alpha, dtype=float).ravel()
    # ravel()[0] accepts both a plain scalar and a 1x1 matrix.
    bias = float(asarray(bet, dtype=float).ravel()[0])

    plt.xlabel(u"x1")
    plt.xlim(0, 100)
    plt.ylabel(u"x2")

    positive = labels > 0
    plt.plot(points[positive, 0], points[positive, 1], 'or')
    plt.plot(points[~positive, 0], points[~positive, 1], 'og')

    weights = (multipliers * labels).dot(points[:, :2])
    w1 = float(weights[0])
    w2 = float(weights[1])

    if w2 != 0.0:
        # x2 = -(w1 / w2) * x1 - bet / w2, margins shifted by +/- 1 / w2.
        slope = -w1 / w2
        intercept = -bias / w2
        margin = 1.0 / w2
        lp_x1 = [10, 90]
        lp_x2 = [slope * x1 + intercept for x1 in lp_x1]
        lp_x2up = [x2 + margin for x2 in lp_x2]
        lp_x2down = [x2 - margin for x2 in lp_x2]
        plt.plot(lp_x1, lp_x2, 'b')
        plt.plot(lp_x1, lp_x2up, 'b--')
        plt.plot(lp_x1, lp_x2down, 'b--')
    elif w1 != 0.0:
        # Vertical boundary: x1 = (-bet +/- 1) / w1 cannot be written as
        # x2 = f(x1), so draw vertical lines instead of dividing by zero.
        x0 = -bias / w1
        margin = 1.0 / w1
        plt.axvline(x0, color='b')
        plt.axvline(x0 + margin, color='b', linestyle='--')
        plt.axvline(x0 - margin, color='b', linestyle='--')
    # If both weights are zero there is no boundary to draw; show points only.
    plt.show()
102
103
104
if __name__ == "__main__":
    # Guarded so that importing this module does not read the data file,
    # train the model or open a plot window as a side effect.
    filestr = "E:\\Kaggle\\Digit Recognizer\\svmtest.txt"

    dataArr, labelArr = loadDataSet(filestr)
    print(dataArr)
    print(labelArr)
    # C = 0.6, KKT tolerance = 0.001, stop after 40 passes without change.
    b, alphas = smoSimple(dataArr, labelArr, 0.6, 0.001, 40)
    print(b)
    print(alphas)
    draw(alphas, b, dataArr, labelArr)
下面是测试集(应为上面代码读取的 svmtest.txt 的内容;每行格式为 x1,x2,label,行首的序号只是排版时带入的行号,不属于数据):
1 27,53,-1
2 49,37,-1
3 56,39,-1
4 28,60,-1
5 68,75,1
6 57,69,1
7 64,62,1
8 77,68,1
9 70,54,1
10 56,63,1
11 25,41,-1
12 66,34,1
13 55,79,1
14 77,31,-1
15 46,66,1
16 30,23,-1
17 21,45,-1
18 68,42,-1
19 43,43,-1
20 56,59,1
21 79,68,1
22 60,34,-1
23 49,32,-1
24 80,79,1
25 77,46,1
26 26,66,1
27 29,29,-1
28 77,34,1
29 20,71,-1
30 49,25,-1
31 58,65,1
32 33,57,-1
33 31,79,1
34 20,78,1
35 77,37,-1
36 73,34,-1
37 60,26,-1
38 77,66,1
39 71,75,1
40 35,36,-1
41 49,61,1
42 26,37,-1
43 42,73,1
44 36,50,-1
45 66,73,1
46 71,43,1
47 33,62,1
48 43,41,-1
49 42,29,-1
50 58,20,-1
下面是结果:
以上推导内容转自:http://liuhongjiang.github.io/tech/blog/2012/12/28/svm-smo/