机器学习算法原理实现——cart决策树:分类&回归
cart决策树示例:
本文目标,仿照sklearn写一个cart树,但是仅仅使用max_depth作为剪枝依据。
我们本次实现cart分类,因此用到gini指数:
为了帮助理解:
好了,理解了基尼指数。我们看下cart树的构建步骤:
注意还有几个细节:
cart树每个treenode存储了哪些数据?
在CART决策树中,每个节点(TreeNode)通常存储以下数据:
-
划分特征:这是用于根据某种条件划分数据集的特征。例如,如果一个节点用"年龄 > 30"作为分割条件,那么"年龄"就是这个节点的划分特征。
-
划分阈值:与划分特征配合使用,定义了数据应如何分割。在上面的例子中,阈值是30。
-
左子节点:满足划分条件的数据子集的节点。例如,在上面的"年龄 > 30"例子中,大于30岁的数据会被划分到左子节点。
-
右子节点:不满足划分条件的数据子集的节点。在上面的例子中,30岁及以下的数据会被划分到右子节点。
-
类标签:只在叶节点中有效。表示该节点所代表的数据子集中最常见的类别。当新数据通过决策树进行预测时,最终到达的叶节点的类标签就是其预测结果。
-
数据子集:节点当前代表的数据子集。在许多实际实现中,为了节省内存,节点可能不直接存储数据子集,而是存储数据索引或其他引用。
-
基尼不纯度或其他不纯度指标:代表当前数据子集的不纯度。在构建树的过程中,这个指标用于判断是否应该继续划分当前节点。
-
其他可选信息:如节点深度、父节点引用、数据点的数量等。
这些数据允许决策树在训练过程中进行递归分割,以及在预测过程中导航通过树结构。
好了,实现代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 | import numpy as np class TreeNode: def __init__( self , gini, num_samples, num_samples_per_class, predicted_class): self .gini = gini self .num_samples = num_samples self .num_samples_per_class = num_samples_per_class self .predicted_class = predicted_class self .feature_index = 0 self .threshold = 0 self .left = None self .right = None def gini(y): m = len (y) return 1.0 - sum ([(np. sum (y = = c) / m) * * 2 for c in np.unique(y)]) def grow_tree(X, y, depth = 0 , max_depth = None ): classes = np.unique(y) num_samples_per_class = [np. sum (y = = c) for c in classes] predicted_class = classes[np.argmax(num_samples_per_class)] node = TreeNode( gini = gini(y), num_samples = len (y), num_samples_per_class = num_samples_per_class, predicted_class = predicted_class, ) if depth < max_depth: idx, thr = best_split(X, y) if idx is not None : indices_left = X[:, idx] < thr X_left, y_left = X[indices_left], y[indices_left] X_right, y_right = X[~indices_left], y[~indices_left] node.feature_index = idx node.threshold = thr node.left = grow_tree(X_left, y_left, depth + 1 , max_depth) node.right = grow_tree(X_right, y_right, depth + 1 , max_depth) return node def best_split(X, y): """ 用numpy实现best_split,见下,可以先看不用numpy的实现 """ n_samples, n_features = X.shape if len (np.unique(y)) = = 1 : return None , None best = {} min_gini = float ( 'inf' ) for feature_idx in range (n_features): thresholds = np.unique(X[:, feature_idx]) for threshold in thresholds: left_mask = X[:, feature_idx] < threshold right_mask = ~left_mask gini_left = gini(y[left_mask]) gini_right = gini(y[right_mask]) weighted_gini = len (y[left_mask]) / n_samples * gini_left + len (y[right_mask]) / n_samples * gini_right if weighted_gini < min_gini: best = { 'feature_index' : feature_idx, 'threshold' : threshold, 'left_labels' : y[left_mask], 'right_labels' : y[right_mask], 'gini' : weighted_gini } min_gini = weighted_gini return best[ 'feature_index' ], best[ 'threshold' ] def best_split2(X, y): """ 不用numpy实现best_split """ n_samples, n_features = len (X), len (X[ 0 ]) # 如果样本中只有一种输出标签或样本为空,则返回None if len ( set (y)) = = 1 : return None , None # 初始化最佳分割的信息 best = {} min_gini = float ( 'inf' ) # 遍历每个特征 for feature_idx in range (n_features): # 获取当前特征的所有唯一值,并排序 unique_values = sorted ( set (row[feature_idx] for row in X)) # 遍历每个唯一值,考虑将其作为分割阈值 for value in unique_values: left_y, right_y = [], [] # 对于每个样本,根据其特征值与阈值的关系分到左子集或右子集 for i, row in enumerate (X): if row[feature_idx] < value: left_y.append(y[i]) else : right_y.append(y[i]) # 计算左子集和右子集的基尼指数 gini_left = 1.0 - sum ([(left_y.count(label) / len (left_y)) * * 2 for label in set (left_y)]) gini_right = 1.0 - sum ([(right_y.count(label) / len (right_y)) * * 2 for label in set (right_y)]) # 计算加权基尼指数 weighted_gini = len (left_y) / len (y) * gini_left + len (right_y) / len (y) * gini_right # 如果当前基尼值小于已知的最小基尼值,更新最佳分割 if weighted_gini < min_gini: best = { 'feature_index' : feature_idx, 'threshold' : value, 'left_labels' : left_y, 'right_labels' : right_y, 'gini' : weighted_gini } min_gini = weighted_gini return best[ 'feature_index' ], best[ 'threshold' ] def predict_tree(node, X): if node.left is None and node.right is None : return node.predicted_class if X[node.feature_index] < node.threshold: return predict_tree(node.left, X) else : return predict_tree(node.right, X) def predict_tree2(node, X): if node.left is None and node.right is None : return node.predicted_class if X[node.feature_index] < node.threshold: return predict_tree(node.left, X) else : return predict_tree(node.right, X) class CARTClassifier: def __init__( self , max_depth = None ): self .max_depth = max_depth def fit( self , X, y): self .tree_ = grow_tree(X, y, max_depth = self .max_depth) def predict( self , X): return [predict_tree( self .tree_, x) for x in X] # 使用示例 if __name__ = = "__main__" : """ # 好好理解下这个分割的函数 X = np.array([[2.5], [3.5], [1], [1.5], [2], [3], [0]]) y = np.array([1, 1, 0, 0, 1, 0, 2]) best_idx, best_thr = best_split(X, y) """ from sklearn.datasets import load_iris data = load_iris() X, y = data.data, data.target clf = CARTClassifier(max_depth = 4 ) clf.fit(X, y) preds = clf.predict(X) accuracy = sum (preds = = y) / len (y) print (f "Accuracy: {accuracy:.4f}" ) from sklearn.tree import DecisionTreeClassifier # 创建分类树实例 clf = DecisionTreeClassifier(max_depth = 4 ) # 分类树训练 clf.fit(X, y) preds = clf.predict(X) accuracy = sum (preds = = y) / len (y) print (f "sklearn Accuracy: {accuracy:.4f}" ) |
输出:
Accuracy: 0.9933
sklearn Accuracy: 0.9933
我们再来实现一个cart回归树吧!
要将分类CART树修改为回归CART树,我们需要做以下几个主要的修改:
1. 在TreeNode类中,我们需要将predicted_class改为predicted_value,因为在回归问题中,我们预测的是一个连续值,而不是类别。
2. 在gini函数中,我们需要将基尼指数的计算方式改为计算均方误差(MSE)。在回归问题中,我们通常使用MSE作为节点不纯度的度量。
3. 在grow_tree函数中,我们需要将predicted_class改为predicted_value,并将其计算方式改为计算目标值的平均值。
4. 在best_split函数中,我们需要将基尼指数的计算方式改为计算MSE。
5. 在CARTClassifier类中,我们需要将类名改为CARTRegressor,并将predict函数中的类别预测改为值预测。
1. 在TreeNode类中,我们需要将predicted_class改为predicted_value,因为在回归问题中,我们预测的是一个连续值,而不是类别。
2. 在gini函数中,我们需要将基尼指数的计算方式改为计算均方误差(MSE)。在回归问题中,我们通常使用MSE作为节点不纯度的度量。
3. 在grow_tree函数中,我们需要将predicted_class改为predicted_value,并将其计算方式改为计算目标值的平均值。
4. 在best_split函数中,我们需要将基尼指数的计算方式改为计算MSE。
5. 在CARTClassifier类中,我们需要将类名改为CARTRegressor,并将predict函数中的类别预测改为值预测。
代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 | from sklearn.model_selection import train_test_split from sklearn.metrics import mean_squared_error from sklearn.tree import DecisionTreeRegressor from sklearn.datasets import load_diabetes # 加载波士顿房价数据集 data = load_diabetes() X, y = data.data, data.target # 划分训练集和测试集 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2 , random_state = 42 ) # 创建CART回归树模型 regressor = DecisionTreeRegressor(max_depth = 4 ) # 训练模型 regressor.fit(X_train, y_train) # 预测测试集 y_pred = regressor.predict(X_test) # 计算均方误差 mse = mean_squared_error(y_test, y_pred) print (f "Mean Squared Error: {mse:.2f}" ) import numpy as np class TreeNode: def __init__( self , mse, num_samples, predicted_value): self .mse = mse self .num_samples = num_samples self .predicted_value = predicted_value self .feature_index = 0 self .threshold = 0 self .left = None self .right = None def mse(y): if len (y) = = 0 : return 0 return np.mean((y - np.mean(y)) * * 2 ) def grow_tree(X, y, depth = 0 , max_depth = None ): num_samples = len (y) predicted_value = np.mean(y) node = TreeNode( mse = mse(y), num_samples = num_samples, predicted_value = predicted_value, ) if depth < max_depth: idx, thr = best_split(X, y) if idx is not None : indices_left = X[:, idx] < thr X_left, y_left = X[indices_left], y[indices_left] X_right, y_right = X[~indices_left], y[~indices_left] node.feature_index = idx node.threshold = thr node.left = grow_tree(X_left, y_left, depth + 1 , max_depth) node.right = grow_tree(X_right, y_right, depth + 1 , max_depth) return node def best_split(X, y): n_samples, n_features = X.shape if n_samples < = 1 : return None , None best = {} min_mse = float ( 'inf' ) for feature_idx in range (n_features): thresholds = np.unique(X[:, feature_idx]) for threshold in thresholds: left_mask = X[:, feature_idx] < threshold right_mask = ~left_mask mse_left = mse(y[left_mask]) mse_right = mse(y[right_mask]) weighted_mse = len (y[left_mask]) / n_samples * mse_left + len (y[right_mask]) / n_samples * mse_right if weighted_mse < min_mse: best = { 'feature_index' : feature_idx, 'threshold' : threshold, 'left_values' : y[left_mask], 'right_values' : y[right_mask], 'mse' : weighted_mse } min_mse = weighted_mse return best[ 'feature_index' ], best[ 'threshold' ] def predict_tree(node, X): if node.left is None and node.right is None : return node.predicted_value if X[node.feature_index] < node.threshold: return predict_tree(node.left, X) else : return predict_tree(node.right, X) class CARTRegressor: def __init__( self , max_depth = None ): self .max_depth = max_depth def fit( self , X, y): self .tree_ = grow_tree(X, y, max_depth = self .max_depth) def predict( self , X): return [predict_tree( self .tree_, x) for x in X] # 创建CART回归树模型 regressor = CARTRegressor(max_depth = 4 ) # 训练模型 regressor.fit(X_train, y_train) # 预测测试集 y_pred = regressor.predict(X_test) # 计算均方误差 mse = mean_squared_error(y_test, y_pred) print (f "Mean Squared Error: {mse:.2f}" ) |
输出结果:
Mean Squared Error: 3682.01
Mean Squared Error: 3682.01
可以看到,和sklearn的输出几乎无差别!
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 记一次.NET内存居高不下排查解决与启示
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· DeepSeek 开源周回顾「GitHub 热点速览」
2022-09-10 metersploit msf 常用命令
2022-09-10 windows 7 iso下载地址
2020-09-10 小样本学习文献
2019-09-10 联邦学习
2019-09-10 数据库索引数据结构总结——ART树就是前缀树