:)模型保存为单一个pb文件
模型保存为单一个pb文件
背景
参考连接: https://www.yuque.com/g/jesse-ztr2k/nkke46/ss4rlv/collaborator/join?token=XUVZNORisVWEWyst#
注意有些时候需要添加一个pb文件。 而不是tensorflow 提供的save 方法生成的一个目录里面包含了若干pb文件。
load时候直接填写这个目录即可。 但是有些时候需要合成一个pb文件。
tf2生成pb 目录描述
1 目录结构
-assets
-variables
-variables.data-00000-of-00001
-variables.index
-saved_model.pb
2 作用
其中 variables 记录模型参数 , pb文件记录模型结构
tf2 都是保存的 权重和 结构分开的, 如果需要兼容tf V1的代码,即导入一个pb文件,就需要 1 )保存常量计算图 2)frozen graph pb格式。
tf1 生成pb脚本
环境准备:
tensorflow==1.15, tf-slim==1.1.0
https://github.com/tensorflow/models/tree/master/research/slim
注意 一定在tf v1 环境下生成pb
1 import cv2 2 import numpy as np 3 import tensorflow as tf 4 import os 5 from tensorflow.python.framework import graph_util 6 7 # 参考连接 https://blog.csdn.net/tensorflowforum/article/details/112352764 代码 8 # 参考连接 参数详解:https://blog.csdn.net/weixin_43529465/article/details/124721583 9 # https://blog.csdn.net/rain6789/article/details/78754516 10 11 class SingleCnn(tf.keras.Model): 12 def __init__(self): 13 super(SingleCnn, self).__init__() 14 # filters=1 卷积核数目,相当于卷积核的channel 15 self.conv = tf.keras.layers.Conv2D(filters=1, 16 kernel_size=[1, 1], 17 # valid表示不填充, same表示合理填充 18 padding='valid', 19 # data_format='channels_last',-> 表示HWC,输入可以定义批次 20 data_format='channels_last', 21 use_bias=False, 22 kernel_initializer=tf.keras.initializers.he_uniform(seed=None), 23 name="conv") 24 25 def call(self, inputs): 26 x = self.conv(inputs) 27 return x 28 if __name__ == "__main__": 29 # 构建场景输入数据 30 31 # images=tf.random.uniform((1, 300, 300, 3)) 32 33 # 图像数据 34 imagefile = r"catanddog\cat\5.JPG" 35 img = cv2.imread(imagefile) 36 img = cv2.resize(img, (64, 64)) 37 img = np.expand_dims(img, axis=0) 38 print(img.shape, type(img), img.dtype) 39 40 # 未量化的model不支持int32和int8 41 # img = img.astype(np.int32) 42 img = tf.convert_to_tensor(img, np.float32) 43 print(img.shape, type(img), img.dtype) 44 singlecnn = SingleCnn() 45 46 output = singlecnn(img) 47 print(output.shape, type(output)) 48 print(output[0][2:10][2:6]) 49 # =========== ckpt保存 with session的写法tf2 已不再使用 =========== 50 # with tf.Session(graph=tf.Graph()) as sess: 51 # constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store']) 52 53 # 保存参考 https://zhuanlan.zhihu.com/p/146243327 54 # save_format='tf' 代表保存pb 55 # singlecnn.save('./pbmodel/singlecnn', save_format='tf') 56 # tf.saved_model.save(singlecnn, './pbmodel/singlecnn') 57 tf.keras.models.save_model(singlecnn, './pbmodel/singlecnn_0', 58 save_format="tf", 59 include_optimizer=False, save_traces=False) 60 61 # 加载模型 验证可以加载 62 new_model = tf.keras.models.load_model('./pbmodel/singlecnn_0', compile=False) 63 # new_model = tf.saved_model.load('./pbmodel/singlecnn_0') 64 # output_ = new_model(img) 65 # # print(output_.shape, output_[0][2:6][2:6]) 66 # print(output_.shape) 67 # 68 # 查看结构 69 new_model.summary() 70 71 # print("----------------") 72 # # 加载模型 73 # saved_model = tf.saved_model.load('./pbmodel/singlecnn_0') 74 # # 将模型转换为pb格式 还是目录方法。 75 # converter = tf.saved_model.save(saved_model, "model.pb") 76 77 def change_pb(pretrained_model): 78 """tf v1 选用tf1 跑这个脚本生成pb""" 79 from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 80 # 重点 81 # Convert Keras model to ConcreteFunction 82 # MobileNet is a function 83 full_model = tf.function(lambda x: pretrained_model(x)) 84 85 # 指定shape和dtype对tf function进行重新追踪 86 full_model = full_model.get_concrete_function( 87 tf.TensorSpec(pretrained_model.inputs[0].shape, pretrained_model.inputs[0].dtype)) 88 89 # Get frozen ConcreteFunction,将计算图中的变量及其取值通过常量的方式保存 90 frozen_func = convert_variables_to_constants_v2(full_model) 91 frozen_func.graph.as_graph_def() 92 93 layers = [op.name for op in frozen_func.graph.get_operations()] 94 print("-" * 50) 95 print("Frozen model layers: ") 96 for layer in layers: 97 print(layer) 98 99 print("-" * 50) 100 print("Frozen model inputs: ") 101 print(frozen_func.inputs) 102 print("Frozen model outputs: ") 103 print(frozen_func.outputs) 104 105 # Save frozen graph from frozen ConcreteFunction to hard drive 106 # as_text: If True, writes the graph as an ASCII proto; otherwise, The graph is written as a text proto 107 tf.io.write_graph(graph_or_graph_def=frozen_func.graph, 108 logdir="./frozen_models", 109 name="frozen_graph.pb", 110 as_text=True) 111 112 113 change_pb(new_model)
python download_and_convert_data.py --dataset_name=flowers --dataset_dir="tmp/dataset"