Tensorflow的ckpt转为npy格式的代码

由于无法把非神经网络分类器在tensorflow框架中搭建,想把神经网络的输出概率用到其他的非神经网络的分类其中,这就需要把神经网络中保存的参数提取出来。由于神经网络是基于图的计算,有自己的保存方式,我们不能随意保存和提取其中的参数。下面是使用一个pywrap_tensorflow的工具去读ckpt文件,将其中保存的信息读取出来,再保存到我们想要的格式。

from tensorflow.python import pywrap_tensorflow
import numpy as np
checkpoint_path="/tmp/prop/softmax_out.ckpt"
reader=pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map=reader.get_variable_to_shape_map()
param =[]

for key in var_to_shape_map:
    print ("tensor_name",key)
    param.append(reader.get_tensor(key))

np.save('dnnout.npy',param)
posted @ 2018-08-30 11:07  Siucaan  阅读(978)  评论(0编辑  收藏  举报