torch.utils.tensorboard


TensorBoard 神经网络可视化工具


1. pytorch 官方文档解析

from torch.utils.tensorboard import SummaryWriterclass

class SummaryWriter(object)
	def __init__(
        self,
        log_dir=None, 
        comment='', 
        purge_step=None, 
        max_queue=10, 
        flush_secs=120, 
        filename_suffix=''
    )
    def add_scalar(
        self,
        tag, 
        scalar_value, 
        global_step=None,
        walltime=None, 
        new_style=False, 
        double_precision=False
    )
    def add_scalars(
        self, 
        main_tag, 
        tag_scalar_dict, 
        global_step=None, 
        walltime=None
    )
    def add_histogram(
        self,
        tag,
        values,
        global_step=None,
        bins="tensorflow",
        walltime=None,
        max_bins=None,
    )
    def add_image(
        self, 
        tag, 
        img_tensor, 
        global_step=None, 
        walltime=None, 
        dataformats="CHW"
    )
    def add_images(
        self, 
        tag, 
        img_tensor, 
        global_step=None, 
        walltime=None, 
        dataformats="NCHW"
    )
    def add_graph(
        self, 
        model, 
        input_to_model=None, 
        verbose=False, 
        use_strict_trace=True
    )

将条目直接写入 \(log\_dir\) 中的时间文件以供 \(TensorBoard\) 使用。

\(SummaryWriter\) 类提供了一个高级 \(API\),用于在给定目录中创建事件文件并向其添加摘要和事件。


__init__

  • log_dir(\(string\)):保存目录位置。默认为 runs/current_datetime_hostname,每次运行后都会更改。可自定义。
  • comment(\(string\)):不指定 log_dir, 文件夹后缀。
  • filename_suffix(\(int\)):log_dir目录中所有事件文件名后缀。

add_scalar:记录标量。

  • tag(\(string\)):标签名。
  • scalar_value(\(float、string、blobname\)):要记录的标量。
  • global_step(\(int\)):轮次。
  • new_stype(\(boolean\)):使用新样式(张亮字段)还是旧样式(\(simple\_value\) 字段)。新样式可能有更快的加载速度。

add_scalars:记录多个标量。

  • main_tag(\(string\)):多个标签名。
  • tag_scalar_dict(\(dict\)):存储标签和对应的键值。
  • global_step(\(int\)):轮次。

add_histogram:统计直方图与多分位数折线图

  • tag(\(string\)):标签名。
  • values(\(torch.Tensor、numpy.array、string、blobname\)):构建直方图的值。
  • global_step(\(int\)):轮次。
  • bins(\(string\)):取值 \(tensorflow、auto、fd\) 等。这决定如何制作垃圾箱。

add_image:显示图像

  • tag(\(string\)):标签名。
  • img_tensor(\(torch.Tensor、numpy.array、string、blobname\)):图像数据。
  • global_step(\(int\)):轮次。
  • dataformats(\(string\)):\(CHW、HWC、HW、WH\) 图像数据的格式。

add_images:批量显示图像

  • tag(\(string\)):标签名。
  • img_tensor(\(torch.Tensor、numpy.array、string、blobname\)):图像数据。
  • global_step(\(int\)):轮次。
  • dataformats(\(string\)):\(NCHW、NHWC、CHW、HWC、HW、WH\) 图像数据的格式。

add_graph:查看模型图

  • model(\(torch.nn.Model\)):模型,必须是 nn.Module
  • input_to_model(\(torch.Tensor、torch.Tensor列表\)):输出给模型的数据。
  • verbose(\(bool\)):是否打印计算图结构信息。

写完记得写 writer.close()



2. 调用方法

2.1 创建接口

writer = SummaryWriter('runs')

2.2 记录多个标量

writer.add_scalars('name', {'dic': val}, epoch)

2.3 统计直方图

writer.add_histogram('weight', self.fc.weight, epoch)

2.4 批次显示图像

writer.add_images(“Cifar10”, img_batch, epoch, 'CHW')

2.5 查看模型图

writer.add_graph(model=net,input_to_model=torch.randn(1,3, 224, 224).to(device))


来自:

https://pytorch.org/docs/stable/tensorboard.html

https://m.w3cschool.cn/article/27419536.html

posted @ 2022-08-16 15:39  做梦当财神  阅读(282)  评论(0编辑  收藏  举报