使用随机森林计算特征重要度

转载:https://blog.csdn.net/IqqIqqIqqIqq/article/details/78857411

 

1 基于sklearn的实现

from sklearn.datasets import load_boston
from sklearn.ensemble import RandomForestRegressor
import numpy as np
#Load boston housing dataset as an example
boston = load_boston()
X = boston["data"]
Y = boston["target"]
names = boston["feature_names"]
rf = RandomForestRegressor()
rf.fit(X, Y)
print "Features sorted by their score:"
print sorted(zip(map(lambda x: round(x, 4), rf.feature_importances_), names), 
             reverse=True)

  输出为

Features sorted by their score:
[(0.5298, 'LSTAT'), (0.4116, 'RM'), (0.0252, 'DIS'), (0.0172, 'CRIM'), (0.0065, 'NOX'), (0.0035, 'PTRATIO'), (0.0021, 'TAX'), (0.0017, 'AGE'), (0.0012, 'B'), (0.0008, 'INDUS'), (0.0004, 'RAD'), (0.0001, 'CHAS'), (0.0, 'ZN')]

  基于不纯度对模型进行排序有几点需要注意: 
(1)基于不纯度降低的特征选择将会偏向于选择那些具有较多类别的变量(bias)。 
(2)当存在相关特征时,一个特征被选择后,与其相关的其他特征的重要度则会变得很低,因为他们可以减少的不纯度已经被前面的特征移除了。

 

2 准确率降低的均值 
这种方法是直接测量每种特征对模型预测准确率的影响,基本思想是重新排列某一列特征值的顺序,观测降低了多少模型的准确率。对于不重要的特征,这种方法对模型准确率的影响很小,但是对于重要特征却会极大降低模型的准确率。 
下面是这种方法的示例:

from sklearn.cross_validation import ShuffleSplit
from sklearn.metrics import r2_score
from collections import defaultdict

X = boston["data"]
Y = boston["target"]

rf = RandomForestRegressor()
scores = defaultdict(list)

#crossvalidate the scores on a number of different random splits of the data
for train_idx, test_idx in ShuffleSplit(len(X), 100, .3):
    X_train, X_test = X[train_idx], X[test_idx]
    Y_train, Y_test = Y[train_idx], Y[test_idx]
    r = rf.fit(X_train, Y_train)
    acc = r2_score(Y_test, rf.predict(X_test))
    for i in range(X.shape[1]):
        X_t = X_test.copy()
        np.random.shuffle(X_t[:, i])
        shuff_acc = r2_score(Y_test, rf.predict(X_t))
        scores[names[i]].append((acc-shuff_acc)/acc)
print "Features sorted by their score:"
print sorted([(round(np.mean(score), 4), feat) for
              feat, score in scores.items()], reverse=True)

  输出:

Features sorted by their score:
[(0.7276, 'LSTAT'), (0.5675, 'RM'), (0.0867, 'DIS'), (0.0407, 'NOX'), (0.0351, 'CRIM'), (0.0233, 'PTRATIO'), (0.0168, 'TAX'), (0.0122, 'AGE'), (0.005, 'B'), (0.0048, 'INDUS'), (0.0043, 'RAD'), (0.0004, 'ZN'), (0.0001, 'CHAS')]

 

posted @ 2018-08-22 21:40  迷茫的计算机呆  阅读(4991)  评论(0编辑  收藏  举报