机器学习-企业破产预测

企业破产预测

选题背景

企业破产是商品经济的必然产物.在社会主义商品经济条件下,企业破产也是一种客观存在的经济现象.新中国的第一部《企业破产法》已经诞生,它的实施必将促进企业努力改善经营管理,提高经济效益,保护债权人和债务人的合法权益.企业破产无疑会给社会经济造成一系列严重损失.因此,预防和控制企业破产具有重大的社会经济意义.一个有效的企业破产预测模型无论是对国家的物资,资金,劳工等部门的计划和政策,或是对投资者,企业管理者的经营管理决策都具有重大作用.

1.数据描述

1.1数据来源

数据取自台湾经济杂志 1999 年至 2009 年的数据。公司破产是根据台湾证券交易所的业务规则定义的。

1.2数据概述

数据包含了与公司是否破产有关的95个因素包括有息债务的成本、现金再投资比率、流动比率、严峻考验、利息支出/总收入、总负债/权益比、负债/总资产、有息债务/股本或有负债/权益等特征以及对应的公司是否破产的信息,数据一共有6819条。下表对于数据中相关性最高的十项数据进行展示:

名称

数据类型

描述

Net Income to Total Assets

浮点型

净收入与总资产之比

ROA(A) before interest and % after tax

浮点型

利息前ROA(A)和税后百分比

ROA(B) before interest and depreciation after tax

浮点型

税后利息和折旧前的ROA(B)

ROA(C) before interest and depreciation before interest

浮点型

利息前ROA(C)和利息前折旧

Net worth/Assets

浮点型

净值/资产

Debt ratio %

浮点型

负债率%

Persistent EPS in the Last Four Seasons

浮点型

过去四个季度的持续每股收益

Retained Earnings to Total Assets

浮点型

留存收益占总资产的比例

Net profit before tax/Paid-in capital

浮点型

税前净利润/实收资本

Per Share Net profit before tax

浮点型

每股税前净利润

2.数据分析

2.1缺失值查找data.info()

经检测未发现数据中含有缺失值,

图形用户界面, 文本, 应用程序

描述已自动生成

2.2 数据集初探

def make_autopct(values):

def my_autopct(pct):

total = sum(values)

val = int(round(pct*total/100.0))

# 同时显示数值和占比的饼图

return '{p:.2f}% ({v:d})'.format(p=pct,v=val)

return my_autopct

plt.pie(label.value_counts(),labels=["没破产","破产了"],autopct=make_autopct(label.value_counts()))

# 解决中文显示问题

plt.rcParams['font.sans-serif'] = ['SimHei']

plt.rcParams['axes.unicode_minus'] = False

plt.legend()

plt.show()

图表, 饼图

描述已自动生成

我们不难发现绝大多数数据是破产了的数据,我们在挑选数据时候应该尽量把少量的破产了的数据都放入我们的训练集中这样才会有更好的效果。

2.3 数据相关性探查data.corr()

我们大致可以找到以下几个参数,有些参数增大会导致破产概率增大,而有些参数与破产概率负相关,他们增大以后公司的破产概率会减少。而且对于数据集的初探过程中我们发现,这些标签几乎都是数值类型的连续变量,所以我们不需要对其进行转化,也不需要使用独热编码,这些数据可以直接进行训练。

文本, 信件

描述已自动生成

corr_mat=data.corr()['Bankrupt?'].apply(np.abs).sort_values(ascending=False)

plt.rcParams['font.sans-serif'] = ['SimHei']

plt.rcParams['axes.unicode_minus'] = False

corr_mat[:11].plot(kind='barh')

图形用户界面, 图表, 条形图

描述已自动生成

2.4 变量直方图

简单观察变量分布,以上观测图是把变量分成50桶以后得出的结果。变量分布未见明显异常,基本上符合高斯分布。

data.hist(bins = 50, figsize = (40, 30))

plt.show()

图示, 工程绘图

描述已自动生成

2.5探索性数据分析

train_describe_std = data.describe().loc['std',:]

extreme_cols = train_describe_std[train_describe_std>10000].index.values

