Understanding the Confusion Matrix

Using handwritten digit recognition as an example

"""
# @Time    :  2020/9/7
# @Author  :  Jimou Chen
"""
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_digits
from sklearn.preprocessing import StandardScaler  # subtract the mean, then divide by the standard deviation
from sklearn.metrics import classification_report, confusion_matrix

if __name__ == '__main__':

    digits_data = load_digits()
    x_data = digits_data.data
    y_data = digits_data.target

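    # Standardize the features
    sc = StandardScaler()
    x_data = sc.fit_transform(x_data)
    # Split into training and test sets (default split: 75% train, 25% test)
    x_train, x_test, y_train, y_test = train_test_split(x_data, y_data)
    # Build and train the model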
    model = MLPClassifier(hidden_layer_sizes=(100, 50), max_iter=100)
    model.fit(x_train, y_train)

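    # Predict
    prediction = model.predict(x_test)
    # Evaluate. scikit-learn expects (y_true, y_pred); the report printed below
    # came from a run with the two arguments swapped, so its precision and
    # recall columns are interchanged.
    print(classification_report(y_test, prediction))
    print(confusion_matrix(y_test, prediction))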
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        44
           1       1.00      0.96      0.98        50
           2       0.98      0.98      0.98        48
           3       0.97      0.97      0.97        35
           4       0.98      0.98      0.98        42
           5       0.96      1.00      0.98        48
           6       1.00      1.00      1.00        43
           7       1.00      0.98      0.99        45
           8       0.96      1.00      0.98        49
           9       0.98      0.96      0.97        46

    accuracy                           0.98       450
   macro avg       0.98      0.98      0.98       450
weighted avg       0.98      0.98      0.98       450

[[44  0  0  0  0  0  0  0  0  0]
 [ 0 48  0  0  0  0  0  0  0  0]
 [ 0  1 47  0  0  0  0  0  0  0]
 [ 0  0  0 34  0  0  0  1  0  0]
 [ 0  0  0  0 41  0  0  0  0  1]
 [ 0  0  0  0  1 48  0  0  0  1]
 [ 0  0  0  0  0  0 43  0  0  0]
 [ 0  0  0  0  0  0  0 44  0  0]
 [ 0  1  1  0  0  0  0  0 49  0]
 [ 0  0  0  1  0  0  0  0  0 44]]

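The report and the matrix carry the same information, and the link between them can be checked by hand. The sketch below is a rough illustration, not part of the original script: it copies the matrix printed above into a plain list (the names `cm`, `row_sums` and `col_sums` are made up for this example) and assumes scikit-learn's convention that rows are true labels and columns are predicted labels. Recall for a digit is its diagonal entry divided by the row sum, precision is the diagonal entry divided by the column sum, and accuracy is the diagonal total divided by the number of samples.

```python
# Confusion matrix copied from the run above.
# Rows are true digits, columns are predicted digits.
cm = [
    [44, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 48, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 1, 47, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 34, 0, 0, 0, 1, 0, 0],
    [0, 0, 0, 0, 41, 0, 0, 0, 0, 1],
    [0, 0, 0, 0, 1, 48, 0, 0, 0, 1],
    [0, 0, 0, 0, 0, 0, 43, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 44, 0, 0],
    [0, 1, 1, 0, 0, 0, 0, 0, 49, 0],
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 44],
]

n = len(cm)
# Row sum: how many test images truly are digit k.
row_sums = [sum(row) for row in cm]
# Column sum: how many test images were predicted as digit k.
col_sums = [sum(cm[i][j] for i in range(n)) for j in range(n)]

recall = [cm[k][k] / row_sums[k] for k in range(n)]
precision = [cm[k][k] / col_sums[k] for k in range(n)]
accuracy = sum(cm[k][k] for k in range(n)) / sum(row_sums)

for k in range(n):
    print(f"digit {k}: precision={precision[k]:.2f} recall={recall[k]:.2f}")
print(f"accuracy={accuracy:.2f}")
```

For digit 1 this gives precision 0.96 and recall 1.00, the reverse of the report above. That is consistent with the report having been generated with the predictions passed as the first argument, which swaps the two columns and makes the support column count predicted labels rather than true ones.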

Understanding confusion_matrix

  • The matrix from the run above (row = true digit, column = predicted digit):
[[44  0  0  0  0  0  0  0  0  0]
 [ 0 48  0  0  0  0  0  0  0  0]
 [ 0  1 47  0  0  0  0  0  0  0]
 [ 0  0  0 34  0  0  0  1  0  0]
 [ 0  0  0  0 41  0  0  0  0  1]
 [ 0  0  0  0  1 48  0  0  0  1]
 [ 0  0  0  0  0  0 43  0  0  0]
 [ 0  0  0  0  0  0  0 44  0  0]
 [ 0  1  1  0  0  0  0  0 49  0]
 [ 0  0  0  1  0  0  0  0  0 44]]
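  • The larger the diagonal entries the better; ideally only the diagonal is nonzero
  • A nonzero entry anywhere else means samples of that row's class were classified as the column's class
  • For example, rows 0 and 1 above are perfect: every 0 and every 1 was classified correctly
  • Row 2, however, has a 1 off the diagonal, in column 1: one image that is actually a 2 was classified as a 1
  • The other rows read the same way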
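The reading rule in the list above can be applied mechanically. The following is a minimal sketch, assuming the same row = true, column = predicted layout; `cm` and `errors` are names chosen for this example only. It walks the matrix and reports every nonzero off-diagonal cell as "digit X classified as Y".

```python
# Confusion matrix copied from the run above.
# Rows are true digits, columns are predicted digits.
cm = [
    [44, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 48, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 1, 47, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 34, 0, 0, 0, 1, 0, 0],
    [0, 0, 0, 0, 41, 0, 0, 0, 0, 1],
    [0, 0, 0, 0, 1, 48, 0, 0, 0, 1],
    [0, 0, 0, 0, 0, 0, 43, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 44, 0, 0],
    [0, 1, 1, 0, 0, 0, 0, 0, 49, 0],
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 44],
]

# Collect (true digit, predicted digit, count) for every nonzero
# cell that is not on the diagonal.
errors = [
    (true, pred, cm[true][pred])
    for true in range(len(cm))
    for pred in range(len(cm))
    if true != pred and cm[true][pred] > 0
]

for true, pred, count in errors:
    print(f"{count} image(s) of digit {true} classified as {pred}")
```

The first line printed corresponds to the row 2 case described above. Summing the counts gives the total number of misclassified test images, which equals the number of samples minus the sum of the diagonal.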
posted @ 2020-09-07 16:35  JackpotNeaya