OneFlow: From Job to Plan

Preface

The previous post analyzed how individual Ops are turned into a Job. This post analyzes how individual Jobs are turned into a Plan.

Plan

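First, let's be clear about the goal: a Plan that can actually be executed on physical devices. For computation, OneFlow uses an Actor mechanism, in which each node of the computation graph is executed by one Actor. So how does a Plan abstract over the Actor mechanism? In my view, an Actor consists of computation and storage. For computation we need to consider the operator and its kernel; for storage we need to look at the Regst (register). The question, then, is how the Plan, as the upper-level abstraction over Actors, captures both.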

Let's look at the Plan data structure.

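  • The first field is a repeated TaskProto. Plan relates to TaskProto the way a computation graph relates to its nodes, so the concrete configuration of each Actor lives in its TaskProto.
  • JobConfs is effectively job_id2job_conf, a map from a job id to that Job's configuration.
  • The remaining fields relate to memory management and SBP (for example block_chunk_list and collective_boxing_plan).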
message Plan {
  repeated TaskProto task = 1;
  required MemBlockAndChunkList block_chunk_list = 2;
  required JobConfs job_confs = 4;
  required CollectiveBoxingPlan collective_boxing_plan = 5;
  required CtrlRegstDescInfo ctrl_regst_desc_info = 6;
  map<int64, OpAttributeRefTable> job_id2op_attribute_ref_table = 7;
}

message JobConfs {
  map<int64, JobConfigProto> job_id2job_conf = 1;
}
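To make the Plan/TaskProto relationship concrete, here is a minimal sketch using plain C++ structs as stand-ins for the protobuf messages. The names TaskProtoLite, JobConfLite, PlanLite, CountTasksOfJob and JobNameOf are hypothetical, and only a handful of fields are kept; these are not OneFlow's generated classes.

```cpp
// Simplified stand-ins for the Plan/TaskProto/JobConfs messages above
// (hypothetical plain structs, NOT the protobuf-generated OneFlow classes).
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

struct TaskProtoLite {
  int64_t task_id;
  int64_t job_id;
  int64_t machine_id;
};

struct JobConfLite {
  std::string job_name;
};

struct PlanLite {
  std::vector<TaskProtoLite> task;                 // like `repeated TaskProto task`
  std::map<int64_t, JobConfLite> job_id2job_conf;  // like JobConfs.job_id2job_conf
};

// Count how many tasks (one per Actor) in the plan belong to a given job.
int64_t CountTasksOfJob(const PlanLite& plan, int64_t job_id) {
  int64_t n = 0;
  for (const auto& t : plan.task) {
    if (t.job_id == job_id) { ++n; }
  }
  return n;
}

// Look up a job's name through the job_id -> job conf map; empty if absent.
std::string JobNameOf(const PlanLite& plan, int64_t job_id) {
  auto it = plan.job_id2job_conf.find(job_id);
  return it == plan.job_id2job_conf.end() ? std::string() : it->second.job_name;
}
```

Because every task carries its job_id, one flat task list can hold the Actors of many Jobs, and the per-job configuration is recovered through the map.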

TaskProto

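As we saw above, TaskProto is a repeated field of Plan and plays the role of a node in the computation graph. Our focus is on computation and storage. Connections between Actors are not declared explicitly; they are implicit in the Actor mechanism itself: once an Actor's inputs and outputs are ready, it runs its computation. Actors are therefore linked simply through their inputs and outputs, with no explicit edge declarations needed.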

message TaskProto {
  // common
  required TaskType task_type = 1;
  required int64 machine_id = 2;
  required int64 thrd_id = 3;
  required int64 task_id = 4;
  required int64 job_id = 5;
  required TaskSetInfo task_set_info = 6;
  required ExecSequence exec_sequence = 7;
  map<string, RegstDescProto> produced_regst_desc = 8;
  map<string, RegstDescIdSet> consumed_regst_desc_id = 9;
  optional bool all_register_num_eq_one_hint = 10 [default = false];
  // compute task
  optional ParallelContext parallel_ctx = 1000; // CompTask
};
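  • Computation: the relevant field is ExecSequence. Following the nesting ExecSequence -> ExecNodeProto -> KernelConf -> OpAttribute -> OperatorConf, we arrive at the Op executed by this Actor.
  • Storage: TaskProto has produced_regst_desc and consumed_regst_desc_id, which describe the registers this Actor produces (its outputs) and consumes (its inputs). ExecNodeProto also has a map, bn_in_op2regst_desc_id, which maps each blob name in the op to a regst desc id.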
message ExecNodeProto {
  required KernelConf kernel_conf = 1;
  map<string, int64> bn_in_op2regst_desc_id = 2;
}

message ExecSequence {
  repeated ExecNodeProto exec_node = 1;
}
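Since Actor connections are implicit, they can be recovered by matching the regst desc ids a task produces with the ids other tasks consume. The sketch below assumes a simplified TaskLite where produced_regst_desc maps a name directly to a regst desc id and consumed_regst_desc_id maps a name to a list of ids (in the real proto these are RegstDescProto and RegstDescIdSet messages); TaskLite and InferActorEdges are hypothetical names.

```cpp
// Hypothetical simplified TaskProto: only the regst-related fields are kept.
#include <cassert>
#include <cstdint>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

struct TaskLite {
  int64_t task_id;
  std::map<std::string, int64_t> produced_regst_desc;                  // name -> regst desc id
  std::map<std::string, std::vector<int64_t>> consumed_regst_desc_id;  // name -> regst desc ids
};

// Recover producer -> consumer Actor edges purely from regst desc ids.
std::set<std::pair<int64_t, int64_t>> InferActorEdges(const std::vector<TaskLite>& tasks) {
  std::map<int64_t, int64_t> regst_id2producer;
  for (const auto& t : tasks) {
    for (const auto& kv : t.produced_regst_desc) { regst_id2producer[kv.second] = t.task_id; }
  }
  std::set<std::pair<int64_t, int64_t>> edges;
  for (const auto& t : tasks) {
    for (const auto& kv : t.consumed_regst_desc_id) {
      for (int64_t regst_id : kv.second) {
        auto it = regst_id2producer.find(regst_id);
        if (it != regst_id2producer.end()) { edges.emplace(it->second, t.task_id); }
      }
    }
  }
  return edges;
}
```

For example, if task 1 produces regst 10, task 2 consumes 10 and produces 20, and task 3 consumes both, the recovered edges are 1->2, 1->3 and 2->3.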

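One detail worth noting: OperatorConf has an op_type field, which is a oneof. One of its alternatives is UserOpConf, the configuration of a user-defined op: its type name, inputs, outputs, and attributes.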

message UserOpConf {
  message ListString {
    repeated string s = 1;
  }
  required string op_type_name = 1;
  map<string, ListString> input = 2;
  map<string, ListString> output = 3;
  map<string, AttrValue> attr = 4;
}
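A protobuf oneof behaves much like a C++ std::variant: exactly one alternative is set, and code branches on which one (the later linking code does this with op_type_case()). The sketch below is only an analogy; UserOpConfLite, SourceTickConfLite, SinkTickConfLite and OperatorConfLite are hypothetical trimmed types, not OneFlow's classes.

```cpp
// Analogy only: models the `op_type` oneof of OperatorConf with std::variant.
// Field sets are trimmed; these are NOT OneFlow's generated classes.
#include <cassert>
#include <map>
#include <string>
#include <variant>
#include <vector>

struct UserOpConfLite {
  std::string op_type_name;
  std::map<std::string, std::vector<std::string>> input;   // arg name -> blob names
  std::map<std::string, std::vector<std::string>> output;  // arg name -> blob names
};
struct SourceTickConfLite { std::string out; };
struct SinkTickConfLite { std::vector<std::string> tick; };

struct OperatorConfLite {
  std::string name;
  // Exactly one alternative is active at a time, like a protobuf oneof.
  std::variant<UserOpConfLite, SourceTickConfLite, SinkTickConfLite> op_type;
};

bool IsUserOp(const OperatorConfLite& conf) {
  return std::holds_alternative<UserOpConfLite>(conf.op_type);
}

// Mirrors the "is this a source/sink tick op?" check done on op_type_case().
bool IsTickInterface(const OperatorConfLite& conf) {
  return std::holds_alternative<SourceTickConfLite>(conf.op_type)
         || std::holds_alternative<SinkTickConfLite>(conf.op_type);
}
```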

Overall Flow

The previous post, on starting the Session, already walked through the overall flow without going into detail. Here is a brief recap.

Call Flow

  • StartLazyGlobalSession is called from Python
  • JUST(Global::Get()->Init(job_set)) initializes the global OneFlow object
  • CompileJobsAndPushMergedPlan compiles the Plan and pushes it out
  • CompileJobsAndMergePlans compiles multiple Jobs into a single Plan

The main work of CompileJobsAndMergePlans is:

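  • Add the Model IO Jobs
  • Add the Push Jobs and Pull Jobs
  • Compile the Jobs one by one with CompileCurJobOnMaster, and merge them with MergeSubPlan
  • Memory reuse and memory sharing across Jobs
  • Partition the critical sections with FinishGlobalCriticalSectionDesc
  • Generate, compile, and link the MainJob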
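Merging sub-plans can be pictured as concatenating their task lists and taking the union of their job_id -> job conf maps. The sketch below rests on that assumption only (the real merge also handles fields such as the memory block/chunk list); TaskLite, PlanLite and MergePlanSketch are hypothetical names, and a duplicate job id simply throws here.

```cpp
// Hypothetical sketch: merge per-job sub-plans into one plan, assuming merging
// boils down to concatenating tasks and unioning job_id -> job conf entries.
#include <cassert>
#include <cstdint>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

struct TaskLite { int64_t task_id; int64_t job_id; };
struct PlanLite {
  std::vector<TaskLite> task;
  std::map<int64_t, std::string> job_id2job_name;  // stand-in for job_id2job_conf
};

void MergePlanSketch(PlanLite* dst, PlanLite&& src) {
  for (auto& t : src.task) { dst->task.push_back(std::move(t)); }
  for (auto& kv : src.job_id2job_name) {
    // A job id must not appear in two sub-plans.
    if (!dst->job_id2job_name.emplace(kv.first, std::move(kv.second)).second) {
      throw std::runtime_error("duplicate job id " + std::to_string(kv.first));
    }
  }
}
```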

The rest of this post focuses on compiling a single Job, and on generating, compiling, and linking the MainJob.

Compiling a Single Job

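  • As its name suggests, CompileCurJobOnMaster compiles the current Job on the Master. Non-Master nodes do not compile; they only wait for the Master to send the Plan over. Note that PlanUtil::GenCollectiveBoxingPlan and PlanUtil::GenRegisterHint, at the end of the function, sit outside the Master check and run on every process.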
// oneflow/core/job/oneflow.cpp: 203
Maybe<void> CompileCurJobOnMaster(Job* job, Plan* plan, bool need_job_complete) {
  const JobDesc& job_desc = GlobalJobDesc();
  if (GlobalProcessCtx::IsThisProcessMaster()) {
    double start = GetCurTime();
    Compiler().Compile(job, plan, need_job_complete);
    PlanUtil::GenMemBlockAndChunk4Plan(plan);

    LOG(INFO) << "\njob_id: " << job_desc.job_id() << " , job_name: " << job_desc.job_name()
              << " , compile time: " << (GetCurTime() - start) / 1000000000.0 << " seconds.\n";
    if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
      TeePersistentLogStream::Create(StrCat("subplan_job_", job_desc.job_id()))->Write(*plan);
    }
  }
  PlanUtil::GenCollectiveBoxingPlan(job, plan);
  PlanUtil::GenRegisterHint(plan);
  return Maybe<void>::Ok();
}
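The "only the Master compiles, everyone else waits for the result" pattern can be simulated in a single process. In the sketch below, threads stand in for processes and a std::shared_future stands in for the Master pushing the Plan; CompileOnMaster and RunAllRanks are hypothetical names, and no real networking is involved.

```cpp
// Simulation only: threads stand in for processes, and a std::shared_future
// stands in for the master pushing the compiled plan to the other ranks.
#include <cassert>
#include <future>
#include <string>
#include <thread>
#include <vector>

// Hypothetical compile step; only rank 0 ("master") runs it.
std::string CompileOnMaster(const std::string& job_name) { return "plan(" + job_name + ")"; }

std::vector<std::string> RunAllRanks(int num_ranks, const std::string& job_name) {
  std::promise<std::string> plan_promise;
  std::shared_future<std::string> plan_future = plan_promise.get_future().share();
  std::vector<std::string> received(num_ranks);
  std::vector<std::thread> ranks;
  for (int rank = 0; rank < num_ranks; ++rank) {
    // Each thread gets its own copy of the shared_future, which makes
    // concurrent get() calls safe.
    ranks.emplace_back([&, rank, plan_future] {
      if (rank == 0) {
        // Master: compile, then "send" the plan to everyone.
        plan_promise.set_value(CompileOnMaster(job_name));
      }
      // Non-master ranks block here until the master's plan arrives.
      received[rank] = plan_future.get();
    });
  }
  for (auto& t : ranks) { t.join(); }
  return received;
}
```

Every rank ends up holding an identical plan, while the compilation itself happens exactly once.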

As the comments in Compile show, compilation proceeds in five steps:

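  • Make sure the job is completed.
  • Create the global OpGraph object.
  • Build the TaskGraph. At first I was puzzled that no argument is passed; it turns out the TaskGraph constructor initializes itself by fetching the global OpGraph object!
  • Put the information from the TaskGraph into the Plan.
  • Post-process the plan and delete the global OpGraph object.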
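Step 3 relies on visiting task nodes in topological order (TopoForEachNode in the code that follows), so that a node's Build runs only after all its predecessors have been processed. The sketch below is a generic Kahn-style traversal illustrating that guarantee; it is not TaskGraph's implementation, and TopoForEach is a hypothetical name.

```cpp
// Generic Kahn-style topological traversal: a node is visited only after all
// of its predecessors. A sketch of the idea, not TaskGraph's implementation.
#include <cassert>
#include <cstddef>
#include <functional>
#include <queue>
#include <stdexcept>
#include <vector>

// edges[u] lists the successors of node u; nodes are 0..edges.size()-1.
void TopoForEach(const std::vector<std::vector<int>>& edges,
                 const std::function<void(int)>& visit) {
  std::vector<int> in_degree(edges.size(), 0);
  for (const auto& succs : edges) {
    for (int v : succs) { ++in_degree[v]; }
  }
  std::queue<int> ready;  // nodes whose predecessors are all visited
  for (std::size_t u = 0; u < edges.size(); ++u) {
    if (in_degree[u] == 0) { ready.push(static_cast<int>(u)); }
  }
  std::size_t visited = 0;
  while (!ready.empty()) {
    int u = ready.front();
    ready.pop();
    visit(u);
    ++visited;
    for (int v : edges[u]) {
      if (--in_degree[v] == 0) { ready.push(v); }
    }
  }
  // Leftover nodes mean some in-degree never reached zero: there is a cycle.
  if (visited != edges.size()) { throw std::runtime_error("graph has a cycle"); }
}
```

For the diamond 0 -> {1, 2} -> 3, the visiting order is 0, 1, 2, 3.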
void Compiler::Compile(Job* job, Plan* plan, bool need_job_complete) const {
  // Step1: ensure job is completed.
  if (need_job_complete) { CHECK_JUST(JobCompleter().Complete(job)); }

  // Step2: new Global<OpGraph> and set log configs.
  Global<OpGraph>::New(*job);
  const JobDesc& job_desc = GlobalJobDesc();
  if (Global<ResourceDesc, ForSession>::Get()->enable_debug_mode()
      || Global<ResourceDesc, ForSession>::Get()->enable_dry_run()) {
    TeePersistentLogStream::Create(StrCat("optimized_job", job_desc.job_id()))->Write(*job);
    Global<OpGraph>::Get()->ToDotWithFilePath("optimized_dlnet_" + std::to_string(job_desc.job_id())
                                              + "_op_graph.dot");
  }

  // Step3: build task_gph.
  // TODO(levi): we can rewrite this part of code in visitor pattern.
  auto task_gph = std::make_unique<TaskGraph>();
  using std::placeholders::_1;
  task_gph->ForEachNode(std::bind(&TaskNode::ProduceAllRegstsAndBindEdges, _1));
  task_gph->ForEachNode(std::bind(&TaskNode::ConsumeAllRegsts, _1));
  task_gph->ForEachNode(std::bind(&TaskNode::PinConsumedRegst, _1));
  task_gph->TopoForEachNode(&TaskNode::Build);
  task_gph->RemoveEmptyRegsts();
  task_gph->MergeChainAndAddOrderingCtrlEdgeInSameChain();
  auto IsReachable = Global<OpGraph>::Get()->MakePredicatorIsOpNameDataOrCtrlReachable();
  if (job_desc.enable_inplace()) { task_gph->EnableInplaceMemSharing(IsReachable); }
  task_gph->TopoForEachNode(&TaskNode::InferTimeShapeIfMeaningful);
  task_gph->ForEachEdge([&](TaskEdge* task_edge) { task_edge->CheckRegstLbiValid(); });

  // Step4: put information from task_gph into plan.
  const int64_t node_num = task_gph->node_num();
  const int64_t cpu_num = std::thread::hardware_concurrency();
  const int64_t thread_pool_size = std::min(node_num, cpu_num);
  BlockingCounter counter(node_num);
  std::mutex mtx;
  ThreadPool thread_pool(thread_pool_size);
  task_gph->ForEachNode([&](TaskNode* task_node) {
    thread_pool.AddWork([task_node, plan, &job_desc, &counter, &mtx]() {
      if (!task_node->IsMeaningLess()) {
        TaskProto task_proto;
        task_node->ToProto(&task_proto);
        {
          std::unique_lock<std::mutex> guard(mtx);
          if (task_node->GetTaskType() == kNormalForward || task_node->GetTaskType() == kRepeat
              || task_node->GetTaskType() == kAcc) {
            CreateOpAttributeRef(plan, job_desc.job_id(), &task_proto);
          }
          plan->mutable_task()->Add(std::move(task_proto));
        }  // guard(mtx)
      }
      counter.Decrease();
    } /* thread_pool.AddWork */);
  } /* task_gph->ForEachNode */);
  counter.WaitUntilCntEqualZero();
  // NOTE(levi): release task_gph here to decrease memory peak.
  task_gph.reset();

  // Step5: post-process for plan and delete Global<OpGraph>.
  auto* job_id2job_conf = plan->mutable_job_confs()->mutable_job_id2job_conf();
  (*job_id2job_conf)[GlobalJobDesc().job_id()] = GlobalJobDesc().job_conf();
  // NOTE(chengcheng): infer mem blob id & set inplace & add ctrl
  IntraJobMemSharingUtil::InferMemBlockId4MemReusedRegst(plan, IsReachable);
  PlanUtil::SetUniqueMemBlockId4UnreusedMemRegst(plan);
  Global<OpGraph>::Delete();
}
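Step 4 converts task nodes to TaskProtos in parallel: the conversion runs unlocked, only the append to the plan is guarded by a mutex, and a BlockingCounter lets the caller wait for all work items. The sketch below reproduces that pattern with standard C++ only. The BlockingCounter here is hand-written (not OneFlow's class), ToProtoSketch and NodesToProtos are hypothetical, and a round-robin split over plain threads replaces OneFlow's ThreadPool.

```cpp
// Standard-C++ sketch of the Step 4 pattern: parallel conversion, a mutex
// around the shared output, and a counter the caller waits on.
#include <algorithm>
#include <cassert>
#include <condition_variable>
#include <cstdint>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>

class BlockingCounter {  // hand-written stand-in, not OneFlow's class
 public:
  explicit BlockingCounter(int64_t cnt) : cnt_(cnt) {}
  void Decrease() {
    std::lock_guard<std::mutex> lock(mtx_);
    if (--cnt_ == 0) { cv_.notify_all(); }
  }
  void WaitUntilCntEqualZero() {
    std::unique_lock<std::mutex> lock(mtx_);
    cv_.wait(lock, [this] { return cnt_ == 0; });
  }
 private:
  std::mutex mtx_;
  std::condition_variable cv_;
  int64_t cnt_;
};

std::string ToProtoSketch(int node) { return "task_" + std::to_string(node); }

std::vector<std::string> NodesToProtos(const std::vector<int>& nodes) {
  const int64_t node_num = static_cast<int64_t>(nodes.size());
  const int64_t cpu_num = std::max<int64_t>(1, std::thread::hardware_concurrency());
  const int64_t pool_size = std::max<int64_t>(1, std::min(node_num, cpu_num));
  std::vector<std::string> plan;
  std::mutex mtx;
  BlockingCounter counter(node_num);
  std::vector<std::thread> workers;
  for (int64_t w = 0; w < pool_size; ++w) {
    workers.emplace_back([&, w] {
      for (int64_t i = w; i < node_num; i += pool_size) {  // round-robin split
        std::string proto = ToProtoSketch(nodes[i]);    // conversion runs unlocked
        {
          std::lock_guard<std::mutex> guard(mtx);       // only the append is serialized
          plan.push_back(std::move(proto));
        }
        counter.Decrease();
      }
    });
  }
  counter.WaitUntilCntEqualZero();
  for (auto& t : workers) { t.join(); }
  return plan;
}
```

As in the original, the order of tasks in the resulting plan depends on thread scheduling, so nothing downstream should rely on it.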

MainJob

Image source: https://zhuanlan.zhihu.com/p/337851255

What is the MainJob for?

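  • It handles the interaction with Python: Python sends a Job ID, and the MainJob launches that Job. "Wait and send ids" waits for an ID from Python and forwards it to the Reentrant lock. The reentrant lock mainly controls concurrency. If the Job for that ID does not conflict with any Job currently running, the ID is passed on to Case, which dispatches it and starts the Job. If it does conflict with a running Job, the ID waits in the Reentrant lock. When a Job finishes, its Job ID is sent back to the Reentrant lock through the Esac node, and the Reentrant lock then releases the Jobs that are now allowed to run.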
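The admission logic described above can be modeled as a small class: an id is admitted only if it does not intersect any running id, otherwise it waits until a finished id comes back through Esac. This is a toy model, not OneFlow's ReentrantLock kernel. The conflict table mirrors the lock_id2intersecting_lock_ids field filled in the generation code further down, and is assumed to be symmetric; ReentrantLockSketch and its methods are hypothetical.

```cpp
// Toy model of the reentrant-lock idea (not OneFlow's ReentrantLock kernel).
#include <cassert>
#include <cstdint>
#include <deque>
#include <map>
#include <set>
#include <utility>
#include <vector>

class ReentrantLockSketch {
 public:
  // lock_id -> ids that conflict with it (assumed symmetric).
  explicit ReentrantLockSketch(std::map<int64_t, std::set<int64_t>> intersecting)
      : intersecting_(std::move(intersecting)) {}

  // A new id arrives from "wait and send ids"; returns the ids admitted now.
  std::vector<int64_t> Enqueue(int64_t id) {
    waiting_.push_back(id);
    return AdmitReady();
  }

  // A finished id comes back via Esac; returns the ids admitted as a result.
  std::vector<int64_t> Release(int64_t id) {
    running_.erase(id);
    return AdmitReady();
  }

 private:
  bool Conflicts(int64_t id) const {
    for (int64_t r : running_) {
      if (r == id) { return true; }
      auto it = intersecting_.find(id);
      if (it != intersecting_.end() && it->second.count(r) > 0) { return true; }
    }
    return false;
  }

  // Scan waiting ids in arrival order and admit every non-conflicting one.
  std::vector<int64_t> AdmitReady() {
    std::vector<int64_t> admitted;
    for (auto it = waiting_.begin(); it != waiting_.end();) {
      if (!Conflicts(*it)) {
        running_.insert(*it);
        admitted.push_back(*it);
        it = waiting_.erase(it);
      } else {
        ++it;
      }
    }
    return admitted;
  }

  std::map<int64_t, std::set<int64_t>> intersecting_;
  std::deque<int64_t> waiting_;
  std::set<int64_t> running_;
};
```

With ids 0 and 1 conflicting and 2 independent: 0 runs at once, 1 waits, 2 runs alongside 0, and releasing 0 admits 1.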

Generation, Compilation, Linking

Where does the MainJob come from? There are three main steps: generation, compilation, and linking.

Generation

The MainJob is built with a JobBuilder. Roughly, the process is:

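  • Set up a ParallelConf (CPU, device 0 of the first machine in the range); every Op added afterwards uses this same conf
  • ReentrantLock Op
  • Case Op
  • For each critical section: a source tick
  • For each critical section, on each machine: an identity tick, a callback sink tick, and a sink tick
  • Esac Op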
// oneflow/core/job/oneflow.cpp: 457
Maybe<ReentrantLockBackEdge> MakeMainJobComponent(
    const std::string& wait_and_send_ids_lbn, const Range& machine_id_range,
    JobBuilder* job_builder, std::vector<std::map<int64_t, std::string>>* identity_tick_op_names,
    std::vector<std::map<int64_t, std::string>>* cb_sink_tick_op_names) {
  ParallelConf parallel_conf;
  parallel_conf.set_device_tag("cpu");
  parallel_conf.add_device_name(std::string("@") + std::to_string(machine_id_range.begin()) + ":0");
  auto lock_back_edge = std::make_shared<ReentrantLockBackEdge>();
  OperatorConf reentrant_lock_op_conf;
  {
    lock_back_edge->reentrant_lock_op_name =
        std::string("System-Main-ReentrantLock_") + NewUniqueId();
    reentrant_lock_op_conf.set_name(lock_back_edge->reentrant_lock_op_name);
    auto* reentrant_lock_conf = reentrant_lock_op_conf.mutable_reentrant_lock_conf();
    reentrant_lock_conf->set_start(wait_and_send_ids_lbn);
    // ibn "end" is set after plan generated because we don't like cycle in job
    reentrant_lock_conf->set_out("out");
    Global<CriticalSectionDesc>::Get()->DumpCriticalSectionId2IntersectinIds(
        reentrant_lock_conf->mutable_lock_id2intersecting_lock_ids());
    JUST(job_builder->AddOp(parallel_conf, reentrant_lock_op_conf));
  }
  // critical section case op conf
  OperatorConf cs_case_op_conf;
  {
    cs_case_op_conf.set_name(std::string("System-Main-Case_") + NewUniqueId());
    auto* cs_case_conf = cs_case_op_conf.mutable_case_conf();
    cs_case_conf->set_in(reentrant_lock_op_conf.name() + "/out");
    FOR_RANGE(int64_t, i, 0, Global<CriticalSectionDesc>::Get()->CriticalSectionNum()) {
      cs_case_conf->add_out(GenRepeatedBn("out", i));
    }
    JUST(job_builder->AddOp(parallel_conf, cs_case_op_conf));
  }
  const int64_t num_critial_sections = Global<CriticalSectionDesc>::Get()->CriticalSectionNum();
  std::vector<std::string> snk_tick_op_names;
  FOR_RANGE(int64_t, i, 0, num_critial_sections) {
    // source tick
    OperatorConf src_tick_op_conf;
    {
      std::string name_prefix = "System-Main-SourceTick_CriticalSection_";
      src_tick_op_conf.set_name(name_prefix + std::to_string(i) + "_" + NewUniqueId());
      auto* src_tick_conf = src_tick_op_conf.mutable_tick_conf();
      src_tick_conf->add_tick(cs_case_op_conf.name() + "/" + GenRepeatedBn("out", i));
      src_tick_conf->set_out("out");
      JUST(job_builder->AddOp(parallel_conf, src_tick_op_conf));
    }

    auto* cur_cb_sink_tick_op_names = &cb_sink_tick_op_names->at(i);
    for (int64_t machine_id = machine_id_range.begin(); machine_id < machine_id_range.end();
         ++machine_id) {
      // identity tick
      OperatorConf identity_tick_op_conf;
      {
        std::string name_prefix = "System-Main-Tick_CriticalSection_";
        identity_tick_op_conf.set_name(name_prefix + std::to_string(i) + "_" + NewUniqueId());
        auto* identity_tick_conf = identity_tick_op_conf.mutable_tick_conf();
        identity_tick_conf->add_tick(src_tick_op_conf.name() + "/out");
        identity_tick_conf->set_out("out");
        JUST(job_builder->AddOp(parallel_conf, identity_tick_op_conf));
        auto* cur_id_tick_op_names = &identity_tick_op_names->at(i);
        CHECK_OR_RETURN(
            cur_id_tick_op_names->emplace(machine_id, identity_tick_op_conf.name()).second);
      }
      // callback
      {
        OperatorConf cb_sink_tick_op_conf;
        std::string name_prefix = "System-Main-CallbackSinkTick_";
        cb_sink_tick_op_conf.set_name(name_prefix + std::to_string(i) + NewUniqueId());
        auto* cb_sink_tick_conf = cb_sink_tick_op_conf.mutable_sink_tick_conf();
        cb_sink_tick_conf->add_tick(identity_tick_op_conf.name() + "/out");
        cb_sink_tick_conf->set_out("out");
        JUST(job_builder->AddOp(parallel_conf, cb_sink_tick_op_conf));
        CHECK_OR_RETURN(
            cur_cb_sink_tick_op_names->emplace(machine_id, cb_sink_tick_op_conf.name()).second);
      }
      // sink tick
      {
        OperatorConf snk_tick_op_conf;
        std::string name_prefix = "System-Main-SinkTick_CriticalSection_";
        snk_tick_op_conf.set_name(name_prefix + std::to_string(i) + NewUniqueId());
        auto* snk_tick_conf = snk_tick_op_conf.mutable_sink_tick_conf();
        snk_tick_conf->add_tick(identity_tick_op_conf.name() + "/out");
        snk_tick_conf->set_out("out");
        JUST(job_builder->AddOp(parallel_conf, snk_tick_op_conf));
        snk_tick_op_names.push_back(snk_tick_op_conf.name());
      }
    }
  }
  // critical section esac op conf
  OperatorConf cs_esac_op_conf;
  {
    cs_esac_op_conf.set_name(std::string("System-Main-Esac_") + NewUniqueId());
    // cs_esac_op_conf.set_pass_tag("main");
    auto* cs_esac_conf = cs_esac_op_conf.mutable_esac_conf();
    for (const auto& snk_tick_op_name : snk_tick_op_names) {
      cs_esac_conf->add_in(snk_tick_op_name + "/out");
    }
    cs_esac_conf->set_out("out");
    cs_esac_conf->set_data_type(DataType::kInt32);
    JUST(job_builder->AddOp(parallel_conf, cs_esac_op_conf));
  }
  lock_back_edge->critical_section_sink_lbi.set_op_name(cs_esac_op_conf.name());
  lock_back_edge->critical_section_sink_lbi.set_blob_name("out");
  return lock_back_edge;
}
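Stripped of the OperatorConf details, MakeMainJobComponent builds a fixed wiring pattern. The sketch below re-creates that pattern with plain strings so the shape is easy to check: with C critical sections and M machines there are 2 + C(1 + 3M) + 1 ops. MainJobSketch and BuildMainJobSketch are hypothetical, op names are simplified, unique-id suffixes are omitted, and machines are numbered 0..M-1 instead of using machine_id_range.

```cpp
// Hypothetical re-creation of the MainJob op topology using plain strings
// instead of OperatorConf/JobBuilder; only the wiring pattern is kept.
#include <cassert>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

struct MainJobSketch {
  std::vector<std::string> ops;
  std::vector<std::pair<std::string, std::string>> edges;  // producer -> consumer
};

MainJobSketch BuildMainJobSketch(int64_t num_critical_sections, int64_t num_machines) {
  MainJobSketch job;
  auto add_op = [&](const std::string& name, const std::string& input) {
    job.ops.push_back(name);
    job.edges.emplace_back(input, name);
  };
  add_op("ReentrantLock", "WaitAndSendIds");  // input comes from outside the component
  add_op("Case", "ReentrantLock");
  std::vector<std::string> sink_ticks;
  for (int64_t i = 0; i < num_critical_sections; ++i) {
    const std::string cs = std::to_string(i);
    add_op("SourceTick_" + cs, "Case");  // consumes Case's i-th output
    for (int64_t m = 0; m < num_machines; ++m) {
      const std::string suffix = cs + "_m" + std::to_string(m);
      add_op("Tick_" + suffix, "SourceTick_" + cs);         // identity tick
      add_op("CallbackSinkTick_" + suffix, "Tick_" + suffix);
      add_op("SinkTick_" + suffix, "Tick_" + suffix);
      sink_ticks.push_back("SinkTick_" + suffix);
    }
  }
  job.ops.push_back("Esac");  // Esac gathers every sink tick
  for (const auto& s : sink_ticks) { job.edges.emplace_back(s, "Esac"); }
  return job;
}
```

For 2 critical sections on 2 machines this yields 17 ops and 20 edges.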

Compilation

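Again, compilation happens only on the Master: set up the scope (GlobalJobDescScope), then compile. Afterwards, each lock back edge is connected with ConnectCriticalSectionEndToReentrantLockEnd.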

// oneflow/core/job/oneflow.cpp: 732
Maybe<void> CompileMainJob(Job* main_job, const std::vector<ReentrantLockBackEdge>& lock_back_edges,
                           int64_t job_id, Plan* main_plan) {
  CHECK_OR_RETURN(GlobalProcessCtx::IsThisProcessMaster());
  {
    auto scope = std::make_unique<GlobalJobDescScope>(main_job->job_conf(), job_id);
    JUST(CompileCurJobOnMaster(main_job, main_plan, false));
  }
  for (const auto& lock_back_edge : lock_back_edges) {
    JUST(ConnectCriticalSectionEndToReentrantLockEnd(main_plan, lock_back_edge));
  }
  return Maybe<void>::Ok();
}
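The comment in MakeMainJobComponent says the reentrant lock's "end" input is set only after the plan is generated, because the job should have no cycle. The sketch below shows why: the MainJob chain is acyclic, but the feedback edge from Esac back to ReentrantLock closes a loop. Graph, HasCycleFrom and HasCycle are hypothetical helpers using plain DFS cycle detection, and the op names are illustrative.

```cpp
// DFS cycle detection showing why the Esac -> ReentrantLock back edge is
// added only after compilation: it would make the job graph cyclic.
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

using Graph = std::map<std::string, std::vector<std::string>>;

bool HasCycleFrom(const Graph& g, const std::string& u, std::set<std::string>* on_stack,
                  std::set<std::string>* done) {
  if (on_stack->count(u)) { return true; }  // reached a node on the current path
  if (done->count(u)) { return false; }     // already fully explored
  on_stack->insert(u);
  auto it = g.find(u);
  if (it != g.end()) {
    for (const auto& v : it->second) {
      if (HasCycleFrom(g, v, on_stack, done)) { return true; }
    }
  }
  on_stack->erase(u);
  done->insert(u);
  return false;
}

bool HasCycle(const Graph& g) {
  std::set<std::string> on_stack, done;
  for (const auto& kv : g) {
    if (HasCycleFrom(g, kv.first, &on_stack, &done)) { return true; }
  }
  return false;
}
```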

Linking

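In effect, linking adds the critical sections of all the other Jobs into the Main Plan, forming one big Plan. Concretely, LinkMainPlan merges the main plan into the plan, indexes the interface tick tasks by op name, links each identity tick to the source and sink ticks of the corresponding critical section on each machine, and finally erases the now-redundant source tick tasks.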

// oneflow/core/job/oneflow.cpp: 306
void LinkMainPlan(Plan* plan, Plan&& main_plan,
                  const std::vector<std::map<int64_t, std::string>>& identity_tick_op_names) {
  std::function<bool(const TaskProto*)> IsInterfaceTickTockTask;
  {
    auto task_ids = std::make_shared<HashSet<int64_t>>();
    for (const auto& task : main_plan.task()) {
      if (task.task_type() == TaskType::kTick) { CHECK(task_ids->emplace(task.task_id()).second); }
    }
    IsInterfaceTickTockTask = [task_ids, plan](const TaskProto* task) {
      if (task_ids->find(task->task_id()) != task_ids->end()) { return true; }
      if (task->exec_sequence().exec_node_size() != 1) { return false; }
      const auto& kernel_conf = task->exec_sequence().exec_node(0).kernel_conf();
      OperatorConf::OpTypeCase op_type_case =
          PlanUtil::GetOpAttribute(plan, task->job_id(), kernel_conf).op_conf().op_type_case();
      return op_type_case == OperatorConf::kSourceTickConf
             || op_type_case == OperatorConf::kSinkTickConf;
    };
  }
  MergePlan(plan, std::move(main_plan));
  HashMap<std::string, TaskProto*> sole_tick_op_name2sole_task;
  FOR_RANGE(int64_t, i, 0, plan->task_size()) {
    TaskProto* task = plan->mutable_task(i);
    if (IsInterfaceTickTockTask(task) == false) { continue; }
    const auto& kernel_conf = task->exec_sequence().exec_node(0).kernel_conf();
    const auto& op_name =
        PlanUtil::GetOpAttribute(plan, task->job_id(), kernel_conf).op_conf().name();
    CHECK(sole_tick_op_name2sole_task.emplace(op_name, task).second);
  }
  auto TaskProto4TaskId = PlanUtil::MakeGetterTaskProto4TaskId(*plan);
  const auto& process_ranks = Global<ResourceDesc, ForSession>::Get()->process_ranks();
  FOR_RANGE(int32_t, i, 0, Global<CriticalSectionDesc>::Get()->CriticalSectionNum()) {
    const CriticalSection& cs = Global<CriticalSectionDesc>::Get()->GetCriticalSection(i);
    for (int64_t machine_id : process_ranks) {
      TaskProto* identity_tick =
          sole_tick_op_name2sole_task.at(identity_tick_op_names.at(i).at(machine_id));
      LinkTickTaskProto(
          plan, identity_tick,
          sole_tick_op_name2sole_task.at(cs.machine_id2source_tick_op_name().at(machine_id)),
          sole_tick_op_name2sole_task.at(cs.machine_id2sink_tick_op_name().at(machine_id)));
    }
  }
  {
    // erase source_tick task_proto
    HashSet<std::string> source_tick_op_names;
    FOR_RANGE(int32_t, i, 0, Global<CriticalSectionDesc>::Get()->CriticalSectionNum()) {
      const CriticalSection& cs = Global<CriticalSectionDesc>::Get()->GetCriticalSection(i);
      for (int64_t machine_id : process_ranks) {
        const auto& src_tick_op_name = cs.machine_id2source_tick_op_name().at(machine_id);
        CHECK(source_tick_op_names.emplace(src_tick_op_name).second);
      }
    }
    Erase<PbRpf<TaskProto>>(*plan->mutable_task(), [&](const TaskProto& task) {
      if (task.task_type() == TaskType::kSourceTick) {
        CHECK(task.exec_sequence().exec_node_size() == 1);
        const auto& kernel_conf = task.exec_sequence().exec_node(0).kernel_conf();
        const auto& op_conf = PlanUtil::GetOpAttribute(plan, task.job_id(), kernel_conf).op_conf();
        CHECK(op_conf.has_source_tick_conf());
        CHECK(source_tick_op_names.find(op_conf.name()) != source_tick_op_names.end());
        return true;
      } else {
        return false;
      }
    });
  }
}
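The skeleton of LinkMainPlan (merge, index tasks by op name, link, erase source ticks) can be sketched with toy structs. The linking step here rests on an assumption made for illustration: the consumers of a job's source tick are handed over to the main job's identity tick. The sink-tick side of LinkTickTaskProto is omitted. TickTask and LinkSketch are hypothetical names, not TaskProto or OneFlow APIs.

```cpp
// Simplified sketch of the LinkMainPlan flow with toy structs (not TaskProto).
// Assumption: linking redirects a source tick's consumers to the identity tick.
#include <algorithm>
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

struct TickTask {
  std::string op_name;
  bool is_source_tick;
  std::vector<std::string> consumers;  // op names of downstream tasks
};

void LinkSketch(std::vector<TickTask>* plan, const std::vector<TickTask>& main_plan,
                const std::map<std::string, std::string>& source_tick2identity_tick) {
  // 1. Merge the main plan's tasks into the big plan.
  plan->insert(plan->end(), main_plan.begin(), main_plan.end());
  // 2. Index tasks by op name; every op name must be unique.
  std::map<std::string, TickTask*> name2task;
  for (auto& task : *plan) {
    if (!name2task.emplace(task.op_name, &task).second) {
      throw std::runtime_error("duplicate op name " + task.op_name);
    }
  }
  // 3. Hand each source tick's consumers over to the matching identity tick.
  for (const auto& kv : source_tick2identity_tick) {
    TickTask* src = name2task.at(kv.first);
    TickTask* identity = name2task.at(kv.second);
    identity->consumers.insert(identity->consumers.end(), src->consumers.begin(),
                               src->consumers.end());
  }
  // 4. Erase the now-redundant source tick tasks (pointers above are unused from here on).
  plan->erase(std::remove_if(plan->begin(), plan->end(),
                             [](const TickTask& t) { return t.is_source_tick; }),
              plan->end());
}
```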
Posted on 2021-09-04 15:56 by 楷哥