ONNX Runtime 源码阅读:模型结点串行执行顺序的确定(转)
概要
ONNX模型中的结构是一个有向图,包含了很多节点。每个节点执行一个特定的操作,最终就得到了推理结果。ONNX模型格式标准并没有要求所有节点按照拓扑顺序来存储,进行模型解析的时候也基本不要求解析出来的节点一定要符合拓扑顺序排列。有些模型很简单,从输入到输出,可能只有一条通路;有些模型很复杂,不仅输入和输出节点间存在多条通路,还有可能存在多个输入和输出节点。ONNX Runtime 是如何确定模型中各个节点执行的先后顺序的呢?怎么确保某个节点被执行之前,其所有先导节点都已经被执行?这就是今天需要解决的疑惑。ONNX Runtime 执行模型的方式主要有两种:串行和并行,好像有点废话了。通过初始化的时候传递个InferenceSession的构造函数的结构体SessionOptions中的ExecutionMode成员来控制。今天主要研究串行执行时节点执行顺序。
涉及文件
onnxruntime\onnxruntime\python\onnxruntime_pybind_state.cc
onnxruntime\onnxruntime\core\session\inference_session.cc
onnxruntime\onnxruntime\core\framework\sequential_executor.cc
onnxruntime\onnxruntime\core\framework\session_state_initializer.cc
onnxruntime\onnxruntime\core\graph\graph_viewer.cc
onnxruntime\onnxruntime\core\framework\session_state.cc
onnxruntime\onnxruntime\core\graph\graph.cc
正文
举个栗子,有一个简单的模型,如图1所示:
在这个简单的模型里面,一共有六个节点,从输入到输出有两条通路。由于ONNX模型格式标准并没有要求所有节点按照拓扑顺序来存储,因此模型再次加载到内存以后,节点的顺序的排列完全是随机的,有可能是1、3、2、4、6、5,也可能是其他的顺序。因此,必须要先确定节点的拓扑结构并按照结构存储起来,这样才能在跑的时候知道那个是输入,哪些节点必须先跑完。
代码调用
在上一篇文章ONNX Runtime 源码阅读:模型推理过程概览中我们说过,模型节点执行顺序的确定是在InferenceSession实例化完毕后,在初始化阶段完成的。
// onnxruntime\onnxruntime\python\onnxruntime_pybind_state.cc py::class_<InferenceSession>(m, "InferenceSession", R"pbdoc(This is the main class used to run a model.)pbdoc") .def( "load_model", [](InferenceSession* sess, std::vector<std::string>& provider_types) { OrtPybindThrowIfError(sess->Load()); InitializeSession(sess, provider_types); }, R"pbdoc(Load a model saved in ONNX format.)pbdoc")
从上面代码中可以看到,初始化也分为两个阶段:1)模型加载 2)InferenceSession实例初始化。
模型加载?模型不是在生成InferenceSession实例的时候已经加载到内存了么?其实在InferenceSession实例化阶段加载的模型知识编译proto文件得到的类ModelProto的一个实例,直接使用还是不太方便,因此还需要对它进行进一步解析和封装,OrtPybindThrowIfError(sess->Load());这句话主要做的就是这件事。
我们接着来看InitializeSession(sess, provider_types);:
// onnxruntime\onnxruntime\python\onnxruntime_pybind_state.cc void InitializeSession(InferenceSession* sess, const std::vector<std::string>& provider_types) { if (provider_types.empty()) { // use default registration priority. RegisterExecutionProviders(sess, GetAllProviders()); } else { RegisterExecutionProviders(sess, provider_types); } OrtPybindThrowIfError(sess->Initialize()); }
可以看到,InitializeSession(sess, provider_types)在注册Provider后,最终调用到了onnxruntime\onnxruntime\core\session\inference_session.cc中类InferenceSession的Initiablize()方法。
Initiablize()方法体非常长,但是有两行非常刺眼,session_initializer.CreatePlan; InitializeSubgraphSessions(graph, *session_state_),字面意思就是创建执行计划,开个上帝视角执行顺序这的是在这里创建的。由于方法体很长,这就贴一部分重要的好了:
// onnxruntime\onnxruntime\core\session\inference_session.cc # InferenceSession::Initialize() onnxruntime::Graph& graph = model_->MainGraph(); // Collect the kernel registries from execution provider instances; // There are 2 kinds of kernel registries with priority from high to low as below, // 1. Custom execution provider type specific kernel registries. // 2. common execution provider type specific kernel registries. // The 1st and 2nd ones are shared across sessions. // The 1st ones should have already been registered via session-level API into KernelRegistryManager. // // Register 2nd registries into KernelRegistryManager. ORT_RETURN_IF_ERROR_SESSIONID_(kernel_registry_manager_.RegisterKernels(execution_providers_)); SessionStateInitializer session_initializer(session_options_.enable_mem_pattern, model_location_, graph, *session_state_, execution_providers_, kernel_registry_manager_); // create SessionState for subgraphs as it's needed by the transformers ORT_RETURN_IF_ERROR_SESSIONID_(CreateSubgraphSessionState(graph, *session_state_)); // apply any transformations to the main graph and any subgraphs ORT_RETURN_IF_ERROR_SESSIONID_(TransformGraph(graph, *graph_transformation_mgr_, execution_providers_, kernel_registry_manager_, insert_cast_transformer_, *session_state_)); // now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs. ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve()); if (!session_options_.optimized_model_filepath.empty()) { // Serialize optimized ONNX model. ORT_RETURN_IF_ERROR_SESSIONID_(Model::Save(*model_, session_options_.optimized_model_filepath)); if (session_options_.graph_optimization_level >= TransformerLevel::Level3) { LOGS(*session_logger_, WARNING) << "Serializing Optimized ONNX model with Graph Optimization" " level greater than ORT_ENABLE_EXTENDED. The generated" " model may contain hardware and execution provider specific" " optimizations, and should only be used in the same environment" " the model was optimized for."; } } ORT_RETURN_IF_ERROR_SESSIONID_(session_initializer.CreatePlan(nullptr, nullptr, session_options_.execution_mode)); // handle any subgraphs ORT_RETURN_IF_ERROR_SESSIONID_(InitializeSubgraphSessions(graph, *session_state_)); is_inited_ = true;
但是,开上帝视角之前,我们是怎么知道这一段就是我们心心念念的代码?一方面,我们从模型推理时的方法调用中发现执行的时候发现直接取到了一个已经按照拓扑顺序存储的结点序列,
// onnxruntime\onnxruntime\core\framework\sequential_executor.cc#SequentialExecutor::Execute() const SequentialExecutionPlan& seq_exec_plan = *session_state.GetExecutionPlan();
和这里的CreatePlan可谓遥相呼应,更重要的是,这个序列是从SessionState的实例中取出来的,有出肯定有入,我们需要紧盯着这个序列什么时候被放进去的。恰好,在SessionStateInitializer的实例中SessionState和模型中取出的主图同时出现,让人不得不将焦点聚集在这;另一方面,这里的代码命名非常好,可谓顾名思义。不禁让人感叹,写的出代码是一回事儿,让人容易看懂又是另一回事儿了,毕竟,良好的代码不仅要高效还要易读。
代码的开始,先从模型中取到主图,然后将主图和一个SessionState的实例session_state_和其他参数一起传递给了SessionStateInitializer的构造函数,该构造函数仅仅是做了些简单的赋值操作,然后就执行到了SessionStateInitializer的方法CreatePlan()。
// onnxruntime\onnxruntime\core\framework\session_state_initializer.cc#SessionStateInitializer::CreatePlan() common::Status SessionStateInitializer::CreatePlan( const Node* parent_node, const ConstPointerContainer<std::vector<NodeArg*>>* outer_scope_node_args, ExecutionMode execution_mode) { session_state_.SetGraph(graph_); const GraphViewer* graph_viewer = session_state_.GetGraphViewer(); // populate the SessionState OrtValueNameIdxMap const auto& ort_value_name_idx_map = session_state_.GetOrtValueNameIdxMap(); // ignore any outer scope args we don't know about. this can happen if a node contains multiple subgraphs. std::vector<const NodeArg*> valid_outer_scope_node_args; if (outer_scope_node_args) { std::for_each(outer_scope_node_args->cbegin(), outer_scope_node_args->cend(), [&ort_value_name_idx_map, &valid_outer_scope_node_args](const NodeArg* node_arg) { int idx; if (ort_value_name_idx_map.GetIdx(node_arg->Name(), idx).IsOK()) { valid_outer_scope_node_args.push_back(node_arg); }; }); } std::unique_ptr<SequentialExecutionPlan> exec_plan; SequentialPlannerContext context(execution_mode); ORT_RETURN_IF_ERROR(SequentialPlanner::CreatePlan(parent_node, *graph_viewer, valid_outer_scope_node_args, execution_providers_, kernel_registry_manager_, ort_value_name_idx_map, context, exec_plan)); session_state_.SetExecutionPlan(std::move(exec_plan)); const auto* exec_plan_ptr = session_state_.GetExecutionPlan(); ORT_ENFORCE(exec_plan_ptr, "Execution plan was not found in SessionState. CreatePlan must be called first."); // omitting other code // .... }
按照我们之前的理论,我们继续跟随SequentialPlanner::CreatePlan()
这个方法:
// onnxruntime\onnxruntime\core\framework\allocation_planner.cc#SequentialPlanner::CreatePlan() Status SequentialPlanner::CreatePlan(const Node* parent_node, const onnxruntime::GraphViewer& graph_viewer, const std::vector<const NodeArg*>& outer_scope_node_args, const ExecutionProviders& providers, const KernelRegistryManager& kernel_registry, const OrtValueNameIdxMap& ort_value_name_idx_map, const ISequentialPlannerContext& context, std::unique_ptr<SequentialExecutionPlan>& plan) { // allocate/reset here so we know it's clean plan = onnxruntime::make_unique<SequentialExecutionPlan>(); PlannerImpl planner(parent_node, graph_viewer, outer_scope_node_args, providers, kernel_registry, ort_value_name_idx_map, context, *plan); return planner.CreatePlan(); }
这个方法生成一个PlannerImpl
实例后,接着套娃:
// onnxruntime\onnxruntime\core\framework\allocation_planner.cc#PlannerImpl::CreatePlan() Status PlannerImpl::CreatePlan() { auto& p_graph_nodes = graph_viewer_.GetNodesInTopologicalOrder(); int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1; Initialize(p_graph_nodes.size(), static_cast<size_t>(num_ml_values)); // Determine execution order: we use the default topological sort order for now. We can later // explore more efficient orderings (from a memory usage perspective). for (auto n : p_graph_nodes) { plan_.execution_plan.emplace_back(n); } // omitting some code // ...... }
看到auto& p_graph_nodes = graph_viewer_.GetNodesInTopologicalOrder();这句,有种守的云开见月明的感觉。可惜,进去一看,里面已经是一个进行了拓扑排序的列表。没道理啊?怎么可能在我们眼皮底下偷摸的把拓扑关系做了?难道我们上帝视角也出了问题?答案当然不是,只不过是因为保存网络节点拓扑关系的SessionState对象非常勤奋,在它获取到模型结构图的时候,就把节点按拓扑排序排了,根本不管你deadline是什么时候。
我们回到上面SessionStateInitializer::CreatePlan()这个方法,方法体第一句session_state_.SetGraph(graph_);把模型结构图给了SessionState,而SessionState马上又把模型结构图给了他的小弟GraphViewer,进入GraphViewer我们终于发现,寻他千百度的拓扑排序就在这里。从字面上看,graph.ReverseDFSFrom()用的拓扑排序算法就是深度优先搜索算法。
进入SessionState.SetGraph():
// onnxruntime\onnxruntime\core\framework\session_state.cc Status SessionState::SetGraph(const Graph& graph) { graph_viewer_ = onnxruntime::make_unique<onnxruntime::GraphViewer>(graph); auto& logger = Logger(); // use graph_viewer_ to initialize ort_value_name_idx_map_ LOGS(logger, INFO) << "SaveMLValueNameIndexMapping"; int idx = 0; // omitted some code // ... } // onnxruntime\onnxruntime\core\graph\graph_viewer.cc GraphViewer::GraphViewer(const Graph& graph) { graph_ = &graph; std::vector<const Node*> leaf_nodes; for (auto& node : graph_->Nodes()) { if (node.OutputNodesBegin() == node.OutputNodesEnd()) { // This is a leaf node (without any output node). leaf_nodes.push_back(&node); } } graph.ReverseDFSFrom( leaf_nodes, nullptr, [this](const Node* n) { nodes_in_topological_order_.push_back(n->Index()); }, NodeCompare()); for (auto& node : graph_->Nodes()) { if (node.InputEdgesBegin() == node.InputEdgesEnd()) { root_nodes_.push_back(node.Index()); } } }
算法
下面让我们来看看具体的算法实现的吧:
// onnxruntime\onnxruntime\core\graph\graph.cc#Graph::ReverseDFSFrom() void Graph::ReverseDFSFrom(const std::vector<const Node*>& from, const std::function<void(const Node*)>& enter, const std::function<void(const Node*)>& leave, const std::function<bool(const Node*, const Node*)>& comp) const { using WorkEntry = std::pair<const Node*, bool>; // bool represents leave or not std::vector<WorkEntry> stack(from.size()); for (size_t i = 0; i < from.size(); i++) { stack[i] = WorkEntry(from[i], false); } std::vector<bool> visited(MaxNodeIndex(), false); while (!stack.empty()) { const WorkEntry last_entry = stack.back(); stack.pop_back(); const Node& n = *last_entry.first; if (last_entry.second) { // leave node leave(&n); continue; } if (visited[n.Index()]) continue; visited[n.Index()] = true; if (enter) enter(&n); if (leave) stack.emplace_back(&n, true); if (comp) { std::vector<const Node*> sorted_nodes; for (auto iter = n.InputNodesBegin(); iter != n.InputNodesEnd(); ++iter) { sorted_nodes.push_back(&(*iter)); } std::sort(sorted_nodes.begin(), sorted_nodes.end(), comp); for (const auto* in : sorted_nodes) { const NodeIndex idx = in->Index(); if (!visited[idx]) { stack.emplace_back(in, false); } } } else { for (auto iter = n.InputNodesBegin(); iter != n.InputNodesEnd(); ++iter) { const NodeIndex idx = (*iter).Index(); if (!visited[idx]) { stack.emplace_back(GetNode(idx), false); } } } } }
算法中通过一个站存储节点,每个节点有一个标志位表示该节点是否可以被取走放入拓扑队列,我们可以称之为可入队列标志,另外再用一个列表表示某个节点是否已经被访问过,我们可以称之为已访问标志。
与一般DFS略有区别的地方,就是它不需要先找到根节点,给定任意一个节点,它最终都能得到一个合理的拓扑列表。它是怎么实现的呢?很简单,直接在存储节点的栈上进行操作:
- 开始的时候节点随机入栈,可如队列标志和已访问标志都清除;
- 栈顶元素出栈,如果:
- 可入队标志位被设置,则该元素进入拓扑队列,重新开始第二步;
- 如果该节点已访问标志位被设置,说明该节点已经进入拓扑队列,重新开始第二步;
- 可入队标志位未被设置,设置该节点的已访问标志位和可入队标志位,重新入栈;并找到该节点所有输入节点,按一定规则排序后,清空输入节点的可入栈标志位,依次入栈。
- 重复第二步直到栈中所有元素都已经弹出并放入拓扑队列中。例如我们最开头的一个简单模型,假设入栈后其排列为:1,4,2,6,5,3。其算法过程如图2图3所示,其中,黄色表示可入队标志被设置,粉红色表示已访问标志被设置,淡蓝色表示拓扑队列里的内容:
最终,我们得到了一个拓扑队列中内容为:1,2,3,4 ,5 ,6。这个队列确保了每个节点被执行的时候,它的输入节点肯定已经被执行。例如,当节点5执行的时候,他的输入节点3和4已经被执行了。
子图
如果模型中还有子图,子图的处理过程也和主图类似,这里就不多说了。
总结
InferenceSession就好似一个统帅,SessionState替他保存推理需要的信息,IExecutor帮他进行推理工作。
————————————————
版权声明:本文为CSDN博主「SunnyZhou-1024」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/ZM_Yang/article/details/104022489
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 记一次.NET内存居高不下排查解决与启示
· DeepSeek 开源周回顾「GitHub 热点速览」
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
2019-03-28 libreoffice python 操作word及excel文档
2018-03-28 tensorflow serving GPU编译问题
2018-03-28 Linux 指令详解 alias 设置别名(转)
2016-03-28 Android 中的 Service 全面总结