我生命中最美好的一天——预测钻石价格的回归分析

我生命中最美好的一天——预测钻石价格的回归分析

希瑟·祖雷尔

2022–09–08

我一生中最美好的一天是我的伴侣向我求婚的时候。他很好奇他是否为他买给我的钻石多付了钱(他认为他得到了很多!)——让我们看看。

我将使用钻石数据集来训练回归模型,以根据钻石的物理特性(包括尺寸/克拉、净度、颜色等)来预测钻石的价格。

第一步是一些 探索性数据分析 看看我们在训练模型之前是否需要做任何事情来清理数据。

PS 还有一些额外的分析、绘图等包含在 钻石.r 文件在 github , 没有出现在本文档中。

准备和探索性数据分析

 # 设置  
 图书馆('tidyverse')  
 图书馆('GGally')  
  
 # 进口  
 数据(钻石)  
  
 # 探索性  
 摘要(钻石) ggpairs(钻石) ##克拉切割颜色净度深度  
 ## 分钟。 :0.2000 公平 : 1610 D: 6775 SI1 :13065 Min. :43.00  
 ## 1st Qu.:0.4000 Good : 4906 E: 9797 VS2 :12258 1st Qu.:61.00  
 ## 中位数:0.7000 非常好:12082 F:9542 SI2:9194 中位数:61.80  
 ## 平均值:0.7979 溢价:13791 G:11292 VS1:8171 平均值:61.75  
 ## 第三排:1.0400 理想:21551 H:8304 VVS2:5066 第三排:62.50  
 ## 最大限度。 :5.0100 I: 5422 VVS1 : 3655 最大。 :79.00  
 ## J:2808(其他):2531  
 ## 表价格 xy  
 ## 分钟。 :43.00 分钟。 : 326 分钟。 : 0.000 分钟。 : 0.000  
 ## 第一区:56.00 第一区:950 第一区:4.710 第一区:4.720  
 ## 中位数:57.00 中位数:2401 中位数:5.700 中位数:5.710  
 ## 平均值:57.46 平均值:3933 平均值:5.731 平均值:5.735  
 ## 第三区:59.00 第三区:5324 第三区:6.540 第三区:6.540  
 ## 最大限度。 :95.00 最大。 :18823 最大。 :10.740 最大。 :58.900  
 ##  
 ## z  
 ## 分钟。 : 0.000  
 ## 第一曲:2.910  
 ## 中位数:3.530  
 ## 平均值:3.539  
 ## 第三曲:4.040  
 ## 最大限度。 :31.800  
 ##

所以 x、y 和 z 值对应于以毫米为单位的钻石的物理尺寸——这些与克拉值密切相关,克拉值是衡量钻石重量的标准——这是有道理的,我们真的不需要包括所有这些。一些 ML 模型对包含彼此强相关的多个特征很敏感——这可能导致过度拟合——因此我们将在训练模型之前删除 x、y 和 z 变量。

目标变量的分布(美元价格)

QQ 图(或分位数-分位数图)用于快速、直观地识别单个变量的分布。

 qq_diamonds ** <-** qqnorm((钻石 **$** 价格),主要 **=** "普通QQ价格图");qqline((钻石 **$** 价格))

_# 嗯_  
  
 qq_log_diamonds ** <-** qq规范( **日志** (钻石 **$** 价格),主要 **=** "对数价格的普通QQ图");qqline( **日志** (钻石 **$** 价格))

_# 哦,这更合适_ hist_norm ** <-** ggplot(钻石,AES( **日志** (价格))) **+**   
 geom_histogram(aes(y **=** ..密度..),颜色 **=** “黑色”,填充 **=**   
 '淡蓝色',垃圾箱 **=** 50) **+**   
 stat_function(有趣 **=** dnorm, args **=** **列表** (意思是 **=**   
 意思是( **日志** (钻石 **$** 价格)),标准差 **=** 标准差( **日志** (钻石 **$** 价格))))  
 hist_norm

