将不同 YOLO 模型的 PR 曲线绘制在同一张图上

步骤:

  • 思路：首先修改 ultralytics 中绘制 PR 曲线的源码，在验证时把每个模型的 PR 曲线数据保存为 CSV 文件，然后再用单独的脚本把这些数据绘制在同一张图上。

1. 修改 PR 曲线绘制源码——保存绘制数据

  • YOLO11 中对应的代码位于 ultralytics/utils/metrics.py 的 plot_pr_curve 函数：
def plot_pr_curve(
    px,
    py,
    ap,
    save_dir=Path("pr_curve.png"),
    names=(),
    on_plot=None,
    pr_csv="pr_data/11n.csv",
):
    """
    Plot a precision-recall curve and save the mean ("all classes") curve to a CSV file.

    Modified from the stock ultralytics version so the plotted data can be reused to
    overlay the PR curves of several models on one figure.

    Args:
        px (np.ndarray): Recall values (x axis), shared by every curve.
        py (list[np.ndarray]): Per-class precision curves. They are stacked column-wise,
            so column ``i`` is the curve of class ``i``.
        ap (np.ndarray): Per-class AP. Column 0 is shown in the legend as mAP@0.5.
        save_dir (Path): Output image path.
        names (Mapping[int, str] | Sequence[str]): Class names indexed by class id.
            A per-class legend is drawn only when there are 1-20 names.
        on_plot (callable | None): Called with ``save_dir`` after the image is saved.
        pr_csv (str | Path): Destination of the mean curve, written as
            ``recall,precision`` lines. Use a different name per model (e.g.
            ``pr_data/v8.csv``). Parent directories are created if missing.
    """
    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
    py = np.stack(py, axis=1)
    mean_py = py.mean(1)  # mean precision over classes at each recall point

    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
        for i, y in enumerate(py.T):
            ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}")  # plot(recall, precision)
            # Optional: also save each class's curve (create pr_data/11n/ first).
            # if 'Pose' in str(save_dir):
            #     with open(f'pr_data/11n/{names[i]}_pose.csv', 'w+') as f:
            #         for px_v, y_v in zip(px, y):
            #             f.write(f'{px_v},{y_v}\n')
            # else:
            #     with open(f'pr_data/11n/{names[i]}_Box.csv', 'w+') as f:
            #         for px_v, y_v in zip(px, y):
            #             f.write(f'{px_v},{y_v}\n')
    else:
        ax.plot(px, py, linewidth=1, color="grey")  # plot(recall, precision)

    ax.plot(px, mean_py, linewidth=3, color="blue", label=f"all classes {ap[:, 0].mean():.3f} mAP@0.5")

    # Save the "all classes" curve so several models can be compared later.
    # NOTE(review): the commented per-class code above suggests this function is also
    # called for pose results; every call overwrites pr_csv, so the last call wins.
    # Confirm this for your task.
    pr_csv = Path(pr_csv)
    pr_csv.parent.mkdir(parents=True, exist_ok=True)  # open() cannot create missing directories
    with open(pr_csv, "w", encoding="utf-8") as f:
        f.writelines(f"{px_v},{mean_y_v}\n" for px_v, mean_y_v in zip(px, mean_py))

    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    ax.set_title("Precision-Recall Curve")
    fig.savefig(save_dir, dpi=250)
    plt.close(fig)
    if on_plot:
        on_plot(save_dir)

2. 运行 val.py（对模型执行验证，触发上面修改过的 plot_pr_curve 并保存数据）

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':
    # Trained weights of the model to evaluate. Change this for every model being
    # compared, together with the CSV file name used in plot_pr_curve.
    weights = '/yolo11/yolo11-2/runs/train/kitti-yolo11/weights/best.pt'
    # Validation settings. 'split' may be 'train', 'val' or 'test', depending on
    # which splits your dataset defines.
    val_kwargs = {
        'data': '/Object_detection/LS/yolo11/yolo11-1/ultralytics/cfg/datasets/kitti.yaml',
        'split': 'val',
        'imgsz': 640,
        'batch': 16,
        'project': 'runs/val',
        'name': 'exp',
    }
    detector = YOLO(weights)
    detector.val(**val_kwargs)

3. 运行结果

  • 对每个需要对比的模型重复上述步骤：只需修改第一步中保存的 CSV 文件名，以及第二步中的权重文件路径。

  • 最终即可得到多个模型的 PR 曲线数据（位于 pr_data/ 目录下）。

