TensorFlow Executor解析
前言
TF的单机运行模式下,DirectSession类是主要的会话运行时的类。我们平时在python中调用的session.run最终会调用到会话的入口方法,即
Status DirectSession::Run(const RunOptions& run_options,
const NamedTensorList& inputs,
const std::vector<string>& output_names,
const std::vector<string>& target_nodes,
std::vector<Tensor>* outputs,
RunMetadata* run_metadata)
其内部包含
- 执行器获取
- 输入数据填充(feed)
- 图运行
- 输出数据获取(fetch)
- 张量保存
等5个步骤,其中图运行是会话执行的核心步骤。也就是本文内容所在。DirectSession::Run会调用DirectSession::RunInternal以进行图运行逻辑。
准备工作
RunInternal会首先进行些准备工作。
创建ExecutorBarrier对象,用于实现不同局部图执行器对象的障栅同步操作。父线程构造ExecutorBarrier对象时会指定参与executor数目到该对象内部的计数器变量中,以及传入调用Notification::Notify方法的回调函数。
初始化Executor::Args对象,该对象中Runner类型的runner字段作为执行函数调度功能的函数对象,其可将传输的函数对象调度到线程池内执行。
父进程在调用每个executor的RunAsync(此时会传入Args对象和关于ExecutorBarrier::WhenDone的回调函数)后,调用WaitForNotification进入阻塞等待通知时间。每个executor在达到同步点时使用传入的ExecutorBarrier::WhenDone会互斥的对计数器减一,当计数器减为0时,whenDone会调用Notification::Notify方法通知父线程。
会话运行
Executor类(位于tensorflow/core/common_runtime/executor.h)是会话执行器的抽象,提供异步执行局部图的RunAsync虚方法及其同步封装版本Run方法。其内嵌结构体Args用于提供运行时的参数,结构体内部定义了两个类型别名:用于表示待执行函数对象Closure类型(std::function<void()>),以及用于为执行器实现函数调度功能、能够执行Closure对象的Runner类型(std::function<void(Closure)>)。结构体中Runner类型的runner字段为Executor对象提供了在特定设备上调度执行函数的能力。Executor类的具体实现位于其子类ExecutorImpl类(tensorflow/core/common_runtime/executor.cc)
ExecutorState类用于维护执行器的运行时状态(子图切分后、关联到特定设备的局部图的执行状态),其在ExecutorImpl::RunAsync方法调用时构造,并被ExecuImpl调用其RunAsync方法。
void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
(new ExecutorState(args, this))->RunAsync(std::move(done));
}
ExecutorState::RunAsync会首先初始化ExecutorState内部的设备上下文,然后将所有当前就绪的ops加入到TaggedNodeSeq类型的ready变量中,并通过ScheduleReady(ready, nullptr);语句调用ScheduleReady。下面是ScheduleReady的定义:
void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready,
TaggedNodeReadyQueue* inline_ready) {
if (ready.empty()) return;
int64 scheduled_nsec = 0;
if (stats_collector_) {
scheduled_nsec = nodestats::NowInNsec();
}
if (inline_ready == nullptr) {
// Schedule to run all the ready ops in thread pool.
for (auto& tagged_node : ready) {
runner_([=]() { Process(tagged_node, scheduled_nsec); });
}
return;
}
const GraphView& gview = impl_->gview_;
const TaggedNode* curr_expensive_node = nullptr;
for (auto& tagged_node : ready) {
const NodeItem& item = *gview.node(tagged_node.node->id());
if (tagged_node.is_dead || !item.kernel->IsExpensive()) {
// Inline this inexpensive node.
inline_ready->push_back(tagged_node);
} else {
if (curr_expensive_node) {
// Dispatch to another thread since there is plenty of work to
// do for this thread.
runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node,
scheduled_nsec));
}
curr_expensive_node = &tagged_node;
}
}
if (curr_expensive_node) {
if (inline_ready->empty()) {
// Tail recursion optimization
inline_ready->push_back(*curr_expensive_node);
} else {
// There are inline nodes to run already. We dispatch this expensive
// node to other thread.
runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node,
scheduled_nsec));
}
}
}
因为此时我们RunAsync调用ScheduleReady时传nullptr参数给inline_ready,所以这次对的ScheduleReady的调用只会将ready里所有TaggedNode分别传参给Process函数,并将这些Process通过runner_调度到线程池。
Process函数会创建新的TaggedNodeSeq类型的ready和TaggedNodeReadyQueue类型的inline_ready。首先将对本Process函数调用时传入的一个TaggedNode压入inline_ready,然后当inline_ready不为空时,依次取出inline_ready中的TaggedNode处理。具体的处理如下:
- 调用ExecutorState::PrepareInputs,为节点准备输入张量
- 根据节点kernel_is_async调用异步或者同步的逻辑。以同步为例,执行device->Compute(op_kernel, &ctx);以完成实际计算
- 调用ExecutorStateProcessOutputs,对输出张量加以收集和处理
- 调用ExecutorState::propagateOutputs方法,更新该Process的ready(将新增的就绪的节点加入的ready中),并将输出张量传递给本节点对应的目标节点。
- 调用NodeDone进行后处理,此时会传入ready和inline_ready的指针,NodeDone函数会执行ScheduleReady(ready, inline_ready),将昂贵的节点分发到另起的新Process中,将其他的节点加入本Process的inline_ready中。
当NodeDone运行结束后,函数返回到Process中,此时的inline_ready已经不为空了,所以循环判定式为真,程序再次进入循环体执行上述的步骤。
这里说下第5步中对SchedReady的调用。因为Process中对ScheduleReady的调用会会给inline_ready形参传入非空实参,所以这次的ScheduleReady会和ExecutorState::RunAsync调用时的ScheduleReady表现的不一样,具体逻辑为:会遍历传入的ready,将不昂贵的节点加入到inline_ready中,将昂贵的节点通过runner_分发到新的Process提交给线程池执行(尾递归优化代码的要实现的目的是:如果ready没有不昂贵的节点的话,最后一个昂贵的节点会留在本Process中执行)
当Process中inline_ready为空后,会调用ScheduleFinish,然后当本ExecutorState中的节点都执行完后,到汇合点,调用作为作为ExecutorState::RunAsync参数传入的回调函数,从而触发障栅同步。
ps:因为runner_是个函数对象(c++中,函数也可以是一等公民了),所以可以对其使用()操作符,对其调用。因为runner_的参数是闭包(就是输入空,输出也为空的函数对象),所以需要匿名函数或者bind封装下对Process函数的调用。
参考资料
《深入理解TensorFlow架构设计与实现原理》第四部分 核心揭秘篇