模型部署: 从tensorflow-ckpt和h5中获取模型权重并存储为bin文件
为什么要转储权重?
转储权重允许我们自行实现模型的推理部分,以便更细粒度地控制推理流程,应用各类推理加速方法,减少推理时间,降低推理的内存占用。对于将模型部署在计算资源有限,又要求高实时性推理的移动或iot设备上至关重要。
从ckpt中转储权重和tensor信息
依赖:tensorflow:2.x
from tensorflow.python.training import py_checkpoint_reader
import numpy as np
import sys
import os
if __name__ == "__main__":
np.set_printoptions(threshold=np.inf)
# ckpt路径,根据需要修改,注意没有后缀。
ckpt_path ="ckpt-01"
# 权重二进制文件和tensor信息输出路径,可以不修改
output_path = './'
# 权重文件和二进制文件名,可以不修改。
tensor_info_name = "tensor_info.txt"
tensor_bin_name = "tensor.bin"
# 如果有传入参数,则用传入参数更新ckpt-path和output-path
if len(sys.argv) > 1:
try:
ckpt_path = sys.argv[1]
output_path = sys.argv[2]
except KeyError:
print("need 2 params: ckpt_path(without suffix) output_path(dir)")
sys.exit()
# 读取ckpt,tensor的shape和类型
reader = py_checkpoint_reader.NewCheckpointReader(ckpt_path)
shape_map = reader.get_variable_to_shape_map()
dtype_map = reader.get_variable_to_dtype_map()
file_info = open(os.path.join(output_path, tensor_info_name), 'w')
bin_file = open(os.path.join(output_path, tensor_bin_name), 'wb')
for k, v in sorted(shape_map.items()):
# shape_map中包含了计算图,无法得到权重和info信息,通过异常处理跳过。
try:
tensor_type = str(dtype_map[k])
tensor_shape = str(v).replace('\n', '')
# 将权重tensor压到一维,以便存储
tensor=reader.get_tensor(k).reshape(-1)
if tensor_shape=="[]":
tensor_shape="[{}]".format(str(len(tensor)))
# 这里,请根据tensor的类型补充格式,float32和int64是最常见的权重格式,一般来说无需调整。
if tensor_type == "<dtype: 'float32'>":
tensor_type = "float32"
file_info.write(str(k).replace('\n', '') + " " + tensor_type.replace('\n', '') + " " + tensor_shape + " " + '\n')
bin_file.write(bytearray(tensor.astype(np.float32)))
elif tensor_type == "<dtype: 'int64'>":
tensor_type = "int64"
file_info.write(str(k).replace('\n', '') + " " + tensor_type.replace('\n', '') + " " + tensor_shape + " " + '\n')
bin_file.write(bytearray(tensor.astype(np.int64)))
else:
raise Exception("exist unsupported type, justify code according to your need")
except:
pass
file_info.close()
bin_file.close()
print("save to {}{} {}{}".format(output_path, tensor_info_name, output_path, tensor_bin_name))
从h5中转储权重和tensor信息
依赖:h5py(存储了h5权重,一般都会安装这个依赖)
import os
import sys
import h5py
import numpy as np
if __name__=="__main__":
# h5权重文件路径,根据需要修改。
h5file_path="model-680000.h5"
output_path = "./"
if len(sys.argv) > 1:
try:
h5file_path = sys.argv[1]
output_path = sys.argv[2]
except KeyError:
print("need 2 params: h5_path output_path(dir)")
sys.exit()
with h5py.File(h5file_path) as file:
# 在这里修改权重二进制文件和info文件名
info_file = open(os.path.join(output_path, "tensor_info_h5.txt"), 'w')
bin_file = open(os.path.join(output_path, 'tensor_h5.bin'), 'wb')
def dump_info(group):
for k, v in group.items():
if type(v)==h5py._hl.group.Group:
dump_info(v)
else:
info_file.write("{} {} {}\n".format(str(v.name), str(v.dtype), str(v.shape).replace('(','[').replace(')',']')))
bin_file.write(bytearray(np.array(v[:]).reshape(-1).astype(np.float32)))
dump_info(file)
info_file.close()
bin_file.close()
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· winform 绘制太阳,地球,月球 运作规律
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· AI与.NET技术实操系列(五):向量存储与相似性搜索在 .NET 中的实现
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人