4. 绘制脚本

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

if __name__ == '__main__':
    # One CSV per model, written by the modified plot_pr_curve as "recall,precision" lines.
    file_list = ['pr_data/v5.csv', 'pr_data/v6.csv', 'pr_data/v7.csv', 'pr_data/v8.csv', 'pr_data/v10.csv', 'pr_data/11n.csv']
    names = ['v5', 'v6', 'v7', 'v8', 'v10', '11']  # legend labels, same order as file_list
    # Tip: to show each model's mAP@0.5 in the legend, use labels like f'{name} {ap:.3f}'.
    if len(file_list) != len(names):
        # Fail fast: a length mismatch would otherwise mislabel or drop curves.
        raise ValueError(f'{len(file_list)} CSV files but {len(names)} legend names')

    fig, ax = plt.subplots(figsize=(6, 6))
    for csv_path, label in zip(file_list, names):
        pr_data = pd.read_csv(csv_path, header=None)  # column 0 = recall, column 1 = precision
        ax.plot(pr_data[0].to_numpy(), pr_data[1].to_numpy(), label=label)
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_xlim(0, 1)  # fixed 0-1 axes, like the ultralytics PR plot, so figures are comparable
    ax.set_ylim(0, 1)
    ax.set_title('Precision-Recall Curve')
    ax.legend()
    fig.tight_layout()
    fig.savefig('pr.png')
    plt.close(fig)  # the figure is only saved, never shown
  • 结果:

5. 将 PR 曲线与 mAP50、mAP50-95 曲线绘制在同一张画布上

def plot_metrics(ax, metric_col_name, y_label, color, modelname, is_pr=False, csv_dict=None):
    """
    Plot one model's training metric, read from its ``results.csv``, on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        metric_col_name (str): Column plotted against ``epoch`` when ``is_pr`` is False,
            e.g. ``'metrics/mAP50(B)'``.
        y_label (str): Unused; kept so existing callers keep working (the caller sets
            the axis label).
        color: Line colour.
        modelname (str): Legend label, and the key looked up in the CSV-path mapping.
        is_pr (bool): If True, plot column ``metrics/precision(B)`` against
            ``metrics/recall(B)`` instead. This traces the per-row (epoch) values; it is
            not a confidence-swept PR curve.
        csv_dict (Mapping[str, str] | None): Maps ``modelname`` to its results.csv path.
            Defaults to the module-level ``pr_csv_dict``.

    Raises:
        KeyError: ``modelname`` is not in the mapping (same as before).

    A file that is missing, unreadable or lacks the needed columns is reported and the
    model is skipped, so one bad run does not abort the whole figure.
    """
    if csv_dict is None:
        csv_dict = pr_csv_dict  # module-level mapping (assigned in plot_all_metrics)
    res_path = csv_dict[modelname]
    try:
        data = pd.read_csv(res_path)
        data.columns = data.columns.str.strip()  # Remove spaces from column names
        if is_pr:
            x = data['metrics/recall(B)'].to_numpy()
            y = data['metrics/precision(B)'].to_numpy()
        else:
            x = data['epoch'].to_numpy()  # epoch column
            y = data[metric_col_name].to_numpy()  # requested metric column
    except (OSError, KeyError, ValueError) as e:
        # OSError: missing/unreadable file. KeyError: missing column.
        # ValueError: covers pandas EmptyDataError / ParserError and decode errors.
        print(f"Error reading {modelname}: {e}")
        return
    ax.plot(x, y, label=modelname, color=color, linewidth=2)

# Main function
def _style_axis(ax, xlabel, ylabel, title, xlim=(0, 1)):
    """Apply the shared look of the comparison panels to one axes."""
    ax.set_xlabel(xlabel, fontsize=16)
    ax.set_ylabel(ylabel, fontsize=16)
    ax.set_xlim(*xlim)
    ax.set_ylim(0, 1)
    ax.legend(loc='lower right', fontsize=16)
    for side in ('top', 'right', 'left', 'bottom'):
        ax.spines[side].set_linewidth(2)
    ax.tick_params(width=2, labelsize=14)
    ax.set_title(title, fontsize=18)