regular_cols = [col for col in data.columns[:-1] if col not in extreme_cols]

发现不同变量之间的量纲相差比较大,有的是十的若干次倍,而有的却连1都没有到,所以需要进行归一化处理。

3.研究问题:

通过6819条数据来探求公司破产可能与哪些因素有关,并且能够通过这些因素来判断一个公司在未来是否具有破产风险,本质上预测一个公司的破产是一种“分类任务”,最根本的任务就是找到特征与公司破产之间的关系。

4.研究方法

分类本身是一个监督学习的范畴,所以本实验选用监督学习中常用的分类器来进行模型预测,在一定的评估方法下通过网格搜索匹配最佳参数的方式来选择最优评估器。这四个模型分别是逻辑回归、补充朴素贝叶斯分类、随机森林和支持向量机。这四种模型的解释性都是比较强的且在二分类问题上都有用武之地。

5.方法流程

5.1数据归一化处理

我们观察到不同标签的数据的量纲不同,需要对数据进行归一化处理,这样做可以加快了梯度下降求最优解的速度,可以让逻辑回归,支持向量机这些模型更快。

from sklearn.preprocessing import MinMaxScaler

minmax = MinMaxScaler(feature_range=(0, 1))

X_train = minmax.fit_transform(X_train)

X_test = minmax.transform(X_test)

5.2 数据集划分

使用网格搜索进行五折交叉验证的方式来进行数据划分,本实验选用全部数据进行训练。

5.3 训练器

sklearn中已经封装了逻辑回归、高斯朴素贝叶斯分类、随机森林和支持向量机这些算法,但我们知道参数不可能适配所有的数据集,所以我们使用网格搜索来匹配参数,并且人为地根据模型复杂程度和数据集学习曲线来判断学习情况,并得出相应结果。

6.数据划分细节

1.采用train_test_split来进行划分,但是在划分前需要先打乱数据,因为在对数据预览的过程中发现,数据前半部分的破产企业的资料较多,后半部分的破产资料较少,所以要先乱序。设置test_size为0.2表示测试集占原来数据量的20%,训练集则占原来数据集的80%。

2.决策树、随机森林和支持向量机因为参数较多,需要使用网格搜索,而网格搜索GridSearchCV带有交叉验证的参数,我们设置五折交叉验证来进行训练。示意图如下[1]

# set seed for reproducibility

from sklearn.utils import shuffle

from sklearn.model_selection import train_test_split

data=shuffle(data)

# train-test split

X_train, X_test, y_train, y_test = train_test_split(data.iloc[:, 1:], data.iloc[:, 0], test_size=0.20)

图形用户界面

描述已自动生成

7.度量指标

对于这种样本量非常悬殊的二分类问题不能以accuracy也就是准确率作为唯一的判别标准,因为对于我们的实验目的来说,我们是要尽可能找出那些可能有破产风险的公司,所以我们要参考混淆矩阵,综合查全率查准率来进行定夺,必要时还可以绘制P-R曲线或者ROC曲线进行分析。本实验所采用score是roc-auc分数。

8.性能评价

8.1性能评价

模型的性能评价将从以下两个部分进行,学习能力、泛化能力。

学习能力是关于训练得到的模型关于训练样本集的预测能力。训练误差就是拿模型对训练集预测的结果与数据实际对应的结果进行比较,计算损失。

泛化能力指的是学习到的模型关于未知样本的预测能力。由于泛化误差难以估计,一般以测试花茶评价模型的泛化能力。泛化误差就是拿模型对测试集预测的结果与数据实际对应的结果进行比较,计算损失。[2]

所以本实验采用网格搜索来最优化模型性能和在限定值内调整参数。

8.2 参数调优:

8.2.1 对于逻辑回归我们进行网格搜索参数如下:

params = {'C': [1, 10, 100, 1000],

'class_weight': ['balanced', None],

'max_iter': [10000, 100000]}

log_model = GridSearchCV(LogisticRegression(), params, scoring = 'roc_auc', verbose = 1)

log_model.fit(X_train, y_train)

