[AI]-模型测试和评价指标

模型测试

import cv2
from torchvision import transforms, datasets, models
from torch.utils.data import  DataLoader
import torch
import numpy as np
import os
from sklearn import metrics 
import matplotlib.pyplot as plt

device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
print(device)
num_class = 3
model_path = 模型路径
model = 模型(num_class).to(device)
model.load_state_dict(torch.load(model_path))
model.eval()   # Set model to evaluate mode

test_dataset = 数据集读取(train=False)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2)

def turn(l):
    l = l.data.cpu().numpy()
    l = l.squeeze()
    l = np.swapaxes(l, 0, 2)
    l = np.swapaxes(l, 0, 1)
    return l

for inputs, labels in test_loader:
    model.to(device)
    inputs = inputs.to(device)
    labels = labels.to(device)

    pred = model(inputs)
    # pred = torch.relu(pred)
    pred = turn(pred)
    gt = turn(labels)

评价指标

混淆矩阵

以分割为例,经过.flatten()处理。

def acc(pred, gt):
    tp = 0
    tn = 0
    fp = 0
    fn = 0
    num = len(pred)
    for i in range(num):
        if pred[i] > 0 and gt[i] == 1:
            tp += 1
        if pred[i] > 0 and gt[i] == 0:
            fp += 1
        if pred[i] == 0 and gt[i] == 1:
            fn += 1
        if pred[i] == 0 and gt[i] == 0:
            tn += 1
    acc = (tp + tn) / num
    iou = tp / (tp + fp + fn)
    rec = tp / (tp + fn)
    pre = tp / (tp + fp)
    f1 = 2 * pre * rec / (pre + rec)
    print("mAcc is :{},  mIou is :{}, recall is :{}, precision is :{}, f1 is :{}".format(acc, iou, rec, pre, f1))

ROC曲线图

def draw_roc(pred, gt, name):
    tpr, fpr, thresholds = metrics.roc_curve(gt, pred, pos_label=0)
    plt.figure
    plt.plot(fpr, tpr, label = name)
    plt.xlabel('FPR')
    plt.ylabel('TPR')
    plt.legend(loc = 'lower right')
    plt.title(name)
    plt.savefig('路径/{}.png'.format(name))
    # plt.close()  如果有多个类别,不close()就会画在一张图上
posted @   CAMILIA  阅读(212)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 单元测试从入门到精通
· 上周热点回顾(3.3-3.9)
· winform 绘制太阳,地球,月球 运作规律
点击右上角即可分享
微信分享提示