wts格式转换(torch)

若使用tensorrt加载wts格式,需将模型训练的pt、pth、ckpt等格式权重转换为wts,其代码细节如下:

 

def checkpint2wts(model, wts_file):
    '''
    model:模型,需要权重
    wts_file:保存wts权重路径,如result.wts
    '''
    import struct
    model_state_dict = model.state_dict()  # 此处需要根据模型情况获得state_dict
    with open(wts_file, 'w') as f:
        f.write('{}\n'.format(len(model_state_dict.keys())))
        for k, v in model_state_dict.items():
            vr = v.reshape(-1).cpu().numpy()
            f.write('{} {} '.format(k, len(vr)))
            for vv in vr:
                f.write(' ')
                f.write(struct.pack('>f', float(vv)).hex())
            f.write('\n')

 

posted @ 2022-06-23 12:45  tangjunjun  阅读(636)  评论(0编辑  收藏  举报
https://rpc.cnblogs.com/metaweblog/tangjunjun