8.2.2 对于支持向量机我们进行网格搜索参数如下:

params = {'C': [1, 10, 100, 1000],

'kernel': ['linear', 'rbf'],

'gamma': [0.001, 0.0001],

'class_weight': ['balanced', None]}

8.2.3 对于补充朴素贝叶斯我们进行网格搜索参数如下:

params = {'alpha': [0.001, 0.01, 0.1, 1, 5, 10, 20]}

nb_model= GridSearchCV(ComplementNB(), params, scoring = 'roc_auc', verbose=1,cv=5)

8.2.4 对于随机森林我们进行网格搜索参数如下:

params = {'criterion': ['gini', 'entropy'],

'max_depth': [10, 50, 100, None],

'max_features': ['auto', None],

'n_estimators': [50, 100, 150]}

9.实验结果分析

9.1 逻辑回归结果分析

图形用户界面, 文本, 应用程序

描述已自动生成

最优参数是C=1,class_weight=”balanced”,max_iter=10000,我们将选用这个训练器进行预测,我们使用这样一个score是auc-roc分数(以下模型都是如此)但是还没有结束,我们还需要检查混淆矩阵的相关参数来评估模型好坏:

混淆矩阵和真正例率、假正例率、真负例率、假负例率如下:

 

roc曲线如下:

P-R曲线如下:

结论:不难看出在测试集上accuracy马马虎虎只有0.84,在那些真正破产=1的公司中,它的误报率很高,精确度也很低。不过好在他的召回率还蛮令人满意,基本上本来如果会破产的百分之七十多都能被找出来,这点还是不错的。

9.2 SVM结果分析

图形用户界面, 文本, 应用程序

描述已自动生成

最优参数是C=100,class_weight=”balanced”,gamma=0.001.

看上去与逻辑回归并没有差很多,混淆矩阵和真正例率、假正例率、真负例率、假负例率如下:

roc曲线如下:

P-R曲线如下:

结论: 即使应用了性能最好的超参数,SVM模型在这个不平衡的数据集上也表现不好。这一模型的误报率很高,在真正破产的公司中,准确率仅为14% ,耗时巨长,与逻辑回归相比并无明显优势。

9.3 补充朴素贝叶斯结果分析

图形用户界面, 文本, 应用程序, 聊天或短信

描述已自动生成

理论上补充朴素贝叶斯对这种不平衡的数据集效果良好,但是就此分数而言似乎不理想。最优参数为alpha=0.1

图形用户界面, 文本, 应用程序, 电子邮件

描述已自动生成

可能这种模型还是过于简单了,混淆矩阵和真正例率、假正例率、真负例率、假负例率如下:

图片包含 文本

描述已自动生成

roc曲线如下:

图形用户界面, 应用程序, Word

描述已自动生成

P-R曲线如下:

朴素贝叶斯模型的精度与支持向量机和逻辑回归模型一致,与其他模型一样,在真正破产=1的公司中,该模型的精度较低,而且这模型预测能力方面有明显低于前两者模型的表现。

9.4 随机森林结果分析

图形用户界面, 文本, 应用程序, 电子邮件

描述已自动生成

作为强学习器的随机森林还是表现出了非常高的水平。最优参数为criterion为信息熵,最大深度为10,基训练器为150,下面我们对其具体分类展开分析。

表格

描述已自动生成

它与之前的所有机器学习算法不同的是它对于破产=1的这部分数据拟合的精确度要更高而且高了近四成。

文本

中度可信度描述已自动生成

auc曲线如下:

P-R曲线如下:

结论: 与之前预测的模型一样,该模型在识别消极因素方面仍表现良好,但识别正例的能力却糟糕透顶。测试集中的大多数正例因素在最终模型中未被正确地识别这是实验前所未曾设想的,即使“准确性”指标很高,模型也没有真正发挥其识别少数群体的功能。

10.总结

图表, 条形图

描述已自动生成

