代码改变世界

tensorflow(3)可视化,日志,调试

2017-12-30 17:08  撞破南墙  阅读(754)  评论(0编辑  收藏  举报

可视化

添加变量
tf.summary.histogram( "weights1", weights1) # 可视化观看变量
还有添加图像和音频、

常量
tf.summary.scalar('x', x)

添加embedding

python

def checkpoint(sess):
    # Output directory for models and summaries
    timestamp = str(int(time.time()))
    out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
    print("Writing to {}\n".format(out_dir))
    # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
    checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)
    saver.save(sess, os.path.join(out_dir, "model.ckpt"), 1)

合并
merged = tf.summary.merge_all()

添加视图
writer=tf.summary.FileWriter('log',sess.graph)

注意!如果Chrome打不开视图,可以尝试换个浏览器。

-- CMD下启动 去 127.0.0.1:6006 查看

tensorboard.exe --logdir=log

调试

例如下面这种简陋的方式,将要观察的变量放进去再取出来就行。
当然有更好的办法例如
https://www.cnblogs.com/huangshiyu13/p/6721805.html

code

summary, _prediction1, _logits1, _prediction_argmax1, _Y_argmax1, _X2, _correct_pred1, _accuracy1, _loss_op1 = \
    sess.run([merged, prediction, logits, prediction_argmax, Y_argmax, X, correct_pred, accuracy, loss_op],
             feed_dict={X: batch_x, Y: batch_y})

日志

复杂的变量打印是不够的,记录到本地就好了

code

handler = logging.handlers.RotatingFileHandler(LOG_FILE, maxBytes=1024 * 1024*1024, backupCount=5)  # 实例化handler
#fmt = '%(asctime)s - %(filename)s:%(lineno)s - %(name)s - %(message)s'
fmt = '%(asctime)s - %(filename)s:%(lineno)s - %(message)s'

formatter = logging.Formatter(fmt)  # 实例化formatter
handler.setFormatter(formatter)  # 为handler添加formatter

logger = logging.getLogger('tst')  # 获取名为tst的logger
logger.addHandler(handler)  # 为logger添加handler
logger.setLevel(logging.DEBUG)
logger.info('==================================')

def myLog(obj):
logger.info("obj begin===========================" + str(len(obj)))
for l1 in obj:
    # print(l1)
    logger.info(l1)

参考
https://www.cnblogs.com/huangshiyu13/p/6721805.html