Using numpy.argmax to compute a confusion matrix
numpy.argmax(a, axis=None, out=None)
Returns the indices of the maximum values along an axis.
Parameters:
    a : array_like
        Input array.
    axis : int, optional
        By default, the index is into the flattened array; otherwise it is along the specified axis.
    out : array, optional
        If provided, the result will be inserted into this array. It should be of the appropriate shape and dtype.
Returns:
    index_array : ndarray of ints
        Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed.
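The sketch below checks three behaviours on a made-up 3-D array: the shape rule above, the flattened index returned when axis=None, and the `out` parameter. The array and the names flat_idx, idx_axis1 and buf are illustrative only. Because the buffer must already match the result's shape and integer dtype, it is assumed here to be np.intp, NumPy's default index type.

```python
import numpy as np

# Illustrative 3-D input of shape (2, 3, 4); values increase, so the max is always last.
a = np.arange(24).reshape(2, 3, 4)

# axis=None (the default): the index refers to the flattened array, so the result is a scalar.
flat_idx = np.argmax(a)

# Along axis=1 the middle dimension is removed: (2, 3, 4) -> (2, 4).
idx_axis1 = np.argmax(a, axis=1)

# `out` must already have the result's shape and an integer index dtype.
buf = np.empty((2, 4), dtype=np.intp)
np.argmax(a, axis=1, out=buf)

print(flat_idx)         # 23
print(idx_axis1.shape)  # (2, 4)
print(buf)              # every entry is 2 (the last row of each 3-row slice)
```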
See also
    amax : The maximum value along a given axis.
    unravel_index : Convert a flat index into an index tuple.
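The two related functions combine naturally with argmax. Without an axis, argmax returns an index into the flattened array, and unravel_index turns that index back into (row, column). amax returns the maximum value itself rather than its position. The 3x3 array below is made up for illustration.

```python
import numpy as np

# Illustrative 2-D array; the maximum (9) sits at row 1, column 2.
a = np.array([[3, 1, 4],
              [1, 5, 9],
              [2, 6, 5]])

flat_idx = np.argmax(a)                         # index into a.ravel()
row, col = np.unravel_index(flat_idx, a.shape)  # convert back to (row, col)

# np.amax gives the value, not its position.
print(flat_idx, (int(row), int(col)), np.amax(a))  # 5 (1, 2) 9
```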
Notes
In case of multiple occurrences of the maximum values, the indices corresponding to the first occurrence are returned.
Examples
>>> a = np.arange(6).reshape(2,3)
>>> a
array([[0, 1, 2],
       [3, 4, 5]])
>>> np.argmax(a)
5
>>> np.argmax(a, axis=0)
array([1, 1, 1])
>>> np.argmax(a, axis=1)
array([2, 2])

>>> b = np.arange(6)
>>> b[1] = 5
>>> b
array([0, 5, 2, 3, 4, 5])
>>> np.argmax(b)  # Only the first occurrence is returned.
1
Here is how I use it when training a multi-class model. The class labels org_labels = [0, 1, 2, ..., max_label] are integers starting from 0. Because of that, the column index that argmax picks from each row of predicted probabilities is directly the predicted class label:

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

# load_data, get_model and EarlyStoppingCallback are my own helpers (not shown here).

if __name__ == "__main__":
    width, height = 32, 32
    X, Y, org_labels = load_data(dirname="data", resize_pics=(width, height))
    trainX, testX, trainY, testY = train_test_split(X, Y, test_size=0.2, random_state=666)
    print("sample data:")
    print(trainX[0])
    print(trainY[0])
    print(testX[-1])
    print(testY[-1])

    model = get_model(width, height, classes=100)

    filename = 'cnn_handwrite-acc0.8.tflearn'
    # try to load model and resume training
    #try:
    #    model.load(filename)
    #    print("Model loaded OK. Resume training!")
    #except:
    #    pass

    # Initialize our callback with desired accuracy threshold.
    early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.6)
    try:
        model.fit(trainX, trainY, validation_set=(testX, testY), n_epoch=500, shuffle=True,
                  snapshot_epoch=True,  # Snapshot (save & evaluate) model every epoch.
                  show_metric=True, batch_size=32, callbacks=early_stopping_cb,
                  run_id='cnn_handwrite')
    except StopIteration:
        # The early-stopping callback ends training by raising StopIteration.
        print("OK, stop iterate!Good!")

    model.save(filename)

    # Predict all data and calculate the confusion matrix.
    model.load(filename)
    pro_arr = model.predict(X)                   # shape (n_samples, n_classes)
    predict_labels = np.argmax(pro_arr, axis=1)  # most probable class per sample
    print(classification_report(org_labels, predict_labels))
    print(confusion_matrix(org_labels, predict_labels))
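The script above depends on my own helpers and on scikit-learn. To show what argmax contributes without those pieces, here is a self-contained NumPy-only sketch. It takes invented softmax outputs for 5 samples and 3 classes, applies argmax along axis=1, and builds the confusion matrix by counting (true, predicted) pairs with np.add.at. Rows are true labels and columns are predicted labels, the same orientation sklearn.metrics.confusion_matrix uses. The function name confusion_from_probs and all data are hypothetical. The sketch assumes labels are 0..K-1, and that if targets are one-hot encoded, argmax along axis=1 recovers the integer labels.

```python
import numpy as np

def confusion_from_probs(true_labels, probs):
    """Hypothetical helper: confusion matrix from integer labels 0..K-1 and (N, K) scores.

    Entry [i, j] counts samples whose true label is i and whose
    predicted label (argmax over the K columns) is j.
    """
    true_labels = np.asarray(true_labels)
    probs = np.asarray(probs)
    n_classes = probs.shape[1]
    pred_labels = np.argmax(probs, axis=1)  # column index == class label
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    np.add.at(cm, (true_labels, pred_labels), 1)  # unbuffered: repeated pairs all count
    return pred_labels, cm

# Invented softmax outputs for 5 samples and 3 classes.
probs = np.array([
    [0.8, 0.1, 0.1],  # predicted 0
    [0.2, 0.7, 0.1],  # predicted 1
    [0.3, 0.3, 0.4],  # predicted 2
    [0.5, 0.4, 0.1],  # predicted 0, but the true label is 1
    [0.1, 0.2, 0.7],  # predicted 2
])
true_labels = [0, 1, 2, 1, 2]

pred, cm = confusion_from_probs(true_labels, probs)
print(pred)  # [0 1 2 0 2]
print(cm)    # [[1 0 0]
             #  [1 1 0]
             #  [0 0 2]]

# Assumed one-hot targets: argmax along axis=1 turns them back into integer labels.
y_onehot = np.eye(3, dtype=np.float32)[true_labels]
recovered = np.argmax(y_onehot, axis=1)
```

With exact ties, argmax returns the lower class index, as the Notes describe, so a tied sample is always counted in the lowest-numbered tied column.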