Ranklib Source Code Analysis: LambdaMart
RankLib is an excellent open-source library for Learning to Rank, implementing models such as MART, RankNet, RankBoost, LambdaMART, and Random Forests. LambdaMART, published by Microsoft, is one of the most widely used Learning to Rank models in IR. This post walks through the concrete implementation of LambdaMART in RankLib, to help connect the method described in the paper with working code. It is based on RankLib version 2.3.
The fundamentals of LambdaMART are covered in an earlier post: http://www.cnblogs.com/bentuwuying/p/6690836.html. The key structural fact is that LambdaMART is built on MART, and MART is an ensemble of regression trees. So we first look at how RankLib implements a regression tree, and how such a tree is fitted given training data with labels.
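Concretely, the model MART produces is additive, and it helps to keep this form in mind when reading the code below (RankLib's Ensemble class stores exactly this list of trees, each paired with its learningRate weight):

$$
F_M(x) = \sum_{m=1}^{M} \eta \, f_m(x)
$$

where $f_m$ is the regression tree fitted in boosting round $m$ and $\eta$ is the shrinkage (learningRate in the code).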
1. regression tree
The procedure by which a regression tree is fitted to the given training data can be summarized as follows:

```
RegressionTree
    nodes           # maximum number of leaf nodes in a tree
    minLeafSupport  # controls splitting: a node holding fewer than 2*minLeafSupport training samples is not split further
    root            # root node
    leaves          # list of leaf nodes
    constructor RegressionTree(int nLeaves, DataPoint[] trainingSamples, double[] labels, FeatureHistogram hist, int minLeafSupport)
        initializes the member variables
    fit             # fits the regression tree to the training data
        create a queue, so that nodes are split in queue order (i.e. level-order traversal)
        initialize the root node of the regression tree
        root.split  # split the root node
            hist.findBestSplit  # the split method of the FeatureHistogram held by the Split object (using the histogram
                                # already computed for this node, find the best split point, split, compute the histograms
                                # of the two children, and initialize the children)
                check the deviance; if it is 0, the split fails
                use samplingRate to pick usedFeatures (indexes of the features considered for splitting)
                call the internal findBestSplit
                    on this node, over usedFeatures, choose the split feature and threshold from the node's feature histogram
                    S = sumLeft * sumLeft / countLeft + sumRight * sumRight / countRight
                    over all candidate split points (feature/threshold pairs), take the largest S, which corresponds to the
                    smallest squared error, i.e. the optimal split (see the derivation after this outline)
                check whether a split was found; S = -1 means the split failed
                assign each training sample on this node to the left or right child according to the best split point
                initialize the feature histograms of the two children
                    construct  # histogram constructor typically used for the left child of a split (built from the parent:
                               # the thresholds table is unchanged, but sum and count must be rebuilt)
                    construct  # histogram constructor typically used for the right child of a split
                compute the squared error (deviance) of this node and of the two children
                sp.set  # call back into the Split object that owns this FeatureHistogram
                    records the split's featureID, threshold, and deviance once the node has been split
                    only internal nodes are ever split (i.e. call this method), so only they have featureID != -1; leaves
                    never call it and keep featureID = -1
                initialize the left child (from the sample indexes routed left, its histogram, its deviance, and the sum of
                its labels) and attach it to this node's left-child field
                initialize the right child (likewise) and attach it to this node's right-child field
        insert      # enqueue the two children for the traversal below
            the queue is kept sorted by deviance, from largest to smallest
        loop: split nodes in queue order (level-order traversal), enqueueing the two children of each successful split
        set the tree's leaves member via root.leaves() (recursive traversal from the root)
```
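The role of the score S deserves one extra step of reasoning. For a node split into a left part L and a right part R, the post-split squared error decomposes as

$$
\sum_{i \in L}(y_i-\bar y_L)^2+\sum_{i \in R}(y_i-\bar y_R)^2
= \sum_i y_i^2 - \left(\frac{\bigl(\sum_{i\in L}y_i\bigr)^2}{|L|}+\frac{\bigl(\sum_{i\in R}y_i\bigr)^2}{|R|}\right)
= \texttt{sqSumResponse} - S .
$$

Since sqSumResponse is constant for the node, maximizing S = sumLeft²/countLeft + sumRight²/countRight minimizes the post-split squared error. The same identity explains the per-node deviance computed in findBestSplit: var = sqSumResponse − sumResponse²/n is exactly Σ(yᵢ − ȳ)².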
Below are the class files involved in fitting a regression tree, with detailed comments added at the key points.
1. FeatureHistogram
```java
package ciir.umass.edu.learning.tree;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.utilities.MyThreadPool;
import ciir.umass.edu.utilities.WorkerThread;

/**
 * @author vdang
 */
//Feature histogram class: builds per-feature histograms over the training data
//and selects the best (feature, threshold) pair for each split.
public class FeatureHistogram {
	//Holds a candidate split: featureIdx, thresholdIdx, and the score used to rank splits,
	//S = sumLeft*sumLeft/countLeft + sumRight*sumRight/countRight
	class Config {
		int featureIdx = -1;
		int thresholdIdx = -1;
		double S = -1;
	}

	//Parameter
	public static float samplingRate = 1;//sampling rate for the number of features considered at each split (feature sub-sampling)

	//Variables
	public int[] features = null;//array of feature ids (fid)
	public float[][] thresholds = null;//[feature index (position in @features, not fid)][candidate threshold]: the sorted distinct feature values observed in the training data, used as candidate split thresholds on that feature
	public double[][] sum = null;//[feature index][threshold index]: sum of the labels of all samples whose value on this feature is <= thresholds[i][j]; same shape as @thresholds
	public double sumResponse = 0;//sum of the labels of all training samples
	public double sqSumResponse = 0;//sum of the squared labels of all training samples
	public int[][] count = null;//[feature index][threshold index]: number of samples whose value on this feature is <= thresholds[i][j]; same shape as @thresholds
	public int[][] sampleToThresholdMap = null;//[feature index][sample index]: the column index in @thresholds of sample samples[i][j]'s value on this feature

	//whether to re-use its parent's @sum and @count instead of cleaning up the parent and re-allocating for the children.
	//@sum and @count of any intermediate tree node (except for root) can be re-used.
	private boolean reuseParent = false;

	public FeatureHistogram()
	{

	}

	//construct (1-1): typically builds the histogram for the whole tree / root node.
	//@samples: training data
	//@labels: labels of the training data
	//@sampleSortedIdx: samples sorted by each feature, so the best split point can be found quickly when growing the tree; need initializing only once (see init() in LambdaMART.java)
	//@features: the feature set of the training data
	//@thresholds: a table of candidate thresholds (split points) for each feature; we will select the best tree split from these candidates later on.
	//Initialized in init() of LambdaMART.java; the last entry of each row is appended as Float.MAX_VALUE.
	public void construct(DataPoint[] samples, double[] labels, int[][] sampleSortedIdx, int[] features, float[][] thresholds)
	{
		this.features = features;
		this.thresholds = thresholds;

		sumResponse = 0;
		sqSumResponse = 0;

		sum = new double[features.length][];
		count = new int[features.length][];
		sampleToThresholdMap = new int[features.length][];

		//decide whether to use multi-threading
		MyThreadPool p = MyThreadPool.getInstance();
		if(p.size() == 1)
			construct(samples, labels, sampleSortedIdx, thresholds, 0, features.length-1);
		else
			p.execute(new Worker(this, samples, labels, sampleSortedIdx, thresholds), features.length);
	}
	//construct (1-2), called by (1-1)
	protected void construct(DataPoint[] samples, double[] labels, int[][] sampleSortedIdx, float[][] thresholds, int start, int end)
	{
		for(int i=start;i<=end;i++)//for each feature
		{
			int fid = features[i];//get the feature id
			//get the list of samples associated with this node (sorted in ascending order with respect to the current feature)
			int[] idx = sampleSortedIdx[i];//indexes of the training samples, sorted by their value on this feature

			double sumLeft = 0;//running total, feeds @sumLabel below
			float[] threshold = thresholds[i];
			double[] sumLabel = new double[threshold.length];//one row of the @sum array
			int[] c = new int[threshold.length];//one row of the @count array
			int[] stMap = new int[samples.length];//one row of the @sampleToThresholdMap array

			int last = -1;
			for(int t=0;t<threshold.length;t++)//for each candidate split threshold
			{
				int j=last+1;
				//find the first sample that exceeds the current threshold
				for(;j<idx.length;j++)
				{
					int k = idx[j];//index of this DataPoint in @samples
					if(samples[k].getFeatureValue(fid) > threshold[t])
						break;
					sumLeft += labels[k];
					if(i == 0)
					{
						sumResponse += labels[k];
						sqSumResponse += labels[k] * labels[k];
					}
					stMap[k] = t;
				}
				last = j-1;
				sumLabel[t] = sumLeft;
				c[t] = last+1;
			}
			sampleToThresholdMap[i] = stMap;
			sum[i] = sumLabel;
			count[i] = c;
		}
	}

	//update (1-1): update the histogram with these training labels (the feature histogram will be used to find the best tree split)
	protected void update(double[] labels)
	{
		sumResponse = 0;
		sqSumResponse = 0;

		//decide whether to use multi-threading
		MyThreadPool p = MyThreadPool.getInstance();
		if(p.size() == 1)
			update(labels, 0, features.length-1);
		else
			p.execute(new Worker(this, labels), features.length);
	}

	//update (1-2), called by (1-1)
	protected void update(double[] labels, int start, int end)
	{
		for(int f=start;f<=end;f++)
			Arrays.fill(sum[f], 0);
		for(int k=0;k<labels.length;k++)
		{
			for(int f=start;f<=end;f++)
			{
				int t = sampleToThresholdMap[f][k];
				sum[f][t] += labels[k];
				if(f == 0)
				{
					sumResponse += labels[k];
					sqSumResponse += labels[k]*labels[k];
				}
				//count doesn't change, so no need to re-compute
			}
		}
		for(int f=start;f<=end;f++)
		{
			for(int t=1;t<thresholds[f].length;t++)
				sum[f][t] += sum[f][t-1];
		}
	}

	//construct (2-1): typically builds the histogram of the *left* child after a parent node splits.
	//The @thresholds table is inherited from the parent, but @sum and @count must be rebuilt.
	//@soi: indexes of the training samples that fall into this child
	public void construct(FeatureHistogram parent, int[] soi, double[] labels)
	{
		this.features = parent.features;
		this.thresholds = parent.thresholds;
		sumResponse = 0;
		sqSumResponse = 0;
		sum = new double[features.length][];
		count = new int[features.length][];
		sampleToThresholdMap = parent.sampleToThresholdMap;

		//decide whether to use multi-threading
		MyThreadPool p = MyThreadPool.getInstance();
		if(p.size() == 1)
			construct(parent, soi, labels, 0, features.length-1);
		else
			p.execute(new Worker(this, parent, soi, labels), features.length);
	}

	//construct (2-2), called by (2-1)
	protected void construct(FeatureHistogram parent, int[] soi, double[] labels, int start, int end)
	{
		//init
		for(int i=start;i<=end;i++)
		{
			float[] threshold = thresholds[i];
			sum[i] = new double[threshold.length];
			count[i] = new int[threshold.length];
			Arrays.fill(sum[i], 0);
			Arrays.fill(count[i], 0);
		}

		//update
		for(int i=0;i<soi.length;i++)
		{
			int k = soi[i];
			for(int f=start;f<=end;f++)
			{
				int t = sampleToThresholdMap[f][k];
				sum[f][t] += labels[k];
				count[f][t] ++;
				if(f == 0)
				{
					sumResponse += labels[k];
					sqSumResponse += labels[k]*labels[k];
				}
			}
		}

		for(int f=start;f<=end;f++)
		{
			for(int t=1;t<thresholds[f].length;t++)
			{
				sum[f][t] += sum[f][t-1];
				count[f][t] += count[f][t-1];
			}
		}
	}

	//construct (3-1): typically builds the histogram of the *right* child by subtracting the left sibling from the parent.
	public void construct(FeatureHistogram parent, FeatureHistogram leftSibling, boolean reuseParent)
	{
		this.reuseParent = reuseParent;
		this.features = parent.features;
		this.thresholds = parent.thresholds;
		sumResponse = parent.sumResponse - leftSibling.sumResponse;
		sqSumResponse = parent.sqSumResponse - leftSibling.sqSumResponse;

		if(reuseParent)
		{
			sum = parent.sum;
			count = parent.count;
		}
		else
		{
			sum = new double[features.length][];
			count = new int[features.length][];
		}
		sampleToThresholdMap = parent.sampleToThresholdMap;

		//decide whether to use multi-threading
		MyThreadPool p = MyThreadPool.getInstance();
		if(p.size() == 1)
			construct(parent, leftSibling, 0, features.length-1);
		else
			p.execute(new Worker(this, parent, leftSibling), features.length);
	}

	//construct (3-2), called by (3-1)
	protected void construct(FeatureHistogram parent, FeatureHistogram leftSibling, int start, int end)
	{
		for(int f=start;f<=end;f++)
		{
			float[] threshold = thresholds[f];
			if(!reuseParent)
			{
				sum[f] = new double[threshold.length];
				count[f] = new int[threshold.length];
			}
			for(int t=0;t<threshold.length;t++)
			{
				sum[f][t] = parent.sum[f][t] - leftSibling.sum[f][t];
				count[f][t] = parent.count[f][t] - leftSibling.count[f][t];
			}
		}
	}

	//findBestSplit (1-2), called by (1-1): on this node, over @usedFeatures, choose the split feature and threshold
	//from the node's feature histogram
	protected Config findBestSplit(int[] usedFeatures, int minLeafSupport, int start, int end)
	{
		Config cfg = new Config();
		int totalCount = count[start][count[start].length-1];
		for(int f=start;f<=end;f++)
		{
			int i = usedFeatures[f];
			float[] threshold = thresholds[i];

			for(int t=0;t<threshold.length;t++)
			{
				int countLeft = count[i][t];
				int countRight = totalCount - countLeft;
				if(countLeft < minLeafSupport || countRight < minLeafSupport)
					continue;

				double sumLeft = sum[i][t];
				double sumRight = sumResponse - sumLeft;

				double S = sumLeft * sumLeft / countLeft + sumRight * sumRight / countRight;
				//the largest S corresponds to the smallest squared error, i.e. the optimal split point
				if(cfg.S < S)
				{
					cfg.S = S;
					cfg.featureIdx = i;
					cfg.thresholdIdx = t;
				}
			}
		}
		return cfg;
	}

	//findBestSplit (1-1): given the histogram already computed for this node, find the best split point, perform the
	//split, compute the histograms of the two children, and initialize the children.
	public boolean findBestSplit(Split sp, double[] labels, int minLeafSupport)
	{
		if(sp.getDeviance() >= 0.0 && sp.getDeviance() <= 0.0)//equals 0
			return false;//no need to split

		int[] usedFeatures = null;//index of the features to be used for tree splitting
		if(samplingRate < 1)//need to do sub sampling (feature sampling)
		{
			int size = (int)(samplingRate * features.length);
			usedFeatures = new int[size];
			//put all features into a pool
			List<Integer> fpool = new ArrayList<Integer>();
			for(int i=0;i<features.length;i++)
				fpool.add(i);
			//do sampling, without replacement
			Random r = new Random();
			for(int i=0;i<size;i++)
			{
				int sel = r.nextInt(fpool.size());
				usedFeatures[i] = fpool.get(sel);
				fpool.remove(sel);
			}
		}
		else//no sub-sampling, all features will be used
		{
			usedFeatures = new int[features.length];
			for(int i=0;i<features.length;i++)
				usedFeatures[i] = i;
		}

		//find the best split
		Config best = new Config();
		//decide whether to use multi-threading
		MyThreadPool p = MyThreadPool.getInstance();
		if(p.size() == 1)
			best = findBestSplit(usedFeatures, minLeafSupport, 0, usedFeatures.length-1);
		else
		{
			WorkerThread[] workers = p.execute(new Worker(this, usedFeatures, minLeafSupport), usedFeatures.length);
			for(int i=0;i<workers.length;i++)
			{
				Worker wk = (Worker)workers[i];
				if(best.S < wk.cfg.S)
					best = wk.cfg;
			}
		}

		if(best.S == -1)//unsplitable, for some reason...
			return false;

		//if(minS >= sp.getDeviance())
			//return null;

		double[] sumLabel = sum[best.featureIdx];
		int[] sampleCount = count[best.featureIdx];

		double s = sumLabel[sumLabel.length-1];
		int c = sampleCount[sumLabel.length-1];

		double sumLeft = sumLabel[best.thresholdIdx];
		int countLeft = sampleCount[best.thresholdIdx];

		double sumRight = s - sumLeft;
		int countRight = c - countLeft;

		int[] left = new int[countLeft];
		int[] right = new int[countRight];
		int l = 0;
		int r = 0;
		int k = 0;
		int[] idx = sp.getSamples();
		//assign each training sample on this node to the left or right child according to the best split point
		for(int j=0;j<idx.length;j++)
		{
			k = idx[j];
			if(sampleToThresholdMap[best.featureIdx][k] <= best.thresholdIdx)//go to the left
				left[l++] = k;
			else//go to the right
				right[r++] = k;
		}

		//initialize the feature histograms of the two children
		FeatureHistogram lh = new FeatureHistogram();
		lh.construct(sp.hist, left, labels);//histogram of the left child
		FeatureHistogram rh = new FeatureHistogram();
		rh.construct(sp.hist, lh, !sp.isRoot());//histogram of the right child
		double var = sqSumResponse - sumResponse * sumResponse / idx.length;//squared error (deviance) of this node
		double varLeft = lh.sqSumResponse - lh.sumResponse * lh.sumResponse / left.length;//deviance of the left child
		double varRight = rh.sqSumResponse - rh.sumResponse * rh.sumResponse / right.length;//deviance of the right child

		sp.set(features[best.featureIdx], thresholds[best.featureIdx][best.thresholdIdx], var);
		sp.setLeft(new Split(left, lh, varLeft, sumLeft));
		sp.setRight(new Split(right, rh, varRight, sumRight));

		sp.clearSamples();//release this node's sortedSampleIDs, samples, hist, etc.

		return true;
	}

	class Worker extends WorkerThread {
		FeatureHistogram fh = null;
		int type = -1;

		//find best split (type == 0)
		int[] usedFeatures = null;
		int minLeafSup = -1;
		Config cfg = null;

		//update (type == 1)
		double[] labels = null;

		//construct (type == 2)
		FeatureHistogram parent = null;
		int[] soi = null;

		//construct (type == 3)
		FeatureHistogram leftSibling = null;

		//construct (type == 4)
		DataPoint[] samples;
		int[][] sampleSortedIdx;
		float[][] thresholds;

		public Worker()
		{
		}
		public Worker(FeatureHistogram fh, int[] usedFeatures, int minLeafSup)
		{
			type = 0;
			this.fh = fh;
			this.usedFeatures = usedFeatures;
			this.minLeafSup = minLeafSup;
		}
		public Worker(FeatureHistogram fh, double[] labels)
		{
			type = 1;
			this.fh = fh;
			this.labels = labels;
		}
		public Worker(FeatureHistogram fh, FeatureHistogram parent, int[] soi, double[] labels)
		{
			type = 2;
			this.fh = fh;
			this.parent = parent;
			this.soi = soi;
			this.labels = labels;
		}
		public Worker(FeatureHistogram fh, FeatureHistogram parent, FeatureHistogram leftSibling)
		{
			type = 3;
			this.fh = fh;
			this.parent = parent;
			this.leftSibling = leftSibling;
		}
		public Worker(FeatureHistogram fh, DataPoint[] samples, double[] labels, int[][] sampleSortedIdx, float[][] thresholds)
		{
			type = 4;
			this.fh = fh;
			this.samples = samples;
			this.labels = labels;
			this.sampleSortedIdx = sampleSortedIdx;
			this.thresholds = thresholds;
		}
		public void run()
		{
			if(type == 0)
				cfg = fh.findBestSplit(usedFeatures, minLeafSup, start, end);
			else if(type == 1)
				fh.update(labels, start, end);
			else if(type == 2)
				fh.construct(parent, soi, labels, start, end);
			else if(type == 3)
				fh.construct(parent, leftSibling, start, end);
			else if(type == 4)
				fh.construct(samples, labels, sampleSortedIdx, thresholds, start, end);
		}
		public WorkerThread clone()
		{
			Worker wk = new Worker();
			wk.fh = fh;
			wk.type = type;

			//find best split (type == 0)
			wk.usedFeatures = usedFeatures;
			wk.minLeafSup = minLeafSup;
			//wk.cfg = cfg;

			//update (type == 1)
			wk.labels = labels;

			//construct (type == 2)
			wk.parent = parent;
			wk.soi = soi;

			//construct (type == 3)
			wk.leftSibling = leftSibling;

			//construct (type == 4)
			wk.samples = samples;
			wk.sampleSortedIdx = sampleSortedIdx;
			wk.thresholds = thresholds;

			return wk;
		}
	}
}
```
2. Split
```java
package ciir.umass.edu.learning.tree;

import java.util.ArrayList;
import java.util.List;

import ciir.umass.edu.learning.DataPoint;

/**
 *
 * @author vdang
 *
 */
//Tree node class, used for:
// 1) split decisions during training (via the FeatureHistogram class);
// 2) storing the node's split rule (featureID, threshold) and its output (avgLabel, deviance, etc.)
public class Split {
	//Key attributes of a split (tree node)
	//the node's split rule (featureID, threshold) and its output (avgLabel, deviance, etc.)
	private int featureID = -1;
	private float threshold = 0F;
	private double avgLabel = 0.0F;

	//Intermediate variables (ONLY used during learning)
	//*DO NOT* attempt to access them once the training is done
	private boolean isRoot = false;
	private double sumLabel = 0.0;
	private double sqSumLabel = 0.0;
	private Split left = null;
	private Split right = null;
	private double deviance = 0F;//mean squared error "S"
	private int[][] sortedSampleIDs = null;
	public int[] samples = null;//during training, indexes of the training samples on this node
	public FeatureHistogram hist = null;//during training, feature histogram of the training samples on this node

	public Split()
	{

	}
	public Split(int featureID, float threshold, double deviance)
	{
		this.featureID = featureID;
		this.threshold = threshold;
		this.deviance = deviance;
	}
	public Split(int[][] sortedSampleIDs, double deviance, double sumLabel, double sqSumLabel)
	{
		this.sortedSampleIDs = sortedSampleIDs;
		this.deviance = deviance;
		this.sumLabel = sumLabel;
		this.sqSumLabel = sqSumLabel;
		avgLabel = sumLabel/sortedSampleIDs[0].length;
	}
	public Split(int[] samples, FeatureHistogram hist, double deviance, double sumLabel)
	{
		this.samples = samples;
		this.hist = hist;
		this.deviance = deviance;
		this.sumLabel = sumLabel;
		avgLabel = sumLabel/samples.length;
	}

	//Called after this node has been split, to record the split's featureID, threshold and deviance.
	//Only internal nodes are ever split (i.e. call this method), so only internal nodes have featureID != -1;
	//leaf nodes never call it and keep featureID == -1.
	public void set(int featureID, float threshold, double deviance)
	{
		this.featureID = featureID;
		this.threshold = threshold;
		this.deviance = deviance;
	}
	public void setLeft(Split s)
	{
		left = s;
	}
	public void setRight(Split s)
	{
		right = s;
	}
	public void setOutput(float output)
	{
		avgLabel = output;
	}

	public Split getLeft()
	{
		return left;
	}
	public Split getRight()
	{
		return right;
	}
	public double getDeviance()
	{
		return deviance;
	}
	public double getOutput()
	{
		return avgLabel;
	}

	//Return the list of all leaf nodes under this node (usually the root).
	//Recursive: a leaf (featureID == -1) is added to the list; otherwise recurse via leaves(list).
	public List<Split> leaves()
	{
		List<Split> list = new ArrayList<Split>();
		leaves(list);
		return list;
	}
	private void leaves(List<Split> leaves)
	{
		if(featureID == -1)
			leaves.add(this);
		else
		{
			left.leaves(leaves);
			right.leaves(leaves);
		}
	}

	//Route a DataPoint down from this node (usually the root), following the split rule at each level,
	//and return the output (avgLabel) of the leaf it finally lands in.
	public double eval(DataPoint dp)
	{
		Split n = this;
		while(n.featureID != -1)
		{
			if(dp.getFeatureValue(n.featureID) <= n.threshold)
				n = n.left;
			else
				n = n.right;
		}
		return n.avgLabel;
	}

	public String toString()
	{
		return toString("");
	}
	public String toString(String indent)
	{
		String strOutput = indent + "<split>" + "\n";
		strOutput += getString(indent + "\t");
		strOutput += indent + "</split>" + "\n";
		return strOutput;
	}
	public String getString(String indent)
	{
		String strOutput = "";
		if(featureID == -1)
		{
			strOutput += indent + "<output> " + avgLabel + " </output>" + "\n";
		}
		else
		{
			strOutput += indent + "<feature> " + featureID + " </feature>" + "\n";
			strOutput += indent + "<threshold> " + threshold + " </threshold>" + "\n";
			strOutput += indent + "<split pos=\"left\">" + "\n";
			strOutput += left.getString(indent + "\t");
			strOutput += indent + "</split>" + "\n";
			strOutput += indent + "<split pos=\"right\">" + "\n";
			strOutput += right.getString(indent + "\t");
			strOutput += indent + "</split>" + "\n";
		}
		return strOutput;
	}

	//Internal functions (ONLY used during learning)
	//*DO NOT* attempt to call them once the training is done
	//*Important*: during training, splits this node by calling findBestSplit on the node's feature histogram
	public boolean split(double[] trainingLabels, int minLeafSupport)
	{
		return hist.findBestSplit(this, trainingLabels, minLeafSupport);
	}
	public int[] getSamples()
	{
		if(sortedSampleIDs != null)
			return sortedSampleIDs[0];
		return samples;
	}
	public int[][] getSampleSortedIndex()
	{
		return sortedSampleIDs;
	}
	public double getSumLabel()
	{
		return sumLabel;
	}
	public double getSqSumLabel()
	{
		return sqSumLabel;
	}
	public void clearSamples()
	{
		sortedSampleIDs = null;
		samples = null;
		hist = null;
	}
	public void setRoot(boolean isRoot)
	{
		this.isRoot = isRoot;
	}
	public boolean isRoot()
	{
		return isRoot;
	}
}
```
3. RegressionTree
```java
package ciir.umass.edu.learning.tree;

import java.util.ArrayList;
import java.util.List;

import ciir.umass.edu.learning.DataPoint;

/**
 * @author vdang
 */
//Regression tree class
public class RegressionTree {

	//Parameters
	protected int nodes = 10;//-1 for unlimited number of nodes (the size of the tree will then be controlled *ONLY* by minLeafSupport)
	protected int minLeafSupport = 1;//controls splitting: a node holding fewer than 2*minLeafSupport training samples is not split further

	//Member variables and functions
	protected Split root = null;//root node
	protected List<Split> leaves = null;//list of leaf nodes

	protected DataPoint[] trainingSamples = null;
	protected double[] trainingLabels = null;
	protected int[] features = null;
	protected float[][] thresholds = null;//[feature index (position in @features, not fid)][candidate threshold]: sorted distinct feature values, the candidate split thresholds on that feature
	protected int[] index = null;
	protected FeatureHistogram hist = null;

	public RegressionTree(Split root)
	{
		this.root = root;
		leaves = root.leaves();
	}
	public RegressionTree(int nLeaves, DataPoint[] trainingSamples, double[] labels, FeatureHistogram hist, int minLeafSupport)
	{
		this.nodes = nLeaves;
		this.trainingSamples = trainingSamples;
		this.trainingLabels = labels;
		this.hist = hist;
		this.minLeafSupport = minLeafSupport;
		index = new int[trainingSamples.length];
		for(int i=0;i<trainingSamples.length;i++)
			index[i] = i;
	}

	/**
	 * Fit the tree from the specified training data
	 */
	public void fit()
	{
		List<Split> queue = new ArrayList<Split>();//nodes are split in queue order (i.e. level-order traversal)
		root = new Split(index, hist, Float.MAX_VALUE, 0);//root node of the regression tree
		root.setRoot(true);
		root.split(trainingLabels, minLeafSupport);//split the root once, producing two children
		insert(queue, root.getLeft());//enqueue the left child for the traversal below
		insert(queue, root.getRight());//enqueue the right child for the traversal below
		//loop: split nodes in queue order (level-order traversal), enqueueing the two children of each successful split
		int taken = 0;
		while( (nodes == -1 || taken + queue.size() < nodes) && queue.size() > 0)
		{
			Split leaf = queue.get(0);
			queue.remove(0);

			if(leaf.getSamples().length < 2 * minLeafSupport)
			{
				taken++;
				continue;
			}

			//split each dequeued node once, producing two children
			if(!leaf.split(trainingLabels, minLeafSupport))//unsplitable (i.e. variance(s)==0; or after-split variance is higher than before)
				taken++;
			else
			{
				insert(queue, leaf.getLeft());//enqueue the left child
				insert(queue, leaf.getRight());//enqueue the right child
			}
		}
		leaves = root.leaves();
	}

	/**
	 * Get the tree output for the input sample
	 * @param dp
	 * @return
	 */
	public double eval(DataPoint dp)
	{
		return root.eval(dp);
	}
	/**
	 * Retrieve all leaf nodes in the tree
	 * @return
	 */
	public List<Split> leaves()
	{
		return leaves;
	}
	/**
	 * Clear samples associated with each leaf (when they are no longer necessary) in order to save memory
	 */
	public void clearSamples()
	{
		trainingSamples = null;
		trainingLabels = null;
		features = null;
		thresholds = null;
		index = null;
		hist = null;
		for(int i=0;i<leaves.size();i++)
			leaves.get(i).clearSamples();
	}

	/**
	 * Generate the string representation of the tree
	 */
	public String toString()
	{
		if(root != null)
			return root.toString();
		return "";
	}
	public String toString(String indent)
	{
		if(root != null)
			return root.toString(indent);
		return "";
	}

	public double variance()
	{
		double var = 0;
		for(int i=0;i<leaves.size();i++)
			var += leaves.get(i).getDeviance();
		return var;
	}
	protected void insert(List<Split> ls, Split s)
	{
		int i=0;
		while(i < ls.size())
		{
			if(ls.get(i).getDeviance() > s.getDeviance())//keep the queue sorted by deviance, from largest to smallest
				i++;
			else
				break;
		}
		ls.add(i, s);
	}
}
```
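To make the division of labor among FeatureHistogram, Split, and RegressionTree concrete, here is a minimal sketch of fitting a single tree. It assumes trainingSamples, labels, sortedIdx, features, and thresholds have already been prepared the way LambdaMART.init() prepares them, and someDataPoint stands in for an arbitrary DataPoint; the snippet is an illustration and does not appear verbatim in RankLib:

```java
// A minimal sketch of the fitting pipeline (inputs assumed prepared as in LambdaMART.init()).
FeatureHistogram hist = new FeatureHistogram();
// Histogram over all training data; it is used to split the root node.
hist.construct(trainingSamples, labels, sortedIdx, features, thresholds);

// 10 = nTreeLeaves, 1 = minLeafSupport
RegressionTree rt = new RegressionTree(10, trainingSamples, labels, hist, 1);
// Level-order splitting, driven by Split.split() -> FeatureHistogram.findBestSplit().
rt.fit();

// Routes the point down to a leaf and returns that leaf's avgLabel.
double prediction = rt.eval(someDataPoint);
```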
2. LambdaMart
The LambdaMART training process can be summarized as follows:

```
LambdaMART
    init
        initialize the training data: martSamples, modelScores, pseudoResponses, weights
        sort the samples by each feature, so the best split point can be found quickly when growing trees: sortedIdx
        initialize the 2-D array thresholds (first dimension: feature, indexed by position in @features, not feature id;
        second dimension: the sorted distinct values of that feature over all training data, i.e. the candidate split
        thresholds for that feature)
        hist.construct  # build a FeatureHistogram over the training data and the thresholds table; it holds the overall
                        # feature statistics and is used to split the root node
            initializes:
            sum                   # [feature index][threshold index]: sum of the labels of all samples whose value on this
                                  # feature is <= thresholds[i][j]; same shape as thresholds
            count                 # [feature index][threshold index]: number of samples whose value on this feature is
                                  # <= thresholds[i][j]; same shape as thresholds
            sampleToThresholdMap  # [feature index][sample index]: the column index in thresholds of each sample's value
                                  # on that feature
            sumResponse           # sum of the labels of all training samples
            sqSumResponse         # sum of the squared labels of all training samples
    learn
        initialize an Ensemble object
        start the gradient boosting process, i.e. build the regression trees one by one:
            computePseudoResponses  # compute the pseudo responses (the gradients, i.e. the lambdas) each instance must
                                    # fit in this round
                computed with LambdaMART's gradient formula (see the formulas after this outline)
            hist.update  # update the feature histogram with this round's pseudo responses (lambdas); only the per-instance
                         # labels changed, everything else (e.g. the features) did not
            initialize a regression tree (from the training data and the feature histogram)
            rt.fit  # fit the regression tree to the training data with this round's pseudo responses (lambdas) as labels
            add the tree fitted in this round to the ensemble
            updateTreeOutput  # update the outputs of the leaf nodes of this round's tree
            compute the prediction scores of all training instances after this round (the new tree is now part of the
            ensemble): modelScores
            computeModelScoreOnTraining   # ranking metric (e.g. NDCG) of the updated model on the training data
            compute the prediction scores of all validation instances after this round: modelScoresOnValidation
            computeModelScoreOnValidation # ranking metric (e.g. NDCG) of the updated model on the validation data
            update the best validation score across all models so far (bestScoreOnValidationData) and the index of the
            best model (bestModelOnValidation)
            if the validation score has not improved for a number of consecutive rounds, stop the iteration
        roll back to the best model observed on the validation data
        compute the ranking metric of the best model on the training data and the validation data
```
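The step "computed with LambdaMART's gradient formula" above corresponds to computePseudoResponses in the code below. For each pair of documents $(i, j)$ in the same query with $label_i > label_j$ and current model scores $s_i$, $s_j$, the code accumulates

$$
\rho_{ij} = \frac{1}{1 + e^{\,s_i - s_j}}, \qquad
\lambda_i \mathrel{+}= \rho_{ij}\,|\Delta \mathrm{NDCG}_{ij}|, \qquad
\lambda_j \mathrel{-}= \rho_{ij}\,|\Delta \mathrm{NDCG}_{ij}|,
$$

$$
w_i \mathrel{+}= \rho_{ij}(1-\rho_{ij})\,|\Delta \mathrm{NDCG}_{ij}|, \qquad
w_j \mathrel{+}= \rho_{ij}(1-\rho_{ij})\,|\Delta \mathrm{NDCG}_{ij}|,
$$

where $|\Delta \mathrm{NDCG}_{ij}|$ is the absolute change in the target metric if documents $i$ and $j$ were swapped (scorer.swapChange), $\lambda$ is the pseudo response each tree fits, and $w$ is the second-order weight used later by the Newton step in updateTreeOutput. This matches the rho/lambda/delta variables in computePseudoResponses, with the RankNet sigmoid parameter fixed at 1.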
Below is the code of the LambdaMART training process, with detailed comments added at the key points.
1. LambdaMART
```java
package ciir.umass.edu.learning.tree;

import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.metric.MetricScorer;
import ciir.umass.edu.utilities.MergeSorter;
import ciir.umass.edu.utilities.MyThreadPool;
import ciir.umass.edu.utilities.RankLibError;
import ciir.umass.edu.utilities.SimpleMath;

import java.io.BufferedReader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * @author vdang
 *
 *  This class implements LambdaMART.
 *  Q. Wu, C.J.C. Burges, K. Svore and J. Gao. Adapting Boosting for Information Retrieval Measures.
 *  Journal of Information Retrieval, 2007.
 */
public class LambdaMART extends Ranker {
	//Parameters
	public static int nTrees = 1000;//the number of trees
	public static float learningRate = 0.1F;//or shrinkage
	public static int nThreshold = 256;
	public static int nRoundToStopEarly = 100;//If no performance gain on the *VALIDATION* data is observed in #rounds, stop the training process right away.
	public static int nTreeLeaves = 10;
	public static int minLeafSupport = 1;

	//for debugging
	public static int gcCycle = 100;

	//Local variables
	protected float[][] thresholds = null;
	protected Ensemble ensemble = null;
	protected double[] modelScores = null;//on training data

	protected double[][] modelScoresOnValidation = null;
	protected int bestModelOnValidation = Integer.MAX_VALUE-2;

	//Training instances prepared for MART
	protected DataPoint[] martSamples = null;//Need initializing only once
	protected int[][] sortedIdx = null;//sorted list of samples in @martSamples by each feature -- Need initializing only once
	protected FeatureHistogram hist = null;
	protected double[] pseudoResponses = null;//different for each iteration
	protected double[] weights = null;//different for each iteration

	public LambdaMART()
	{
	}
	public LambdaMART(List<RankList> samples, int[] features, MetricScorer scorer)
	{
		super(samples, features, scorer);
	}

	public void init()
	{
		PRINT("Initializing... ");
		//initialize samples for MART
		int dpCount = 0;
		for(int i=0;i<samples.size();i++)
		{
			RankList rl = samples.get(i);
			dpCount += rl.size();
		}
		int current = 0;
		martSamples = new DataPoint[dpCount];
		modelScores = new double[dpCount];
		pseudoResponses = new double[dpCount];
		weights = new double[dpCount];
		for(int i=0;i<samples.size();i++)
		{
			RankList rl = samples.get(i);
			for(int j=0;j<rl.size();j++)
			{
				martSamples[current+j] = rl.get(j);
				modelScores[current+j] = 0.0F;
				pseudoResponses[current+j] = 0.0F;
				weights[current+j] = 0;
			}
			current += rl.size();
		}

		//sort (MART) samples by each feature so that we can quickly retrieve a sorted list of samples by any feature later on,
		//to find the best split point quickly when growing trees
		sortedIdx = new int[features.length][];
		MyThreadPool p = MyThreadPool.getInstance();
		if(p.size() == 1)//single-thread
			sortSamplesByFeature(0, features.length-1);
		else//multi-thread
		{
			int[] partition = p.partition(features.length);
			for(int i=0;i<partition.length-1;i++)
				p.execute(new SortWorker(this, partition[i], partition[i+1]-1));
			p.await();
		}

		//Create a table of candidate thresholds (for each feature). Later on, we will select the best tree split from these candidates.
		thresholds = new float[features.length][];
		for(int f=0;f<features.length;f++)
		{
			//For this feature, keep track of the list of unique values and the max/min
			List<Float> values = new ArrayList<Float>();
			float fmax = Float.NEGATIVE_INFINITY;
			float fmin = Float.MAX_VALUE;
			for(int i=0;i<martSamples.length;i++)
			{
				int k = sortedIdx[f][i];//get samples sorted with respect to this feature
				float fv = martSamples[k].getFeatureValue(features[f]);
				values.add(fv);
				if(fmax < fv)
					fmax = fv;
				if(fmin > fv)
					fmin = fv;
				//skip all samples with the same feature value
				int j=i+1;
				while(j < martSamples.length)
				{
					if(martSamples[sortedIdx[f][j]].getFeatureValue(features[f]) > fv)
						break;
					j++;
				}
				i = j-1;//[i, j] gives the range of samples with the same feature value
			}

			if(values.size() <= nThreshold || nThreshold == -1)
			{
				thresholds[f] = new float[values.size()+1];
				for(int i=0;i<values.size();i++)
					thresholds[f][i] = values.get(i);
				thresholds[f][values.size()] = Float.MAX_VALUE;
			}
			else
			{
				float step = (Math.abs(fmax - fmin))/nThreshold;
				thresholds[f] = new float[nThreshold+1];
				thresholds[f][0] = fmin;
				for(int j=1;j<nThreshold;j++)
					thresholds[f][j] = thresholds[f][j-1] + step;
				thresholds[f][nThreshold] = Float.MAX_VALUE;
			}
		}

		if(validationSamples != null)
		{
			modelScoresOnValidation = new double[validationSamples.size()][];
			for(int i=0;i<validationSamples.size();i++)
			{
				modelScoresOnValidation[i] = new double[validationSamples.get(i).size()];
				Arrays.fill(modelScoresOnValidation[i], 0);
			}
		}

		//compute the feature histogram (this is used to speed up the procedure of finding the best tree split later on)
		hist = new FeatureHistogram();
		hist.construct(martSamples, pseudoResponses, sortedIdx, features, thresholds);
		//we no longer need the sorted indexes of samples
		sortedIdx = null;

		System.gc();
		PRINTLN("[Done]");
	}
	public void learn()
	{
		ensemble = new Ensemble();

		PRINTLN("---------------------------------");
		PRINTLN("Training starts...");
		PRINTLN("---------------------------------");
		PRINTLN(new int[]{7, 9, 9}, new String[]{"#iter", scorer.name()+"-T", scorer.name()+"-V"});
		PRINTLN("---------------------------------");

		//Start the gradient boosting process
		for(int m=0; m<nTrees; m++)
		{
			PRINT(new int[]{7}, new String[]{(m+1)+""});

			//Compute lambdas (which act as the "pseudo responses")
			//Create training instances for MART:
			//  - Each document is a training sample
			//  - The lambda for this document serves as its training label
			computePseudoResponses();

			//update the histogram with these training labels (the feature histogram will be used to find the best tree split);
			//only the per-instance labels changed -- other state (e.g. the features) did not
			hist.update(pseudoResponses);

			//Fit a regression tree
			RegressionTree rt = new RegressionTree(nTreeLeaves, martSamples, pseudoResponses, hist, minLeafSupport);
			rt.fit();

			//Add this tree to the ensemble (our model)
			ensemble.add(rt, learningRate);

			//update the outputs of the tree (with gamma computed using the Newton-Raphson method)
			updateTreeOutput(rt);

			//Update the model's outputs on all training samples
			List<Split> leaves = rt.leaves();
			for(int i=0;i<leaves.size();i++)
			{
				Split s = leaves.get(i);
				int[] idx = s.getSamples();
				for(int j=0;j<idx.length;j++)
					modelScores[idx[j]] += learningRate * s.getOutput();
			}

			//clear references to data that is no longer used
			rt.clearSamples();

			//beg the garbage collector to work...
			if(m % gcCycle == 0)
				System.gc();//this call is expensive. We shouldn't do it too often.

			//Evaluate the current model
			scoreOnTrainingData = computeModelScoreOnTraining();
			//**** NOTE ****
			//The above function to evaluate the current model on the training data is equivalent to a single call:
			//
			//		scoreOnTrainingData = scorer.score(rank(samples));
			//
			//However, this function is more efficient since it uses the cached outputs of the model (as opposed to
			//re-evaluating the model on the entire training set).

			PRINT(new int[]{9}, new String[]{SimpleMath.round(scoreOnTrainingData, 4) + ""});

			//Evaluate the current model on the validation data (if available)
			if(validationSamples != null)
			{
				//Update the model's scores on all validation samples
				for(int i=0;i<modelScoresOnValidation.length;i++)
					for(int j=0;j<modelScoresOnValidation[i].length;j++)
						modelScoresOnValidation[i][j] += learningRate * rt.eval(validationSamples.get(i).get(j));

				//again, equivalent to scoreOnValidation=scorer.score(rank(validationSamples)), but more efficient since we use the cached models' outputs
				double score = computeModelScoreOnValidation();

				PRINT(new int[]{9}, new String[]{SimpleMath.round(score, 4) + ""});
				if(score > bestScoreOnValidationData)
				{
					bestScoreOnValidationData = score;
					bestModelOnValidation = ensemble.treeCount()-1;
				}
			}

			PRINTLN("");

			//Should we stop early?
			if(m - bestModelOnValidation > nRoundToStopEarly)
				break;
		}

		//Rollback to the best model observed on the validation data
		while(ensemble.treeCount() > bestModelOnValidation+1)
			ensemble.remove(ensemble.treeCount()-1);

		//Finishing up
		scoreOnTrainingData = scorer.score(rank(samples));
		PRINTLN("---------------------------------");
		PRINTLN("Finished successfully.");
		PRINTLN(scorer.name() + " on training data: " + SimpleMath.round(scoreOnTrainingData, 4));
		if(validationSamples != null)
		{
			bestScoreOnValidationData = scorer.score(rank(validationSamples));
			PRINTLN(scorer.name() + " on validation data: " + SimpleMath.round(bestScoreOnValidationData, 4));
		}
		PRINTLN("---------------------------------");
	}
	public double eval(DataPoint dp)
	{
		return ensemble.eval(dp);
	}
	public Ranker createNew()
	{
		return new LambdaMART();
	}
	public String toString()
	{
		return ensemble.toString();
	}
	public String model()
	{
		String output = "## " + name() + "\n";
		output += "## No. of trees = " + nTrees + "\n";
		output += "## No. of leaves = " + nTreeLeaves + "\n";
		output += "## No. of threshold candidates = " + nThreshold + "\n";
		output += "## Learning rate = " + learningRate + "\n";
		output += "## Stop early = " + nRoundToStopEarly + "\n";
		output += "\n";
		output += toString();
		return output;
	}
	@Override
	public void loadFromString(String fullText)
	{
		try {
			String content = "";
			//String model = "";
			StringBuffer model = new StringBuffer();
			BufferedReader in = new BufferedReader(new StringReader(fullText));
			while((content = in.readLine()) != null)
			{
				content = content.trim();
				if(content.length() == 0)
					continue;
				if(content.indexOf("##")==0)
					continue;
				//actual model component
				//model += content;
				model.append(content);
			}
			in.close();
			//load the ensemble
			ensemble = new Ensemble(model.toString());
			features = ensemble.getFeatures();
		}
		catch(Exception ex)
		{
			throw RankLibError.create("Error in LambdaMART::load(): ", ex);
		}
	}
	public void printParameters()
	{
		PRINTLN("No. of trees: " + nTrees);
		PRINTLN("No. of leaves: " + nTreeLeaves);
		PRINTLN("No. of threshold candidates: " + nThreshold);
		PRINTLN("Min leaf support: " + minLeafSupport);
		PRINTLN("Learning rate: " + learningRate);
		PRINTLN("Stop early: " + nRoundToStopEarly + " rounds without performance gain on validation data");
	}
	public String name()
	{
		return "LambdaMART";
	}
	public Ensemble getEnsemble()
	{
		return ensemble;
	}

	protected void computePseudoResponses()
	{
		Arrays.fill(pseudoResponses, 0F);
		Arrays.fill(weights, 0);
		MyThreadPool p = MyThreadPool.getInstance();
		if(p.size() == 1)//single-thread
			computePseudoResponses(0, samples.size()-1, 0);
		else //multi-threading
		{
			List<LambdaComputationWorker> workers = new ArrayList<LambdaMART.LambdaComputationWorker>();
			//divide the entire dataset into chunks of equal size for each worker thread
			int[] partition = p.partition(samples.size());
			int current = 0;
			for(int i=0;i<partition.length-1;i++)
			{
				//execute the worker
				LambdaComputationWorker wk = new LambdaComputationWorker(this, partition[i], partition[i+1]-1, current);
				workers.add(wk);//keep it so we can get back results from it later on
				p.execute(wk);

				if(i < partition.length-2)
					for(int j=partition[i]; j<=partition[i+1]-1;j++)
						current += samples.get(j).size();
			}

			//wait for all workers to complete before we move on to the next stage
			p.await();
		}
	}
	protected void computePseudoResponses(int start, int end, int current)
	{
		int cutoff = scorer.getK();
		//compute the lambda for each document (a.k.a "pseudo response")
		for(int i=start;i<=end;i++)
		{
			RankList orig = samples.get(i);
			int[] idx = MergeSorter.sort(modelScores, current, current+orig.size()-1, false);
			RankList rl = new RankList(orig, idx, current);
			double[][] changes = scorer.swapChange(rl);
			//NOTE: j, k are indices in the sorted (by modelScore) list, not the original
			// ==> need to map back with idx[j] and idx[k]
			for(int j=0;j<rl.size();j++)
			{
				DataPoint p1 = rl.get(j);
				int mj = idx[j];
				for(int k=0;k<rl.size();k++)
				{
					if(j > cutoff && k > cutoff)//swapping these pairs won't result in any change in target measures since they're below the cut-off point
						break;
					DataPoint p2 = rl.get(k);
					int mk = idx[k];
					if(p1.getLabel() > p2.getLabel())
					{
						double deltaNDCG = Math.abs(changes[j][k]);
						if(deltaNDCG > 0)
						{
							double rho = 1.0 / (1 + Math.exp(modelScores[mj] - modelScores[mk]));
							double lambda = rho * deltaNDCG;
							pseudoResponses[mj] += lambda;
							pseudoResponses[mk] -= lambda;
							double delta = rho * (1.0 - rho) * deltaNDCG;
							weights[mj] += delta;
							weights[mk] += delta;
						}
					}
				}
			}
			current += orig.size();
		}
	}
	protected void updateTreeOutput(RegressionTree rt)
	{
		List<Split> leaves = rt.leaves();
		for(int i=0;i<leaves.size();i++)
		{
			float s1 = 0F;
			float s2 = 0F;
			Split s = leaves.get(i);
			int[] idx = s.getSamples();
			for(int j=0;j<idx.length;j++)
			{
				int k = idx[j];
				s1 += pseudoResponses[k];
				s2 += weights[k];
			}
			if(s2 == 0)
				s.setOutput(0);
			else
				s.setOutput(s1/s2);
		}
	}
	protected int[] sortSamplesByFeature(DataPoint[] samples, int fid)
	{
		double[] score = new double[samples.length];
		for(int i=0;i<samples.length;i++)
			score[i] = samples[i].getFeatureValue(fid);
		int[] idx = MergeSorter.sort(score, true);
		return idx;
	}
	/**
	 * This function is equivalent to the inherited function rank(...), but it uses the cached model's outputs instead of computing them from scratch.
	 * @param rankListIndex
	 * @param current
	 * @return
	 */
	protected RankList rank(int rankListIndex, int current)
	{
		RankList orig = samples.get(rankListIndex);
		double[] scores = new double[orig.size()];
		for(int i=0;i<scores.length;i++)
			scores[i] = modelScores[current+i];
		int[] idx = MergeSorter.sort(scores, false);
		return new RankList(orig, idx);
	}
	protected float computeModelScoreOnTraining()
	{
		/*float s = 0;
		int current = 0;
		MyThreadPool p = MyThreadPool.getInstance();
		if(p.size() == 1)//single-thread
			s = computeModelScoreOnTraining(0, samples.size()-1, current);
		else
		{
			List<Worker> workers = new ArrayList<Worker>();
			//divide the entire dataset into chunks of equal size for each worker thread
			int[] partition = p.partition(samples.size());
			for(int i=0;i<partition.length-1;i++)
			{
				//execute the worker
				Worker wk = new Worker(this, partition[i], partition[i+1]-1, current);
				workers.add(wk);//keep it so we can get back results from it later on
				p.execute(wk);

				if(i < partition.length-2)
					for(int j=partition[i]; j<=partition[i+1]-1;j++)
						current += samples.get(j).size();
			}
			//wait for all workers to complete before we move on to the next stage
			p.await();
			for(int i=0;i<workers.size();i++)
				s += workers.get(i).score;
		}*/
		float s = computeModelScoreOnTraining(0, samples.size()-1, 0);
		s = s / samples.size();
		return s;
	}
	protected float computeModelScoreOnTraining(int start, int end, int current)
	{
		float s = 0;
		int c = current;
		for(int i=start;i<=end;i++)
		{
			s += scorer.score(rank(i, c));
			c += samples.get(i).size();
		}
		return s;
	}
	protected float computeModelScoreOnValidation()
	{
		/*float score = 0;
		MyThreadPool p = MyThreadPool.getInstance();
		if(p.size() == 1)//single-thread
			score = computeModelScoreOnValidation(0, validationSamples.size()-1);
		else
		{
			List<Worker> workers = new ArrayList<Worker>();
			//divide the entire dataset into chunks of equal size for each worker thread
			int[] partition = p.partition(validationSamples.size());
			for(int i=0;i<partition.length-1;i++)
			{
				//execute the worker
				Worker wk = new Worker(this, partition[i], partition[i+1]-1);
				workers.add(wk);//keep it so we can get back results from it later on
				p.execute(wk);
			}
			//wait for all workers to complete before we move on to the next stage
			p.await();
			for(int i=0;i<workers.size();i++)
				score += workers.get(i).score;
		}*/
		float score = computeModelScoreOnValidation(0, validationSamples.size()-1);
		return score/validationSamples.size();
	}
	protected float computeModelScoreOnValidation(int start, int end)
	{
		float score = 0;
		for(int i=start;i<=end;i++)
		{
			int[] idx = MergeSorter.sort(modelScoresOnValidation[i], false);
			score += scorer.score(new RankList(validationSamples.get(i), idx));
		}
		return score;
	}

	protected void sortSamplesByFeature(int fStart, int fEnd)
	{
		for(int i=fStart;i<=fEnd; i++)
			sortedIdx[i] = sortSamplesByFeature(martSamples, features[i]);
	}
	//For multi-threading processing
	class SortWorker implements Runnable {
		LambdaMART ranker = null;
		int start = -1;
		int end = -1;
		SortWorker(LambdaMART ranker, int start, int end)
		{
			this.ranker = ranker;
			this.start = start;
			this.end = end;
		}
		public void run()
		{
			ranker.sortSamplesByFeature(start, end);
		}
	}
	class LambdaComputationWorker implements Runnable {
		LambdaMART ranker = null;
		int rlStart = -1;
		int rlEnd = -1;
		int martStart = -1;
		LambdaComputationWorker(LambdaMART ranker, int rlStart, int rlEnd, int martStart)
		{
			this.ranker = ranker;
			this.rlStart = rlStart;
			this.rlEnd = rlEnd;
			this.martStart = martStart;
		}
		public void run()
		{
			ranker.computePseudoResponses(rlStart, rlEnd, martStart);
		}
	}
	class Worker implements Runnable {
		LambdaMART ranker = null;
		int rlStart = -1;
		int rlEnd = -1;
		int martStart = -1;
		int type = -1;

		//compute score on validation
		float score = 0;

		Worker(LambdaMART ranker, int rlStart, int rlEnd)
		{
			type = 3;
			this.ranker = ranker;
			this.rlStart = rlStart;
			this.rlEnd = rlEnd;
		}
		Worker(LambdaMART ranker, int rlStart, int rlEnd, int martStart)
		{
			type = 4;
			this.ranker = ranker;
			this.rlStart = rlStart;
			this.rlEnd = rlEnd;
			this.martStart = martStart;
		}
		public void run()
		{
			if(type == 4)
				score = ranker.computeModelScoreOnTraining(rlStart, rlEnd, martStart);
			else if(type == 3)
				score = ranker.computeModelScoreOnValidation(rlStart, rlEnd);
		}
	}
}
```
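One detail worth spelling out: the comment "gamma computed using the Newton-Raphson method" refers to the leaf-output formula implemented in updateTreeOutput. Each leaf's output is replaced by a one-step Newton estimate,

$$
\gamma_{\text{leaf}} = \frac{\sum_{k \in \text{leaf}} \lambda_k}{\sum_{k \in \text{leaf}} w_k},
$$

i.e. s1/s2 in the code: the sum of the pseudo responses of the samples in the leaf divided by the sum of their second-order weights (with output 0 when the denominator is 0). The model score of every sample in that leaf is then increased by learningRate × γ.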
Copyright notice:
This article belongs to 笨兔勿应 and was published at http://www.cnblogs.com/bentuwuying. If you repost it, please credit the source; using it for commercial purposes without the author's consent may incur legal liability.