plot_pr_curve
plot_pr_curve2
def plot_pr_curve2(px, py, conf_p, conf_r, ap, save_dir='.', names=(), *, verbose=True):
    """Plot precision vs. recall, save it as a PNG, and optionally print the data.

    Args:
        px: Recall values used as the x-axis for every curve.
        py: Sequence of per-class precision curves. They are stacked along
            axis 1, so each is assumed to have the same length as ``px`` --
            confirm against the caller.
        conf_p: Printed as-is in the verbose dump; not used for plotting.
        conf_r: Currently unused; kept so existing callers keep working.
        ap: Array indexed as ``ap[i, 0]``. Column 0 is shown in the legend and
            its mean is labelled "mAP@0.5".
        save_dir: Directory that receives ``precision_recall_curve.png``.
        names: Class names. Per-class curves get legend entries only when
            1 to 20 names are given; otherwise they are drawn in grey.
        verbose: When true (the default, matching the previous behaviour),
            print recall, the precision of the first two classes, their mean,
            and ``conf_p`` to stdout.
    """
    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
    py = np.stack(py, axis=1)  # one column per class

    if 0 < len(names) < 21:  # per-class legend entries only for fewer than 21 classes
        for i, y in enumerate(py.T):
            # Format the AP in the f-string itself: the old "f'...' % ap" form
            # broke on class names containing '%'.
            ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}')  # plot(recall, precision)
    else:
        ax.plot(px, py, linewidth=1, color='grey')  # plot(recall, precision)

    if verbose:
        print("==========Recall==================")
        print(px)
        print("-----\n\n\nPrecision:")
        # Up to the first two class columns; the previous code indexed column 1
        # unconditionally and raised IndexError when only one class was present.
        first_two = py[:, :2]
        for column in first_two.T:
            print(list(column))
            print("-----\n\n\n")
        print(list(first_two.mean(axis=1)))  # element-wise mean of the printed columns
        print("-------------conf_p:")
        print(conf_p)
        print("---------------------------------------------------------")

    ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    fig.savefig(Path(save_dir) / 'precision_recall_curve.png', dpi=250)
    plt.close(fig)  # release the figure; otherwise every call leaves one open
#######################
QQ 3087438119