模型部署: 从tensorflow-ckpt和h5中获取模型权重并存储为bin文件

为什么要转储权重?

转储权重允许我们自行实现模型的推理部分,以便更细粒度地控制推理流程,应用各类推理加速方法,减少推理时间,降低推理的内存占用。对于将模型部署在计算资源有限,又要求高实时性推理的移动或iot设备上至关重要。

从ckpt中转储权重和tensor信息

依赖:tensorflow:2.x

from tensorflow.python.training import py_checkpoint_reader
import numpy as np
import sys
import os

if __name__ == "__main__":
    np.set_printoptions(threshold=np.inf)
    # ckpt路径,根据需要修改,注意没有后缀。
    ckpt_path ="ckpt-01"
    # 权重二进制文件和tensor信息输出路径,可以不修改
    output_path = './'
    # 权重文件和二进制文件名,可以不修改。
    tensor_info_name = "tensor_info.txt"
    tensor_bin_name = "tensor.bin"
    # 如果有传入参数,则用传入参数更新ckpt-path和output-path
    if len(sys.argv) > 1:
        try:
            ckpt_path = sys.argv[1]
            output_path = sys.argv[2]
        except KeyError:
            print("need 2 params: ckpt_path(without suffix)  output_path(dir)")
            sys.exit()
    # 读取ckpt,tensor的shape和类型
    reader = py_checkpoint_reader.NewCheckpointReader(ckpt_path)
    shape_map = reader.get_variable_to_shape_map()
    dtype_map = reader.get_variable_to_dtype_map()

    file_info = open(os.path.join(output_path, tensor_info_name), 'w')
    bin_file = open(os.path.join(output_path, tensor_bin_name), 'wb')
    for k, v in sorted(shape_map.items()):
        # shape_map中包含了计算图,无法得到权重和info信息,通过异常处理跳过。
        try:
            tensor_type = str(dtype_map[k])
            tensor_shape = str(v).replace('\n', '')
            # 将权重tensor压到一维,以便存储
            tensor=reader.get_tensor(k).reshape(-1)
            if tensor_shape=="[]":
                tensor_shape="[{}]".format(str(len(tensor)))
            # 这里,请根据tensor的类型补充格式,float32和int64是最常见的权重格式,一般来说无需调整。
            if tensor_type == "<dtype: 'float32'>":
                tensor_type = "float32"
                file_info.write(str(k).replace('\n', '') + " " + tensor_type.replace('\n', '') + " " + tensor_shape + " " + '\n')
                bin_file.write(bytearray(tensor.astype(np.float32)))
            elif tensor_type == "<dtype: 'int64'>":
                tensor_type = "int64"
                file_info.write(str(k).replace('\n', '') + " " + tensor_type.replace('\n', '') + " " + tensor_shape + " " + '\n')
                bin_file.write(bytearray(tensor.astype(np.int64)))
            else:
                raise Exception("exist unsupported type, justify code according to your need")
        except:
             pass
    file_info.close()
    bin_file.close()
    print("save to {}{} {}{}".format(output_path, tensor_info_name, output_path, tensor_bin_name))

从h5中转储权重和tensor信息

依赖:h5py(存储了h5权重,一般都会安装这个依赖)

import os
import sys
import h5py
import numpy as np

if __name__=="__main__":
    # h5权重文件路径,根据需要修改。
    h5file_path="model-680000.h5"
    output_path = "./"
    if len(sys.argv) > 1:
        try:
            h5file_path = sys.argv[1]
            output_path = sys.argv[2]
        except KeyError:
            print("need 2 params: h5_path  output_path(dir)")
            sys.exit()
    with h5py.File(h5file_path) as file:
        # 在这里修改权重二进制文件和info文件名
        info_file = open(os.path.join(output_path, "tensor_info_h5.txt"), 'w')
        bin_file = open(os.path.join(output_path, 'tensor_h5.bin'), 'wb')
        def dump_info(group):
            for k, v in group.items():
                if type(v)==h5py._hl.group.Group:
                    dump_info(v)
                else:
                    info_file.write("{} {} {}\n".format(str(v.name), str(v.dtype), str(v.shape).replace('(','[').replace(')',']')))
                    bin_file.write(bytearray(np.array(v[:]).reshape(-1).astype(np.float32)))
        dump_info(file)
        info_file.close()
        bin_file.close()

posted @   8bit布丁  阅读(838)  评论(0编辑  收藏  举报
编辑推荐:
· AI与.NET技术实操系列:向量存储与相似性搜索在 .NET 中的实现
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
阅读排行:
· winform 绘制太阳,地球,月球 运作规律
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· AI与.NET技术实操系列(五):向量存储与相似性搜索在 .NET 中的实现
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
点击右上角即可分享
微信分享提示