TensorFlow `tf.add_to_collection` 用法

训练代码（用 `tf.add_to_collection` 把预测张量 `pred_proba` 放入同名集合；`saver.save` 写出的 `.meta` 文件会一并保存这些集合）：

# coding: utf-8
from __future__ import print_function
from __future__ import division

import tensorflow as tf
import numpy as np
import argparse


def dense_to_one_hot(input_data, class_num):
    """Convert a vector of class labels into a one-hot matrix.

    Args:
        input_data: array-like of class labels, each a whole number in
            ``[0, class_num)``. Integer-valued floats (e.g. labels read with
            ``np.genfromtxt``) are accepted; extra dimensions are flattened.
        class_num: number of classes, i.e. the width of the output matrix.

    Returns:
        A float64 array of shape ``(number_of_labels, class_num)`` holding a
        single 1 per row, in the column given by that row's label.

    Raises:
        ValueError: if a label is not a whole number or lies outside
            ``[0, class_num)``.
    """
    labels = np.asarray(input_data).ravel()
    int_labels = labels.astype(np.int64)
    # Reject fractional / NaN labels rather than silently truncating them.
    if not np.array_equal(int_labels, labels):
        raise ValueError("labels must be whole numbers")
    # An out-of-range label would otherwise land in a neighbouring row of the
    # flattened buffer (or past its end) instead of being reported.
    if np.any((int_labels < 0) | (int_labels >= class_num)):
        raise ValueError(
            "labels must lie in [0, {0}), got min={1}, max={2}".format(
                class_num, int_labels.min(), int_labels.max()))
    data_num = int_labels.size
    labels_one_hot = np.zeros((data_num, class_num))
    # Row i gets a 1 in column int_labels[i].
    labels_one_hot[np.arange(data_num), int_labels] = 1
    return labels_one_hot


def build_parser(argv=None):
    """Parse the training script's command-line options.

    Despite its name, this returns the parsed namespace, not the parser.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as the original did.

    Returns:
        argparse.Namespace with ``data_path`` and ``model_path`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, required=True,
                        help='comma-separated file: columns 0-1 are features, '
                             'column 2 is the class label')
    parser.add_argument('--model_path', type=str, required=True,
                        help='checkpoint prefix passed to tf.train.Saver.save')
    return parser.parse_args(argv)


p = build_parser()
origin = np.genfromtxt(p.data_path, delimiter=',')

# Columns 0-1 are the two input features, column 2 is the class label.
data = origin[:, 0:2]
labels = origin[:, 2]


learning_rate = 0.001
training_epochs = 5000
display_step = 1  # NOTE(review): unused; the loss is printed every 100 epochs below

n_features = 2
n_class = 2
# Named "input" so the inference script can look the placeholder up by name.
x = tf.placeholder(tf.float32, [None, n_features], "input")
y = tf.placeholder(tf.float32, [None, n_class])

W = tf.Variable(tf.zeros([n_features, n_class]), name="w")
b = tf.Variable(tf.zeros([n_class]), name="b")

scores = tf.nn.xw_plus_b(x, W, b, name='scores')
pred_proba = tf.nn.softmax(scores, name="pred_proba")

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=scores, labels=y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

saver = tf.train.Saver()
# Register the prediction tensor in a named collection. Collections are
# exported into the .meta file by saver.save, so after import_meta_graph the
# inference side can fetch it with tf.get_collection('pred_proba').
tf.add_to_collection('pred_proba', pred_proba)
init = tf.global_variables_initializer()

# The training set never changes, so encode the labels once instead of
# rebuilding the one-hot matrix on every one of the training_epochs steps.
train_feed = {x: data, y: dense_to_one_hot(labels.astype(int), n_class)}

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(training_epochs):
        # Only the train op and the loss are needed here; the predictions
        # were previously fetched every step and never used.
        _, c = sess.run([optimizer, cost], feed_dict=train_feed)
        if epoch % 100 == 0:
            print(c)
    saver.save(sess, p.model_path)
    print("Optimization Finished!")

推理代码（用 `tf.train.import_meta_graph` 恢复计算图后，通过 `tf.get_collection('pred_proba')` 直接取回预测张量，无需在推理端重新定义模型）：

# coding: utf-8
from __future__ import print_function
from __future__ import division

import tensorflow as tf
import numpy as np
import argparse


def build_parser(argv=None):
    """Parse the inference script's command-line options.

    Despite its name, this returns the parsed namespace, not the parser.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as the original did.

    Returns:
        argparse.Namespace with a ``model_path`` attribute.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, required=True,
                        help='checkpoint prefix; "<prefix>.meta" must also exist')
    return parser.parse_args(argv)

p = build_parser()

with tf.Session() as sess:
    # Rebuild the graph from "<model_path>.meta", then load the variable
    # values from the checkpoint saved under the same prefix.
    new_saver = tf.train.import_meta_graph(p.model_path + ".meta")
    new_saver.restore(sess, p.model_path)
    # The training script stored its softmax output in the 'pred_proba'
    # collection; report a missing collection clearly instead of failing
    # with a bare IndexError.
    collected = tf.get_collection('pred_proba')
    if not collected:
        raise RuntimeError(
            "collection 'pred_proba' not found in %s.meta; was the model saved "
            "after tf.add_to_collection('pred_proba', ...)?" % p.model_path)
    pred_proba = collected[0]
    graph = tf.get_default_graph()
    # "input:0" is output 0 of the placeholder op named "input".
    input_x = graph.get_tensor_by_name('input:0')
    r = sess.run(pred_proba, feed_dict={input_x: np.array([[0.6211, 5]])})
    print(r)
    # One row per input sample; pick the more probable of the two classes
    # (an exact tie resolves to class 1).
    print(0 if r[0][0] > r[0][1] else 1)

参考资料

TensorFlow 模型保存/载入的两种方法

posted on 2018-02-06 21:21  荷楠仁  阅读(685)  评论(0)  编辑  收藏  举报

导航