将不同YOLO模型的PR曲线绘制在同一张图上
步骤:
- 首先需要修改PR曲线的绘制源码,将每个模型的PR曲线数据保存为CSV文件;然后分别对各模型运行验证,最后用脚本把所有CSV绘制到同一张图上
1.修改PR绘制源码--保存绘制数据
- yolo11代码路径:/ultralytics/utils/metrics.py
def plot_pr_curve(
    px,
    py,
    ap,
    save_dir=Path("pr_curve.png"),
    names={},
    on_plot=None,
    pr_save_path="pr_data/11n.csv",
    per_class_dir=None,
):
    """
    Plot a precision-recall curve and save the curve data as CSV for cross-model comparison plots.

    Args:
        px (np.ndarray): Recall values shared by all curves (x axis).
        py (list[np.ndarray]): One precision array per class, each aligned with ``px``; stacked column-wise.
        ap (np.ndarray): AP values indexed ``ap[class, column]``; column 0 is shown in the legend as mAP@0.5.
        save_dir (Path): Output path of the rendered PNG.
        names (dict | list): Class index -> class name. Only read, never mutated, so the ``{}`` default is safe.
        on_plot (callable, optional): Called with ``save_dir`` after the figure has been saved.
        pr_save_path (str | Path): CSV that receives the "all classes" curve as headerless ``recall,precision``
            rows. Use a different name per model (e.g. ``pr_data/v8.csv``) so runs do not overwrite each other.
            Missing parent directories are created.
        per_class_dir (str | Path, optional): If given and per-class curves are drawn (1-20 classes), also
            write one ``<name>_Box.csv`` file per class (``<name>_pose.csv`` when ``save_dir`` contains "Pose").
    """
    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
    py = np.stack(py, axis=1)
    mean_py = py.mean(1)  # the "all classes" curve, reused for the plot and for the CSV dump

    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
        # Tag per-class files so Pose curves do not overwrite Box curves.
        suffix = "pose" if "Pose" in str(save_dir) else "Box"
        for i, y in enumerate(py.T):
            ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}")  # plot(recall, precision)
            if per_class_dir is not None:
                _write_xy_csv(Path(per_class_dir) / f"{names[i]}_{suffix}.csv", px, y)
    else:
        ax.plot(px, py, linewidth=1, color="grey")  # plot(recall, precision)

    ax.plot(px, mean_py, linewidth=3, color="blue", label=f"all classes {ap[:, 0].mean():.3f} mAP@0.5")
    # NOTE: every call writes to the same path; if this function runs more than once per validation
    # (e.g. once for boxes and once for poses), the last call wins unless distinct paths are passed.
    _write_xy_csv(pr_save_path, px, mean_py)

    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    ax.set_title("Precision-Recall Curve")
    fig.savefig(save_dir, dpi=250)
    plt.close(fig)
    if on_plot:
        on_plot(save_dir)


def _write_xy_csv(path, xs, ys):
    """Write paired values as headerless ``x,y`` lines, creating parent directories as needed."""
    path = Path(path)
    # A bare open() raises FileNotFoundError when the directory (e.g. pr_data/) does not exist yet.
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(f"{x},{y}\n" for x, y in zip(xs, ys))
2.运行val.py
import warnings
# Installed before importing ultralytics, so warnings raised during that import are suppressed too.
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':
    # Trained weights to evaluate; change this path for every model being compared.
    weights = '/yolo11/yolo11-2/runs/train/kitti-yolo11/weights/best.pt'
    val_kwargs = dict(
        data='/Object_detection/LS/yolo11/yolo11-1/ultralytics/cfg/datasets/kitti.yaml',
        split='val',  # 'train', 'val' or 'test', depending on which splits the dataset provides
        imgsz=640,
        batch=16,
        project='runs/val',
        name='exp',
    )
    YOLO(weights).val(**val_kwargs)
