mmdetection 生成c++ 的anchor头文件

import os
import os.path as osp

import numpy as np
from mmcv import Config
from mmdet.models import build_detector
import math
import argparse
import pickle


def save_cplus_h(multi_level_anchors, savefilepath):
    """Write per-level anchors into a C++ header as a nested ``std::vector``.

    The generated header declares ``ssd_anchor_all_levels`` of type
    ``std::vector<std::vector<std::vector<float>>>``. It contains one brace
    block per level and one ``{x1, y1, x2, y2}`` row per anchor, in input
    order.

    Args:
        multi_level_anchors: sequence of levels, each an iterable of
            4-element boxes ``(x1, y1, x2, y2)``, i.e.
            ``[[box1, box2, ...], [...], ...]``. A list of ``[N, 4]`` numpy
            arrays also works, since iterating one yields 4-value rows.
        savefilepath: path of the header file to create. It is overwritten
            if it exists and is written as UTF-8.

    Returns:
        None.
    """
    _template = '''
    
#ifndef _ANCHOR_H
#define _ANCHOR_H

#include "stdlib.h"
#include <vector>

/*
 * 目标检测的anchor框 头文件
 * */

static std::vector<std::vector<std::vector<float>>> ssd_anchor_all_levels = {
//        // stride 8
//        {
//                {-10., 0, -10., 20.},      // x1 y1 x2 y2
//                {-10., 0, -10., 20.},
//                {-10., 0, -10., 20.},
//        },
//
//        // stride 16
//        {
//                {-10., 0, -10., 20.},      // x1 y1 x2 y2
//                {-10., 0, -10., 20.},
//                {-10., 0, -10., 20.},
//        },
//
//        // stride 32
//        {
//                {-10., 0, -10., 20.},      // x1 y1 x2 y2
//                {-10., 0, -10., 20.},
//                {-10., 0, -10., 20.},
//        },
//
//        //...
        @here
};


#endif //_ANCHOR_H
    '''
    # Render every level inside the loop so that all levels end up in the
    # header. An empty input yields an empty initializer instead of failing.
    level_blocks = []
    for anchors in multi_level_anchors:
        # One '{x1, y1, x2, y2},' row per anchor of this level.
        rows = ''.join(
            '{' + f'{x1}, {y1}, {x2}, {y2}' + '},\n'
            for x1, y1, x2, y2 in anchors
        )
        level_blocks.append('{\n' + rows + '},\n')

    # Build the full text before opening the file, so a formatting error
    # cannot leave a truncated header behind.
    content = _template.replace('@here', ''.join(level_blocks))
    with open(savefilepath, 'w', encoding='utf-8') as f:
        f.write(content)


def parse_args(argv=None):
    """Parse the command-line options of the anchor export script.

    Args:
        argv: optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with:
            config: path of the MMDetection config file.
            output_file: output path, used as the base name of the exported
                files (default ``'anchor.npy'``).
            shape: list of exactly three ints, the input size as C H W
                (default ``[3, 224, 320]``).
    """
    parser = argparse.ArgumentParser(
        description='Export the anchors of an MMDetection model to '
                    '.bin/.npy/.pkl files and a C++ header')
    parser.add_argument('config', help='test config file path')
    parser.add_argument(
        '--output-file',
        type=str,
        default='anchor.npy',
        help='output path; its trailing extension (e.g. ".npy") is dropped '
             'and the result is used as the base name of the exported files')
    parser.add_argument(
        '--shape',
        type=int,
        # The caller unpacks exactly (C, H, W), so require three values here
        # and let argparse report a clear usage error otherwise.
        nargs=3,
        metavar=('C', 'H', 'W'),
        default=[3, 224, 320],
        help='input image size (CHW)')
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args()
    CONFIG = args.config

    assert args.output_file is not None
    INPUT_C, INPUT_H, INPUT_W = tuple(args.shape)
    # Base path shared by every exported file. splitext (rather than slicing
    # off the last 4 characters) also handles names with no extension or a
    # longer one.
    out_stem = osp.splitext(args.output_file)[0]

    cfg = Config.fromfile(CONFIG)
    model = build_detector(cfg.model)

    model.eval()
    assert hasattr(model, 'bbox_head')
    anchor_generator = model.bbox_head.anchor_generator
    strides_tuple_list = anchor_generator.strides
    featmap_sizes = []
    h, w = INPUT_H, INPUT_W
    for sh, sw in strides_tuple_list:
        # NOTE(review): the feature-map size is estimated as
        # ceil(size / stride + 0.5). Confirm this matches the backbone's
        # real output sizes.
        featmap_sizes.append((math.ceil(h / sh + 0.5), math.ceil(w / sw + 0.5)))
        print([sh, sw], featmap_sizes)
    multi_level_anchors = anchor_generator.grid_anchors(featmap_sizes, 'cpu')
    # Convert every level to numpy once. The arrays are reused for the
    # concatenation, the pickle and the C++ header.
    multi_level_anchors = [i.cpu().numpy() for i in multi_level_anchors]

    total_anchors = np.concatenate(multi_level_anchors, axis=0)

    print(total_anchors.shape)
    print(total_anchors[:10])

    # Normalize (x1, y1, x2, y2) by the input (w, h, w, h).
    wh = np.array([[w, h] * 2])
    total_anchors_norm = total_anchors / wh

    ###################### export .bin ######################
    anchor_num = total_anchors.shape[0]
    total_anchors_bin_file = total_anchors_norm.reshape((1, 1, -1))
    # TODO: only bbox_coder.target_stds is read here. This assumes the
    # configured bbox coder defines target_stds; confirm for other coders.
    target_stds = cfg.model.bbox_head.bbox_coder.target_stds
    # Repeat the stds once per anchor so they line up with the flattened
    # anchors. np.tile also behaves correctly if the config holds an array.
    target_stds = np.tile(np.asarray(target_stds), anchor_num).reshape((1, 1, -1))
    # Layout [1, 2, anchor_num * 4]: row 0 = normalized anchors, row 1 = stds.
    total_anchors_bin_file_res = np.concatenate((total_anchors_bin_file, target_stds), axis=1)
    _savefilepath_bin = f'{out_stem}.bin'
    total_anchors_bin_file_res.astype('float32').tofile(_savefilepath_bin)

    # .npy for python caffe
    _savefilepath_npy = f'{out_stem}.npy'
    np.save(_savefilepath_npy, total_anchors_norm)
    print(_savefilepath_bin)
    print(_savefilepath_npy)

    # Save each level's anchors separately (not normalized):
    # [[box1, box2, ...], [...], ...]
    _savefilepath_pkl_per_levels = f'{out_stem}_per_level.pkl'
    with open(_savefilepath_pkl_per_levels, 'wb') as f:
        pickle.dump(multi_level_anchors, f)
    print(_savefilepath_pkl_per_levels)

    # Save the C++ header.
    _savefilepath_h_per_levels = f'{out_stem}_per_level.h'
    save_cplus_h(multi_level_anchors, _savefilepath_h_per_levels)
    print(_savefilepath_h_per_levels)
    print('done.')

 

posted @ 2022-10-18 15:10  dangxusheng  阅读(52)  评论(0)  编辑  收藏  举报