SMO序列最小最优化算法

SMO例子:

复制代码
  1 from numpy import *
  2 import matplotlib
  3 import matplotlib.pyplot as plt
  4 
  5 def loadDataSet(fileName):
  6     dataMat = []; labelMat = []
  7     fr = open(fileName)
  8     for line in fr.readlines():
  9         lineArr = line.strip().split(',')
 10         dataMat.append([float(lineArr[0]), float(lineArr[1])])
 11         labelMat.append(float(lineArr[2]))
 12     return dataMat, labelMat
 13     
 14 def selectJrand(i, m):
 15     j = i
 16     while (j == i):
 17         j = int(random.uniform(0, m))
 18     return j
 19 
 20 def clipAlpha(aj, H, L):
 21     if aj > H:
 22         aj = H
 23     if L > aj:
 24         aj = L
 25     return aj
 26 
 27 
 28 def smoSimple(dataMatIn, classLabels, C, toler, maxIter):
 29     dataMatrix = mat(dataMatIn); labelMat = mat(classLabels).transpose()
 30     b = 0; m, n = shape(dataMatrix)
 31     alphas = mat(zeros((m, 1)))
 32     iter = 0
 33     while (iter < maxIter):
 34         alphaPairsChanged = 0 #用于记录alpha是否已经进行优化
 35         for i in range(m):
 36             fXi = float(multiply(alphas, labelMat).T*(dataMatrix*dataMatrix[i,:].T))+b # 预测的类别
 37             Ei = fXi - float(labelMat[i]) #实际结果与真实结果的误差,如果误差很大,那么就要对该数据实例所对应的alpha值进行优化
 38             if ((labelMat[i]*Ei < -toler) and (alphas[i] < C)) or ((labelMat[i]*Ei > toler) and (alphas[i]>0)):
 39                 j = selectJrand(i, m)
 40                 fXj = float(multiply(alphas, labelMat).T*(dataMatrix*dataMatrix[j, :].T))+b
 41                 Ej = fXj - float(labelMat[j])
 42                 alphaIold = alphas[i].copy()
 43                 alphaJold = alphas[j].copy()
 44                 if (labelMat[i] != labelMat[j]):
 45                     L = max(0, alphas[j] - alphas[i])
 46                     H = min(C, C+alphas[j]-alphas[i])
 47                 else:
 48                     L = max(0, alphas[j]+alphas[i]-C)
 49                     H = min(C, alphas[j]+alphas[i])
 50                 if L == H: print("L == H"); continue
 51                 eta = 2.0 * dataMatrix[i, :]*dataMatrix[j,:].T-dataMatrix[i,:]*dataMatrix[i,:].T-dataMatrix[j,:]*dataMatrix[j,:].T
 52                 if eta >= 0: print("eta >= 0"); continue
 53                 alphas[j] -= labelMat[j]*(Ei-Ej)/eta
 54                 alphas[j] = clipAlpha(alphas[j], H, L)
 55                 if (abs(alphas[j] - alphaJold) < 0.00001): print("j not moving enough"); continue
 56                 alphas[i] += labelMat[j] * labelMat[i] * (alphaJold-alphas[j])
 57                 b1 = b - Ei - labelMat[i]*(alphas[i]-alphaIold)*dataMatrix[i,:]*dataMatrix[i,:].T-labelMat[j]*(alphas[j]-alphaJold)*dataMatrix[i,:]*dataMatrix[j,:].T
 58                 b2 = b - Ej - labelMat[i]*(alphas[i]-alphaIold)*dataMatrix[i,:]*dataMatrix[j,:].T-labelMat[j]*(alphas[j]-alphaJold)*dataMatrix[j,:]*dataMatrix[j,:].T
 59                 if (0 < alphas[i]) and (C > alphas[i]): b = b1
 60                 elif (0 < alphas[j]) and (C > alphas[j]): b = b2
 61                 else: b = (b1+b2)/2.0
 62                 alphaPairsChanged += 1
 63                 print("iter: %d  i:%d, pairs changed %d" % (iter, i, alphaPairsChanged))
 64         if (alphaPairsChanged == 0): iter += 1
 65         else: iter = 0
 66         print("iteration number: %d" % iter)
 67     return b, alphas
 68     
 69 def draw(alpha, bet, data, label):
 70     plt.xlabel(u"x1")
 71     plt.xlim(0, 100)
 72     plt.ylabel(u"x2")
 73     for i in range(len(label)):
 74         if label[i] > 0:
 75             plt.plot(data[i][0], data[i][1], 'or')
 76         else:
 77             plt.plot(data[i][0], data[i][1], 'og')
 78     w1 = 0.0
 79     w2 = 0.0
 80     for i in range(len(label)):
 81         w1 += alpha[i] * label[i] * data[i][0]
 82         w2 += alpha[i] * label[i] * data[i][1]
 83     w = float(- w1 / w2)
 84     
 85     b = float(- bet / w2)
 86     r = float(1 / w2)
 87     lp_x1 = list([10, 90])
 88     lp_x2 = []
 89     lp_x2up = []
 90     lp_x2down = []
 91     for x1 in lp_x1:
 92         lp_x2.append(w * x1 + b)
 93         lp_x2up.append(w * x1 + b + r)
 94         lp_x2down.append(w * x1 + b - r)
 95     lp_x2 = list(lp_x2)
 96     lp_x2up = list(lp_x2up)
 97     lp_x2down = list(lp_x2down)
 98     plt.plot(lp_x1, lp_x2, 'b')
 99     plt.plot(lp_x1, lp_x2up, 'b--')
100     plt.plot(lp_x1, lp_x2down, 'b--')
101     plt.show()
102     
103 
104     
if __name__ == "__main__":
    import sys

    # Comma-separated training file (x1,x2,label per line). The first
    # command-line argument, if given, overrides the default path.
    filestr = (sys.argv[1] if len(sys.argv) > 1
               else "E:\\Kaggle\\Digit Recognizer\\svmtest.txt")

    dataArr, labelArr = loadDataSet(filestr)
    print(dataArr)
    print(labelArr)
    # C = 0.6, toler = 0.001, stop after 40 consecutive passes without change.
    b, alphas = smoSimple(dataArr, labelArr, 0.6, 0.001, 40)
    print(b)
    print(alphas)
    draw(alphas, b, dataArr, labelArr)
复制代码

下面是训练数据集(即上面代码中读取的 svmtest.txt 的内容,每行格式为 x1,x2,标签,标签取 1 或 -1)

复制代码
27,53,-1
49,37,-1
56,39,-1
28,60,-1
68,75,1
57,69,1
64,62,1
77,68,1
70,54,1
56,63,1
25,41,-1
66,34,1
55,79,1
77,31,-1
46,66,1
30,23,-1
21,45,-1
68,42,-1
43,43,-1
56,59,1
79,68,1
60,34,-1
49,32,-1
80,79,1
77,46,1
26,66,1
29,29,-1
77,34,1
20,71,-1
49,25,-1
58,65,1
33,57,-1
31,79,1
20,78,1
77,37,-1
73,34,-1
60,26,-1
77,66,1
71,75,1
35,36,-1
49,61,1
26,37,-1
42,73,1
36,50,-1
66,73,1
71,43,1
33,62,1
43,41,-1
42,29,-1
58,20,-1
复制代码

下面是结果(运行结果图未包含在本文本中):

以上推导内容转自:http://liuhongjiang.github.io/tech/blog/2012/12/28/svm-smo/

posted on 2016-06-10 23:06  _harvey  阅读(1105)  评论(0)  编辑  收藏  举报

导航