3.运行结果
- 对每个需要绘制的模型重复上述步骤即可:只需更改第一步中保存的CSV文件名(如 pr_data/v8.csv)和第二步中的权重文件路径
- 最后拿到了多个模型的数据
4.绘制脚本
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
if __name__ == '__main__':
    # Each CSV holds headerless "recall,precision" rows written by the modified plot_pr_curve (step 1).
    file_list = ['pr_data/v5.csv', 'pr_data/v6.csv', 'pr_data/v7.csv', 'pr_data/v8.csv', 'pr_data/v10.csv', 'pr_data/11n.csv']
    names = ['v5', 'v6', 'v7', 'v8', 'v10', '11']
    # The two lists are paired by position; fail loudly instead of mislabeling or dropping curves.
    if len(file_list) != len(names):
        raise ValueError(f'file_list has {len(file_list)} entries but names has {len(names)}')

    plt.figure(figsize=(6, 6))
    for csv_path, label in zip(file_list, names):
        pr_data = pd.read_csv(csv_path, header=None)
        recall, precision = pr_data[0].to_numpy(), pr_data[1].to_numpy()
        plt.plot(recall, precision, label=label)
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    # Precision and recall both live in [0, 1]; fix the axes so figures from different runs are comparable.
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.title('Precision-Recall Curve')
    plt.legend()
    plt.tight_layout()
    plt.savefig('pr.png')
- 结果:
5.将PR曲线与训练过程中的mAP50、mAP50-95曲线绘制在同一张画布上
def plot_metrics(ax, metric_col_name, y_label, color, modelname, is_pr=False, *, csv_dict=None):
    """
    Draw one model's curve from its training ``results.csv`` onto ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        metric_col_name (str): Column plotted against ``epoch`` (e.g. ``'metrics/mAP50(B)'``); ignored when ``is_pr``.
        y_label (str): Unused; kept so existing calls keep working (the caller sets axis labels).
        color: Line colour passed to ``ax.plot``.
        modelname (str): Key into the CSV mapping; also used as the legend label.
        is_pr (bool): If True, plot the ``metrics/recall(B)`` column (x) against ``metrics/precision(B)`` (y)
            instead of a metric against epoch.
            NOTE(review): these columns hold one value per epoch row, so this traces training progress
            rather than a confidence-swept PR curve; confirm that is the intent.
        csv_dict (dict, optional): ``modelname -> results.csv path``. Defaults to the module-level
            ``pr_csv_dict`` that ``plot_all_metrics`` assigns.

    Raises:
        KeyError: If ``modelname`` is not in the mapping.

    File read/parse failures and missing columns are reported with ``print`` and the model is skipped,
    so one bad run does not abort the whole figure.
    """
    res_path = (pr_csv_dict if csv_dict is None else csv_dict)[modelname]
    try:
        data = pd.read_csv(res_path)
        data.columns = data.columns.str.strip()  # Remove surrounding spaces from column names
        if is_pr:
            x = data['metrics/recall(B)'].to_numpy()
            y = data['metrics/precision(B)'].to_numpy()
        else:
            x = data['epoch'].to_numpy()
            y = data[metric_col_name].to_numpy()
    except (OSError, KeyError, pd.errors.EmptyDataError, pd.errors.ParserError) as e:
        print(f"Error reading {modelname}: {e}")
        return
    ax.plot(x, y, label=modelname, color=color, linewidth=2)
