PyTorch消除训练瓶颈 提速技巧

PyTorch消除训练瓶颈 提速技巧

1. 硬件层面

CPU的话尽量看主频比较高的,缓存比较大的,核心数也是比较重要的参数。

显卡尽可能选现存比较大的,这样才能满足大batch训练,多卡当让更好。

内存要求64G,4根16G的内存条插满绝对够用了。

主板性能也要跟上,否则装再好的CPU也很难发挥出全部性能。

电源供电要充足,GPU运行的时候会对功率有一定要求,全力运行的时候如果电源供电不足对性能影响还是比较大的。

存储如果有条件,尽量使用SSD存放数据,SSD和机械硬盘的在训练的时候的读取速度不是一个量级。笔者试验过,相同的代码,将数据移动到SSD上要比在机械硬盘上快10倍。

操作系统尽量用Ubuntu就可以(实验室用)

如何实时查看Ubuntu下各个资源利用情况呢?

  • GPU使用 watch -n 1 nvidia-smi 来动态监控
  • IO情况,使用 iostat 命令来监控
  • CPU情况,使用 htop 命令来监控

2. 如何测试训练过程的瓶颈

如果现在程序运行速度很慢,那应该如何判断瓶颈在哪里呢?PyTorch中提供了工具,非常方便的可以查看设计的代码在各个部分运行所消耗的时间。

可以使用PyTorch中bottleneck工具,具体使用方法如下:

瓶颈测试:https://pytorch.org/docs/stable/bottleneck.html

也可以用以下代码分析:

def test_loss_profiling():
    loss = nn.BCEWithLogitsLoss()
    with torch.autograd.profiler.profile(use_cuda=True) as prof:
        input = torch.randn((8, 1, 128, 128)).cuda()
        input.requires_grad = True

        target = torch.randint(1, (8, 1, 128, 128)).cuda().float()

        for i in range(10):
            l = loss(input, target)
            l.backward()
    print(prof.key_averages().table(sort_by="self_cpu_time_total"))

3. 图片解码

PyTorch中默认使用的是Pillow进行图像的解码,但是其效率要比Opencv差一些,如果图片全部是JPEG格式,可以考虑使用TurboJpeg库解码。具体速度对比如下图所示:

4. 数据增强加速

在PyTorch中,通常使用transformer做图片分类任务的数据增强,而其调用的是CPU做一些Crop、Flip、Jitter等操作。

如果你通过观察发现你的CPU利用率非常高,GPU利用率比较低,那说明瓶颈在于CPU预处理,可以使用Nvidia提供的DALI库在GPU端完成这部分数据增强操作。

Dali链接:https://github.com/NVIDIA/DALI

文档也非常详细:

Dali文档:https://docs.nvidia.com/deeplearning/sdk/dali-developer-guide/index.html

当然,Dali提供的操作比较有限,仅仅实现了常用的方法,有些新的方法比如cutout需要自己搞。

具体实现可以参考这一篇:https://zhuanlan.zhihu.com/p/77633542

5. data Prefetch

Nvidia Apex中提供的解决方案

参考来源:https://zhuanlan.zhihu.com/p/66145913

Apex提供的策略就是预读取下一次迭代需要的数据。

class data_prefetcher():
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

在训练函数中进行如下修改:

原先是:

training_data_loader = DataLoader(
    dataset=train_dataset,
    num_workers=opts.threads,
    batch_size=opts.batchSize,
    pin_memory=True,
    shuffle=True,
)
for iteration, batch in enumerate(training_data_loader, 1):
    # 训练代码

修改以后:

data, label = prefetcher.next()
iteration = 0
while data is not None:
    iteration += 1
    # 训练代码
    data, label = prefetcher.next()

6. 其他细节

batch_images = batch_images.pin_memory() 
Batch_labels = Variable(batch_labels).cuda(non_blocking=True) 

PyTorch的DataLoader有一个参数pin_memory,使用固定内存,并使用non_blocking=True来并行处理数据传输。

torch.backends.cudnn.benchmark=True

及时释放掉不需要的显存、内存。

如果数据集比较小,直接将数据复制到内存中,从内存中读取可以极大加快数据读取的速度。

调整workers数量,过少的线程读取数据会导致速度非常慢,过多线程读取数据可能会由于阻塞也导致速度非常慢。所以需要根据自己机器的情况,尝试不同数量的workers,选择最合适的数量。一般设置为 cpu 核心数或gpu数量

编码的时候要注意尽可能减少CPU和GPU之间的数据传输,使用类似numpy的编码方式,通过并行的方式来处理,可以提高性能。

7. 参考文献

【1】https://zhuanlan.zhihu.com/p/66145913

【2】https://pytorch.org/docs/stable/bottleneck.html

【3】https://blog.csdn.net/dancer__sky/article/details/78631577

【4】https://sagivtech.com/2017/09/19/optimizing-pytorch-training-code/

【5】https://zhuanlan.zhihu.com/p/77633542

【6】https://github.com/NVIDIA/DALI

【7】https://zhuanlan.zhihu.com/p/147723652

【8】https://www.zhihu.com/question/356829360/answer/907832358

posted @ 2022-03-03 18:53  梁君牧  阅读(462)  评论(0编辑  收藏  举报