plot_pr_curve

plot_pr_curve2

 

def _dump_pr_values(px, py, conf_p):
    """Print the recall axis, every per-class precision column, their mean, and conf_p.

    Debug helper for ``plot_pr_curve2``. ``py`` is the 2-D array built there by
    ``np.stack(..., axis=1)``, so column ``i`` is the precision curve of class
    ``i`` sampled at the points in ``px``.
    """
    print("==========Recall==================")
    print(px)

    print("-----\n\n\nPrecision:")
    # One list per class column. The previous version read only columns 0 and 1,
    # which raised IndexError for a single class and ignored classes beyond the second.
    for column in py.T:
        print(column.tolist())
        print("-----\n\n\n")
    # Average over *all* classes so the printed mean matches the bold plotted curve.
    print(py.mean(1).tolist())
    print("-------------conf_p:")
    print(conf_p)
    print("---------------------------------------------------------")


def plot_pr_curve2(px, py, conf_p, conf_r, ap, save_dir='.', names=()):
    """Plot precision-recall curves, save them under ``save_dir``, and dump the values.

    Variant of ``plot_pr_curve`` that also prints the underlying numbers to
    stdout for inspection.

    Args:
        px: Recall values used as the x axis of every curve.
        py: Sequence of precision curves; ``np.stack(py, axis=1)`` must yield an
            array whose column ``i`` is the curve for class ``i`` sampled at ``px``.
        conf_p: Printed verbatim after the precision values.
        conf_r: Accepted for signature compatibility; not used.
        ap: Indexed as ``ap[class, k]``; only column 0 is read, and its mean is
            shown in the legend as mAP@0.5.
        save_dir: Directory that receives ``precision_recall_curve.png``.
        names: Class names indexed by class; per-class curves get their own
            legend entry only when 0 < len(names) < 21.
    """
    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
    py = np.stack(py, axis=1)

    if 0 < len(names) < 21:  # per-class legend entries only for 1..20 classes
        for i, y in enumerate(py.T):
            # Use a format spec rather than "f'{name} %.3f' % ap": the old form
            # raised if a class name contained '%'.
            ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}')  # plot(recall, precision)
    else:
        ax.plot(px, py, linewidth=1, color='grey')  # plot(recall, precision)

    _dump_pr_values(px, py, conf_p)

    ax.plot(px, py.mean(1), linewidth=3, color='blue',
            label=f'all classes {ap[:, 0].mean():.3f} mAP@0.5')
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    fig.savefig(Path(save_dir) / 'precision_recall_curve.png', dpi=250)
    # Release the figure; otherwise repeated calls keep accumulating open figures.
    plt.close(fig)

 

 

#######################

posted @ 2022-03-11 21:21  西北逍遥  阅读(124)  评论(0)  编辑  收藏  举报