决策树算法理解和应用
决策树算法是一种监督式学习算法,它简单好用,易于解释,在金融科技,数字健康,教育服务,消费互联网等许多领域发挥着积极作用。决策树算法学习的结果,类似下图结构:
本文首先介绍决策树的原理,然后基于tidymodels框架设计和执行决策树算法以解决实际问题。
一、决策树算法原理
决策树算法的理解,可以参考下面的算法伪代码(来源:数据挖掘概念与技术)
决策树算法需要解决关键问题
1 如何选择特征做拆分?
主要采用这些度量方法
1)信息增益
最大化变量的信息增益,确定变量的拆分以及先后顺序
2)增益率
增益率用于优化信息增益偏向于具有变量值分布不一致所导致的问题。
3)Gini 指数
2 如何对树的结构进行裁剪?
目的:防止学习的模型过拟合(对训练集效果好,而测试集上效果不佳)
使用统计测量删除不可靠的分支或者有少量样本组成的分支。实际操作中,可以通过设置树生成的一些超参数来控制树的结构,比方说:
1)树的最大深度max_depth
2)树的最小划分样本数min_samples_split
3)数的叶子节点最小样本数min_samples_leaf
通过裁剪技术,可以让树更加简洁,容易理解,也可提提升模型的泛化性能。
决策树算法的优点:
-
简单可解释
-
可以处理各种数据
-
非参数模型
-
稳健
-
快速
决策树算法的缺点:
-
过度拟合问题
-
不稳定问题
-
偏差问题
-
优化问题
二、决策树算法应用案例
利用决策树算法预测Scooby Doo monsters是否真实?
第一步:数据理解与准备
options(warn = -1)
library(tidyverse)
# 数据获取
scooby_raw <- read_csv("https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2021/2021-07-13/scoobydoo.csv")
scooby_raw %>%
filter(monster_amount > 0) %>%
count(monster_real)
第二步:从不同维度做洞察
时间维度
scooby_raw %>%
filter(monster_amount > 0) %>%
count(
year_aired = 10 * ((lubridate::year(date_aired) + 1) %/% 10),
monster_real
) %>%
mutate(year_aired = factor(year_aired)) %>%
ggplot(aes(year_aired, n, fill = monster_real)) +
geom_col(position = position_dodge(preserve = "single"), alpha = 0.8) +
labs(x = "Date aired", y = "Monsters per decade", fill = "Real monster?")
imdb评分维度
scooby_raw %>%
filter(monster_amount > 0) %>%
mutate(imdb = parse_number(imdb)) %>%
ggplot(aes(imdb, after_stat(density), fill = monster_real)) +
geom_histogram(position = "identity", alpha = 0.5) +
labs(x = "IMDB rating", y = "Density", fill = "Real monster?")
第三步:决策树模型构建
数据集划分
训练集,用于训练模型
测试集,用于评价模型性能
训练集中利用bootstraps策略用于做超参数选择和优化
library(tidymodels)
set.seed(123)
scooby_split <- scooby_raw %>%
mutate(
imdb = parse_number(imdb),
year_aired = lubridate::year(date_aired)
) %>%
filter(monster_amount > 0, !is.na(imdb)) %>%
mutate(
monster_real = case_when(
monster_real == "FALSE" ~ "fake",
TRUE ~ "real"
),
monster_real = factor(monster_real)
) %>%
select(year_aired, imdb, monster_real, title) %>%
initial_split(strata = monster_real)
scooby_train <- training(scooby_split)
scooby_test <- testing(scooby_split)
set.seed(234)
scooby_folds <- bootstraps(scooby_train, strata = monster_real)
scooby_folds
决策树模型设计
# 设计决策树模型
tree_spec <-
decision_tree(
cost_complexity = tune(),
tree_depth = tune(),
min_n = tune()
) %>%
set_mode("classification") %>%
set_engine("rpart")
tree_spec
tree_grid <- grid_regular(cost_complexity(), tree_depth(), min_n(), levels = 4)
tree_grid
doParallel::registerDoParallel()
set.seed(345)
tree_rs <-
tune_grid(
tree_spec,
monster_real ~ year_aired + imdb,
resamples = scooby_folds,
grid = tree_grid,
metrics = metric_set(accuracy, roc_auc, sensitivity, specificity)
)
tree_rs
第四步:模型性能评价
# 模型评估和理解
show_best(tree_rs)
# 超参数可视化
autoplot(tree_rs) + theme_light(base_family = "IBMPlexSans")
# 基于所关注的指标选择最佳模型的超参数
simpler_tree <- select_by_one_std_err(tree_rs,
-cost_complexity,
metric = "roc_auc"
)
# 根据最佳参数重构模型
final_tree <- finalize_model(tree_spec, simpler_tree)
final_fit <- fit(final_tree, monster_real ~ year_aired + imdb, scooby_train)
final_rs <- last_fit(final_tree, monster_real ~ year_aired + imdb, scooby_split)
collect_metrics(final_rs)
第五步:模型结果可视化
# 决策树执行决策的可视化
library(parttree)
scooby_train %>%
ggplot(aes(imdb, year_aired)) +
geom_parttree(data = final_fit, aes(fill = monster_real), alpha = 0.2) +
geom_jitter(alpha = 0.7, width = 0.05, height = 0.2, aes(color = monster_real))
参考资料:
1 Understanding the Mathematics Behind Decision Trees | by Nikita Sharma | Heartbeat (fritz.ai)
2 https://juliasilge.com/blog/scooby-doo/
3 https://github.com/grantmcdermott/parttree
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· SQL Server 2025 AI相关能力初探
· AI编程工具终极对决:字节Trae VS Cursor,谁才是开发者新宠?
· 开源Multi-agent AI智能体框架aevatar.ai,欢迎大家贡献代码
· Manus重磅发布:全球首款通用AI代理技术深度解析与实战指南