如何统计随机森林节点数

from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
# Generate a synthetic classification dataset: 1000 samples with 4 features,
# of which 2 are informative and none are redundant. random_state is pinned
# and shuffling is disabled.
X, y = make_classification(
    n_samples=1000,
    n_features=4,
    n_informative=2,
    n_redundant=0,
    random_state=0,
    shuffle=False,
)

# Train a random forest whose trees are limited to depth 2.
rf = RandomForestClassifier(max_depth=2, random_state=0)
rf.fit(X, y)

# Show the predicted class for a single all-zero sample, then dump the
# estimator's parameters.
print(rf.predict([[0, 0, 0, 0]]))
print(rf.get_params())


# Import tools needed for visualization
from sklearn.tree import export_graphviz
import pydot
# Pull out one tree from the forest (the estimator at index 5).
tree = rf.estimators_[5]

# Display names for the four input features, in column order.
feature_list = ["feature1", "feature2", "feature3", "feature4"]

# Export the tree to a Graphviz dot file, labelling split nodes with the
# feature names above.
export_graphviz(
    tree,
    out_file='tree.dot',
    feature_names=feature_list,
    rounded=True,
    precision=1,
)

# Load the dot file back as a graph. The single-element unpacking requires
# the result to contain exactly one graph and raises ValueError otherwise.
(graph,) = pydot.graph_from_dot_file('tree.dot')

# Render the graph to a PNG image.
# NOTE(review): presumably needs the Graphviz `dot` binary installed - confirm.
graph.write_png('tree.png')

  

posted @ 2020-09-23 20:46  bonelee  阅读(438)  评论(1)  编辑  收藏  举报