torch.utils.tensorboard
TensorBoard 神经网络可视化工具
1. pytorch 官方文档解析
from torch.utils.tensorboard import SummaryWriterclass
class SummaryWriter(object)
def __init__(
self,
log_dir=None,
comment='',
purge_step=None,
max_queue=10,
flush_secs=120,
filename_suffix=''
)
def add_scalar(
self,
tag,
scalar_value,
global_step=None,
walltime=None,
new_style=False,
double_precision=False
)
def add_scalars(
self,
main_tag,
tag_scalar_dict,
global_step=None,
walltime=None
)
def add_histogram(
self,
tag,
values,
global_step=None,
bins="tensorflow",
walltime=None,
max_bins=None,
)
def add_image(
self,
tag,
img_tensor,
global_step=None,
walltime=None,
dataformats="CHW"
)
def add_images(
self,
tag,
img_tensor,
global_step=None,
walltime=None,
dataformats="NCHW"
)
def add_graph(
self,
model,
input_to_model=None,
verbose=False,
use_strict_trace=True
)
将条目直接写入 \(log\_dir\) 中的时间文件以供 \(TensorBoard\) 使用。
\(SummaryWriter\) 类提供了一个高级 \(API\),用于在给定目录中创建事件文件并向其添加摘要和事件。
__init__
:
log_dir
(\(string\)):保存目录位置。默认为runs/current_datetime_hostname
,每次运行后都会更改。可自定义。comment
(\(string\)):不指定log_dir
, 文件夹后缀。filename_suffix
(\(int\)):log_dir
目录中所有事件文件名后缀。
add_scalar
:记录标量。
tag
(\(string\)):标签名。scalar_value
(\(float、string、blobname\)):要记录的标量。global_step
(\(int\)):轮次。new_stype
(\(boolean\)):使用新样式(张亮字段)还是旧样式(\(simple\_value\) 字段)。新样式可能有更快的加载速度。
add_scalars
:记录多个标量。
main_tag
(\(string\)):多个标签名。tag_scalar_dict
(\(dict\)):存储标签和对应的键值。global_step
(\(int\)):轮次。
add_histogram
:统计直方图与多分位数折线图
tag
(\(string\)):标签名。values
(\(torch.Tensor、numpy.array、string、blobname\)):构建直方图的值。global_step
(\(int\)):轮次。bins
(\(string\)):取值 \(tensorflow、auto、fd\) 等。这决定如何制作垃圾箱。
add_image
:显示图像
tag
(\(string\)):标签名。img_tensor
(\(torch.Tensor、numpy.array、string、blobname\)):图像数据。global_step
(\(int\)):轮次。dataformats
(\(string\)):\(CHW、HWC、HW、WH\) 图像数据的格式。
add_images
:批量显示图像
tag
(\(string\)):标签名。img_tensor
(\(torch.Tensor、numpy.array、string、blobname\)):图像数据。global_step
(\(int\)):轮次。dataformats
(\(string\)):\(NCHW、NHWC、CHW、HWC、HW、WH\) 图像数据的格式。
add_graph
:查看模型图
model
(\(torch.nn.Model\)):模型,必须是nn.Module
。input_to_model
(\(torch.Tensor、torch.Tensor列表\)):输出给模型的数据。verbose
(\(bool\)):是否打印计算图结构信息。
写完记得写 writer.close()
2. 调用方法
2.1 创建接口
writer = SummaryWriter('runs')
2.2 记录多个标量
writer.add_scalars('name', {'dic': val}, epoch)
2.3 统计直方图
writer.add_histogram('weight', self.fc.weight, epoch)
2.4 批次显示图像
writer.add_images(“Cifar10”, img_batch, epoch, 'CHW')
2.5 查看模型图
writer.add_graph(model=net,input_to_model=torch.randn(1,3, 224, 224).to(device))
来自: