R语言学习笔记-单一决策树
决策树比较简单明晰,但存在不稳定的风险,数据的微小变化会导致最佳决策树结构的巨大变化,且决策树可能会变得比较复杂。
其算法原理参见https://zhuanlan.zhihu.com/p/148010749。笔记中主要以R语言中iris数据集描述实现步骤。
1 | data ( "iris" ) <br> #导入iris数据集 |
1 2 | set.seed (1926) #设置种子,便于复现操作结果 +1S |
之后需要将数据分为两部分,训练集与测试集,可以用多种写法实现。这些写法的本质上都是sample函数
方法一:
1 2 3 | train.data <- sample ( nrow (iris),0.7* nrow (iris),replace = F) train <-iris[train.data,] test <-iris[-train.data,] |
写法二:
1 2 3 4 5 6 | formula <- sample (2, nrow (iris), replace= TRUE , prob= c (0.7, 0.3) ) train <- iris[formula==1,] test <- iris[formula==2,] |
写法三:
1 2 3 4 | smple.size <- floor (0.7* nrow (data)) ) train.ind <- sample ( seq_len ( nrow (data)), smple.size) train <- data[train.ind, ] test <- data[-train.ind, ] |
写法四:
1 2 3 | rank_num <- sample (1:150,105) train <- iris[rank_num,] test <- iris[-rank_num,] |
接下来进行单一决策树分析,常用的包有tree,rpart,party等。
package "party"
1 2 3 4 5 6 7 8 | library (party) mdna.tree <- ctree (Species ~ Sepal.Length+Sepal.Width+Petal.Length+Petal.Width, data = train) mdna.tree #可以看看具体的分析 plot (mdna.tree, type = "simple" ,main = "mdna的简单决策树" ) #也可以自己画图,上面的是简单树装图 plot (mdna.tree,main = "mdna的全决策树" ) #全面树状图 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | Conditional inference tree with 4 terminal nodes Response: Species Inputs: Sepal.Length, Sepal.Width, Petal.Length, Petal.Width Number of observations: 105 1) Petal.Length <= 1.9; criterion = 1, statistic = 98.207 2)* weights = 35 1) Petal.Length > 1.9 3) Petal.Width <= 1.7; criterion = 1, statistic = 50.335 4) Petal.Length <= 4.6; criterion = 0.984, statistic = 8.319 5)* weights = 33 4) Petal.Length > 4.6 6)* weights = 8 3) Petal.Width > 1.7 7)* weights = 29 |
package "rpart"
1 2 3 4 5 | library ( 'rpart' ) library ( 'rpart.plot' ) model.2<- rpart (formula =Species~.,data=train ,method= 'class' ) model.2 rpart.plot (model.2) |
1 2 3 4 5 6 7 8 9 10 | n= 105 node), split, n, loss, yval, (yprob) * denotes terminal node 1) root 105 67 versicolor (0.33333333 0.36190476 0.30476190) 2) Petal.Length< 2.6 35 0 setosa (1.00000000 0.00000000 0.00000000) * 3) Petal.Length>=2.6 70 32 versicolor (0.00000000 0.54285714 0.45714286) 6) Petal.Width< 1.75 41 3 versicolor (0.00000000 0.92682927 0.07317073) * 7) Petal.Width>=1.75 29 0 virginica (0.00000000 0.00000000 1.00000000) * |
rpart包提供了复杂度损失修剪的修剪方法,printcp会告诉分裂到每一层,cp是多少,平均相对误差是多少
1 | printcp (model.2) |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 | Classification tree: rpart (formula = Species ~ ., data = train, method = "class" ) Variables actually used in tree construction: [1] Petal.Length Petal.Width Root node error: 67/105 = 0.6381 n= 105 CP nsplit rel error xerror xstd 1 0.52239 0 1.000000 1.268657 0.060056 2 0.43284 1 0.477612 0.582090 0.073898 3 0.01000 2 0.044776 0.074627 0.032570 |
#一般使用1-SE法则选出最优cp值:找到xerror最小的行,得到误差阈值为该行的xerror+xstd
##找到所有xerror小于这个阈值的行,取其中最大值的上限为prune的阈值
###根据我们的结果,来看最小的交叉验证误差为0.074,刚好是最后一个节点,不需要剪枝
####需要剪枝的案例https://danzhuibing.github.io/r_decision_tree.html
剪枝的代码
1 | # model.prune <- prune(cfit, cp=0.03) <br>#cp值为示例 |
换个颜色
1 | rpart.plot (model.2, box.col= c ( "pink" , "purple" , "lightblue" )) |
这个图也可以用DMwR绘制
1 2 | library (DMwR) prettyTree (model.2,main= 'tree of mdna with DMwR' ) |
package 'tree'
1 2 3 4 5 6 7 8 9 10 | library (tree) model.3<- tree (Species ~ Sepal.Width + Sepal.Length + Petal.Length + Petal.Width, data = iris ) summary (model.3) plot (model.3) text (model.3) |
结果:
1 2 3 4 5 6 7 8 | Classification tree: tree (formula = Species ~ Sepal.Width + Sepal.Length + Petal.Length + Petal.Width, data = iris) Variables actually used in tree construction: [1] "Petal.Length" "Petal.Width" "Sepal.Length" Number of terminal nodes: 6 Residual mean deviance: 0.1253 = 18.05 / 144 Misclassification error rate: 0.02667 = 4 / 150 |
最后使用测试集进行检验,一般使用predict函数
1 2 3 | predict <- predict (model.2,newdata=test,type= 'class' ) result.2<- table (test$Species,predict) sum ( diag (result.2))/ sum (result.2) |
结果是
[1] 0.9333333
即93.3%的准确率
这个矩阵的样子如下,对角线上的值代表预测正确的值,用对角线除以总数,就可以得到正确率了。
1 2 3 4 5 6 | table (test$Species,predict) predict setosa versicolor virginica setosa 15 0 0 versicolor 0 11 1 virginica 0 2 16 |
【一般用不到】如果列名不正常,可以使用如下代码apply每行的列名为最大值对应列名
1 | a <- predict (model.2,newdata=test,type= 'class' )b <- apply (a, 1, function (t) colnames (a)[ which.max (t)]) |
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· Linux系列:如何用 C#调用 C方法造成内存泄露
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 单线程的Redis速度为什么快?
· 展开说说关于C#中ORM框架的用法!
· SQL Server 2025 AI相关能力初探
· Pantheons:用 TypeScript 打造主流大模型对话的一站式集成库