def _plot_pr_curves(ax, pr_files, colors):
    """Overlay each model's saved recall,precision CSV on ``ax``; skip unreadable files."""
    for name, csv_path in pr_files.items():
        try:
            pr_data = pd.read_csv(csv_path, header=None)
        except (OSError, ValueError) as e:
            # Same skip-and-report behaviour as plot_metrics, instead of aborting the figure.
            print(f"Error reading {name}: {e}")
            continue
        recall, precision = pr_data[0].to_numpy(), pr_data[1].to_numpy()
        ax.plot(recall, precision, label=name, color=colors[name], linewidth=2)


# Main function
def plot_all_metrics(save_path='/yolo11/yolo11-1/images/aa/diff_yolo_metrics6.png'):
    """
    Draw a 1x3 comparison figure: PR curves, mAP@0.5 per epoch, mAP@0.5:0.95 per epoch.

    The PR data comes from the recall,precision CSVs saved during validation. The
    per-epoch metrics come from each run's ``results.csv``.

    Args:
        save_path (str | Path): Where to save the figure. Parent directories are created.

    Side effects:
        Sets the module-level ``pr_csv_dict`` (read by ``plot_metrics``), updates
        ``plt.rcParams['font.size']``, saves the figure and calls ``plt.show()``.
    """
    from pathlib import Path

    global pr_csv_dict
    pr_csv_dict = {
        'YOLOv5': r'/v5/n/results.csv',
        'YOLOv6': r'/v6/n/results.csv',
        'YOLOv7_tiny': r'/results_v7.csv',
        'YOLOv8n': r'/yolov8-main/runs/train/re39/results.csv',
        'YOLOv10n': r'yolov10-main/runs/kitti/n/results.csv',
        'YOLO11n': r'/yolo11-1/runs/train/kitti-yolo11n/results.csv',
        'YOLOn': r'/yolo11-1/runs/train/kitti-YOLOn/results.csv',
    }

    colors = {
        'YOLOv5': '#00EE76',
        'YOLOv6': '#EEEE00',
        'YOLOv7_tiny': '#8470FF',
        'YOLOv8n': 'orange',
        'YOLOv10n': '#838B8B',
        'YOLO11n': '#00BFFF',
        'YOLOn': '#FF3030',
    }

    # Saved PR curves, keyed by the same model names as pr_csv_dict / colors.
    pr_curve_files = {
        'YOLOv5': 'pr_data/v5.csv',
        'YOLOv6': 'pr_data/v6.csv',
        'YOLOv7_tiny': 'pr_data/v7.csv',
        'YOLOv8n': 'pr_data/v8.csv',
        'YOLOv10n': 'pr_data/v10.csv',
        'YOLO11n': 'pr_data/11n.csv',
        'YOLOn': 'pr_data/YOLO.csv',
    }

    # Set the global font size before creating the figure so it applies to all of it.
    plt.rcParams.update({'font.size': 16})
    fig, axs = plt.subplots(1, 3, figsize=(24, 8), tight_layout=True)  # 1 row, 3 columns

    # Panel 1: PR curves
    _plot_pr_curves(axs[0], pr_curve_files, colors)
    _style_axis(axs[0], 'Recall', 'Precision', 'Precision-Recall Curve')

    # Panel 2: mAP@0.5 per epoch
    for modelname in pr_csv_dict:
        plot_metrics(axs[1], 'metrics/mAP50(B)', 'mAP@0.5', colors[modelname], modelname)
    _style_axis(axs[1], 'Epoch', 'mAP@0.5', 'mAP@0.5', xlim=(0, None))

    # Panel 3: the mAP50-95 column is mAP averaged over IoU 0.5:0.95, not mAP at IoU 0.95.
    for modelname in pr_csv_dict:
        plot_metrics(axs[2], 'metrics/mAP50-95(B)', 'mAP@0.5:0.95', colors[modelname], modelname)
    _style_axis(axs[2], 'Epoch', 'mAP@0.5:0.95', 'mAP@0.5:0.95', xlim=(0, None))

    # No subplots_adjust(wspace=...) here: with tight_layout=True the layout is
    # recomputed when the figure is drawn/saved, which overrides manual spacing.
    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_path, dpi=250)
    plt.show()

# Script entry point: build, save and display the three-panel comparison figure.
if __name__ == "__main__":
    plot_all_metrics()
  • 结果:

posted @   Frommoon  阅读(426)  评论(1编辑  收藏  举报
相关博文:
阅读排行:
· 终于写完轮子一部分:tcp代理 了,记录一下
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
点击右上角即可分享
微信分享提示