深度森林gcForest模型
Python实现gcForest模型
1.介绍
gcForest v1.1.1是gcForest的一个官方托管在GitHub上的版本,是由Ji Feng(Deep Forest的paper的作者之一)维护和开发,该版本支持Python3.5,且有类似于Scikit-Learn的API接口风格,在该项目中提供了一些调用例子,目前支持的基分类器有RandomForestClassifier,XGBClassifer,ExtraTreesClassifier,LogisticRegression,SGDClassifier如果采用XGBoost的基分类器还可以使用GPU,如果想增加其他基分类器,可以在模块中的lib/gcforest/estimators/__init__.py
中添加,使用该模块需要依赖安装如下模块:
- argparse
- joblib
- keras
- psutil
- scikit-learn>=0.18.1
- scipy
- simplejson
- tensorflow
- xgboost
2.API调用样例
这里先列出gcForest提供的API接口:
-
fit_tranform(X_train,y_train) 是gcForest模型最后一层每个估计器预测的概率concatenated的结果
-
fit_transform(X_train,y_train,X_test=x_test,y_test=y_test) 测试数据的准确率在训练的过程中也会被记录下来
-
set_keep_model_mem(False) 如果你的缓存不够,把该参数设置成False(默认为True),如果设置成False,你需要使用fit_transform(X_train,y_train,X_test=x_test,y_test=y_test)来评估你的模型
-
predict(X_test) # 模型预测
-
transform(X_test)
最简单的调用gcForest的方式如下:
# 导入必要的模块
from gcforest.gcforest import GCForest
# 初始化一个gcForest对象
gc = GCForest(config) # config是一个字典结构
# gcForest模型最后一层每个估计器预测的概率concatenated的结果
X_train_enc = gc.fit_transform(X_train,y_train)
# 测试集的预测
y_pred = gc.predict(X_test)
下面我们使用MNIST数据集来演示gcForest的使用及代码的详细说明:
# 导入必要的模块
import argparse # 命令行参数调用模块
import numpy as np
import sys
from keras.datasets import mnist # MNIST数据集
import pickle
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
sys.path.insert(0, "lib")
from gcforest.gcforest import GCForest
from gcforest.utils.config_utils import load_json
def parse_args():
'''
解析终端命令行参数(model)
'''
parser = argparse.ArgumentParser()
parser.add_argument("--model", dest="model", type=str, default=None,
help="gcfoest Net Model File")
args = parser.parse_args()
return args
def get_toy_config():
'''
生成级联结构的相关结构
'''
config = {}
ca_config = {}
ca_config["random_state"] = 0
ca_config["max_layers"] = 100
ca_config["early_stopping_rounds"] = 3
ca_config["n_classes"] = 10
ca_config["estimators"] = []
ca_config["estimators"].append(
{"n_folds": 5, "type": "XGBClassifier", "n_estimators": 10,
"max_depth": 5,"objective": "multi:softprob", "silent":
True, "nthread": -1, "learning_rate": 0.1} )
ca_config["estimators"].append({"n_folds": 5, "type": "RandomForestClassifier",
"n_estimators": 10, "max_depth": None, "n_jobs": -1})
ca_config["estimators"].append({"n_folds": 5, "type": "ExtraTreesClassifier",
"n_estimators": 10, "max_depth": None, "n_jobs": -1})
ca_config["estimators"].append({"n_folds": 5, "type": "LogisticRegression"})
config["cascade"] =