mmdetection训练的json文件可视化代码

# coding:utf-8
# 这是一个对mmdet训练结束后的json文件进行可视化的代码
# 主要是对log中记录的各种参数如:cls_loss, box_loss, mAP进行可视化
# 本方法已经经过测试,输出图片在out文件夹中
import json
import matplotlib.pyplot as plt
from collections import OrderedDict
# 可视化类
class visualize_mmdetection():
    # *args不定长形参
    def __init__(self, path, *args):
        '''

        Args:
            path: 要解析的路径,json文件
            *args: 可视化的键值,字符串型,不定长
        '''
        self.log = open(path)
        self.dict_list = list()
        self.AP_list = list()
        self.loss_dict = {}
        self.ap_dict = {}
        self.outname = path.split('/')[-1].split('.')[0]
        for i in args:
            if 'AP' in i:
                self.ap_dict[i] = list()
            else:
                self.loss_dict[i] = list()

    def load_data(self):
        for row, line in enumerate(self.log):
            # 单独处理loss 和 ap
            if 'mAP' not in line:
                info = json.loads(line)
                self.dict_list.append(info)
            if 'mAP' in line:
                info = json.loads(line)
                # print(num, info)
                self.AP_list.append(info)
        for i in range(1, len(self.dict_list)):
            for key in self.loss_dict.keys():
                # list append
                self.loss_dict[key].append(dict(self.dict_list[i])[key])
        for i in range(1, len(self.AP_list)):
            for key in self.ap_dict.keys():
                # list append
                self.ap_dict[key].append(dict(self.AP_list[i])[key])
        # 重新对key的value进行排序
        for key in self.loss_dict.keys():
            self.loss_dict[key] = list(OrderedDict.fromkeys(self.loss_dict[key]))
        for key in self.ap_dict.keys():
            self.ap_dict[key] = list(OrderedDict.fromkeys(self.ap_dict[key]))

    def show_chart(self):
        plt.rcParams.update({'font.size': 15})
        plt.figure(figsize=(20, 20))
        # 配置画图
        # 总类别数量
        num = len(self.loss_dict.keys())+len(self.ap_dict.keys())
        # 每行类别数量
        col = 2
        # 行数
        import math
        line = math.ceil(num/col)
        # 先画loss
        ind = 0
        for key in self.loss_dict.keys():
            ind += 1
            plt.subplot(line, col, ind, title= key, ylabel='loss')
            plt.plot(self.loss_dict[key])
        # 再画ap
        for key in self.ap_dict.keys():
            ind += 1
            plt.subplot(line, col, ind, title=key, ylabel='ap')
            plt.plot(self.ap_dict[key])
        # import time
        # now_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        plt.suptitle((self.outname + "\n training result"), fontsize=30)
        plt.savefig((self.outname + '_result.png'))


if __name__ == '__main__':
    # 输入你想要监视的键值,如在json文件中的保存的'loss_cls', 'loss_bbox', 'loss_obj', 'loss'等值
    # 测试方法
    # 输出的图片保存在本代码同级目录下,以json文件的前缀+.png命名
    x = visualize_mmdetection('../../work_dirs/20220513-ct/20220517_083215.log.json',
                              'loss_cls', 'loss_bbox', 'loss_obj', 'loss',
                              'mAP', 'AP50')
    x.load_data()
    x.show_chart()

 

posted @ 2022-06-02 12:43  海_纳百川  阅读(650)  评论(0编辑  收藏  举报
本站总访问量