根据 QQ 图和直方图,价格的对数似乎遵循双峰或多峰分布。让我们尝试另外几个情节来看看。

 小提琴 ** <-** ggplot(钻石,aes(x **=** 颜色,y **=** **日志** (价格),填写 **=**   
 颜色))  
 小提琴 **+**   
 geom_violin() **+**   
 scale_y_log10() **+**   
 facet_grid(清晰度 **~** 切)

 克拉 ** <-** ggplot(数据 **=** 钻石,AES(X **=** 克拉,y **=** **日志** (价格),  
 颜色 **=** 颜色))  
 克拉 **+**   
 stat_ecdf() **+**   
 facet_grid(清晰度 **~** 切)

是的,这绝对看起来像一个多峰分布,分布中有多个峰,对应于将钻石的克拉从 0.99 增加到 1 克拉、从 1.99 到 2 克拉等。并且在 1/2 克拉左右跳跃更小。

数值变量的方差

我要检查的另一件事是数值变量的方差。如果任何变量的方差与其他变量的差异 >= 1 个数量级,我们将对这些值进行标准化。如果一个变量的方差远大于其他变量,则可能会过分强调这些变量在训练模型中的重要性。

克拉 变量小于 1 个数量级 桌子 变量(因此接近 1 OOM 小于 深度 ),所以我们将继续标准化表格和深度。但是,这应该在数据集被拆分为训练和测试数据集之后发生。

 钻石 **% >%** summarise_if(is.numeric, **列表** (意思是 **=** 意思是,var **=** 曾是)) **% >%** t() ## [,1]  
 ## 克拉平均 7.979397e-01  
 ## depth_mean 6.174940e+01  
 ## table_mean 5.745718e+01  
 ## price_mean 3.932800e+03  
 ## x_mean 5.731157e+00  
 ## y_mean 5.734526e+00  
 ## z_mean 3.538734e+00  
 ## carat_var 2.246867e-01  
 ## depth_var 2.052404e+00  
 ## table_var 4.992948e+00  
 ## price_var 1.591563e+07  
 ## x_var 1.258347e+00  
 ## y_var 1.304472e+00  
 ## z_var 4.980109e-01

克拉 变量小于 1 个数量级 桌子 变量(因此接近 1 OOM 小于 深度 ),所以我们将继续标准化表格和深度。然而,这应该在数据集被分成训练和测试数据集之后发生。

我想我有足够的信息进行下一步……

数据清洗

我们将删除一些彼此高度相关的变量,留下一个变量来捕获其他 3 个变量中包含的信息。

我们要将价格转换为价格的对数。由于该数据集来自 2017 年,并且我试图预测 2021 年购买的钻石的价值,我们还将根据通货膨胀进行调整(约 10.55%)。

训练模型之前的另一个重要考虑因素是处理分类数据。通常,这些将被转换为“虚拟变量”或单热编码。这适用于没有自然排名或类别顺序的情况。在这里,切工、净度和颜色都有一个自然的顺序。例如,“良好”切工的钻石比“一般”切工的钻石更好。如果您从 r ( 数据(钻石) ) 那么这些变量将已经是具有正确顺序的因素。但是,如果您下载了此数据集的 csv,则需要将它们从字符串转换为有序因子,因此我将在此处包含转换步骤(即使它不应该更改我的数据集中的任何内容 - 尽管看起来“颜色”变量的顺序相反,所以我也会修复它)。注意:此功能中的级别是从最差到最好分配的。

