Chapter 10: Evaluating Classification Results

 

10-1 The Accuracy Trap and the Confusion Matrix
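Accuracy alone can be badly misleading on skewed data. If 99.9% of patients are healthy, a "model" that always predicts "healthy" reaches 99.9% accuracy while never finding a single sick patient. The notebooks in this chapter build exactly such a skewed problem from the digits dataset ("is this digit a 9?", roughly a 9:1 class ratio). A minimal sketch of the trap on that data (not from the course notebooks):

import numpy as np
from sklearn import datasets

# The same skewed binary problem used throughout this chapter:
# label 1 means "the digit is a 9", label 0 means everything else.
digits = datasets.load_digits()
y = (digits.target == 9).astype(int)

# A classifier that ignores its input and always predicts 0 still
# scores about 90% accuracy, because only ~10% of samples are 9s.
y_dummy = np.zeros_like(y)
print((y_dummy == y).mean())  # about 0.90

The confusion matrix records what accuracy hides. For a binary problem it is a 2x2 table with true classes as rows and predicted classes as columns: row 0 holds TN (true negatives) and FP (false positives), row 1 holds FN (false negatives) and TP (true positives). All the metrics in this chapter are computed from these four counts.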
10-2 Precision and Recall
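Precision and recall examine the two kinds of error separately. Precision is TP / (TP + FP): of all samples the model flagged as positive, the fraction that really are positive. Recall is TP / (TP + FN): of all samples that really are positive, the fraction the model found. With the confusion matrix obtained in the notebook below, [[403, 2], [9, 36]], precision is 36 / (36 + 2), about 0.947, and recall is 36 / (36 + 9) = 0.8. Which metric matters more depends on the task: in disease screening, missing a patient (low recall) is the costly error, while for something like stock picks a false alarm (low precision) is.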
10-3 Implementing the Confusion Matrix, Precision, and Recall

 

Notebook example
Notebook source code

[1]
import numpy as np
from sklearn import datasets
[2]
digits = datasets.load_digits()
X = digits.data
y = digits.target.copy()

y[digits.target==9] = 1
y[digits.target!=9] = 0
[3]
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
[4]
X_train.shape
(1347, 64)
[5]
from sklearn.linear_model import LogisticRegression

log_reg = LogisticRegression()
log_reg.fit(X_train, y_train)
log_reg.score(X_test, y_test)
F:\anaconda\lib\site-packages\sklearn\linear_model\_logistic.py:814: ConvergenceWarning: lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(

0.9755555555555555
[6]
y_log_predict = log_reg.predict(X_test)
[7]
def TN(y_true, y_predict):
    assert len(y_true) == len(y_predict)
    return np.sum((y_true == 0) & (y_predict == 0))

TN(y_test, y_log_predict)
403
[8]
def FP(y_true, y_predict):
    assert len(y_true) == len(y_predict)
    return np.sum((y_true == 0) & (y_predict == 1))

FP(y_test, y_log_predict)
2
[9]
def FN(y_true, y_predict):
    assert len(y_true) == len(y_predict)
    return np.sum((y_true == 1) & (y_predict == 0))

FN(y_test, y_log_predict)
9
[10]
def TP(y_true, y_predict):
    assert len(y_true) == len(y_predict)
    return np.sum((y_true == 1) & (y_predict == 1))

TP(y_test, y_log_predict)
36
[11]
def confusion_matrix(y_true, y_predict):
    return np.array([
        [TN(y_true, y_predict), FP(y_true, y_predict)],
        [FN(y_true, y_predict), TP(y_true, y_predict)]
    ])

confusion_matrix(y_test, y_log_predict)
array([[403,   2],
       [  9,  36]])
[12]
def precision_score(y_true, y_predict):
    tp = TP(y_true, y_predict)
    fp = FP(y_true, y_predict)
    try:
        return tp / (tp + fp)
    except ZeroDivisionError:
        return 0.0

precision_score(y_test, y_log_predict)
0.9473684210526315
[13]
def recall_score(y_true, y_predict):
    tp = TP(y_true, y_predict)
    fn = FN(y_true, y_predict)
    try:
        return tp / (tp + fn)
    except ZeroDivisionError:
        return 0.0

recall_score(y_test, y_log_predict)
0.8
Confusion matrix, precision, and recall in scikit-learn
[14]
from sklearn.metrics import confusion_matrix

confusion_matrix(y_test, y_log_predict)
array([[403,   2],
       [  9,  36]], dtype=int64)
[15]
from sklearn.metrics import precision_score

precision_score(y_test, y_log_predict)
0.9473684210526315
[16]
from sklearn.metrics import recall_score

recall_score(y_test, y_log_predict)
0.8
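One caveat on the try/except guard above: the counts returned by np.sum are numpy integers, and dividing numpy integers by zero produces nan (with a RuntimeWarning) instead of raising ZeroDivisionError, so the except branch may never actually fire. A sketch of a guard that also covers numpy counts:

# Hypothetical variant, not from the course code: check the denominator
# explicitly instead of relying on an exception, since 0/0 on numpy
# integers yields nan with a RuntimeWarning rather than raising.
def precision_score_safe(y_true, y_predict):
    tp = TP(y_true, y_predict)
    fp = FP(y_true, y_predict)
    return tp / (tp + fp) if tp + fp > 0 else 0.0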

 

10-4 F1 Score
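Precision and recall often pull in opposite directions, and which to favor depends on the task. When both matter, the F1 score combines them into a single number using the harmonic mean:

F1 = 2 * precision * recall / (precision + recall)

Unlike the arithmetic mean, the harmonic mean is dragged toward the smaller of the two values, so a model cannot hide a terrible recall behind an excellent precision: for precision 0.1 and recall 0.9 the arithmetic mean is 0.5, but F1 is only 0.18, as the notebook below verifies.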
10-5 The Precision-Recall Trade-off
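Raising the classification threshold makes the model more conservative about predicting the positive class, which pushes precision up and recall down; lowering it does the opposite. scikit-learn exposes the underlying signed score through decision_function, and for binary logistic regression predict simply thresholds that score at 0. By comparing the scores against our own threshold instead of 0, we can walk along the trade-off, which the notebook below does with a threshold of 5 (precision 0.96, recall about 0.53) and a threshold of -5 (precision about 0.73, recall about 0.89).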
Notebook example
Notebook source code

F1 Score
[18]
import numpy as np
import matplotlib.pyplot as plt
[4]
def f1_score(precision, recall):
    try:
        return 2 * precision * recall / (precision + recall)
    except ZeroDivisionError:
        return 0.0
[5]
precision = 0.5
recall = 0.5
f1_score(precision, recall)
0.5
[6]
precision = 0.1
recall = 0.9
f1_score(precision, recall)
0.18000000000000002
[7]
precision = 0.0
recall = 0.9
f1_score(precision, recall)
0.0
[8]
from sklearn import datasets

digits = datasets.load_digits()
X = digits.data
y = digits.target.copy()

y[digits.target==9] = 1
y[digits.target!=9] = 0
[9]
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
[10]
from sklearn.linear_model import LogisticRegression

log_reg = LogisticRegression()
log_reg.fit(X_train, y_train)
log_reg.score(X_test, y_test)
F:\anaconda\lib\site-packages\sklearn\linear_model\_logistic.py:814: ConvergenceWarning: lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(

0.9755555555555555
[12]
y_predict = log_reg.predict(X_test)
[14]
from sklearn.metrics import confusion_matrix

confusion_matrix(y_test, y_predict)
array([[403,   2],
       [  9,  36]], dtype=int64)
[15]
from sklearn.metrics import precision_score

precision_score(y_test, y_predict)
0.9473684210526315
[16]
from sklearn.metrics import recall_score

recall_score(y_test, y_predict)
0.8
[17]
from sklearn.metrics import f1_score

f1_score(y_test, y_predict)
0.8674698795180723
[19]
log_reg.decision_function(X_test)
array([-21.39857276, -32.89731271, -16.41797156, -79.82318954,
       -48.03305046, -24.18254017, -44.60990955, -24.24479014,
        -1.14284305, -19.00457455, -65.82296325, -50.97066379,
       ...
       -45.70161834, -16.05036959])
[20]
log_reg.decision_function(X_test)[:10]
array([-21.39857276, -32.89731271, -16.41797156, -79.82318954,
       -48.03305046, -24.18254017, -44.60990955, -24.24479014,
        -1.14284305, -19.00457455])
[21]
log_reg.predict(X_test)[:10]
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
[22]
decision_scores = log_reg.decision_function(X_test)
[23]
np.min(decision_scores)
-85.72824719706493
[24]
np.max(decision_scores)
19.899001392471213
[25]
y_predict_2 = np.array(decision_scores >= 5, dtype='int')
[26]
confusion_matrix(y_test, y_predict_2)
array([[404,   1],
       [ 21,  24]], dtype=int64)
[27]
precision_score(y_test, y_predict_2)
0.96
[28]
recall_score(y_test, y_predict_2)
0.5333333333333333
[29]
y_predict_3 = np.array(decision_scores >= -5, dtype='int')
[30]
confusion_matrix(y_test, y_predict_3)
array([[390,  15],
       [  5,  40]], dtype=int64)
[31]
precision_score(y_test, y_predict_3)
0.7272727272727273
[32]
recall_score(y_test, y_predict_3)
0.8888888888888888
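A common use of this trade-off is to pick the threshold that meets a precision requirement while giving up as little recall as possible. A sketch of that search, reusing decision_scores and y_test from the cells above (in practice the threshold should be tuned on validation data, not on the test set):

# Hypothetical helper, not part of the course code: scan candidate
# thresholds and keep the one with the best recall among those whose
# precision reaches the target.
target_precision = 0.95
best_threshold, best_recall = None, -1.0
for threshold in np.arange(np.min(decision_scores), np.max(decision_scores), 0.1):
    y_pred = np.array(decision_scores >= threshold, dtype='int')
    p = precision_score(y_test, y_pred)
    r = recall_score(y_test, y_pred)
    if p >= target_precision and r > best_recall:
        best_threshold, best_recall = threshold, r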

 

10-6 The Precision-Recall Curve in scikit-learn
Notebook example
Notebook source code

[8]
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
[2]
digits = datasets.load_digits()
X = digits.data
y = digits.target.copy()

y[digits.target==9] = 1
y[digits.target!=9] = 0
[3]
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
[5]
from sklearn.linear_model import LogisticRegression

log_reg = LogisticRegression()
log_reg.fit(X_train, y_train)
decision_score = log_reg.decision_function(X_test)
F:\anaconda\lib\site-packages\sklearn\linear_model\_logistic.py:814: ConvergenceWarning: lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(

[6]
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score

precisions = []
recalls = []
thresholds = np.arange(np.min(decision_score), np.max(decision_score), 0.1)
for threshold in thresholds:
    y_predict = np.array(decision_score >= threshold, dtype='int')
    precisions.append(precision_score(y_test, y_predict))
    recalls.append(recall_score(y_test, y_predict))
[9]
plt.plot(thresholds, precisions)
plt.plot(thresholds, recalls)
[<matplotlib.lines.Line2D at 0x1bf900a73a0>]

Precision-Recall curve
[10]
plt.plot(precisions, recalls)
[<matplotlib.lines.Line2D at 0x1bf9016b4c0>]

Precision-Recall curve in scikit-learn
[11]
from sklearn.metrics import precision_recall_curve

precisions, recalls, thresholds = precision_recall_curve(y_test, decision_score)
[12]
precisions.shape
(151,)
[13]
recalls.shape
(151,)
[14]
thresholds.shape
(150,)
[15]
plt.plot(thresholds, precisions[:-1])
plt.plot(thresholds, recalls[:-1])
[<matplotlib.lines.Line2D at 0x1bf901f6fa0>]

[16]
plt.plot(precisions, recalls)
[<matplotlib.lines.Line2D at 0x1bf901e8160>]
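Two details of precision_recall_curve are worth noting. First, it computes the metrics only at the distinct score values that actually change the predictions, rather than stepping a threshold at a fixed 0.1 interval as the manual loop above does. Second, it appends one final point with precision 1 and recall 0 that has no corresponding threshold; that is why precisions and recalls have 151 elements here while thresholds has 150, and why the threshold plot slices them with [:-1].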

 

 

10-7 ROC
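The ROC (Receiver Operating Characteristic) curve describes how two rates move together as the classification threshold varies: the true positive rate TPR = TP / (TP + FN), which is the same quantity as recall, and the false positive rate FPR = FP / (FP + TN), the fraction of true negatives that get incorrectly flagged. Lowering the threshold raises both rates, so the curve runs from (0, 0) to (1, 1); the more it bulges toward the top-left corner, the better the classifier. A random classifier stays near the diagonal, and the area under the curve (AUC) summarizes the whole picture in a single number between 0 and 1, which makes it convenient for comparing models.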
Notebook example
Notebook source code

ROC curve
[1]
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
[2]
digits = datasets.load_digits()
X = digits.data
y = digits.target.copy()

y[digits.target==9] = 1
y[digits.target!=9] = 0
[3]
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
[4]
from sklearn.linear_model import LogisticRegression

log_reg = LogisticRegression()
log_reg.fit(X_train, y_train)
decision_score = log_reg.decision_function(X_test)
F:\anaconda\lib\site-packages\sklearn\linear_model\_logistic.py:814: ConvergenceWarning: lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(

[5]
from playML.metrics import FPR, TPR

fprs = []
tprs = []
thresholds = np.arange(np.min(decision_score), np.max(decision_score), 0.1)
for threshold in thresholds:
    y_predict = np.array(decision_score >= threshold, dtype='int')
    fprs.append(FPR(y_test, y_predict))
    tprs.append(TPR(y_test, y_predict))
[6]
plt.plot(fprs, tprs)
[<matplotlib.lines.Line2D at 0x1600c413820>]

ROC in scikit-learn
[7]
from sklearn.metrics import roc_curve

fprs, tprs, thresholds = roc_curve(y_test, decision_score)
[8]
plt.plot(fprs, tprs)
[<matplotlib.lines.Line2D at 0x1600c190cd0>]

[9]
from sklearn.metrics import roc_auc_score

roc_auc_score(y_test, decision_score)
0.9823868312757201
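playML here is the course's own helper package, whose source is not included in this post. A minimal sketch of FPR and TPR consistent with the TN/FP/FN/TP helpers from section 10-3 might look like this (an assumption about the package, not its actual source):

import numpy as np

def FPR(y_true, y_predict):
    # false positive rate: FP / (FP + TN); 0.0 when there are no true negatives
    fp = np.sum((y_true == 0) & (y_predict == 1))
    tn = np.sum((y_true == 0) & (y_predict == 0))
    return fp / (fp + tn) if fp + tn > 0 else 0.0

def TPR(y_true, y_predict):
    # true positive rate (i.e., recall): TP / (TP + FN); 0.0 when there are no true positives
    tp = np.sum((y_true == 1) & (y_predict == 1))
    fn = np.sum((y_true == 1) & (y_predict == 0))
    return tp / (tp + fn) if tp + fn > 0 else 0.0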

 

10-8 The Confusion Matrix in Multiclass Problems

Notebook example
Notebook source code

Confusion matrix for multiclass problems
[2]
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
[5]
digits = datasets.load_digits()
X = digits.data
y = digits.target.copy()
[6]
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
[8]
from sklearn.linear_model import LogisticRegression

log_reg = LogisticRegression()
log_reg.fit(X_train, y_train)
log_reg.score(X_test, y_test)
F:\anaconda\lib\site-packages\sklearn\linear_model\_logistic.py:814: ConvergenceWarning: lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(

0.9711111111111111
[9]
y_predict = log_reg.predict(X_test)
[10]
from sklearn.metrics import precision_score

precision_score(y_test, y_predict, average='micro')
0.9711111111111111
[11]
from sklearn.metrics import confusion_matrix

confusion_matrix(y_test, y_predict)
array([[46,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0, 40,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0, 50,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  1, 50,  0,  0,  0,  0,  1,  1],
       [ 0,  0,  0,  0, 47,  0,  0,  0,  1,  0],
       [ 0,  0,  0,  0,  0, 37,  0,  1,  0,  0],
       [ 0,  0,  0,  0,  0,  1, 38,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0, 43,  0,  0],
       [ 0,  0,  0,  0,  1,  2,  0,  1, 44,  0],
       [ 0,  0,  0,  1,  0,  2,  0,  0,  0, 42]], dtype=int64)
[12]
cfm = confusion_matrix(y_test, y_predict)
plt.matshow(cfm, cmap=plt.cm.gray)
<matplotlib.image.AxesImage at 0x1cdcf0153d0>

[14]
row_sums = np.sum(cfm, axis=1)
err_matrix = cfm / row_sums
np.fill_diagonal(err_matrix, 0)
err_matrix
array([[0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.02      , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.02083333, 0.02222222],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.02083333, 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.02325581, 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.02631579, 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.02083333,
        0.05263158, 0.        , 0.02325581, 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.01886792, 0.        ,
        0.05263158, 0.        , 0.        , 0.        , 0.        ]])
[15]
plt.matshow(err_matrix, cmap=plt.cm.gray)
<matplotlib.image.AxesImage at 0x1cdcf386550>
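One broadcasting subtlety in cell [14] above: row_sums has shape (10,), so cfm / row_sums broadcasts along the last axis and divides entry (i, j) by row_sums[j], the true count of class j, which is what the printed matrix reflects. To normalize each row by its own row total, give row_sums a column shape:

# Sketch of row-wise normalization: keepdims makes row_sums shape (10, 1),
# so broadcasting divides each row by that row's true-class count.
row_sums = np.sum(cfm, axis=1, keepdims=True)
err_matrix = cfm / row_sums
np.fill_diagonal(err_matrix, 0)

Because cfm is nearly diagonal and the class counts are similar, the two versions give almost the same picture here, but the keepdims form matches the intended "errors per true class" reading.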

 
