目前学了几个ML的分类的经典算法,但是一直想着是否有一种能将这些算法集成起来的,今天看到了AdaBoost,也算是半个集成,感觉这个思路挺好,很像人的训练过程,并且对决策树是一个很好的补充,因为决策树容易过拟合,用AdaBoost可以让一棵很深的决策树将其分开成多棵矮树,后来发现原来这个想法和random forest比较相似,RF的代码等下周有空的时候可以写一下。
这个貌似挺厉害的,看那些专门搞学术的人说是一篇很牛逼的论文证明说可以把弱学习提升到强学习。我这种搞工程的,能知道他的原理,适用范围,能自己写一遍代码,感觉还是比那些读几遍论文只能惶惶其谈的要安心些。
关于AdaBoost的基本概念,通过《机器学习方法》来概要的说下。
bagging和boosting的区别
bagging:是指在原始数据上通过放回抽样,抽出和原始数据大小相等的新数据集(这个性质说明新数据集存在重复的值,而原始数据部分数据值不会出现在新数据集中),并重复该过程选择N个新数据集,这样通过N个分类器对这个N个数据集进行分类,最后选择分类器投票结果中最多类别作为最后的分类结果。
boosting:相比bagging,boosting像是一种串行,bagging是一种并行的,bagging可以对于N个数据集通过N个分类器同时进行分类,并且每个分类器的权重是一样的,但是boosting则相反,boosting是利用一个数据集依次由每个分类器进行分类,而确定每个分类器的权重是加大正确率高的分类器的权重,减少正确率低的分类器的权重。同时为了提高准确率,每次会降低被正确分类的样本的权重,提高没有正确分类的样本的权重。这样做其实比较符合人的决策过程,就是要多训练自己容易做错的题型,并且要多听取正确性高的老师的意见。
那么AdaBoost的主要的两个过程就是提高错误分类的样本权重和提高正确率高的分类器的权重。
算法的步骤:
输入:训练集T,弱学习分类器(这里是一个节点的决策树)
输出:最终的分类器G
1 先初始化样本权重值,D1={W11...W1n}W1i=1/n
2 根据样本权重D1以及决策树求分类误差率,并求的最小的误差率em,以及该决策树
em=
3 计算该分类器的权重
可以看出,误差率越小的,其权重越大
4 更新各个样本的权重,Dm+1,(用公式编辑器好麻烦。。。 )
其中Zm是规范化银子:
5 构建基本分类器
F(X)=
6 计算该分类器下的误差率,如果小于某个阈值就停止,否则从第二步开始迭代
终于不用打公式了。。。。
附上代码:
1 import java.io.BufferedReader; 2 import java.io.FileInputStream; 3 import java.io.IOException; 4 import java.io.InputStreamReader; 5 import java.util.ArrayList; 6 7 class Stump{ 8 public int dim; 9 public double thresh; 10 public String condition; 11 public double error; 12 public ArrayList<Integer> labelList; 13 double factor; 14 15 public String toString(){ 16 return "dim is "+dim+"\nthresh is "+thresh+"\ncondition is "+condition+"\nerror is "+error+"\nfactor is "+factor+"\nlabel is "+labelList; 17 } 18 } 19 20 class Utils{ 21 //加载数据集 22 public static ArrayList<ArrayList<Double>> loadDataSet(String filename) throws IOException{ 23 ArrayList<ArrayList<Double>> dataSet=new ArrayList<ArrayList<Double>>(); 24 FileInputStream fis=new FileInputStream(filename); 25 InputStreamReader isr=new InputStreamReader(fis,"UTF-8"); 26 BufferedReader br=new BufferedReader(isr); 27 String line=""; 28 29 while((line=br.readLine())!=null){ 30 ArrayList<Double> data=new ArrayList<Double>(); 31 String[] s=line.split(" "); 32 33 for(int i=0;i<s.length-1;i++){ 34 data.add(Double.parseDouble(s[i])); 35 } 36 dataSet.add(data); 37 } 38 return dataSet; 39 } 40 41 //加载类别 42 public static ArrayList<Integer> loadLabelSet(String filename) throws NumberFormatException, IOException{ 43 ArrayList<Integer> labelSet=new ArrayList<Integer>(); 44 45 FileInputStream fis=new FileInputStream(filename); 46 InputStreamReader isr=new InputStreamReader(fis,"UTF-8"); 47 BufferedReader br=new BufferedReader(isr); 48 String line=""; 49 50 while((line=br.readLine())!=null){ 51 String[] s=line.split(" "); 52 labelSet.add(Integer.parseInt(s[s.length-1])); 53 } 54 return labelSet; 55 } 56 //测试用的 57 public static void showDataSet(ArrayList<ArrayList<Double>> dataSet){ 58 for(ArrayList<Double> data:dataSet){ 59 System.out.println(data); 60 } 61 } 62 //获取最大值,用于求步长 63 public static double getMax(ArrayList<ArrayList<Double>> dataSet,int index){ 64 double max=-9999.0; 65 for(ArrayList<Double> data:dataSet){ 66 if(data.get(index)>max){ 67 max=data.get(index); 68 } 69 } 70 return max; 71 } 72 //获取最小值,用于求步长 73 public static double getMin(ArrayList<ArrayList<Double>> dataSet,int index){ 74 double min=9999.0; 75 for(ArrayList<Double> data:dataSet){ 76 if(data.get(index)<min){ 77 min=data.get(index); 78 } 79 } 80 return min; 81 } 82 83 //获取数据集中以该feature为特征,以thresh和conditions为value的叶子节点的决策树进行划分后得到的预测类别 84 public static ArrayList<Integer> getClassify(ArrayList<ArrayList<Double>> dataSet,int feature,double thresh,String condition){ 85 ArrayList<Integer> labelList=new ArrayList<Integer>(); 86 if(condition.compareTo("lt")==0){ 87 for(ArrayList<Double> data:dataSet){ 88 if(data.get(feature)<=thresh){ 89 labelList.add(1); 90 }else{ 91 labelList.add(-1); 92 } 93 } 94 }else{ 95 for(ArrayList<Double> data:dataSet){ 96 if(data.get(feature)>=thresh){ 97 labelList.add(1); 98 }else{ 99 labelList.add(-1); 100 } 101 } 102 } 103 return labelList; 104 } 105 //求预测类别与真实类别的加权误差 106 public static double getError(ArrayList<Integer> fake,ArrayList<Integer> real,ArrayList<Double> weights){ 107 double error=0; 108 109 int n=real.size(); 110 111 for(int i=0;i<fake.size();i++){ 112 if(fake.get(i)!=real.get(i)){ 113 error+=weights.get(i); 114 115 } 116 } 117 118 return error; 119 } 120 //构造一棵单节点的决策树,用一个Stump类来存储这些基本信息。 121 public static Stump buildStump(ArrayList<ArrayList<Double>> dataSet,ArrayList<Integer> labelSet,ArrayList<Double> weights,int n){ 122 int featureNum=dataSet.get(0).size(); 123 124 int rowNum=dataSet.size(); 125 Stump stump=new Stump(); 126 double minError=999.0; 127 System.out.println("第"+n+"次迭代"); 128 for(int i=0;i<featureNum;i++){ 129 double min=getMin(dataSet,i); 130 double max=getMax(dataSet,i); 131 double step=(max-min)/(rowNum); 132 for(double j=min-step;j<=max+step;j=j+step){ 133 String[] conditions={"lt","gt"};//如果是lt,表示如果小于阀值则为真类,如果是gt,表示如果大于阀值则为正类 134 for(String condition:conditions){ 135 ArrayList<Integer> labelList=getClassify(dataSet,i,j,condition); 136 137 double error=Utils.getError(labelList,labelSet,weights); 138 if(error<minError){ 139 minError=error; 140 stump.dim=i; 141 stump.thresh=j; 142 stump.condition=condition; 143 stump.error=minError; 144 stump.labelList=labelList; 145 stump.factor=0.5*(Math.log((1-error)/error)); 146 } 147 148 } 149 } 150 151 } 152 153 return stump; 154 } 155 156 public static ArrayList<Double> getInitWeights(int n){ 157 double weight=1.0/n; 158 ArrayList<Double> weights=new ArrayList<Double>(); 159 for(int i=0;i<n;i++){ 160 weights.add(weight); 161 } 162 return weights; 163 } 164 //更新样本权值 165 public static ArrayList<Double> updateWeights(Stump stump,ArrayList<Integer> labelList,ArrayList<Double> weights){ 166 double Z=0; 167 ArrayList<Double> newWeights=new ArrayList<Double>(); 168 int row=labelList.size(); 169 double e=Math.E; 170 double factor=stump.factor; 171 for(int i=0;i<row;i++){ 172 Z+=weights.get(i)*Math.pow(e,-factor*labelList.get(i)*stump.labelList.get(i)); 173 } 174 175 176 for(int i=0;i<row;i++){ 177 double weight=weights.get(i)*Math.pow(e,-factor*labelList.get(i)*stump.labelList.get(i))/Z; 178 newWeights.add(weight); 179 } 180 return newWeights; 181 } 182 //对加权误差累加 183 public static ArrayList<Double> InitAccWeightError(int n){ 184 ArrayList<Double> accError=new ArrayList<Double>(); 185 for(int i=0;i<n;i++){ 186 accError.add(0.0); 187 } 188 return accError; 189 } 190 191 public static ArrayList<Double> accWeightError(ArrayList<Double> accerror,Stump stump){ 192 ArrayList<Integer> t=stump.labelList; 193 double factor=stump.factor; 194 ArrayList<Double> newAccError=new ArrayList<Double>(); 195 for(int i=0;i<t.size();i++){ 196 double a=accerror.get(i)+factor*t.get(i); 197 newAccError.add(a); 198 } 199 return newAccError; 200 } 201 202 public static double calErrorRate(ArrayList<Double> accError,ArrayList<Integer> labelList){ 203 ArrayList<Integer> a=new ArrayList<Integer>(); 204 int wrong=0; 205 for(int i=0;i<accError.size();i++){ 206 if(accError.get(i)>0){ 207 if(labelList.get(i)==-1){ 208 wrong++; 209 } 210 }else if(labelList.get(i)==1){ 211 wrong++; 212 } 213 } 214 double error=wrong*1.0/accError.size(); 215 return error; 216 } 217 218 public static void showStumpList(ArrayList<Stump> G){ 219 for(Stump s:G){ 220 System.out.println(s); 221 System.out.println(" "); 222 } 223 } 224 } 225 226 227 public class Adaboost { 228 229 /** 230 * @param args 231 * @throws IOException 232 */ 233 234 public static ArrayList<Stump> AdaBoostTrain(ArrayList<ArrayList<Double>> dataSet,ArrayList<Integer> labelList){ 235 int row=labelList.size(); 236 ArrayList<Double> weights=Utils.getInitWeights(row); 237 ArrayList<Stump> G=new ArrayList<Stump>(); 238 ArrayList<Double> accError=Utils.InitAccWeightError(row); 239 int n=1; 240 while(true){ 241 Stump stump=Utils.buildStump(dataSet,labelList,weights,n);//求一棵误差率最小的单节点决策树 242 G.add(stump); 243 weights=Utils.updateWeights(stump,labelList,weights);//更新权值 244 accError=Utils.accWeightError(accError,stump);//将加权误差累加,因为这样不用再利用分类器再求了 245 double error=Utils.calErrorRate(accError,labelList); 246 if(error<0.001){ 247 break; 248 } 249 n++; 250 } 251 return G; 252 } 253 254 public static void main(String[] args) throws IOException { 255 // TODO Auto-generated method stub 256 String file="C:/Users/Administrator/Desktop/upload/AdaBoost1.txt"; 257 ArrayList<ArrayList<Double>> dataSet=Utils.loadDataSet(file); 258 ArrayList<Integer> labelSet=Utils.loadLabelSet(file); 259 ArrayList<Stump> G=AdaBoostTrain(dataSet,labelSet); 260 Utils.showStumpList(G); 261 System.out.println("finished"); 262 } 263 264 }
这里的数据采用的是统计学习方法中的数据
0 1 1 1 2 1 3 -1 4 -1 5 -1 6 1 7 1 8 1 9 -1
这里是单个特征的,也可以是多维数据,例如
1.0 2.1 1 2.0 1.1 1 1.3 1.0 -1 1.0 1.0 -1 2.0 1.0 1