在进一步研究了台面和深度场之后,这些值是与钻石平均直径的比值。表 % 影响钻石的光性能(即它看起来有多闪亮)。深度 % 会影响钻石的亮度和火彩(我不是 100% 确定这意味着什么,但我们会看看它是否会影响价格)。

 钻石 ** <-** 钻石 **% >%**  
 变异(价格 **=** 价格 ***** 1.1055) **% >%**  
 变异(log_price **=** **日志** (价格)) **% >%**  
 选择( **-** 价格, **-** X, **-** 是的, **-** z) **% >%**  
 变异(切 **=** 因子(切割,水平 **=** **C** (“一般”、“好”、“非常好”、“高级”、“理想”),已订购 **=** **真的** ),  
 颜色 **=** 因子(颜色,级别 **=** **C** ('J','I','H','G','F',  
 'E', 'D'), 有序 **=** **真的** ),  
 明晰 **=** 因素(清晰度,水平 **=** **C** ('I1','SI2','SI1',  
 'VS2'、'VS1'、'VVS2'、'VVS1'、'IF'),已订购 **=** **真的** ))  
  
 _# 这应该表明这三个变量现在是有序因子。_  
 str(钻石) _# 这里有一个小技巧,可以让 R 输出所有可能的因子水平的顺序,而不仅仅是前几个:_  
 **分钟** (钻石 **$** 切)  
 **分钟** (钻石 **$** 颜色)  
 **分钟** (钻石 **$** 明晰) ## tibble [53,940 × 7] (S3: tbl_df/tbl/data.frame)  
 ## $ 克拉:数字 [1:53940] 0.23 0.21 0.23 0.29 0.31 0.24 0.24  
 ## $ cut : Ord.factor w/ 5 个级别 "Fair"<"Good"<..: 5 4 2 4 2  
 ## $ color : Ord.factor w/7 个级别 "J"<"I"<"H"<"G"<..: 6 6 6 2  
 ## $ 清晰度:Ord.factor w/8 个级别 "I1"<"SI2"<"SI1"<..: 2 3 5 4  
 ## $ 深度:数字 [1:53940] 61.5 59.8 56.9 62.4 63.3 62.8 62.3  
 ## $ table : num [1:53940] 55 61 65 58 58 57 57 55 61 61 ...  
 ## $ log_price: num [1:53940] 5.89 5.89 5.89 5.91 5.91 ... ## [1] 公平  
 ## 级别:一般 < 好 < 非常好 < 高级 < 理想 ## [1] Ĵ  
 ## 级别:J < I < H < G < F < E < D ## [1] I1  
 ## 级别:I1 < SI2 < SI1 < VS2 < VS1 < VVS2 < VVS1 < IF

看起来不错!

模型准备和训练

注意:我们还对拆分数据集后的数据进行标准化(缩放)以避免 数据泄露 .这意味着训练数据集的值会影响测试数据集的值,因为它们的值用于标准化步骤。在新数据上运行模型时,这可能会影响模型的性能。

现在数据集已分为测试数据集和训练数据集,我们将训练几个不同的模型,测试它们的性能,并使用最好的模型进行预测。

_# 准备_  
 库(caTools)  
 图书馆(tictoc)  
  
 set.seed(42)  
  
 tic.clearlog()  
  
 分裂 ** <-** sample.split(钻石 **$** log_price, SplitRatio **=** 0.8)  
 diamonds_train ** <-** 子集(钻石,拆分 **==** **真的** )  
 diamonds_test ** <-** 子集(钻石,拆分 **==** **错误的** )  
  
 diamonds_train ** <-** diamonds_train **% >%**   
 变异_at( **C** ('表','深度'), **~** (规模(。) **% >%** 作为向量))  
 diamonds_test ** <-** diamonds_test **% >%**   
 变异_at( **C** ('表','深度'), **~** (规模(。) **% >%** 作为向量))  
  
 一瞥(diamonds_test) _# 让我们看看标准化变量:  
_ 平均(钻石测试 **$** 桌子)  
 sd(钻石测试 **$** 桌子) ## 行:9,706  
 ## 列:7  
 ## $ 克拉<dbl>0.30, 0.23, 0.30, 0.23, 0.23, 0.32, 0.32, 0.24,  
 ## $ 切<ord>很好,很好,很好,很好,很好,很好  
 ## $ 颜色<ord>我,H,J,F,E,H,H,F,我,E,H,F,E,H,G,G,  
 ## $ 清晰度<ord>SI2,VS1,VS2,VS1,VS1,SI2,SI2,SI1,SI1,  
 ## $ 深度<dbl>1.0417048, -0.5231586, 0.2932918, -0.5911962,  
 ##$表<dbl>-0.6422915,-0.1949679,-0.1949679,-0.1949679,  
 ## $ log_price<dbl> 5.961084, 5.966766, 5.978034, 5.978034, ... ## [1] 8.708704e-16 ## [1] 1 _# 另一个(深度)看起来也很相似,您可以使用 diamonds.r 中的代码自行检查。_

