XGBoost：输出特征重要性并据此筛选特征
1. 输出 XGBoost 模型的特征重要性
# Plot the fitted model's feature_importances_ as a bar chart: one bar per
# input column, x axis = column index, bar height = importance value.
# Assumes model_XGB is an XGBoost sklearn-API model fitted earlier in the script.
from matplotlib import pyplot

importances = model_XGB.feature_importances_
pyplot.bar(range(len(importances)), importances)
pyplot.show()
也可以使用 XGBoost 内置的特征重要性绘图函数 plot_importance：
# Plot feature importance using XGBoost's built-in helper.
# NOTE(review): plot_importance defaults to importance_type="weight" (number of
# splits using each feature), while feature_importances_ may be computed with a
# different importance type (e.g. "gain" in recent xgboost versions), so this
# chart can rank features differently from a feature_importances_ bar chart.
# Pass importance_type=... explicitly to compare like with like.
# Assumes model_XGB was fitted earlier in the script.
from matplotlib import pyplot
from xgboost import plot_importance

plot_importance(model_XGB)
pyplot.show()
2. 根据特征重要性筛选特征：依次以每个（去重后的）特征重要性值作为阈值，用 SelectFromModel 保留重要性不低于该阈值的特征，在这些特征上重新训练模型并计算测试集准确率，从而比较不同特征数量下的效果。
# Feature selection by importance threshold.
# For each distinct importance value of the already-fitted model_XGB:
#   1. keep only the features whose importance is >= that value,
#   2. train a fresh XGBClassifier on the reduced feature set,
#   3. report its accuracy on the test set.
# Assumes model_XGB, X_train, y_train, X_test, y_test are defined earlier in the script.
from numpy import unique
from sklearn.feature_selection import SelectFromModel
from sklearn.metrics import accuracy_score
from xgboost import XGBClassifier

# unique() returns the importances sorted ascending with duplicates removed.
# Tied importances (e.g. many features at 0.0) would otherwise retrain and
# report the exact same feature subset several times.
thresholds = unique(model_XGB.feature_importances_)
for thresh in thresholds:
    # prefit=True: use model_XGB's existing importances instead of refitting it.
    selection = SelectFromModel(model_XGB, threshold=thresh, prefit=True)
    select_X_train = selection.transform(X_train)

    # Train a new model on the selected columns only.
    selection_model = XGBClassifier()
    selection_model.fit(select_X_train, y_train)

    # Evaluate on the same columns of the test set.
    # NOTE(review): picking a threshold by test-set accuracy leaks test data into
    # model selection; a separate validation split would be safer.
    select_X_test = selection.transform(X_test)
    # XGBClassifier.predict already returns class labels, so no rounding is needed.
    y_pred = selection_model.predict(select_X_test)
    accuracy = accuracy_score(y_test, y_pred)
    print("Thresh=%.3f, n=%d, Accuracy: %.2f%%"
          % (thresh, select_X_train.shape[1], accuracy * 100.0))
参考：https://blog.csdn.net/u011630575/article/details/79423162