torch.utils.tensorboard
TensorBoard 神经网络可视化工具
1. pytorch 官方文档解析
from torch.utils.tensorboard import SummaryWriterclass
class SummaryWriter(object)
def __init__(
self,
log_dir=None,
comment='',
purge_step=None,
max_queue=10,
flush_secs=120,
filename_suffix=''
)
def add_scalar(
self,
tag,
scalar_value,
global_step=None,
walltime=None,
new_style=False,
double_precision=False
)
def add_scalars(
self,
main_tag,
tag_scalar_dict,
global_step=None,
walltime=None
)
def add_histogram(
self,
tag,
values,
global_step=None,
bins="tensorflow",
walltime=None,
max_bins=None,
)
def add_image(
self,
tag,
img_tensor,
global_step=None,
walltime=None,
dataformats="CHW"
)
def add_images(
self,
tag,
img_tensor,
global_step=None,
walltime=None,
dataformats="NCHW"
)
def add_graph(
self,
model,
input_to_model=None,
verbose=False,
use_strict_trace=True
)
将条目直接写入 \(log\_dir\) 中的时间文件以供 \(TensorBoard\) 使用。
\(SummaryWriter\) 类提供了一个高级 \(API\),用于在给定目录中创建事件文件并向其添加摘要和事件。
__init__:
log_dir(\(string\)):保存目录位置。默认为runs/current_datetime_hostname,每次运行后都会更改。可自定义。comment(\(string\)):不指定log_dir, 文件夹后缀。filename_suffix(\(int\)):log_dir目录中所有事件文件名后缀。
add_scalar:记录标量。
tag(\(string\)):标签名。scalar_value(\(float、string、blobname\)):要记录的标量。global_step(\(int\)):轮次。new_stype(\(boolean\)):使用新样式(张亮字段)还是旧样式(\(simple\_value\) 字段)。新样式可能有更快的加载速度。
add_scalars:记录多个标量。
main_tag(\(string\)):多个标签名。tag_scalar_dict(\(dict\)):存储标签和对应的键值。global_step(\(int\)):轮次。
add_histogram:统计直方图与多分位数折线图
tag(\(string\)):标签名。values(\(torch.Tensor、numpy.array、string、blobname\)):构建直方图的值。global_step(\(int\)):轮次。bins(\(string\)):取值 \(tensorflow、auto、fd\) 等。这决定如何制作垃圾箱。
add_image:显示图像
tag(\(string\)):标签名。img_tensor(\(torch.Tensor、numpy.array、string、blobname\)):图像数据。global_step(\(int\)):轮次。dataformats(\(string\)):\(CHW、HWC、HW、WH\) 图像数据的格式。
add_images:批量显示图像
tag(\(string\)):标签名。img_tensor(\(torch.Tensor、numpy.array、string、blobname\)):图像数据。global_step(\(int\)):轮次。dataformats(\(string\)):\(NCHW、NHWC、CHW、HWC、HW、WH\) 图像数据的格式。
add_graph:查看模型图
model(\(torch.nn.Model\)):模型,必须是nn.Module。input_to_model(\(torch.Tensor、torch.Tensor列表\)):输出给模型的数据。verbose(\(bool\)):是否打印计算图结构信息。
写完记得写 writer.close()
2. 调用方法
2.1 创建接口
writer = SummaryWriter('runs')
2.2 记录多个标量
writer.add_scalars('name', {'dic': val}, epoch)
2.3 统计直方图
writer.add_histogram('weight', self.fc.weight, epoch)
2.4 批次显示图像
writer.add_images(“Cifar10”, img_batch, epoch, 'CHW')
2.5 查看模型图
writer.add_graph(model=net,input_to_model=torch.randn(1,3, 224, 224).to(device))
来自:

浙公网安备 33010602011771号