numpy.argmax 用在求解混淆矩阵用

numpy.argmax

numpy.argmax(a, axis=None, out=None)[source]

Returns the indices of the maximum values along an axis.

Parameters:

a : array_like

Input array.

axis : int, optional

By default, the index is into the flattened array, otherwise along the specified axis.

out : array, optional

If provided, the result will be inserted into this array. It should be of the appropriate shape and dtype.

Returns:

index_array : ndarray of ints

Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed.

See also

ndarray.argmax, argmin

amax
The maximum value along a given axis.
unravel_index
Convert a flat index into an index tuple.

Notes

In case of multiple occurrences of the maximum values, the indices corresponding to the first occurrence are returned.

Examples

>>> a = np.arange(6).reshape(2,3)
>>> a
array([[0, 1, 2],
       [3, 4, 5]])
>>> np.argmax(a)
5
>>> np.argmax(a, axis=0)
array([1, 1, 1])
>>> np.argmax(a, axis=1)
array([2, 2])
>>> b = np.arange(6)
>>> b[1] = 5
>>> b
array([0, 5, 2, 3, 4, 5])
>>> np.argmax(b) # Only the first occurrence is returned.
1

在多分类模型训练中,我的使用:org_labels = [0,1,2,....max_label] 从0开始的标记类别
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
if __name__ == "__main__":
    width, height = 32, 32
    X, Y, org_labels = load_data(dirname="data", resize_pics=(width, height))
    trainX, testX, trainY, testY = train_test_split(X, Y, test_size=0.2, random_state=666)
    print("sample data:")
    print(trainX[0])
    print(trainY[0])
    print(testX[-1])
    print(testY[-1])
 
    model = get_model(width, height, classes=100)
 
    filename = 'cnn_handwrite-acc0.8.tflearn'
    # try to load model and resume training
    #try:
    #    model.load(filename)
    #    print("Model loaded OK. Resume training!")
    #except:
    #    pass
 
    # Initialize our callback with desired accuracy threshold.
    early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.6)
    try:
        model.fit(trainX, trainY, validation_set=(testX, testY), n_epoch=500, shuffle=True,
                  snapshot_epoch=True, # Snapshot (save & evaluate) model every epoch.
                  show_metric=True, batch_size=32, callbacks=early_stopping_cb, run_id='cnn_handwrite')
    except StopIteration as e:
        print("OK, stop iterate!Good!")
 
    model.save(filename)
 
    # predict all data and calculate confusion_matrix
    model.load(filename)
 
    pro_arr =model.predict(X)
    predict_labels = np.argmax(pro_arr, axis=1)
    print(classification_report(org_labels, predict_labels))
    print(confusion_matrix(org_labels, predict_labels))

 

posted @   bonelee  阅读(1410)  评论(0编辑  收藏  举报
编辑推荐:
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 记一次.NET内存居高不下排查解决与启示
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· DeepSeek 开源周回顾「GitHub 热点速览」
历史上的今天:
2017-05-01 BaezaYates 交集python和golang代码
2017-05-01 go 安装方法
点击右上角即可分享
微信分享提示