pytorch权重转wts格式-用于tensorrt权重加载
若使用tensorrt加载wts格式,需将模型训练的pt、pth、ckpt等格式权重转换为wts。
我只记录此代码,供读者使用。
def checkpint2wts(model, wts_file): ''' model:模型,需要权重 wts_file:保存wts权重路径,如result.wts ''' import struct model_state_dict = model.state_dict() # 此处需要根据模型情况获得state_dict with open (wts_file, 'w' ) as f: f.write( '{}\n' . format ( len (model_state_dict.keys()))) for k, v in model_state_dict.items(): vr = v.reshape( - 1 ).cpu().numpy() f.write( '{} {} ' . format (k, len (vr))) for vv in vr: f.write( ' ' ) f.write(struct.pack( '>f' , float (vv)). hex ()) f.write( '\n' ) |
本文来自博客园,作者:海_纳百川,转载请注明原文链接:https://www.cnblogs.com/chentiao/p/16668280.html,如有侵权联系删除