AdaBoost算法

理论知识可参考：《统计学习方法》（李航 著）第八章“提升方法”。

简单代码实现（Python + NumPy，以单层决策树桩作为弱分类器）：

 1 from numpy import *
 2 import matplotlib.pyplot as plt
 3 
 4 def loadSimpData():
 5     dataMat = matrix([[1,2.1],
 6         [2,1.1],
 7         [1.3,1],
 8         [1,1],
 9         [2,1]])
10     classLabels = [1.0, 1.0, -1.0, -1.0, 1.0]
11     return dataMat, classLabels
12 
def stumpClassify(dataMatrix, dimen, threshVal, threshIneq):
    """Classify every row of *dataMatrix* with a one-level decision stump.

    Every sample starts out predicted as +1.0. With ``threshIneq == 'lt'``
    the samples whose feature ``dimen`` is ``<= threshVal`` are flipped to
    -1.0; with any other value of ``threshIneq`` the samples whose feature
    is ``> threshVal`` are flipped instead.

    Returns:
        A float ndarray of shape ``(rows, 1)`` holding +1.0 / -1.0.
    """
    column = dataMatrix[:, dimen]
    if threshIneq == 'lt':
        negative = column <= threshVal
    else:
        negative = column > threshVal
    predictions = ones((shape(dataMatrix)[0], 1))
    predictions[negative] = -1.0
    return predictions
20 
def buildStump(dataArr, classLabels, D, numSteps=10.0):
    """Find the decision stump with the lowest weighted error under weights D.

    Every feature column ``i`` is scanned with thresholds
    ``rangeMin + j * stepSize`` for ``j`` in ``-1 .. int(numSteps)``, i.e. one
    threshold below the column minimum up to one at the column maximum.
    For each threshold, both inequality directions ('lt' and 'gt') are
    evaluated with ``stumpClassify``.

    Args:
        dataArr: samples of shape (m, n); converted with ``asmatrix``.
        classLabels: sequence of m labels (+1.0 / -1.0 in this file).
        D: (m, 1) matrix of sample weights; the weighted error is ``D.T * errArr``.
        numSteps: number of threshold steps per feature (default 10.0, the
            value that used to be hard-coded). Expected to be a positive
            whole number.

    Returns:
        ``(bestStump, minError, bestClasEst)``:
        - a dict with keys 'dim', 'thresh' and 'ineq';
        - that stump's weighted error as a 1x1 matrix (or ``inf`` if the
          data has no feature columns);
        - the stump's (m, 1) predictions.

    Raises:
        ValueError: if ``numSteps`` is not positive.
    """
    if numSteps <= 0:
        raise ValueError("numSteps must be positive, got %r" % (numSteps,))
    # asmatrix is what `mat` aliased; `mat` was removed in NumPy 2.0.
    dataMatrix = asmatrix(dataArr)
    labelMat = asmatrix(classLabels).T
    m, n = shape(dataMatrix)
    bestStump = {}
    bestClasEst = asmatrix(zeros((m, 1)))
    minError = inf
    for i in range(n):
        rangeMin = dataMatrix[:, i].min()
        rangeMax = dataMatrix[:, i].max()
        stepSize = (rangeMax - rangeMin) / numSteps
        for j in range(-1, int(numSteps) + 1):
            # The threshold depends only on j, so compute it once for both
            # inequality directions.
            threshVal = rangeMin + float(j) * stepSize
            for inequal in ['lt', 'gt']:
                predictedVals = stumpClassify(dataMatrix, i, threshVal, inequal)
                # 1 for each misclassified sample, 0 for each correct one.
                errArr = asmatrix(ones((m, 1)))
                errArr[predictedVals == labelMat] = 0
                weightedError = D.T * errArr
                if weightedError < minError:
                    minError = weightedError
                    bestClasEst = predictedVals.copy()
                    bestStump['dim'] = i
                    bestStump['thresh'] = threshVal
                    bestStump['ineq'] = inequal
    return bestStump, minError, bestClasEst
44 
def adaBoostTrainDS(dataArr, classLabels, numIt=40, verbose=True):
    """Train an AdaBoost ensemble of decision stumps.

    Each round does four things:
    - fit the best stump under the current sample weights D (via
      ``buildStump``);
    - give it the vote weight ``alpha = 0.5 * ln((1 - e) / e)``;
    - re-weight the samples with ``D <- D * exp(-alpha * y * h(x))`` and
      renormalize;
    - accumulate ``alpha * h(x)`` into the ensemble output.

    Training stops early once the ensemble's sign-based training error
    reaches zero.

    Args:
        dataArr: samples of shape (m, n).
        classLabels: sequence of m labels (+1.0 / -1.0 in this file).
        numIt: maximum number of boosting rounds.
        verbose: when True (the default, matching the original behavior),
            print D, the stump predictions, the aggregate output and the
            training error every round.

    Returns:
        A list of stump dicts, one per round. Each has the keys 'dim',
        'thresh' and 'ineq' from ``buildStump``, plus 'alpha'.
    """
    weakClassArr = []
    m = shape(dataArr)[0]
    # Loop-invariant (m, 1) label column, built once instead of every round.
    labelMat = asmatrix(classLabels).T
    # Start from uniform weights. asmatrix replaces `mat`, removed in NumPy 2.0.
    D = asmatrix(ones((m, 1)) / m)
    aggClassEst = asmatrix(zeros((m, 1)))
    for _ in range(numIt):
        bestStump, error, classEst = buildStump(dataArr, classLabels, D)
        if verbose:
            print("D:", D.T)
        # buildStump returns the error as a 1x1 matrix. Extract the scalar
        # explicitly: float() on an ndim>0 array is deprecated since NumPy 1.25.
        error = float(asarray(error).item())
        # The 1e-16 floor avoids division by zero when a stump is perfect.
        alpha = float(0.5 * log((1.0 - error) / max(error, 1e-16)))
        bestStump['alpha'] = alpha
        weakClassArr.append(bestStump)
        if verbose:
            print("classEst:", classEst)
        # Increase the weight of misclassified samples (y * h(x) = -1) and
        # decrease it for correct ones, then renormalize to sum to 1.
        expon = multiply(-1 * alpha * labelMat, classEst)
        D = multiply(D, exp(expon))
        D = D / D.sum()
        aggClassEst += alpha * classEst
        if verbose:
            print("aggClassEst:", aggClassEst.T)
        aggErrors = multiply(sign(aggClassEst) != labelMat, ones((m, 1)))
        errorRate = aggErrors.sum() / m
        if verbose:
            print("total error:", errorRate, "\n")
        if errorRate == 0.0:
            break
    return weakClassArr
67 
# Demo: train on the toy data set for at most 9 boosting rounds and print the
# resulting list of weighted stumps.
dataMat, classLabels = loadSimpData()
# NOTE: this uniform weight vector is not passed to adaBoostTrainDS, which
# builds its own initial weights. Its size is taken from the data instead of
# being hard-coded. asmatrix replaces `mat`, which NumPy 2.0 removed.
D = asmatrix(ones((len(classLabels), 1)) / len(classLabels))
classifierArray = adaBoostTrainDS(dataMat, classLabels, 9)
print(classifierArray)
View Code

 

posted on 2016-03-17 20:24  JustForCS  阅读(231)  评论(0)  编辑  收藏  举报

导航