OneFlow: 启动 Runtime

前言

我们前面介绍了从 Op 到 Job,又从 Job 到 Plan,这篇文章将会分析运行时(Runtime)启动,分析 Actor 是如何启动的。运行时启动的时机,发生在启动 Session 的时候,将 Job 编译成一个物理可以执行的 Plan 之后,就可以按照 Plan 启动运行时,启动 Actor 了。

流程回顾

运行时 Runtime 在什么时候启动的呢?在 Python 调用 StartLazyGlobalSession 的时候,在这个方法初始化全局 OneFlow 对象,将 JobSet 编译成 Plan,使用这个 Plan 启动 Runtime。

Runtime 的初始化流程如下。我们知道 Plan 是物理上可以执行的计算图,Plan 中的节点 TaskProto 则对应计算图上的节点,一个 Task 对应一个 Actor。Runtime 启动的时候,调用 HandoutTasks 将 Task 分发出去,构造 Actor。

// oneflow/core/job/runtime.cpp
// Starts the runtime for `plan` on the current process.
//
// Steps, in order:
//   1. Hand the plan to the global managers (RegstMgr, ThreadMgr, RuntimeJobDescs,
//      CollectiveBoxingExecutor).
//   2. Collect the tasks whose machine_id equals this process' rank, split into "source" tasks
//      (HasNonCtrlConsumedRegstDescId(task) is false) and all other tasks, and count the
//      local tasks per job id.
//   3. Hand out all local tasks so that their actors get constructed, and wait until the
//      "constructing_actor_cnt" counter reaches zero.
//   4. Synchronize with the other processes, create one running-actor counter per job, and
//      send kStart to the source tasks.
//
// plan: the compiled physical plan.
// variable_op_name2eager_blob: forwarded unchanged to RegstMgr::AddPlan.
Runtime::Runtime(const Plan& plan, const HashMap<std::string, Blob*>& variable_op_name2eager_blob) {
  {
    // NOTE(chengcheng): All runtime Global objects AddPlan
    Global<RegstMgr>::Get()->AddPlan(plan, variable_op_name2eager_blob);
    Global<ThreadMgr>::Get()->AddPlan(plan);
    Global<RuntimeJobDescs>::Get()->AddPlan(plan);
    collective_boxing_executor_plan_token_ =
        Global<boxing::collective::CollectiveBoxingExecutor>::Get()->AddPlan(plan);
  }
  // The rank does not change while iterating, so query it once.
  const int64_t this_rank = GlobalProcessCtx::Rank();
  std::vector<const TaskProto*> source_tasks;
  std::vector<const TaskProto*> other_tasks;
  int64_t this_machine_task_num = 0;
  for (const TaskProto& task : plan.task()) {
    if (task.machine_id() != this_rank) { continue; }
    if (!HasNonCtrlConsumedRegstDescId(task)) {
      source_tasks.push_back(&task);
    } else {
      other_tasks.push_back(&task);
    }
    // operator[] value-initializes a missing entry to 0 before it is incremented.
    ++job_id2actor_size_[task.job_id()];
    ++this_machine_task_num;
  }
  RuntimeCtx* runtime_ctx = Global<RuntimeCtx>::Get();
  runtime_ctx->NewCounter("constructing_actor_cnt", this_machine_task_num);
  // Source tasks are handed out before the remaining tasks.
  HandoutTasks(source_tasks);
  HandoutTasks(other_tasks);
  runtime_ctx->WaitUntilCntEqualZero("constructing_actor_cnt");
  LOG(INFO) << "Actors on this machine constructed";
  OF_SESSION_BARRIER();
  LOG(INFO) << "Actors on every machine constructed";
  // Iterate by const reference to avoid copying every map entry.
  for (const auto& pair : job_id2actor_size_) {
    runtime_ctx->NewCounter(GetRunningActorCountKeyByJobId(pair.first), pair.second);
  }
  // Only the source tasks receive kStart here.
  SendCmdMsg(source_tasks, ActorCmd::kStart);
}

HandoutTasks 接受 Task 数组作为参数,这些 Task 将会逐个添加到对应的 Thread 里面,最后通过基于消息机制的 ActorMsgBus 发送构造指令来构造 Actor。

// oneflow/core/job/runtime.cpp: 36
// Posts one command message carrying `cmd` to the global ActorMsgBus for every task in
// `tasks`, in the order the tasks appear. The destination of each message is the task's id.
void SendCmdMsg(const std::vector<const TaskProto*>& tasks, ActorCmd cmd) {
  for (auto task_it = tasks.cbegin(); task_it != tasks.cend(); ++task_it) {
    const auto dst_actor_id = (*task_it)->task_id();
    ActorMsg cmd_msg = ActorMsg::BuildCommandMsg(dst_actor_id, cmd);
    Global<ActorMsgBus>::Get()->SendMsg(cmd_msg);
  }
}

// Registers every task with the thread named by its thrd_id, then asks for the actors to be
// built by sending kConstructActor for all of them. All tasks are registered before the first
// command message is sent.
void HandoutTasks(const std::vector<const TaskProto*>& tasks) {
  for (size_t idx = 0; idx < tasks.size(); ++idx) {
    const TaskProto& task_proto = *tasks[idx];
    Thread* owner_thread = Global<ThreadMgr>::Get()->GetThrd(task_proto.thrd_id());
    owner_thread->AddTask(task_proto);
  }
  SendCmdMsg(tasks, ActorCmd::kConstructActor);
}

ThreadMgr

下面是 ThreadMgr 的头文件,提供了两个成员方法,和两个函数。前面在启动 Runtime 的时候,会将具有全局信息的 Plan 加入到 ThreadMgr 当中。

  • AddPlan 成员方法,初始化 Thread 对象。
  • GetThrd 成员方法,根据线程 id 获取对应的线程。一个线程对应多个 Actor,启动 Actor 的时候,已经给它分配好了一个线程,Actor 启动的时候去寻找对应的线程即可。
  • SingleThreadLoop,单线程循环调用一个函数
  • MultiThreadLoop,多线程多次调用一个函数。它向线程池提交 thread_num 个任务(thread_num 取调用次数 num 与线程池线程数中的较小值),但是如果调用的次数多于线程数怎么办?OneFlow 中提供了一个 BalancedSplitter,将 num 次调用均匀地划分成 thread_num 段,每个任务负责其中一段。这个方法使用 BlockingCounter 来进行同步,初始化为 thread_num,每个任务执行完自己那一段,就减一。当所有的任务都执行完毕,才继续往下走。
// oneflow/core/thread/thread_manager.h
namespace oneflow {

class Plan;

// Owns the Thread objects of this process, keyed by thread id.
// Neither copyable nor movable (OF_DISALLOW_COPY_AND_MOVE).
class ThreadMgr final {
 public:
  OF_DISALLOW_COPY_AND_MOVE(ThreadMgr);
  ThreadMgr() = default;
  ~ThreadMgr();

  // Creates the Thread objects needed by the tasks of `plan`.
  void AddPlan(const Plan& plan);
  // Returns the thread registered under `thrd_id`.
  Thread* GetThrd(int64_t thrd_id);

 private:
  friend class Global<ThreadMgr>;

  // thread id -> owned Thread.
  HashMap<int64_t, std::unique_ptr<Thread>> threads_;
};

// Calls Callback(i) for every i in [0, num) on the calling thread.
void SingleThreadLoop(size_t num, std::function<void(size_t i)> Callback);
// Calls Callback(i) for every i in [0, num), spreading the calls over the global ThreadPool,
// and returns once all of them have finished.
void MultiThreadLoop(size_t num, std::function<void(size_t i)> Callback);

// Registers `creator` for `device` by forwarding to REGISTER_CLASS_CREATOR with key type int,
// base class Thread and constructor argument const StreamId&.
// NOTE(review): presumably this is what NewObj<int, Thread, const StreamId&> looks up — confirm.
#define REGISTER_DEVICE_THREAD_CREATOR_WITH_STREAM_ID(device, creator) \
  REGISTER_CLASS_CREATOR(int, device, Thread, creator, const StreamId&)

}  // namespace oneflow

// oneflow/core/thread/thread_manager.cpp: 28
namespace oneflow {

// Shuts down every managed thread, one after another: send kStopThread on its message
// channel, destroy the Thread object, then log that the thread has finished.
ThreadMgr::~ThreadMgr() {
  for (auto it = threads_.begin(); it != threads_.end(); ++it) {
    std::unique_ptr<Thread>& thread_ptr = it->second;
    ActorMsg stop_msg = ActorMsg::BuildCommandMsg(-1, ActorCmd::kStopThread);
    thread_ptr->GetMsgChannelPtr()->Send(stop_msg);
    // Destroying the Thread here (rather than with the map) keeps the log line below accurate.
    thread_ptr.reset();
    LOG(INFO) << "actor thread " << it->first << " finish";
  }
}

// Returns the Thread registered under `thrd_id`.
// The returned pointer is non-owning; the CHECK fails if no such thread exists.
Thread* ThreadMgr::GetThrd(int64_t thrd_id) {
  auto iter = threads_.find(thrd_id);
  CHECK(iter != threads_.end()) << "thread " << thrd_id << " not found";
  return iter->second.get();
}

// Creates one Thread per distinct stream used by the tasks of `plan` that belong to this rank.
// The thread id is the serialized stream id; streams that already have a thread are skipped.
void ThreadMgr::AddPlan(const Plan& plan) {
  const int64_t local_rank = GlobalProcessCtx::Rank();
  for (const TaskProto& task_proto : plan.task()) {
    StreamId stream_id = DeserializeTaskIdFromInt64(task_proto.task_id()).stream_id();
    const bool on_this_rank = (stream_id.device_id().rank() == local_rank);
    if (!on_this_rank) { continue; }
    const int64_t thrd_id = SerializeStreamIdToInt64(stream_id);
    if (threads_.count(thrd_id) > 0) { continue; }
    // The concrete Thread subclass is selected by the device type of the stream.
    Thread* thread =
        NewObj<int, Thread, const StreamId&>(stream_id.device_id().device_type(), stream_id);
    CHECK_NOTNULL(thread);
    threads_[thrd_id].reset(thread);
  }
}

// Invokes Callback(i) for i = 0, 1, ..., num - 1 on the calling thread, in increasing order.
void SingleThreadLoop(size_t num, std::function<void(size_t i)> Callback) {
  for (size_t i = 0; i < num; ++i) { Callback(i); }
}

// Invokes Callback(i) for every i in [0, num) using the global ThreadPool and blocks until
// all invocations have finished.
//
// The index range is cut into min(num, pool size) contiguous pieces by BalancedSplitter; each
// piece is submitted as one unit of work and runs its indices sequentially. No ordering is
// guaranteed between different pieces.
//
// num: number of invocations; 0 is a no-op.
// Callback: copied into every unit of work.
void MultiThreadLoop(size_t num, std::function<void(size_t i)> Callback) {
  // Nothing to run. This also avoids constructing a BalancedSplitter with zero ranges.
  // NOTE(review): BalancedSplitter(0, 0) is presumably unsupported — confirm.
  if (num == 0) { return; }
  size_t thread_num = Global<ThreadPool>::Get()->thread_num();
  thread_num = std::min(num, thread_num);
  BalancedSplitter bs(num, thread_num);
  // One count per submitted unit of work; the caller waits until all have checked in.
  BlockingCounter bc(thread_num);
  FOR_RANGE(size_t, range_id, 0, thread_num) {
    // bc and bs are captured by reference: they outlive the work because of the wait below.
    Global<ThreadPool>::Get()->AddWork([&bc, &bs, range_id, Callback] {
      FOR_RANGE(size_t, i, bs.At(range_id).begin(), bs.At(range_id).end()) { Callback(i); }
      bc.Decrease();
    });
  }
  bc.WaitUntilCntEqualZero();
}

}  // namespace oneflow

Thread

接下来考察一下 Thread 这个类,从接口来看,这个类提供的接口允许添加 Task,给 Actor 发送消息。从类成员来看,需要存储各种映射,存储线程对象和 mutex,存储当前线程 id,是否使用本地的消息队列,是否开启 light actor。

// oneflow/core/thread/thread.h
namespace oneflow {

// Base class of an actor thread. It owns one std::thread, the message channel that feeds it,
// the tasks waiting to become actors and the actors that have already been constructed.
// Subclasses start actor_thread_ (via mut_actor_thread()) and run PollMsgChannel on it.
class Thread {
 public:
  OF_DISALLOW_COPY_AND_MOVE(Thread);
  virtual ~Thread();

  // Stores `task` so that a later kConstructActor command can build its actor.
  void AddTask(const TaskProto&);

  Channel<ActorMsg>* GetMsgChannelPtr() { return &msg_channel_; }

  // Queues one message for this thread. When the local queue is enabled and the caller is the
  // actor thread itself, the message goes straight into local_msg_queue_; otherwise it is sent
  // through msg_channel_.
  inline void EnqueueActorMsg(const ActorMsg& msg) {
    if (UseLocalMsgQueue()) {
      local_msg_queue_.push(msg);
    } else {
      msg_channel_.Send(msg);
    }
  }

  // Range overload of EnqueueActorMsg: queues every message in [first, last) in order, using
  // the same queue selection for the whole range.
  template<typename InputIt>
  inline void EnqueueActorMsg(InputIt first, InputIt last) {
    if (UseLocalMsgQueue()) {
      for (auto it = first; it != last; ++it) { local_msg_queue_.push(*it); }
    } else {
      for (auto it = first; it != last; ++it) { msg_channel_.Send(*it); }
    }
  }

  // Waits for the actor thread to exit. Safe to call when the thread was never started or has
  // already been joined: std::thread::join() throws on a non-joinable thread, so guard it.
  void JoinAllActor() {
    if (actor_thread_.joinable()) { actor_thread_.join(); }
  }

 protected:
  Thread();
  std::thread& mut_actor_thread() { return actor_thread_; }
  // Message loop of the actor thread; returns after a kStopThread command.
  void PollMsgChannel(const ThreadCtx& thread_ctx);
  void set_thrd_id(int64_t val) { thrd_id_ = val; }

 private:
  // Builds the actor for the task stored under `actor_id`.
  void ConstructActor(int64_t actor_id, const ThreadCtx& thread_ctx);

  // True only when the local queue is enabled AND the caller runs on the actor thread.
  inline bool UseLocalMsgQueue() const {
    return local_msg_queue_enabled_ && std::this_thread::get_id() == actor_thread_.get_id();
  }

  // Tasks added but not yet turned into actors; guarded by id2task_mtx_.
  HashMap<int64_t, TaskProto> id2task_;
  std::mutex id2task_mtx_;

  std::thread actor_thread_;
  Channel<ActorMsg> msg_channel_;
  // Constructed actors and the job each of them belongs to, keyed by actor id.
  HashMap<int64_t, std::unique_ptr<ActorBase>> id2actor_ptr_;
  HashMap<int64_t, int64_t> id2job_id_;
  std::queue<ActorMsg> local_msg_queue_;
  // Default initializers keep these members well-defined even before the constructor body or
  // set_thrd_id() has run.
  bool local_msg_queue_enabled_ = false;
  int64_t thrd_id_ = -1;
  bool light_actor_enabled_ = false;
};

}  // namespace oneflow

Thread 的方法是如何实现的呢?

  • AddTask,加锁,然后直接往映射的数据结构中加东西。
  • ConstructActor,根据 actor_id 构建 Actor,根据是否 light,调用不同的方法初始化。如果是 light,那么调用 TryNewLightActor。如果不是,那么调用 NewActor。构造完成之后,往映射的数据结构中加东西。
  • PollMsgChannel,这个方法非常重要!!它做了什么事呢?看名字,拉取消息。如果消息是 kCmdMsg 类型的,那么这是一条控制命令:kStopThread 会结束线程的消息循环,kConstructActor 会构造对应的 Actor。其余的消息(包括 kStart 等其他控制命令)都会转发给目标 Actor 的 ProcessMsg 去处理。那么 PollMsgChannel 由谁调用呢?
// oneflow/core/thread/thread.cpp
namespace oneflow {

// Reads the two feature switches from the environment; both default to false when the
// corresponding variable is not set.
Thread::Thread()
    : local_msg_queue_enabled_(
          ParseBooleanFromEnv("ONEFLOW_THREAD_ENABLE_LOCAL_MESSAGE_QUEUE", false)),
      light_actor_enabled_(ParseBooleanFromEnv("ONEFLOW_ACTOR_ENABLE_LIGHT_ACTOR", false)) {}

// Waits for the actor thread to exit, verifies that no task is left unconstructed and closes
// the message channel.
Thread::~Thread() {
  // join() on a thread that was never started or was already joined (e.g. through
  // JoinAllActor) throws std::system_error, which would terminate the process from a
  // destructor; only join when there is something to join.
  if (actor_thread_.joinable()) { actor_thread_.join(); }
  CHECK(id2task_.empty());
  msg_channel_.Close();
}

// Stores a copy of `task` under its task id, holding id2task_mtx_ for the duration.
// The CHECK fails if a task with the same id has already been added.
void Thread::AddTask(const TaskProto& task) {
  std::lock_guard<std::mutex> guard(id2task_mtx_);
  CHECK(id2task_.emplace(task.task_id(), task).second);
}

// Message loop of the actor thread.
//
// Messages are taken from local_msg_queue_; when it is empty it is refilled from msg_channel_.
// Two commands are handled by the thread itself:
//   - kStopThread: leaves the loop (all actors must be gone by then);
//   - kConstructActor: builds the destination actor.
// Every other message, including other commands, is passed to the destination actor's
// ProcessMsg. When ProcessMsg returns 1 the actor is destroyed and the running-actor counter
// of its job is decreased; the only other accepted return value is 0.
void Thread::PollMsgChannel(const ThreadCtx& thread_ctx) {
  while (true) {
    if (local_msg_queue_.empty()) {
      CHECK_EQ(msg_channel_.ReceiveMany(&local_msg_queue_), kChannelStatusSuccess);
    }
    ActorMsg msg = std::move(local_msg_queue_.front());
    local_msg_queue_.pop();
    if (msg.msg_type() == ActorMsgType::kCmdMsg) {
      if (msg.actor_cmd() == ActorCmd::kStopThread) {
        CHECK(id2actor_ptr_.empty());
        break;
      } else if (msg.actor_cmd() == ActorCmd::kConstructActor) {
        ConstructActor(msg.dst_actor_id(), thread_ctx);
        continue;
      } else {
        // Any other command is not for the thread: fall through and deliver it to the actor.
      }
    }
    int64_t actor_id = msg.dst_actor_id();
    auto actor_it = id2actor_ptr_.find(actor_id);
    CHECK(actor_it != id2actor_ptr_.end());
    int process_msg_ret = actor_it->second->ProcessMsg(msg);
    if (process_msg_ret == 1) {
      LOG(INFO) << "thread " << thrd_id_ << " deconstruct actor " << actor_id;
      auto job_id_it = id2job_id_.find(actor_id);
      // Dereferencing or erasing end() is undefined behavior, so fail loudly instead.
      CHECK(job_id_it != id2job_id_.end());
      const int64_t job_id = job_id_it->second;
      id2job_id_.erase(job_id_it);
      id2actor_ptr_.erase(actor_it);
      Global<RuntimeCtx>::Get()->DecreaseCounter(GetRunningActorCountKeyByJobId(job_id));
    } else {
      CHECK_EQ(process_msg_ret, 0);
    }
  }
}

// Builds the actor for the task previously added under `actor_id` and registers it with this
// thread. Holds id2task_mtx_ for the whole call.
//
// When light actors are enabled, TryNewLightActor is tried first; if it yields no actor (or
// light actors are disabled) NewActor is used. Afterwards the task is removed from the pending
// map and the "constructing_actor_cnt" counter is decreased.
//
// actor_id: id of the task/actor; a task with this id must have been added via AddTask.
// thread_ctx: passed through to the actor factory functions.
void Thread::ConstructActor(int64_t actor_id, const ThreadCtx& thread_ctx) {
  std::unique_lock<std::mutex> lck(id2task_mtx_);
  auto task_it = id2task_.find(actor_id);
  // Dereferencing end() is undefined behavior, so fail loudly on an unknown id.
  CHECK(task_it != id2task_.end())
      << "thread " << thrd_id_ << " has no task for actor " << actor_id;
  std::unique_ptr<ActorBase> actor_ptr;
  const TaskProto& task = task_it->second;
  if (light_actor_enabled_) { actor_ptr = TryNewLightActor(task, thread_ctx); }
  if (!actor_ptr) {
    actor_ptr = NewActor(task, thread_ctx);
    LOG(INFO) << "Thread " << thrd_id_ << " construct Actor " << TaskType_Name(task.task_type())
              << " " << actor_id;
  } else {
    LOG(INFO) << "Thread " << thrd_id_ << " construct LightActor "
              << TaskType_Name(task.task_type()) << " " << actor_id;
  }
  CHECK(id2actor_ptr_.emplace(actor_id, std::move(actor_ptr)).second);
  CHECK(id2job_id_.emplace(actor_id, task.job_id()).second);
  // `task` refers into id2task_, so the erase must come after its last use above.
  id2task_.erase(task_it);
  Global<RuntimeCtx>::Get()->DecreaseCounter("constructing_actor_cnt");
}

}  // namespace oneflow

搜索代码,看看哪些地方调用了 PollMsgChannel。

  • cpu_thread.cpp
  • gpu_thread.cpp

这两个文件中的实现结构是类似的:都是在构造函数里通过 std::thread 启动一个线程来运行 PollMsgChannel,接着这个线程就会不断从消息队列中拉取消息,然后执行。那这些 CpuThread 和 GpuThread 对象又是在哪里创建的呢?在 ThreadMgr 的 AddPlan 里面!

// oneflow/core/thread/cpu_thread.cpp
namespace oneflow {

// Records the thread id and starts the actor thread, whose body runs the message loop
// (PollMsgChannel) with a default-constructed ThreadCtx.
CpuThread::CpuThread(int64_t thrd_id) {
  set_thrd_id(thrd_id);
  // The lambda captures `this`; it runs PollMsgChannel on this object from the new thread.
  mut_actor_thread() = std::thread([this, thrd_id]() {
    OF_PROFILER_NAME_THIS_HOST_THREAD("CPU Actor : (" + std::to_string(thrd_id) + ")");
    ThreadCtx ctx;
#ifdef WITH_CUDA
    // A CPU thread has no callback event channel.
    ctx.cb_event_chan = nullptr;
#endif  // WITH_CUDA
    PollMsgChannel(ctx);
  });
}

// Registers the creator for DeviceType::kCPU: it builds a CpuThread whose thread id is the
// serialized stream id. The caller takes ownership of the returned raw pointer.
REGISTER_DEVICE_THREAD_CREATOR_WITH_STREAM_ID(DeviceType::kCPU,
                                              ([](const StreamId& stream_id) -> Thread* {
                                                return new CpuThread(
                                                    SerializeStreamIdToInt64(stream_id));
                                              }));

}  // namespace oneflow

Actor

前面分析了线程是如何产生的,线程运行的核心是 Actor。一个线程上有多个 Actor,线程通过轮询消息队列,然后将消息发送给不同的 Actor 来执行。真正干活的 Actor 是如何构造,如何执行的呢?

Actor 的构造很简单,通过 TaskProto 上面的类型,去选择一个对应的 Actor 进行初始化。

// oneflow/core/actor/actor_base.cpp
std::unique_ptr<ActorBase> NewActor(const TaskProto& task_proto, const ThreadCtx& thread_ctx) {
  ActorBase* rptr = NewObj<int32_t, ActorBase>(task_proto.task_type());
  const auto& job_descs = *Global<RuntimeJobDescs>::Get();
  rptr->Init(&job_descs.job_desc(task_proto.job_id()), task_proto, thread_ctx);
  return std::unique_ptr<ActorBase>(rptr);
}

Actor 的执行通过 ProcessMsg 方法来进行。前面我们已经看到了线程会轮询消息队列来拉取消息,然后将消息发送给对应的 Actor 进行处理。下面的分析可能有点零碎,核心要抓住一点:从拿到消息到启动 Kernel,中间经过了哪些步骤。

  • 线程通过轮询消息队列,拉取消息。将消息发送给 Actor 去处理,Actor 交给 msg_handler_ 处理。
// Dispatches `msg` to the currently installed member-function handler (msg_handler_) and
// returns the handler's result:
//   1: success, and actor finish
//   0: success, and actor not finish
int ProcessMsg(const ActorMsg& msg) override { return (this->*msg_handler_)(msg); }
  • msg_handler_ 可以被设置,不同的 Actor 可以设置 msg_handler_ 来处理。
  // Msg Handler
  // Installs `val` as the handler that ProcessMsg dispatches to.
  void set_msg_handler(MsgHandler val) { msg_handler_ = val; }
// Logs the handler switch (actor id plus the handler expression as written at the call site)
// and installs `val` after casting it to MsgHandler. The do/while(0) wrapper makes the macro
// usable as a single statement.
#define OF_SET_MSG_HANDLER(val)                                   \
  do {                                                            \
    LOG(INFO) << "actor " << actor_id() << " switch to " << #val; \
    set_msg_handler(static_cast<MsgHandler>(val));                \
  } while (0)
  • NaiveActor 在 VirtualActorInit 中把 handler 设置为 HandlerNormal。其他各种各样的 Actor 也可以设置各自的 handler,从而用不同的方式处理消息。
// Initialization hook of NaiveActor: installs HandlerNormal as the message handler.
// The TaskProto argument is not used.
void NaiveActor::VirtualActorInit(const TaskProto&) {
  OF_SET_MSG_HANDLER(&NaiveActor::HandlerNormal);
}
  • NaiveActor 中设置的 HandlerNormal 在 Actor 中提供了实现,它调用了 ActUntilFail 来执行 Act 方法。
// oneflow/core/actor/actor.cpp: 258
// Default message handler of an actor.
//
// Processes one message, updates the consumed/produced regst bookkeeping, runs ActUntilFail
// for regst and command messages, and finally decides whether the handler should stop.
//
// Returns 1 when the actor has finished (handler reset to nullptr), otherwise 0. When the stop
// condition holds but EORD messages or readings are still outstanding, the handler is switched
// to HandlerZombie and 0 is returned.
int Actor::HandlerNormal(const ActorMsg& msg) {
  if (msg.msg_type() == ActorMsgType::kEordMsg) {
    // EORD message: one fewer is expected; each regst desc id may report EORD only once.
    remaining_eord_cnt_ -= 1;
    CHECK(eord_regst_desc_ids_.insert(msg.eord_regst_desc_id()).second);
    if (naive_consumed_rs_.HasRegstDescId(msg.eord_regst_desc_id())) {
      is_naive_consumed_eord_ = true;
    } else if (inplace_consumed_rs_.HasRegstDescId(msg.eord_regst_desc_id())) {
      is_inplace_consumed_eord_ = true;
    } else {
      // Neither naive nor inplace consumed: leave it to the customization hook.
      NormalProcessCustomizedEordMsg(msg);
    }
  } else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
    if (msg.SrcMachineId() == GlobalProcessCtx::Rank()) {
      // Regst message sent from this rank.
      Regst* regst = msg.regst();
      if (naive_consumed_rs_.HasRegstDescId(regst->regst_desc_id())) {
        // Naive consumed regst: queue it; data regsts additionally trigger the readable hook.
        CHECK_EQ(0, naive_consumed_rs_.TryPushBackRegst(regst));
        const auto& rdeq = naive_consumed_rs_.RegstDeq4RegstDescId(regst->regst_desc_id());
        CHECK(rdeq.empty() == false);
        if (rdeq.front()->regst_desc()->regst_desc_type().has_data_regst_desc()) {
          NormalProcessNaiveReadableDataRegstMsg(rdeq);
        }
      } else if (inplace_consumed_rs_.HasRegstDescId(regst->regst_desc_id())) {
        // Inplace consumed regst: queue it and verify that its sole blob has the same data
        // pointer as the front regst of the mapped output regst desc.
        CHECK_EQ(0, inplace_consumed_rs_.TryPushBackRegst(regst));
        int64_t out_regst_desc_id = inplace_regst_desc_id_in2out_.at(regst->regst_desc_id());
        CHECK(regst->GetSoleBlob()->dptr()
              == inplace_produced_rs_.Front(out_regst_desc_id)->GetSoleBlob()->dptr());
      } else if (TryUpdtStateAsProducedRegst(regst) == 0) {
        // do nothing
        // NOTE(review): presumably 0 means the regst was one of this actor's produced regsts
        // and has been accounted for — confirm.
      } else {
        NormalProcessCustomizedReadableRegstMsg(msg);
      }
    } else {
      // Regst message sent from another rank.
      if (NormalTryProcessReadableMsgFromOtherMachine(msg) == false) {
        // process ctrl msg from other rank
        if (IsConsumedCtrlRegstDescId(msg.regst_desc_id())) {
          Regst* regst = msg.regst();
          CHECK(naive_consumed_rs_.HasRegstDescId(msg.regst_desc_id()));
          CHECK(Global<RegstMgr>::Get()->HasProducerTaskId4RegstDescId(msg.regst_desc_id()));
          CHECK_EQ(0, naive_consumed_rs_.TryPushBackRegst(regst, msg.regst_desc_id()));
          const auto& rdeq = naive_consumed_rs_.RegstDeq4RegstDescId(msg.regst_desc_id());
          CHECK(rdeq.empty() == false);
        } else {
          CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst()), 0);
        }
      }
    }
    // Regst state changed: act as long as reading and writing are both possible.
    ActUntilFail();
  } else if (msg.msg_type() == ActorMsgType::kCmdMsg) {
    // The only command accepted here is kStart.
    CHECK_EQ(msg.actor_cmd(), ActorCmd::kStart);
    ActUntilFail();
  } else {
    UNIMPLEMENTED();
  }
  // handler halts
  // The handler stops when either
  //   - the actor has naive/inplace consumed regsts, at least one of the two kinds has
  //     received EORD, and none of them has an available regst left, or
  //   - it has no naive/inplace consumed regsts and the customized read side reports that it
  //     will never be ready again.
  bool has_naive_or_inplace = naive_consumed_rs_.total_regst_desc_cnt() != 0
                              || inplace_consumed_rs_.total_regst_desc_cnt() != 0;
  bool naive_or_inplace_eord_and_empty =
      (is_naive_consumed_eord_ || is_inplace_consumed_eord_)
      && (naive_consumed_rs_.available_regst_desc_cnt() == 0
          && inplace_consumed_rs_.available_regst_desc_cnt() == 0);
  bool customized_eord = IsCustomizedReadAlwaysUnReadyFromNow();
  if ((has_naive_or_inplace && naive_or_inplace_eord_and_empty)
      || (!has_naive_or_inplace && customized_eord)) {
    CHECK_EQ(naive_consumed_rs_.available_regst_desc_cnt(), 0);
    AsyncReturnAllCustomizedReadableRegst();
    AsyncSendEORDMsgForAllProducedRegstDesc();
    if (remaining_eord_cnt_ == 0 && total_reading_cnt_ == 0) {
      // Nothing outstanding: the actor is finished.
      OF_SET_MSG_HANDLER(nullptr);
      return 1;
    } else {
      // EORD messages or readings are still outstanding: continue in the zombie handler.
      OF_SET_MSG_HANDLER(&Actor::HandlerZombie);
      return 0;
    }
  }
  return 0;
}
  • 当读和写都准备好了之后,ActUntilFail 就会调用 Act 方法去执行。
// oneflow/core/actor/actor.cpp
// Runs Act() repeatedly for as long as the actor is ready to both read and write. After each
// Act() it sends the produced-regst messages to consumers, then the consumed-regst messages
// back to producers, and then the queued messages.
void Actor::ActUntilFail() {
  while (IsReadReady() && IsWriteReady()) {
    Act();

    // Produced regsts -> consumers.
    AsyncSendCustomizedProducedRegstMsgToConsumer();
    AsyncSendNaiveProducedRegstMsgToConsumer();
    AsyncSendInplaceProducedRegstMsgToConsumer();

    // Consumed regsts -> producers.
    AsyncSendCustomizedConsumedRegstMsgToProducer();
    AsyncSendNaiveConsumedRegstMsgToProducer();
    AsyncRetInplaceConsumedRegstIfNoConsumer();

    AsyncSendQueuedMsg();
  }
  // NOTE(liujuncheng): return inplace consumed
  // This call also runs when the loop body never executed.
  AsyncSendQueuedMsg();
}
  • Act 方法中,将会调用 AsyncLaunchKernel 启动 Kernel。名字里的“异步”是相对主线程而言的:Kernel 在 Actor 所在的线程上执行,不会阻塞主线程。不过对于当前的 Actor 线程来说,它并不是异步的,而是一行一行顺序执行下来的。
void NaiveActor::Act() {
  KernelCtx kernel_ctx = GenDefaultKernelCtx();
  AsyncLaunchKernel(kernel_ctx, [&](int64_t regst_desc_id) -> Regst* { return nullptr; });
}
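  • 启动 ExecKernel,ExecKernel 是一个包含了计算信息、存储信息的结构体。Kernel 启动的时候,需要传入 context,还有一个函数体。这个函数的作用是根据 Blob 在 Op 中的名字(bn_in_op)找到对应的 Blob:先在 bn_in_op2blob_info 中查到 BlobInfo,再定位到对应的 Regst,最后从 Regst 中取出 Blob;其中任何一步找不到,都返回 nullptr。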
// oneflow/core/actor/actor.h: 58
// One kernel of an actor together with the lookup table used to resolve its blobs.
struct ExecKernel {
  // The kernel to launch; owned by this struct.
  std::unique_ptr<const Kernel> kernel;
  // Blob name in op (bn_in_op) -> information needed to locate the blob.
  HashMap<std::string, BlobInfo> bn_in_op2blob_info;
};

// oneflow/core/actor/actor.cpp: 470
// Launches every kernel in exec_kernel_vec_, in order, with `kernel_ctx`.
//
// Each kernel receives a callback that resolves a blob name (bn_in_op) to a Blob*:
//   1. look the name up in the kernel's bn_in_op2blob_info; unknown name -> nullptr;
//   2. a regst_desc_id of -1 -> nullptr;
//   3. take the regst from info.rs when it is set, otherwise ask Regst4RegstDescId;
//      no regst -> nullptr;
//   4. fetch the blob by ordinal when the ordinal is non-negative, otherwise by lbi.
void Actor::AsyncLaunchKernel(const KernelCtx& kernel_ctx,
                              std::function<Regst*(int64_t)> Regst4RegstDescId) {
  for (const ExecKernel& exec_kernel : exec_kernel_vec_) {
    const auto BnInOp2Blob = [&exec_kernel,
                              &Regst4RegstDescId](const std::string& bn_in_op) -> Blob* {
      const auto found = exec_kernel.bn_in_op2blob_info.find(bn_in_op);
      if (found == exec_kernel.bn_in_op2blob_info.cend()) { return nullptr; }
      const BlobInfo& blob_info = found->second;
      if (blob_info.regst_desc_id == -1) { return nullptr; }
      Regst* regst = nullptr;
      if (blob_info.rs == nullptr) {
        regst = Regst4RegstDescId(blob_info.regst_desc_id);
      } else {
        regst = blob_info.rs->Front(blob_info.regst_desc_id);
      }
      if (regst == nullptr) { return nullptr; }
      if (blob_info.ordinal < 0) { return regst->GetBlobByLbi(blob_info.lbi); }
      return regst->GetBlobByOrdinal(blob_info.ordinal);
    };
    exec_kernel.kernel->Launch(kernel_ctx, BnInOp2Blob);
  }
}
  • Kernel Launch -> Forward -> ForwardDataContent。ForwardHeader 应该是做输入的检查。ForwardDataContent 会调用计算的方法。
// oneflow/core/kernel/kernel.cpp: 43
// Entry point for running the kernel: forwards directly to Forward with the same arguments.
void Kernel::Launch(const KernelCtx& ctx,
                    const std::function<Blob*(const std::string&)>& BnInOp2Blob) const {
  Forward(ctx, BnInOp2Blob);
}

// Runs one forward step of the kernel: ForwardHeader first, then ForwardDataContent.
//
// The data-content step is skipped when all three conditions hold: not all blobs are static,
// every output blob is empty, and the kernel is stateless. Unless the blob access checker is
// disabled, the checker of the output blobs is switched before each phase and once more at
// the end.
void Kernel::Forward(const KernelCtx& ctx,
                     const std::function<Blob*(const std::string&)>& BnInOp2Blob) const {
  if (!blob_access_checker_disabled_) { SetOutputBlobProducerInferAccessChecker(BnInOp2Blob); }
  ForwardHeader(ctx, BnInOp2Blob);
  // Early exit: nothing to compute for empty outputs of a stateless kernel.
  if ((!kernel_conf_.all_blobs_are_static())
      && IsAllBlobEmpty(op_attribute().output_bns(), BnInOp2Blob) && IsStateless()) {
    return;
  }
  if (!blob_access_checker_disabled_) { SetOutputBlobProducerComputeAccessChecker(BnInOp2Blob); }
  // The profiler trace calls bracket exactly the data-content computation.
  OF_PROFILER_ONLY_CODE(profiler::TraceKernelForwardDataContentStart(this, ctx, BnInOp2Blob));
  ForwardDataContent(ctx, BnInOp2Blob);
  OF_PROFILER_ONLY_CODE(profiler::TraceKernelForwardDataContentEnd(this, ctx, BnInOp2Blob));
  if (!blob_access_checker_disabled_) { SetOutputBlobConsumerAccessChecker(BnInOp2Blob); }
}
  • ForwardDataContent 是 Kernel 提供的虚函数,每个子类实现不一样。UserKernel 和 OpKernel 用于定义扩展算子,OpKernel 中提供了 Compute 虚函数用于计算,需要注意的是 OpKernel 其实并没有继承 Kernel,OpKernel 作为 UserKernel 的一个成员存在。当调用 ForwardDataContent 的时候,它会调用 ForwardUserKernel,进而调用 OpKernel 的计算函数 Compute。
// Data-content step of a user kernel: delegates to ForwardUserKernel with the kernel's own
// state object. The `ctx` argument is not used here.
void UserKernel::ForwardDataContent(
    const KernelCtx& ctx, const std::function<Blob*(const std::string&)>& BnInOp2Blob) const {
  ForwardUserKernel(BnInOp2Blob, opkernel_state_.get());
}

// Refreshes the tensors of the compute context from the current blobs and runs the wrapped
// OpKernel's Compute.
//
// With WITH_CUDA_GRAPHS and a CUDA graph context present:
//   - if a graph has already been captured and the tensor update reported no change, the
//     captured graph is launched and Compute is skipped;
//   - otherwise, unless a capture is already in progress, Compute is recorded between
//     BeginCapture/EndCapture and the freshly captured graph is launched afterwards.
void UserKernel::ForwardUserKernel(const std::function<Blob*(const std::string&)>& BnInOp2Blob,
                                   user_op::OpKernelState* opkernel_state) const {
  // `updated` is only read on the CUDA-graph path below.
  const bool updated = ctx_->UpdateTensorWithCorrBlob(BnInOp2Blob);

#ifdef WITH_CUDA_GRAPHS
  // True only when this call itself started the capture.
  bool capturing = false;
  if (cuda_graph_ctx_) {
    if (!cuda_graph_ctx_->IsCapturing()) {
      if (cuda_graph_ctx_->IsCaptured() && (!updated)) {
        // Replay the existing graph instead of computing again.
        cuda_graph_ctx_->Launch();
        return;
      }
      capturing = true;
      cuda_graph_ctx_->BeginCapture();
    }
  }
#endif  // WITH_CUDA_GRAPHS

  kernel_->Compute(ctx_.get(), opkernel_state);

#ifdef WITH_CUDA_GRAPHS
  if (cuda_graph_ctx_ && capturing) {
    cuda_graph_ctx_->EndCapture();
    cuda_graph_ctx_->Launch();
  }
#endif  // WITH_CUDA_GRAPHS
}
  • Compute 是如何计算的呢?下面随便找一个 Kernel 来看看,我随便找了个 CpuAddNKernel,关注下面的 Compute 函数。它的主要工作是从 KernelComputeContext 取出输入和输出的指针,最后调用 cpu_add 将所有的输入逐元素相加,并把结果写入输出。至此我们终于完成了一次 Kernel 的计算。
// oneflow/user/kernels/add_n_kernel.cpp: 22
template<typename T>
// Element-wise sum of several input arrays: out[i] = in[0][i] + in[1][i] + ... for i in [0, n).
//
// n:   number of elements per array; values <= 0 make the call a no-op.
// out: destination array with room for at least n elements.
// in:  pointers to the input arrays, each holding at least n elements. Must contain at least
//      one pointer when n > 0, otherwise std::out_of_range is thrown.
template<typename T>
void cpu_add(const int64_t n, T* out, const std::vector<const T*>& in) {
  // Guard first: the loop below would otherwise never terminate normally for a negative n.
  if (n <= 0) { return; }
  // Bounds-checked once here; the inner loop can then index without re-checking.
  const T* const first_in = in.at(0);
  const size_t in_num = in.size();
  for (int64_t i = 0; i < n; ++i) {
    out[i] = first_in[i];
    // size_t index avoids comparing a signed counter against the unsigned size.
    for (size_t j = 1; j < in_num; ++j) { out[i] += in[j][i]; }
  }
}

// oneflow/user/kernels/add_n_kernel.cpp: 32
template<typename T>
// CPU kernel that adds all "in" tensors element-wise and writes the result to "out".
template<typename T>
class CpuAddNKernel : public user_op::OpKernel {
 public:
  CpuAddNKernel() = default;
  ~CpuAddNKernel() = default;

  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }

 private:
  // Gathers the data pointer of every input tensor plus the output pointer and element count,
  // then delegates the summation to cpu_add.
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const int32_t input_cnt = static_cast<int32_t>(ctx->inputs().size());

    std::vector<const T*> input_ptrs;
    input_ptrs.reserve(input_cnt);
    for (int32_t idx = 0; idx < input_cnt; ++idx) {
      input_ptrs.push_back(ctx->Tensor4ArgNameAndIndex("in", idx)->dptr<T>());
    }

    user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0);
    const int64_t elem_cnt = out_tensor->shape().elem_cnt();
    cpu_add<T>(elem_cnt, out_tensor->mut_dptr<T>(), input_ptrs);
  }
};

总结

这篇文章从 Runtime 启动开始,讲了如何启动线程,启动 Actor。线程通过轮询消息队列拉取消息,将消息转发给对应的 Actor 去执行。Actor 将启动 Kernel,Kernel 从 KernelComputeContext 获取输入和输出的信息,最后执行运算。

posted @ 2021-09-05 19:51  楷哥  阅读(148)  评论(0编辑  收藏  举报