[TensorBoard] Tracking train and test accuracy simultaneously

训练时对模型状态进行实时跟踪，其重要性不言而喻。

参见 [TensorBoard] Cookbook - TensorBoard，其中讲解了如何调节更新频率。

 

直接上代码展示:


import numpy as np
import tensorflow as tf
from random import randint
import datetime
import os
import time

import implementation as imp

# --- Run configuration -------------------------------------------------------
# The batch size is owned by the implementation module so the graph and the
# batch generators below always agree on it.
batch_size = imp.batch_size
# iterations: number of optimizer steps (i runs 0 .. 20000 inclusive).
# seq_length: maximum length of a sentence (width of each batch row).
iterations, seq_length = 20001, 40

# Directory that tf.train.Saver checkpoints are written into.
checkpoints_dir = "./checkpoints"


def getTrainBatch():
    """Sample one class-balanced training batch from ``training_data``.

    Even batch slots draw a row index uniformly (with replacement) from
    [0, 11499] and get the one-hot label [1, 0]; odd slots draw from
    [12500, 23999] and get [0, 1].

    Returns:
        (arr, labels): ``arr`` is a float ndarray of shape
        (batch_size, seq_length) holding the sampled rows, and ``labels`` is a
        list of ``batch_size`` two-element one-hot lists.
    """
    # (lowest index, highest index, one-hot label) for even / odd slots.
    sources = ((0, 11499, (1, 0)), (12500, 23999, (0, 1)))

    rows = np.zeros([batch_size, seq_length])
    batch_labels = []
    for slot in range(batch_size):
        low, high, onehot = sources[slot % 2]
        picked = randint(low, high)
        # Build a fresh list per slot so callers may mutate labels safely.
        batch_labels.append(list(onehot))
        # assumes each row of training_data has seq_length entries -- TODO confirm
        rows[slot] = training_data[picked]
    return rows, batch_labels


def getTestBatch():
    """Sample one class-balanced evaluation batch from ``training_data``.

    The index ranges used here are disjoint from those used by
    ``getTrainBatch``. Even batch slots draw a row index uniformly (with
    replacement) from [11500, 12499] and get the one-hot label [1, 0]; odd
    slots draw from [24000, 24999] and get [0, 1].

    Returns:
        (arr, labels): ``arr`` is a float ndarray of shape
        (batch_size, seq_length) holding the sampled rows, and ``labels`` is a
        list of ``batch_size`` two-element one-hot lists.
    """
    # (lowest index, highest index, one-hot label) for even / odd slots.
    sources = ((11500, 12499, (1, 0)), (24000, 24999, (0, 1)))

    rows = np.zeros([batch_size, seq_length])
    batch_labels = []
    for slot in range(batch_size):
        low, high, onehot = sources[slot % 2]
        picked = randint(low, high)
        # Build a fresh list per slot so callers may mutate labels safely.
        batch_labels.append(list(onehot))
        # assumes each row of training_data has seq_length entries -- TODO confirm
        rows[slot] = training_data[picked]
    return rows, batch_labels
    
###############################################################################
    
# Call implementation: load the GloVe embeddings and the data, then let the
# implementation module build the graph and return the tensors/ops used below.
glove_array, glove_dict = imp.load_glove_embeddings()
training_data = imp.load_data(glove_dict)
input_data, labels, optimizer, accuracy, loss, dropout_keep_prob = imp.define_graph(glove_array)

###############################################################################

# tensorboard
# A single merged summary op is evaluated twice per logging step: once on a
# train batch (written by writer_train) and once on a test batch (written by
# writer_test). The tags are identical in both, but each writer has its own
# log directory, so TensorBoard treats them as two runs and overlays their
# curves on the same chart per tag.
train_accuracy_op = tf.summary.scalar("training_accuracy", accuracy)
tf.summary.scalar("loss", loss)
summary_op = tf.summary.merge_all()

# saver
all_saver = tf.train.Saver()

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

# Take the timestamp once so the train and test run directories always share
# the same prefix (two separate now() calls can straddle a second boundary).
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

logdir_train = "tensorboard/" + run_stamp + "-train/"
writer_train = tf.summary.FileWriter(logdir_train, sess.graph)

logdir_test = "tensorboard/" + run_stamp + "-test/"
writer_test = tf.summary.FileWriter(logdir_test, sess.graph)


timePoint1 = time.time()
timePoint2 = time.time()
try:
    for i in range(iterations):
        batch_data, batch_labels = getTrainBatch()

        # Training step. dropout_keep_prob is fed 0.8 here and 1.0 for the
        # evaluation runs below (1.0 = dropout disabled); presumably it is a
        # keep probability -- confirm in imp.define_graph.
        sess.run(optimizer, {input_data: batch_data,
                             labels: batch_labels,
                             dropout_keep_prob: 0.8})

        if i % 50 == 0:
            print("--------------------------------------")
            print("Iteration: ", i, round(i / iterations, 2))
            print("--------------------------------------")

            ##########################################################
            # Evaluate on the batch just trained on -> "train" run.
            loss_value, accuracy_value, summary = sess.run(
                [loss, accuracy, summary_op],
                {input_data: batch_data,
                 labels: batch_labels,
                 dropout_keep_prob: 1.0})
            writer_train.add_summary(summary, i)

            print("loss [train]", loss_value)
            print("acc  [train]", accuracy_value)

            ##########################################################
            # Evaluate on a test batch -> "test" run. The batch is sampled
            # only here, where it is used, rather than on every iteration.
            batch_data_test, batch_labels_test = getTestBatch()
            loss_value_test, accuracy_value_test, summary_test = sess.run(
                [loss, accuracy, summary_op],
                {input_data: batch_data_test,
                 labels: batch_labels_test,
                 dropout_keep_prob: 1.0})
            writer_test.add_summary(summary_test, i)

            print("loss [test]", loss_value_test)
            print("acc [test]", accuracy_value_test)

            ##########################################################
            timePoint2 = time.time()
            print("Time:", round(timePoint2 - timePoint1, 2))
            timePoint1 = timePoint2

        # Periodic checkpoint (skip step 0, where nothing has been learned).
        if i % 10000 == 0 and i != 0:
            if not os.path.exists(checkpoints_dir):
                os.makedirs(checkpoints_dir)
            save_path = all_saver.save(sess,
                                       checkpoints_dir + "/trained_model.ckpt",
                                       global_step=i)
            print("Saved model to %s" % save_path)
finally:
    # Flush buffered summary events to disk and release the session even if
    # training is interrupted; previously the FileWriters were never closed.
    writer_train.close()
    writer_test.close()
    sess.close()

总之，要把不同的 summary 写入不同的 writer 对象中：训练和测试各用一个 FileWriter，分别写到各自的日志目录。由于两者的标量 tag 相同，TensorBoard 会把它们当作两个 run，并在同一张图中叠加显示训练和测试曲线。

 

posted @ 2017-10-04 20:28  郝壹贰叁  阅读(754)  评论(0)  编辑  收藏  举报