现在数据集已分为测试数据集和训练数据集,我们将训练几个不同的模型,测试它们的性能,并使用最好的模型进行预测。

多元线性回归

 抽动('传销')  
 传销 ** <-** lm(log_price **~** ., diamonds_train)  
 目录(日志 **=** **真的** , 安静的 **=** **真的** )  
 摘要(传销) ##  
 ## 称呼:  
 ## lm(公式 = log_price ~ ., data = diamonds_train)  
 ##  
 ## 残差:  
 ## 最小值 1Q 中值 3Q 最大值  
 ## -5.8529 -0.2258 0.0605 0.2531 1.5885  
 ##  
 ## 系数:  
 ## 估计标准。误差 t 值 Pr(>|t|)  
 ##(拦截)6.042146 0.004654 1298.384 < 2e-16 ***  
 ## 克拉 2.167970 0.003860 561.587 < 2e-16 ***  
 ## cut.L 0.065157 0.007584 8.592 < 2e-16 ***  
 ## cut.Q -0.009315 0.006074 -1.53​​4 0.1251  
 ## cut.C 0.030696 0.005215 5.886 3.98e-09 ***  
 ## 切割^4 0.006782 0.004174 1.625 0.1042  
 ## 颜色.L 0.510645 0.005811 87.878 < 2e-16 ***  
 ## 颜色.Q -0.159645 0.005279 -30.240 < 2e-16 ***  
 ## 颜色.C 0.004640 0.004939 0.940 0.3475  
 ## 颜色^4 0.039363 0.004538 8.674 < 2e-16 ***  
 ## 颜色^5 0.022161 0.004290 5.165 2.41e-07 ***  
 ## 颜色^6 0.004631 0.003903 1.187 0.2354  
 ## 清晰度.L 0.768912 0.010191 75.449 < 2e-16 ***  
 ## 清晰度.Q -0.366598 0.009537 -38.441 < 2e-16 ***  
 ## 清晰度.C 0.216408 0.008156 26.534 < 2e-16 ***  
 ## 清晰度^4 -0.063964 0.006507 -9.830 < 2e-16 ***  
 ## 清晰度^5 0.052507 0.005299 9.910 < 2e-16 ***  
 ## 清晰度^6 0.006066 0.004606 1.317 0.1879  
 ## 清晰度^7 0.007096 0.004069 1.744 0.0812 。  
 ## 深度 -0.001029 0.001921 -0.535 0.5923  
 ## 表 0.014223 0.002188 6.500 8.14e-11 ***  
 ## ---  
 ## 意义。代码:0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1''1  
 ##  
 ## 剩余标准误差:44213 自由度上的 0.3442  
 ## 多重 R 平方:0.8884,调整后 R 平方:0.8884  
 ## F 统计量:20 和 44213 DF 上的 1.76e+04,p 值:< 2.2e-16