# Main function
def plot_all_metrics():
global pr_csv_dict
pr_csv_dict = {
'YOLOv5': r'/v5/n/results.csv',
'YOLOv6': r'/v6/n/results.csv',
'YOLOv7_tiny': r'/results_v7.csv',
'YOLOv8n': r'/yolov8-main/runs/train/re39/results.csv',
'YOLOv10n': r'yolov10-main/runs/kitti/n/results.csv',
'YOLO11n': r'/yolo11-1/runs/train/kitti-yolo11n/results.csv',
'YOLOn': r'/yolo11-1/runs/train/kitti-YOLOn/results.csv',
}
colors = {
'YOLOv5': '#00EE76',
'YOLOv6': '#EEEE00',
'YOLOv7_tiny': '#8470FF',
'YOLOv8n': 'orange',
'YOLOv10n': '#838B8B',
'YOLO11n': '#00BFFF',
'YOLOn': '#FF3030',
}
fig, axs = plt.subplots(1, 3, figsize=(24, 8), tight_layout=True) # 1 row, 3 columns
# Set global font size
plt.rcParams.update({'font.size': 16})
# Plot PR Curve
file_list = ['pr_data/v5.csv', 'pr_data/v6.csv', 'pr_data/v7.csv', 'pr_data/v8.csv', 'pr_data/v10.csv', 'pr_data/11n.csv', 'pr_data/YOLO.csv']
names = ['YOLOv5', 'YOLOv6', 'YOLOv7_tiny', 'YOLOv8n', 'YOLOv10n', 'YOLO11n', 'YOLOn']
for i in range(len(file_list)):
pr_data = pd.read_csv(file_list[i], header=None)
recall, precision = np.array(pr_data[0]), np.array(pr_data[1])
color = colors[f'{names[i]}'] # Use the corresponding color
axs[0].plot(recall, precision, label=f'{names[i]}', color=color, linewidth='2') # Set linewidth
axs[0].set_xlabel('Recall', fontsize=16)
axs[0].set_ylabel('Precision', fontsize=16)
axs[0].set_xlim(0, 1)
axs[0].set_ylim(0, 1)
axs[0].legend(loc='lower right', fontsize=16)
axs[0].spines['top'].set_linewidth(2)
axs[0].spines['right'].set_linewidth(2)
axs[0].spines['left'].set_linewidth(2)
axs[0].spines['bottom'].set_linewidth(2)
axs[0].tick_params(width=2, labelsize=14)
axs[0].set_title('Precision-Recall Curve', fontsize=18)
# Plot mAP@0.5
for modelname in pr_csv_dict:
plot_metrics(axs[1], 'metrics/mAP50(B)', 'mAP@0.5', colors[modelname], modelname)
axs[1].set_xlabel('Epoch', fontsize=16)
axs[1].set_ylabel('mAP@0.5', fontsize=16)
axs[1].set_xlim(0, None)
axs[1].set_ylim(0, 1)
axs[1].legend(loc='lower right', fontsize=16)
axs[1].spines['top'].set_linewidth(2)
axs[1].spines['right'].set_linewidth(2)
axs[1].spines['left'].set_linewidth(2)
axs[1].spines['bottom'].set_linewidth(2)
axs[1].tick_params(width=2, labelsize=14)
axs[1].set_title('mAP@0.5', fontsize=18)
# Plot mAP@0.95
for modelname in pr_csv_dict:
plot_metrics(axs[2], 'metrics/mAP50-95(B)', 'mAP@0.95', colors[modelname], modelname)
axs[2].set_xlabel('Epoch', fontsize=16)
axs[2].set_ylabel('mAP@0.95', fontsize=16)
axs[2].set_xlim(0, None)
axs[2].set_ylim(0, 1)
axs[2].legend(loc='lower right', fontsize=16)
axs[2].spines['top'].set_linewidth(2)
axs[2].spines['right'].set_linewidth(2)
axs[2].spines['left'].set_linewidth(2)
axs[2].spines['bottom'].set_linewidth(2)
axs[2].tick_params(width=2, labelsize=14)
axs[2].set_title('mAP@0.95', fontsize=18)
plt.subplots_adjust(wspace=0.3) # Adjust spacing between subplots
# Save the figure
plt.savefig('/yolo11/yolo11-1/images/aa/diff_yolo_metrics6.png', dpi=250)#保存位置
plt.show()
# Script entry point: build and save the comparison figure only when run directly, not on import.
if "__main__" == __name__:
    plot_all_metrics()
- 结果:
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 终于写完轮子一部分:tcp代理 了,记录一下
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理