• 博客园logo
  • 会员
  • 周边
  • 众包
  • 新闻
  • 博问
  • 闪存
  • 赞助商
  • Chat2DB
    • 搜索
      所有博客
    • 搜索
      当前博客
  • 写随笔 我的博客 短消息 简洁模式
    用户头像
    我的博客 我的园子 账号设置 会员中心 简洁模式 ... 退出登录
    注册 登录

LR233

  • 博客园
  • 联系
  • 订阅
  • 管理

公告

View Post

2、混淆矩阵

1、基础方法

from sklearn.metrics import confusion_matrix        # 导包
cm = confusion_matrix(targets, predictions)     # 传入的参数为真实值和预测值

 

2、自定义的方法如下:

 1 def plot_confusion_matrix(y_true, y_pred, classes,
 2                                            normalize=False,
 3                                            title=None,
 4                                            cmap=plt.cm.Blues):
 5     """
 6     这个函数打印并绘制混淆矩阵。
 7     规范化可以通过设置 'normalize=True' 来应用
 8     """
 9     if not title:
10         if normalize:
11             title = 'Normalized confusion matrix'
12         else:
13             title = 'Confusion matrix, without normalization'
14 
15     # Compute confusion matrix(横轴为预测值,纵轴为实际值)
16     cm = metrics.confusion_matrix(y_true, y_pred)    
17     # 只使用出现在数据中的标签
18     classes = unique_labels(y_true, y_pred)    # 提取唯一标签(只要出现过都进行统计)
19     if normalize:
20         cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]    # 搞成多维一列
21 
22     fig, ax = plt.subplots()
23     im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
24     ax.figure.colorbar(im, ax=ax)
25     # 我们要显示所有的刻度...
26     ax.set(xticks=np.arange(cm.shape[1]),    # shape[1]为列数
27            yticks=np.arange(cm.shape[0]),    # shape[0]为行数
28            # ... 并用各自的列表项为它们标记
29            xticklabels=fruits, yticklabels=fruits,
30            title=title,
31            ylabel='True label',
32            xlabel='Predicted label')
33 
34     # 旋转刻度标签并设置它们的对齐方式。
35     plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
36              rotation_mode="anchor")
37 
38     # 遍历数据维度并创建文本注释.
39     fmt = '.2f' if normalize else 'd'
40     thresh = cm.max() / 2.
41     for i in range(cm.shape[0]):
42         for j in range(cm.shape[1]):
43             ax.text(j, i, format(cm[i, j], fmt),
44                         ha="center", va="center",
45                         color="white" if cm[i, j] > thresh else "black")
46     fig.tight_layout()
47     return cm,ax

绘制混淆矩阵

1 cm , _ = plot_confusion_matrix(y_test, y_pred, classes=y_train, normalize=True, title='Normalized confusion matrix')
2 plt.show()

 效果如下:

     

posted on 2022-11-02 14:14  LR233  阅读(244)  评论(0)    收藏  举报

刷新页面返回顶部
 
博客园  ©  2004-2025
浙公网安备 33010602011771号 浙ICP备2021040463号-3