numpy.argmax 用在求解混淆矩阵用
numpy.argmax
numpy.
argmax
(a, axis=None, out=None)[source]-
Returns the indices of the maximum values along an axis.
Parameters: a : array_like
Input array.
axis : int, optional
By default, the index is into the flattened array, otherwise along the specified axis.
out : array, optional
If provided, the result will be inserted into this array. It should be of the appropriate shape and dtype.
Returns: index_array : ndarray of ints
Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed.
See also
amax
- The maximum value along a given axis.
unravel_index
- Convert a flat index into an index tuple.
Notes
In case of multiple occurrences of the maximum values, the indices corresponding to the first occurrence are returned.
Examples
>>> a = np.arange(6).reshape(2,3) >>> a array([[0, 1, 2], [3, 4, 5]]) >>> np.argmax(a) 5 >>> np.argmax(a, axis=0) array([1, 1, 1]) >>> np.argmax(a, axis=1) array([2, 2])
>>> b = np.arange(6) >>> b[1] = 5 >>> b array([0, 5, 2, 3, 4, 5]) >>> np.argmax(b) # Only the first occurrence is returned. 1
在多分类模型训练中,我的使用:org_labels = [0,1,2,....max_label] 从0开始的标记类别1234567891011121314151617181920212223242526272829303132333435363738if
__name__
=
=
"__main__"
:
width, height
=
32
,
32
X, Y, org_labels
=
load_data(dirname
=
"data"
, resize_pics
=
(width, height))
trainX, testX, trainY, testY
=
train_test_split(X, Y, test_size
=
0.2
, random_state
=
666
)
print
(
"sample data:"
)
print
(trainX[
0
])
print
(trainY[
0
])
print
(testX[
-
1
])
print
(testY[
-
1
])
model
=
get_model(width, height, classes
=
100
)
filename
=
'cnn_handwrite-acc0.8.tflearn'
# try to load model and resume training
#try:
# model.load(filename)
# print("Model loaded OK. Resume training!")
#except:
# pass
# Initialize our callback with desired accuracy threshold.
early_stopping_cb
=
EarlyStoppingCallback(val_acc_thresh
=
0.6
)
try
:
model.fit(trainX, trainY, validation_set
=
(testX, testY), n_epoch
=
500
, shuffle
=
True
,
snapshot_epoch
=
True
,
# Snapshot (save & evaluate) model every epoch.
show_metric
=
True
, batch_size
=
32
, callbacks
=
early_stopping_cb, run_id
=
'cnn_handwrite'
)
except
StopIteration as e:
print
(
"OK, stop iterate!Good!"
)
model.save(filename)
# predict all data and calculate confusion_matrix
model.load(filename)
pro_arr
=
model.predict(X)
predict_labels
=
np.argmax(pro_arr, axis
=
1
)
print
(classification_report(org_labels, predict_labels))
print
(confusion_matrix(org_labels, predict_labels))
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 记一次.NET内存居高不下排查解决与启示
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· DeepSeek 开源周回顾「GitHub 热点速览」
2017-05-01 BaezaYates 交集python和golang代码
2017-05-01 go 安装方法