MindSpore代码评析(十八)图算融合优化之IR_Fission代码实现详解
一、算子拆分基础概念
绝大部分算子都是可以拆分为若干更为基本的算子组成的的子图。在实际使用中,MindSpore会对计算图进行三个阶段的处理。
1)为了避免对已有网络的直接修改,并实现优化,我们会自动将复合算子展开成复合算子图;
2)将相邻的复合算子子图聚合成更大的聚合算子子图;
3)根据融合策略,将聚合算子子图拆分拆分为多个融合算子子图。
拆分后算子虽然数目增加,但有着更好的计算和内存局部性,从而获得更好的执行性能。
实现算子拆分的代码位于:ir_fission ,为了满足用户需求,MindSpore提供了多类型的算子切分代码具体讲解如下。
二、代码实现
(1)AddnFission
Addn算子是逐元素计算所有输入张量的加法,它要求所有输入张量必须具有相同的形状。
- 输入:由多个张量构成的元组(Tuple)或列表(List)
-
输出:与输入的每个张量具有相同形状和数据类型的张量
-
实例:输入张量数组[1,2,3]和[4 ,5,6],执行addn操作,输出结果为[5,7,9],即对数组中每个张量都逐一进行了加法操作。
x = Tensor(np.array([1, 2, 3]), mindspore.float32)
y = Tensor(np.array([4, 5, 6]), mindspore.float32)
output = net(x, y)
print(output)
在了解Addn后,我们来看MindSpore是如何实现对Addn算子的拆分的:
const AnfNodePtr AddnFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
//异常控制,防止输入为空
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
// 实际输入从索引1开始 所以减去1
size_t origin_input_size = cnode->inputs().size() - 1;
//源输入大小需要大于输入因子
if (origin_input_size <= inputs_divisor_) {
return nullptr;
}
//创建新节点
CNodePtr new_cnode = cnode;
///对每个输入都进行拆分
while (origin_input_size > inputs_divisor_) {
MS_EXCEPTION_IF_NULL(new_cnode);
std::vector<AnfNodePtr> base_addn_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimAddN->name()))};
size_t cur_input_index = 1
// 通过输入因子对每个输入都进行拆分
while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) {
base_addn_inputs.push_back(CreateNewAddn(func_graph, new_cnode, cur_input_index, inputs_divisor_));
cur_input_index += inputs_divisor_;
}
for (size_t i = cur_input_index; i <= origin_input_size; i++) {
base_addn_inputs.emplace_back(new_cnode->input(i));
}
CNodePtr base_addn = func_graph->NewCNode(base_addn_inputs);
MS_EXCEPTION_IF_NULL(base_addn);
base_addn->set_scope(new_cnode->scope());
base_addn->set_abstract(new_cnode->abstract());
//为新节点设置参数
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(base_addn_inputs.size() - 1)), base_addn);
//输入大小的设置
std::vector<int64_t> dyn_input_sizes{SizeToLong(base_addn_inputs.size() - 1)};
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), base_addn);
new_cnode = base_addn;
origin_input_size = base_addn->inputs().size() - 1;
}
//返回最终分离出的计算图
return new_cnode;
}
(2)ConcatFission
Concat用于将两个数据集进行拼接,合并为一个数据集,要求输入数据集的列名,列数据类型和列数据的排列相同。
- 输入:两个对应相同的数据集
- 输出:拼接后的数据集
Concat算子的拆分代码解析如下:
const AnfNodePtr ConcatFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
//异常控制,防止输入为空
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
if (AnfAlgo::IsDynamicShape(node)) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
// 实际输入从索引1开始 所以减去1
size_t origin_input_size = cnode->inputs().size() - 1;、
// 源输入大小需要大于输入因子
if (origin_input_size <= inputs_divisor_) {
return nullptr;
}
CNodePtr new_cnode = cnode;
//对每个输入进行分割
while (origin_input_size > inputs_divisor_) {
MS_EXCEPTION_IF_NULL(new_cnode);
std::vector<AnfNodePtr> base_concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
size_t cur_input_index = 1;
// 通过输入因子对每个输入都进行拆分
while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) {
base_concat_inputs.push_back(CreateNewConcat(func_graph, new_cnode, cur_input_index, inputs_divisor_));
cur_input_index += inputs_divisor_;
}
for (size_t i = cur_input_index; i <= origin_input_size; i++) {
base_concat_inputs.emplace_back(new_cnode->input(i));
}
CNodePtr base_concat = func_graph->NewCNode(base_concat_inputs);
MS_EXCEPTION_IF_NULL(base_concat);
base_concat->set_scope(new_cnode->scope());
base_concat->set_abstract(new_cnode->abstract());
// 为差分后的子图设置参数
if (AnfAlgo::HasNodeAttr(kAttrAxis, new_cnode)) {
AnfAlgo::CopyNodeAttr(kAttrAxis, new_cnode, base_concat);
}
if (AnfAlgo::HasNodeAttr(kAttrT, new_cnode)) {
AnfAlgo::CopyNodeAttr(kAttrT, new_cnode, base_concat);
}
//为新节点设置参数
AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToLong(base_concat_inputs.size() - 1)), base_concat);
AnfAlgo::SetNodeAttr(kAttrInputNums,