计算图

机器学习程序从前端到后端需要编译成不同的 IR 来获得更好的优化性能,在 mlsys 中这个 IR 就是计算图。

对机器学习程序描述的调度执行、自动更新模型所需的梯度都需要依赖计算图。

一个计算图的逻辑结构大概是下图右边的前向部分:

也就是由神经网络中的算子、各算子的输入、以及算子之间的计算顺序组成。有了计算图就可以从逻辑上表达一个神经网络前向的计算过程并由此得到在各个硬件平台的 IR 从而进行更多的优化。

简单地说,计算图实现了以下关键功能:

  • 统一的计算过程表达: 不同的计算后端有不同的模型表达数据结构,深度学习框架需要将机器学习程序统一表达为某种 IR 方便进行一些共性的优化(如常量替换、死代码消除、循环优化等),比如 LLVM IR、TensorIR;
  • 自动化计算梯度 : 用户的模型训练程序接收训练数据集的数据样本,通过神经网络前向计算,最终计算出损失值。根据损失值,机器学习系统为每个模型参数计算出梯度来更新模型参数。考虑到用户可以写出任意的模型拓扑和损失值计算方法,计算梯度的方法必须通用并且能实现自动运行。计算图可以辅助机器学习系统快速分析参数之间的梯度传递关系,实现自动化计算梯度的目标。因为计算图中记录了每个输出是有哪些输入、哪些算子得到的,可以很方便的计算梯度;
  • 分析模型变量生命周期 : 在用户训练模型的过程中,系统会通过计算产生临时的中间变量,如前向计算中的激活值和反向计算中的梯度。前向计算的中间变量可能与梯度共同参与到模型的参数更新过程中。通过计算图,系统可以准确分析出中间变量的生命周期(一个中间变量生成以及销毁时机),从而帮助框架优化内存管理。举个例子:在动态 tensor 重构造(DTR)中,可以根据动态图的计算路径获得哪些中间结果的值对反向的梯度更新不重要(比如 elemwise 中从 host load 进来一个 tensor),从而可以将这些 tensor 从显存中剔除出去从而节省显存;
  • 优化程序执行 : 用户给定的模型程序具备不同的网络拓扑结构。机器学习框架利用计算图来分析模型结构和算子执行依赖关系,并自动寻找算子并行计算的策略,从而提高模型的执行效率。举个例子:在计算图中我们可能经常遇到 conv 之后接一个 bn 这样的结构,类似于传统编译器对程序做的优化,这里 conv + bn 可以通过算子融合 fuse 成一个算子 ConvBn,从而在发 kernel 的时候只需要发一次 kernel 而不是两次,减少了通信和计算时间。

参考

posted @ 2023-08-05 23:16  machine_gun_lin  阅读(155)  评论(0编辑  收藏  举报