SCIKIT-LEARN与GBDT使用案例

http://blog.csdn.net/superzrx/article/details/47073847

安装

SCIKIT-LEARN是一个基于Python/numpy/scipy的机器学习库 
windows下最简单的安装方式是使用winpython进行安装 
WinPython地址

GBDT使用

这段代码展示了一个简单的GBDT调用过程 
数据维数24,训练数据1990,测试数据221

import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
gbdt=GradientBoostingRegressor(
  loss='ls'
, learning_rate=0.1
, n_estimators=100
, subsample=1
, min_samples_split=2
, min_samples_leaf=1
, max_depth=3
, init=None
, random_state=None
, max_features=None
, alpha=0.9
, verbose=0
, max_leaf_nodes=None
, warm_start=False
)
train_feat=np.genfromtxt("train_feat.txt",dtype=np.float32)
train_id=np.genfromtxt("train_id.txt",dtype=np.float32)
test_feat=np.genfromtxt("test_feat.txt",dtype=np.float32)
test_id=np.genfromtxt("test_id.txt",dtype=np.float32)
print train_feat.shape,rain_id.shape,est_feat.shape,est_id.shape
gbdt.fit(train_feat,train_id)
pred=gbdt.predict(test_feat)
total_err=0
for i in range(pred.shape[0]):
    print pred[i],test_id[i]
    err=(pred[i]-test_id[i])/test_id[i]
    total_err+=err*err
print total_err/pred.shape[0]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31

train_id.txt示例

320
361
364
336
358
  • 1
  • 2
  • 3
  • 4
  • 5
  • 1
  • 2
  • 3
  • 4
  • 5

train_feat.txt示例

0.00598802 0.569231 0.647059 0.95122 -0.225434 0.837989 0.357258 -0.0030581 -0.383475
0.161677 0.743195 0.682353 0.960976 -0.0867052 0.780527 0.282945 0.149847 -0.0529661 
0.113772 0.744379 0.541176 0.990244 -0.00578035 0.721468 0.43411 -0.318043 0.288136 
0.0538922 0.608284 0.764706 0.95122 -0.248555 0.821229 0.848604 -0.0030581 0.239407 
0.173653 0.866272 0.682353 0.95122 0.017341 0.704709 -0.0210016 -0.195719 0.150424 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 1
  • 2
  • 3
  • 4
  • 5

测试结果与真值

333.986169852 334.0
360.84170859 360.0
342.658750421 343.0
329.591753015 328.0
374.247432336 374.0
  • 1
  • 2
  • 3
  • 4
  • 5
  • 1
  • 2
  • 3
  • 4
  • 5

更多详细功能介绍请看这里

调参与结果对比

误差度量采用预测值与真值的误差占真值的百分比的均值

方法参数平均误差百分比
svm 最佳参数 1.60452%
gdbt n_estimators=100,max_depth=3 2.29247%
gdbt n_estimators=1000,max_depth=3 1.23875%
gdbt n_estimators=1000,max_depth=5 1.14202%
gdbt n_estimators=1000,max_depth=7 1.02505%

可以看出n_estimatorsmax_depth 与gbdt的表达能力相关度很高 
同时gbdt相对svm效果更优

 
posted @ 2016-12-22 14:38  Django's blog  阅读(1492)  评论(0编辑  收藏  举报