MindSpore代码评析(十八)图算融合优化之IR_Fission代码实现详解

一、算子拆分基础概念

绝大部分算子都是可以拆分为若干更为基本的算子组成的的子图。在实际使用中,MindSpore会对计算图进行三个阶段的处理。

1)为了避免对已有网络的直接修改,并实现优化,我们会自动将复合算子展开成复合算子图;

2)将相邻的复合算子子图聚合成更大的聚合算子子图;

3)根据融合策略,将聚合算子子图拆分拆分为多个融合算子子图。

拆分后算子虽然数目增加,但有着更好的计算和内存局部性,从而获得更好的执行性能。

实现算子拆分的代码位于:ir_fission ,为了满足用户需求,MindSpore提供了多类型的算子切分代码具体讲解如下。

二、代码实现

(1)AddnFission

Addn算子是逐元素计算所有输入张量的加法,它要求所有输入张量必须具有相同的形状。

  • 输入:由多个张量构成的元组(Tuple)或列表(List)
  • 输出:与输入的每个张量具有相同形状和数据类型的张量

  • 实例:输入张量数组[1,2,3]和[4 ,5,6],执行addn操作,输出结果为[5,7,9],即对数组中每个张量都逐一进行了加法操作。

x = Tensor(np.array([1, 2, 3]), mindspore.float32)
y = Tensor(np.array([4, 5, 6]), mindspore.float32)
output = net(x, y)
print(output)

在了解Addn后,我们来看MindSpore是如何实现对Addn算子的拆分的:

const AnfNodePtr AddnFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
  //异常控制,防止输入为空
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // 实际输入从索引1开始 所以减去1
  size_t origin_input_size = cnode->inputs().size() - 1;
  //源输入大小需要大于输入因子
  if (origin_input_size <= inputs_divisor_) {
    return nullptr;
  }
  //创建新节点
  CNodePtr new_cnode = cnode;
  ///对每个输入都进行拆分
  while (origin_input_size > inputs_divisor_) {
    MS_EXCEPTION_IF_NULL(new_cnode);
    std::vector<AnfNodePtr> base_addn_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimAddN->name()))};
    size_t cur_input_index = 1
    // 通过输入因子对每个输入都进行拆分
    while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) {
      base_addn_inputs.push_back(CreateNewAddn(func_graph, new_cnode, cur_input_index, inputs_divisor_));
      cur_input_index += inputs_divisor_;
    }
    for (size_t i = cur_input_index; i <= origin_input_size; i++) {
      base_addn_inputs.emplace_back(new_cnode->input(i));
    }
    CNodePtr base_addn = func_graph->NewCNode(base_addn_inputs);
    MS_EXCEPTION_IF_NULL(base_addn);
    base_addn->set_scope(new_cnode->scope());
    base_addn->set_abstract(new_cnode->abstract());
    //为新节点设置参数
    AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(base_addn_inputs.size() - 1)), base_addn);
    //输入大小的设置
    std::vector<int64_t> dyn_input_sizes{SizeToLong(base_addn_inputs.size() - 1)};
    AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), base_addn);
    new_cnode = base_addn;
    origin_input_size = base_addn->inputs().size() - 1;
  }
  //返回最终分离出的计算图
  return new_cnode;
}
(2)ConcatFission

Concat用于将两个数据集进行拼接,合并为一个数据集,要求输入数据集的列名,列数据类型和列数据的排列相同。

  • 输入:两个对应相同的数据集
  • 输出:拼接后的数据集

Concat算子的拆分代码解析如下:

const AnfNodePtr ConcatFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                        const EquivPtr &) const {
  //异常控制,防止输入为空
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  if (AnfAlgo::IsDynamicShape(node)) {
    return nullptr;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // 实际输入从索引1开始 所以减去1
  size_t origin_input_size = cnode->inputs().size() - 1;、
  // 源输入大小需要大于输入因子
  if (origin_input_size <= inputs_divisor_) {
    return nullptr;
  }
  CNodePtr new_cnode = cnode;
  //对每个输入进行分割
  while (origin_input_size > inputs_divisor_) {
    MS_EXCEPTION_IF_NULL(new_cnode);
    std::vector<AnfNodePtr> base_concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
    size_t cur_input_index = 1;
     // 通过输入因子对每个输入都进行拆分
    while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) {
      base_concat_inputs.push_back(CreateNewConcat(func_graph, new_cnode, cur_input_index, inputs_divisor_));
      cur_input_index += inputs_divisor_;
    }
    for (size_t i = cur_input_index; i <= origin_input_size; i++) {
      base_concat_inputs.emplace_back(new_cnode->input(i));
    }
    CNodePtr base_concat = func_graph->NewCNode(base_concat_inputs);
    MS_EXCEPTION_IF_NULL(base_concat);
    base_concat->set_scope(new_cnode->scope());
    base_concat->set_abstract(new_cnode->abstract());
    // 为差分后的子图设置参数
    if (AnfAlgo::HasNodeAttr(kAttrAxis, new_cnode)) {
      AnfAlgo::CopyNodeAttr(kAttrAxis, new_cnode, base_concat);
    }
    if (AnfAlgo::HasNodeAttr(kAttrT, new_cnode)) {
      AnfAlgo::CopyNodeAttr(kAttrT, new_cnode, base_concat);
    }
    //为新节点设置参数
    AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(base_concat_inputs.size() - 1)), base_concat);
    AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(SizeToLong(base_concat_inputs.size() - 1)), base_concat);
    //输入大小的设置
    std::vector<int64_t> dyn_input_sizes{SizeToLong(base_concat_inputs.size() - 1)};
    AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), base_concat);
    new_cnode = base_concat;
    origin_input_size = base_concat->inputs().size() - 1;
  }
  //返回最终分离出的计算图
  return new_cnode;
}
posted @ 2021-12-25 11:41  MS小白  阅读(117)  评论(0)    收藏  举报