多项式回归

 tic('聚')  
 聚 ** <-** lm(log_price **~** 聚(克拉,3) **+** 颜色 **+** 切 **+** 明晰 **+** 聚(表,3) **+** 聚(深度,3),diamonds_train)  
 目录(日志 **=** **真的** , 安静的 **=** **真的** )  
 摘要(聚) ##  
 ## 称呼:  
 ## lm(formula = log_price ~ poly(carat, 3) + 颜色 + 切工 + 净度 +  
 ## poly(table, 3) + poly(depth, 3), data = diamonds_train)  
 ##  
 ## 残差:  
 ## 最小值 1Q 中值 3Q 最大值  
 ## -3.2508 -0.0859 -0.0011 0.0872 1.9011  
 ##  
 ## 系数:  
 ## 估计标准。误差 t 值 Pr(>|t|)  
 ##(拦截)7.855146 0.001432 5486.562 < 2e-16 ***  
 ## 聚(克拉,3)1 224.223650 0.153651 1459.304 < 2e-16 ***  
 ## 聚(克拉,3)2 -65.007798 0.136049 -477.825 < 2e-16 ***  
 ## 聚(克拉,3)3 20.466149 0.135187 151.391 < 2e-16 ***  
 ## 颜色.L 0.441946 0.002259 195.623 < 2e-16 ***  
 ## 颜色.Q -0.086252 0.002053 -42.007 < 2e-16 ***  
 ## 颜色.C 0.009905 0.001916 5.169 2.36e-07 ***  
 ## 颜色^4 0.011069 0.001761 6.284 3.32e-10 ***  
 ## 颜色^5 0.008375 0.001664 5.033 4.86e-07 ***  
 ## 颜色^6 -0.001586 0.001514 -1.047 0.294889  
 ## cut.L 0.091883 0.003636 25.273 < 2e-16 ***  
 ## cut.Q -0.009912 0.002696 -3.676 0.000237 ***  
 ## cut.C 0.012412 0.002088 5.944 2.80e-09 ***  
 ## 切割^4 -0.002513 0.001625 -1.546 0.122172  
 ## 清晰度.L 0.887564 0.003990 222.462 < 2e-16 ***  
 ## 清晰度.Q -0.244990 0.003728 -65.708 < 2e-16 ***  
 ## 清晰度.C 0.141636 0.003181 44.523 < 2e-16 ***  
 ## 清晰度^4 -0.062813 0.002532 -24.807 < 2e-16 ***  
 ## 清晰度^5 0.029161 0.002058 14.167 < 2e-16 ***  
 ## 清晰度^6 -0.003592 0.001787 -2.010 0.044466 *  
 ## 清晰度^7 0.029769 0.001579 18.854 < 2e-16 ***  
 ## 聚(表,3)1 -0.691116 0.179675 -3.846 0.000120 ***  
 ## poly(table, 3)2 -0.603061 0.140336 -4.297 1.73e-05 ***  
 ## 聚(表,3)3 0.591239 0.136248 4.339 1.43e-05 ***  
 ## 聚(深度,3)1 -1.288272 0.162940 -7.906 2.71e-15 ***  
 ## 聚(深度,3)2 -1.105807 0.168139 -6.577 4.86e-11 ***  
 ## 多边形(深度,3)3 -0.057543 0.134870 -0.427 0.669631  
 ## ---  
 ## 意义。代码:0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1''1  
 ##  
 ## 剩余标准误差:44207 自由度上为 0.1335  
 ## 多重 R 平方:0.9832,调整后的 R 平方:0.9832  
 ## F 统计量:26 和 44207 DF 上的 9.96e+04,p 值:< 2.2e-16

哇!多项式回归似乎更适合!

支持向量回归 (SVR)

SVR 不依赖于基础因变量和自变量的分布。也可用于构建非线性模型 内核='径向' 选项。我认为是这种情况,因为到目前为止线性模型的表现最差。

 抽动('svr')  
 图书馆(e1071)  
 支持 ** <-** svm(公式 **=** log_price **~** .,  
 数据 **=** 钻石火车,  
 类型 **=** 'eps 回归',  
 核心 **=** '径向')  
 目录(日志 **=** **真的** , 安静的 **=** **真的** )

决策树回归

决策树使用一组 if-then-else 决策规则。树越深,决策规则越复杂,模型越适合。训练 DT 模型有时会导致无法很好地概括数据的过于复杂的树。这称为过拟合。

 抽动('树')  
 库(rpart)  
 树 ** <-** rpart(公式 **=** log_price **~** .,  
 数据 **=** 钻石火车,  
 方法 **=** '方差分析',  
 模型 **=** **真的** )  
 目录(日志 **=** **真的** , 安静的 **=** **真的** )  
 树 ## n= 44234  
 ##  
 ## node), split, n, deviance, yval  
 ## * 表示终端节点  
 ##  
 ## 1) 根 44234 46934.3000 7.922656  
 ## 2) 克拉< 0.695 20143 4323.3280 6.963920  
 ## 4) 克拉< 0.455 13848 1212.0970 6.718800 *  
 ## 5) 克拉>=0.455 6295 448.8476 7.503143 *  
 ## 3) 克拉>=0.695 24091 8615.3050 8.724275  
 ## 6) 克拉< 0.995 7837 612.7115 8.100690 *  
 ## 7) 克拉>=0.995 16254 3485.7290 9.024943  
 ## 14) 克拉< 1.385 10379 1099.8590 8.772585  
 ## 28) 清晰度=I1,SI2,SI1 5582 243.9614 8.574533 *  
 ## 29) 清晰度=VS2,VS1,VVS2,VVS1,IF 4797 382.1680 9.003046 *  
 ## 15) 克拉>=1.385 5875 557.1727 9.470768 *