这四种模型在这个数据集上的表现都很差。给定破产的完整数据集的比例(低于96%),我认为这些模型的性能差可能是由于数据集(和样本)不平衡造成的。我还想尝试对少数群体的样本进行过采样/有意加权,以便用更多的少数群体样本来训练模型,但这些都还只是停留在猜想阶段,对于这种非平衡数据集还是需要好的模型和好的数据增强的方式来学习。这应该不是学习器的错,是数据集拆分或者采样的艺术吧。

估计公司破产的风险对债权人和投资者来说非常重要,也是金融经济学的重要学科。因此,破产预测是一个重要的研究领域。近年来,人工智能和机器学习方法在企业破产预测方面取得了可喜的成果。因此,本研究使用四种机器学习算法,即逻辑回归、扩充朴素贝叶斯、随机森林、支持向量机来预测企业破产。这些数据是从1999年到2009年的《台湾经济杂志》上收集的,公司破产的定义是根据台湾证券交易所的业务规定而定的[3]

 

关键词:经济学、企业破产预测、机器学习、模型选择

11.全部代码

  1. #!/usr/bin/env python
  2. # coding: utf-8
  3. # In[1]:
  4. import pandas as pd
  5. import sklearn
  6. import numpy as np
  7. import matplotlib.pyplot as plt
  8. get_ipython().run_line_magic('matplotlib', 'inline')
  9. # In[2]:
  10. data=pd.read_csv("data.csv")
  11. data.shape
  12. # In[3]:
  13. data.info()
  14. # In[4]:
  15. label=data.loc[:,"Bankrupt?"].copy()
  16. #data.drop(columns=["Bankrupt?"],inplace=True)
  17. # In[5]:
  18. def make_autopct(values):
  19. def my_autopct(pct):
  20. total = sum(values)
  21. val = int(round(pct*total/100.0))
  22. # 同时显示数值和占比的饼图
  23. return '{p:.2f}% ({v:d})'.format(p=pct,v=val)
  24. return my_autopct
  25. plt.pie(label.value_counts(),labels=["没破产","破产了"],autopct=make_autopct(label.value_counts()))
  26. # 解决中文显示问题
  27. plt.rcParams['font.sans-serif'] = ['SimHei']
  28. plt.rcParams['axes.unicode_minus'] = False
  29. plt.legend()
  30. plt.show()
  31. # In[6]:
  32. import seaborn as sns
  33. sns.heatmap(data.corr())
  34. # In[7]:
  35. corr_mat=data.corr()['Bankrupt?'].apply(np.abs).sort_values(ascending=False)
  36. plt.rcParams['font.sans-serif'] = ['SimHei']
  37. plt.rcParams['axes.unicode_minus'] = False
  38. corr_mat[:11].plot(kind='barh')
  39. # In[8]:
  40. data.hist(bins = 50, figsize = (40, 30))
  41. plt.show()
  42. # In[9]:
  43. train_describe_std = data.describe().loc['std',:]
  44. extreme_cols = train_describe_std[train_describe_std>10000].index.values
  45. regular_cols = [col for col in data.columns[:-1] if col not in extreme_cols]
  46. # In[10]:
  47. data.mean().sort_values()
  48. # In[11]:
  49. data.describe()
  50. # In[12]:
  51. # set seed for reproducibility
  52. from sklearn.utils import shuffle
  53. from sklearn.model_selection import train_test_split
  54. data=shuffle(data)
  55. # train-test split
  56. X_train, X_test, y_train, y_test = train_test_split(data.iloc[:, 1:], data.iloc[:, 0], test_size=0.20)
  57. # In[13]:
  58. from sklearn.preprocessing import MinMaxScaler
  59. minmax = MinMaxScaler(feature_range=(0, 1))
  60. X_train = minmax.fit_transform(X_train)
  61. X_test = minmax.transform(X_test)
  62. # In[14]:
  63. from sklearn.tree import DecisionTreeClassifier
  64. from sklearn.ensemble import RandomForestClassifier
  65. from sklearn.linear_model import LogisticRegression
  66. from sklearn.naive_bayes import ComplementNB
  67. from sklearn.svm import SVC
  68. from sklearn.model_selection import GridSearchCV
  69. from sklearn.model_selection import cross_val_score
  70. from sklearn.model_selection import train_test_split
  71. from sklearn.feature_selection import SelectPercentile
  72. from sklearn.metrics import classification_report
  73. from sklearn.metrics import average_precision_score
  74. from sklearn.metrics import precision_recall_curve
  75. from sklearn.metrics import confusion_matrix
  76. from sklearn.metrics import roc_curve
  77. from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, roc_curve, RocCurveDisplay, auc, PrecisionRecallDisplay
  78. # ## 逻辑回归
  79. # In[53]:
  80. LogisticRegression().get_params()
  81. # In[ ]:
  82. params = {'C': [1, 10, 100, 1000],
  83. 'class_weight': ['balanced', None],
  84. 'max_iter': [10000, 100000]}
  85. log_model = GridSearchCV(LogisticRegression(), params, scoring = 'roc_auc', verbose = 1)
  86. log_model.fit(X_train, y_train)
  87. # In[49]:
  88. log_model.best_estimator_
  89. # In[50]:
  90. log_model.best_score_
  91. # In[15]:
  92. log_best = LogisticRegression(C=1, class_weight='balanced', max_iter=10000).fit(X_train, y_train)
  93. y_pred = log_best.predict(X_test)
  94. print(classification_report(y_test, y_pred))
  95. # In[16]:
  96. log_matrix = confusion_matrix(y_test, y_pred)
  97. print(log_matrix)
  98. log_tn = log_matrix[0][0]
  99. print('TN: ', log_tn)
  100. log_fn = log_matrix[1][0]
  101. print('FN: ', log_fn)
  102. log_tp = log_matrix[1][1]
  103. print('TP: ', log_tp)
  104. log_fp = log_matrix[0][1]
  105. print('FP: ', log_fp)
  106. # 真正例率
  107. print('TPR: ', log_tp/(log_tp + log_fn))
  108. # 假正例率
  109. print('FPR: ', log_fp/(log_fp + log_tn))
  110. # 真负例率
  111. print('TNR: ', log_tn/(log_tn + log_fp))
  112. # 假负例率
  113. print('FNR: ', log_fn/(log_fn + log_tp))
  114. # In[17]:
  115. ax=sns.heatmap(log_matrix,annot=True,fmt='.20g')
  116. ax.set_title("log_matrix")
  117. ax.set_xlabel("predict")
  118. ax.set_ylabel("true")
  119. # In[18]:
  120. # generate ROC curve
  121. fpr, tpr, thresholds = roc_curve(y_test, y_pred)
  122. roc_auc = auc(fpr, tpr)
  123. rocdisplay = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name='Logistic Regression')
  124. rocdisplay.plot()
  125. # In[19]:
  126. prdisplay = PrecisionRecallDisplay.from_predictions(y_test, y_pred, name='Logistic Regression')
  127. # # svc
  128. # In[20]:
  129. SVC().get_params()
  130. # In[65]:
  131. params = {'C': [1, 10, 100, 1000],
  132. 'kernel': ['linear', 'rbf'],
  133. 'gamma': [0.001, 0.0001],
  134. 'class_weight': ['balanced', None]}
  135. # In[66]:
  136. svm_model = GridSearchCV(SVC(), params, scoring = 'roc_auc', cv=5,verbose = 1)
  137. # In[67]:
  138. svm_model.fit(X_train, y_train)
  139. # In[70]:
  140. svm_model.best_estimator_
  141. # In[71]:
  142. svm_model.best_score_
  143. # In[21]
  144. svm_best = SVC(C=100, class_weight='balanced', gamma=0.001, kernel='linear').fit(X_train, y_train)
  145. # In[22]:
  146. y_pred = svm_best.predict(X_test)
  147. # In[23]
  148. print(classification_report(y_test, y_pred))
  149. # In[24]:
  150. svm_matrix = confusion_matrix(y_test, y_pred)
  151. print(svm_matrix)
  152. svm_tn = svm_matrix[0][0]
  153. print('TN: ', svm_tn)
  154. svm_fn = svm_matrix[1][0]
  155. print('FN: ', svm_fn)
  156. svm_tp = svm_matrix[1][1]
  157. print('TP: ', svm_tp)
  158. svm_fp = svm_matrix[0][1]
  159. print('FP: ', svm_fp)
  160. # 真正例率
  161. print('TPR: ', svm_tp/(svm_tp + svm_fn))
  162. # 假正例率
  163. print('FPR: ', svm_fp/(svm_fp + svm_tn))
  164. # 真负例率
  165. print('TNR: ', svm_tn/(svm_tn + svm_fp))
  166. # 假负例率
  167. print('FNR: ', svm_fn/(svm_fn + svm_tp))
  168. # In[25]:
  169. ax=sns.heatmap(svm_matrix,annot=True,fmt='.20g')
  170. ax.set_title("svm_matrix")
  171. ax.set_xlabel("predict")
  172. ax.set_ylabel("true")
  173. # In[26]:
  174. # generate ROC curve
  175. fpr, tpr, thresholds = roc_curve(y_test, y_pred)
  176. roc_auc = auc(fpr, tpr)
  177. rocdisplay = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name='SVM')
  178. rocdisplay.plot()
  179. # In[27]:
  180. prdisplay = PrecisionRecallDisplay.from_predictions(y_test, y_pred, name='SVM')
  181. # # 补充朴素贝叶斯
  182. # In[81]:
  183. params = {'alpha': [0.001, 0.01, 0.1, 1, 5, 10, 20]}
  184. nb_model= GridSearchCV(ComplementNB(), params, scoring = 'roc_auc', verbose=1,cv=5)
  185. # In[83]:
  186. nb_model.fit(X_train, y_train)
  187. # In[85]:
  188. nb_model.best_estimator_
  189. # In[86]:
  190. nb_model.best_score_
  191. # In[28]:
  192. nb_best = ComplementNB(alpha = 1).fit(X_train, y_train)
  193. y_pred = nb_best.predict(X_test)
  194. print(classification_report(y_test, y_pred))
  195. # In[29]:
  196. nb_matrix = confusion_matrix(y_test, y_pred)
  197. print(nb_matrix)
  198. nb_tn = nb_matrix[0][0]
  199. print('TN: ', nb_tn)
  200. nb_fn = nb_matrix[1][0]
  201. print('FN: ', nb_fn)
  202. nb_tp = nb_matrix[1][1]
  203. print('TP: ', nb_tp)
  204. nb_fp = nb_matrix[0][1]
  205. print('FP: ', nb_fp)
  206. # 真正例率
  207. print('TPR: ', nb_tp/(nb_tp + nb_fn))
  208. # 假正例率
  209. print('FPR: ', nb_fp/(nb_fp + nb_tn))
  210. # 真负例率
  211. print('TNR: ', nb_tn/(nb_tn + nb_fp))
  212. # 假负例率
  213. print('FNR: ', nb_fn/(nb_fn + nb_tp))
  214. # In[30]:
  215. ax=sns.heatmap(nb_matrix,annot=True,fmt='.20g')
  216. ax.set_title("nb_matrix")
  217. ax.set_xlabel("predict")
  218. ax.set_ylabel("true")
  219. # In[31]:
  220. # generate ROC curve
  221. fpr, tpr, thresholds = roc_curve(y_test, y_pred)
  222. roc_auc = auc(fpr, tpr)
  223. rocdisplay = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name='Naive Bayes')
  224. rocdisplay.plot()
  225. # In[32]:
  226. prdisplay = PrecisionRecallDisplay.from_predictions(y_test, y_pred, name='Naive Bayes')
  227. # # 随机森林
  228. # In[92]:
  229. params = {'criterion': ['gini', 'entropy'],
  230. 'max_depth': [10, 50, 100, None],
  231. 'max_features': ['auto', None],
  232. 'n_estimators': [50, 100, 150]}
  233. # In[93]:
  234. rf_model = GridSearchCV(RandomForestClassifier(), params, scoring = 'roc_auc', verbose=1,cv=5)
  235. # In[94]:
  236. rf_model.fit(X_train, y_train)
  237. # In[95]:
  238. rf_model.best_estimator_
  239. # In[96]:
  240. rf_model.best_score_
  241. # In[33]:
  242. rf_best = RandomForestClassifier(criterion= 'entropy', max_depth= 10, max_features= None, n_estimators= 150).fit(X_train, y_train)
  243. # In[34]
  244. y_pred = rf_best.predict(X_test)
  245. # In[35]:
  246. print(classification_report(y_test, y_pred))
  247. # In[36]:
  248. rf_matrix = confusion_matrix(y_test, y_pred)
  249. print(rf_matrix)
  250. rf_tn = rf_matrix[0][0]
  251. print('TN: ', rf_tn)
  252. rf_fn = rf_matrix[1][0]
  253. print('FN: ', rf_fn)
  254. rf_tp = rf_matrix[1][1]
  255. print('TP: ', rf_tp)
  256. rf_fp = rf_matrix[0][1]
  257. print('FP: ', rf_fp)
  258. # 真正例率
  259. print('TPR: ', rf_tp/(rf_tp + rf_fn))
  260. # 假正例率
  261. print('FPR: ', rf_fp/(rf_fp + rf_tn))
  262. # 真负例率
  263. print('TNR: ', rf_tn/(rf_tn + rf_fp))
  264. # 假负例率
  265. print('FNR: ', rf_fn/(rf_fn + rf_tp))
  266. # In[37]:
  267. ax=sns.heatmap(rf_matrix,annot=True,fmt='.20g')
  268. ax.set_title("rf_matrix")
  269. ax.set_xlabel("predict")
  270. ax.set_ylabel("true")
  271. # In[38]:
  272. fpr, tpr, thresholds = roc_curve(y_test, y_pred)
  273. roc_auc = auc(fpr, tpr)
  274. rocdisplay = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name='Random Forest')
  275. rocdisplay.plot()
  276. # In[39]:
  277. prdisplay = PrecisionRecallDisplay.from_predictions(y_test, y_pred, name='Random Forest')
  278. # In[40]:
  279. df=pd.DataFrame([[log_tp/(log_tp + log_fn),log_fp/(log_fp + log_tn),log_tn/(log_tn + log_fp),log_fn/(log_fn + log_tp)],
  280. [svm_tp/(svm_tp + svm_fn),log_fp/(svm_fp + svm_tn),svm_tn/(svm_tn + svm_fp),svm_fn/(svm_fn + svm_tp)],
  281. [nb_tp/(nb_tp + nb_fn),nb_fp/(nb_fp + nb_tn),nb_tn/(nb_tn + nb_fp),nb_fn/(nb_fn + nb_tp)],
  282. [rf_tp/(rf_tp + rf_fn),rf_fp/(rf_fp + rf_tn),rf_tn/(rf_tn + rf_fp),rf_fn/(rf_fn + rf_tp)]])
  283. # In[41]:
  284. df.index=['log','svm','nb','rf']
  285. df.columns=['TPR','FPR','TNR','FNR']
  286. df
  287. # In[42]:
  288. x = np.arange(len(df))
  289. width = 0.2
  290. fig, ax = plt.subplots(figsize=(12,8))
  291. rects1 = ax.bar(x - width, df['TPR'], width, label='TPR')
  292. rects2 = ax.bar(x , df['FPR'], width, label='FPR')
  293. rects3 = ax.bar(x +width, df['TNR'], width, label='TNR')
  294. rects4 = ax.bar(x + width*2, df['FNR'], width, label='FNR')
  295. ax.set_ylabel('Rate')
  296. ax.set_title('The Score Of Each Estimator')
  297. ax.set_xticks(x, df.index)
  298. ax.legend()
  299. ax.bar_label(rects1, padding=2)
  300. ax.bar_label(rects2, padding=2)
  301. ax.bar_label(rects3, padding=2)
  302. ax.bar_label(rects4, padding=2)
  303. fig.tight_layout()
  304. plt.show()
posted @ 2022-12-23 10:19  我还没想好w  阅读(879)  评论(2编辑  收藏  举报