随机森林回归

与在此模型中使用任何单个决策树相比,这使决策树模型更进一步,并使用许多决策树来做出更好的预测。

 抽动('rf')  
 图书馆(随机森林)  
 射频 ** <-** 随机森林(log_price **~** .,  
 数据 **=** 钻石火车,  
 ntree **=** 500,  
 重要性 **=** **真的** )  
 目录(日志 **=** **真的** , 安静的 **=** **真的** )  
 射频 ##  
 ## 称呼:  
 ## randomForest(formula = log_price ~ ., data = diamonds_train, ntree = 500, 重要性 = TRUE)  
 ## 随机森林类型:回归  
 ## 树的数量:500  
 ## 每次拆分尝试的变量数:2  
 ##  
 ## 均方残差:0.01159442  
 ## % Var 解释:98.91

XGBoost 回归

XGBoost 使用梯度提升决策树,是一个非常强大的模型,在各种应用程序中表现非常出色。它对影响其他一些模型性能的问题(例如多重共线性或数据规范化/标准化)也不那么敏感。

 抽动('xgb')  
 库(xgboost)  
 diamonds_train_xgb ** <-** diamonds_train **% >%**  
 mutate_if(is.factor, as.numeric)  
 diamonds_test_xgb ** <-** diamonds_test **% >%**  
 mutate_if(is.factor, as.numeric)  
  
 xgb ** <-** xgboost(数据 **=** as.matrix(diamonds_train_xgb[-7]), 标签 **=** diamonds_train_xgb **$** log_price, nrounds **=** 6166,详细 **=** 0)  
 _# rmse 在 6166 轮后停止下降_  
 目录(日志 **=** **真的** , 安静的 **=** **真的** )

模型性能

现在,我们将使用我们训练的每个模型来预测测试数据集中钻石价格的对数,以确定哪个模型在它未见过的数据上表现最好。

_# 进行预测并比较模型性能_  
 tic('predict_all')  
 mlm_prev ** <-** 预测(传销,钻石测试)  
 poly_prev ** <-** 预测(聚,diamonds_test)  
 svr_pred ** <-** 预测(svr,diamonds_test)  
 tree_pred ** <-** 预测(树,diamonds_test)  
 rf_pred ** <-** 预测(射频,钻石测试)  
 xgb_pred ** <-** 预测(xgb,as.matrix(diamonds_test_xgb[-7]))  
 目录(日志 **=** **真的** , 安静的 **=** **真的** )  
  
 _# 计算残差(即预测与测试数据集的 log_price 有多大不同)_  
 xgb_resid ** <-** diamonds_test_xgb **$** log_price **-** xgb_pred  
 图书馆(建模师)  
 渣油 ** <-** diamonds_test **% >%**    
 spread_residuals(mlm, poly, svr, tree, rf) **% >%**  
 选择(传销,聚,支持率,树,射频) **% >%**  
 rename_with( **~** paste0(.x, '_resid')) **% >%**  
 cbind(xgb_resid)  
  
 预测 ** <-** diamonds_test **% >%**  
 选择(log_price) **% >%**  
 cbind(mlm_pred) **% >%**  
 cbind(poly_pred) **% >%**  
 cbind(svr_pred) **% >%**  
 cbind(tree_pred) **% >%**  
 cbind(rf_pred) **% >%**  
 cbind(xgb_pred) **% >%**  
 cbind(残渣) _# 这对以后的绘图很有用_  
  
 _# 计算 R 平方 - 这描述了模型解释了多少可变性 - 越接近 1,越好_  
  
 mean_log_price ** <-** 平均(钻石测试 **$** log_price)  
 tss **=**  **和** ((diamonds_test_xgb **$** log_price **-** mean_log_price) **^** 2)  
  
 正方形 ** <-** **功能** (x) {x ****** 2}  
 r2 ** <-** **功能** (x) {1 **-** X **/** tss}  
  
 r2_df ** <-** 渣油 **% >%**  
 mutate_all(正方形) **% >%**  
 总结_所有(总和) **% >%**  
 mutate_all(r2) **% >%**  
 收集(关键 **=** '模型',值 **=** 'r2') **% >%**  
 变异(模型 **=** str_replace(模型,'_resid',''))  
 r2_df ## 模型 r2  
 ## 1 传销 0.8842696  
 ## 2 聚 0.9803430  
 ## 3 服务器 0.9847216  
 ## 4 树 0.9114766  
 ## 5 射频 0.9870050  
 ## 6 xgb 0.9842275

可视化模型的性能

随机森林模型根据 R² 值表现最佳——这是衡量模型解释了数据集中多少可变性的指标,因此我们将主要关注这个模型以进行可视化。它等于 1 - RMSE(均方根误差,它描述了所有预测与真实值的差异程度) y_test 数据集)。

 图书馆(ggplot2)  
 r2_plot ** <-** ggplot(r2_df, aes(x **=** 模型,和 **=** r2,颜色 **=** 模型,填充 **=** 模型)) **+** geom_bar(统计 **=** '身份')  
 r2_plot **+** ggtitle('每个模型的 R 平方值') **+** coord_cartesian(顶部 **=** **C** (0.75, 1))

 样本 ** <-** 预测 **% >%**  
 slice_sample(n **=** 1000)  
 ggplot(样本,aes(x **=** **经验** (log_price), 是 **=** **经验** (rf_pred),  
 尺寸 **=** **腹肌** (rf_resid))) **+**  
 几何点(阿尔法 **=** 0.1) **+**   
 实验室(标题 **=** '以美元计的钻石的预测成本与实际成本',  
 X **=** “价格”和 **=** “预计价格”,尺寸 **=** '残差')

随机森林 模型在我们尝试的所有模型中表现最好。这并不奇怪,因为它是 集成方法 这意味着它使用多个模型之间的一致性来做出比任何模型自己做出的更好的预测。 XGBoost 是集成方法的另一个例子,也表现得非常好。在第二个图中,大小与残差的绝对值成正比,这意味着预测与实际值的差异程度。

特征重要性

哪些变量对预测钻石价格最重要?

 varImpPlot(rf)

x 轴上的值表示如果该变量不包含在模型中,预测误差会增加多少。正如预期的那样,克拉(或大小)是最重要的变量。尽管台面和深度应该会影响钻石的闪亮程度,但实际上它们对价格的影响并不大。

训练模型和做出预测需要多长时间?

_# 训练和预测时间_  
 时间日志 ** <-** tic.log(格式 **=** **真的** )  
 时间日志 ## [[1]]  
 ## [1] "mlm: 0.057 秒过去"  
 ##  
 ## [[2]]  
 ## [1] “多边形:经过 0.154 秒”  
 ##  
 ## [[3]]  
 ## [1] "svr: 194.602 秒过去"  
 ##  
 ## [[4]]  
 ## [1] “树:经过 0.248 秒”  
 ##  
 ## [[5]]  
 ## [1] "rf: 407.62 秒过去"  
 ##  
 ## [[6]]  
 ## [1] "xgb: 62.249 秒过去"  
 ##  
 ## [[7]]  
 ## [1] "predict_all: 8.219 秒过去"

因此,性能最好的模型训练时间最长。这并不奇怪,因为它们实际上包含许多模型。请记住,这不是一个大型数据集。我为工作训练的大多数模型都需要数小时和数小时(我大多将它们设置为通宵运行)。这可以通过 AWS 中非常强大的机器甚至机器集群来加速。

最初发表于 https://zeather709.github.io .

订阅: https://zeather.medium.com/ 订阅/

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明

本文链接:https://www.qanswer.top/24584/43201009

posted @ 2022-09-10 09:43  哈哈哈来了啊啊啊  阅读(468)  评论(0编辑  收藏  举报