[源码解析] 深度学习分布式训练框架 horovod (10) --- run on spark

[源码解析] 深度学习分布式训练框架 horovod (10) --- run on spark

0x00 摘要

Horovod 是Uber于2017年发布的一个易于使用的高性能的分布式训练框架,在业界得到了广泛应用。

本系列将通过源码分析来带领大家了解 Horovod。本文是系列第十篇,看看horovod 如何运行在 spark 之上。

Horovod on Spark 具体有两种底层实现:MPI,GLOO。因为篇幅所限,本文介绍 MPI 实现,下一篇介绍GLOO实现。

本系列其他文章如下:

[源码解析] 深度学习分布式训练框架 Horovod (1) --- 基础知识

[源码解析] 深度学习分布式训练框架 horovod (2) --- 从使用者角度切入

[源码解析] 深度学习分布式训练框架 horovod (3) --- Horovodrun背后做了什么

[源码解析] 深度学习分布式训练框架 horovod (4) --- 网络基础 & Driver

[源码解析] 深度学习分布式训练框架 horovod (5) --- 融合框架

[源码解析] 深度学习分布式训练框架 horovod (6) --- 后台线程架构

[源码解析] 深度学习分布式训练框架 horovod (7) --- DistributedOptimizer

[源码解析] 深度学习分布式训练框架 horovod (8) --- on spark

[源码解析] 深度学习分布式训练框架 horovod (9) --- 启动 on spark

0x01 回顾

1.1 总体序列图

接上文,我们首先要回顾下 Horovod on Spark 的总体序列图,这样脑子里有一个全景,温故而知新。

img

1.2 总体逻辑

总体来说,Horovod on Spark 的总体逻辑分为以下阶段:

  • 启动 SparkDriverService 服务,利用 _make_spark_thread 启动 Spark task,然后 horovod 会等待启动结束;
  • 多线程在 spark executor 之中启动 spark task,每个task之中运行一个 SparkTaskService,SparkTaskService 会向 hovorod 主进程中的 SparkDriverTask 进行注册,并且等待下一步运行启动的指令;
  • Horovod 收到所有 task 结束的信息之后,通知各个 task,进入下一阶段;
  • Horovod 调用 mpi_run (又利用到 mpirun_rsh.py)在每一个 spark executor 上启动 orted(这里是通过 SparkTaskService 来启动 orted),以启动 MPI cluster;
  • orted 在每一个 executor 之上运行训练代码;

前文已经分析了前面三个阶段,本文继续后面两个阶段的分析。

1.3 问题

结合上面的流程图,这里就有一个问题会令人疑惑。

Horovod 按说可以直接调用 mpirun 来在远端启动 orted(orted 就是 mpi 可执行程序。mpirun 是 orterun 的别名,而 ortedrun 会最终调用到 orted)。但是为什么流程图上不是直接调用,而是通过 mpirun_rsh.py,进而通过 SparkTaskService 来启动 orted?

原因应该是:

  • 通常 MPI 会通过 SSH 来连接 hosts,但是这种方式无法在 Spark Executor 之中启动 Python function。
  • Orted 需要运行在 Spark Executor 之中,但是 mpirun 在启动时候,没办法知道 Spark Executor 的 IP : PORT 这个组合,所以没法直接启动。
  • 因此 MPI 使用RPC 来启动用户代码:
    • 通过 SparkDriverService 和 SparkTaskService 等交互才可以知道这个 IP : PORT 组合信息。
    • 使用 horovod.spark.driver.mpirun_rsh 来连接每个 Executor,然后 "remote shell" 到这些 executors 之中。
    • 直接使用 SparkTaskService 来启动 orted。

0x02 第四阶段 : 启动 Job

下面我们看看第四阶段,就是如何运行 训练 job。

2.1 _launch_job

_launch_job 很简单:

  • 首先 driver.get_common_interfaces 获取网络路由信息;
  • 其次 调用 run_contoller 来启动 job;
def _launch_job(use_mpi, use_gloo, settings, driver, env, stdout=None, stderr=None):
    nics = driver.get_common_interfaces()
    run_controller(use_gloo, lambda: gloo_run(settings, nics, driver, env, stdout, stderr),
                   use_mpi, lambda: mpi_run(settings, nics, driver, env, stdout, stderr),
                   False, lambda: None,
                   settings.verbose)

2.2 获取路由信息

get_common_interfaces 与普通模式下的 get_common_interfaces 不同。因为此时,Spark Executor 之中的 SparkTaskService 的信息已经保存在 Driver 之中,直接获取即可。

def get_common_interfaces(self):
    if self._nics is not None:
        return self._nics

    nics = None
    if len(self._task_addresses_for_tasks) > 0:
        # in Elastic Horovod on Spark with auto-scaling
        # keys in task_addresses are in range(max_np or proc_num)
        # but not all keys may exist, so we don't do for index in range(proc_num)
        indices = list(self._task_addresses_for_tasks.keys())
        nics = set(self._task_addresses_for_tasks[indices[0]].keys()) # 直接获取
        for index in indices[1:]:
            nics.intersection_update(self._task_addresses_for_tasks[index].keys())

    return nics

2.3 run_controller

就是依据配置和编译情况来进行处理,选择 gloo,js,还是 mpi。

def run_controller(use_gloo, gloo_run, use_mpi, mpi_run, use_jsrun, js_run, verbosity):
    if use_gloo:
        gloo_run()
    elif use_mpi:
        mpi_run()
    elif use_jsrun:
        js_run()
    else:
        if mpi_built(verbose=verbose):
            if lsf.LSFUtils.using_lsf() and is_jsrun_installed():
                js_run()
            else:
                mpi_run()
        elif gloo_built(verbose=verbose):
            gloo_run()

所以我们开始启动 job,具体我们分为 MPI,GLOO两种进行分析。

0x03 MPI 实验

我们首先要做一些 MPI 相关实验,其原因是因为:

  • MPI 的调用之中有些看起来很奇怪的行为,或者说是一些 trick。
  • 这些 trick 对于 "horovod on spark" 基于 MPI 的实现是很有帮助,但是对于理解代码却是一个极大的干扰
  • 我们暂时没有时间和精力去研究 MPI 的源码是如何实现的,因为已经超出了本文范畴。

所以我们只能针对某些奇怪的行为,对 MPI 的相关实现机制做一些假设和估计。然后通过一个简单的实验来验证我们的假设。

3.1 问题点

我们执行的 mpi 命令格式如下,这个命令格式就是为了模拟Horovod的 MPI 命令:

mpirun --allow-run-as-root -n 4 --hostfile ./remote_hosts -mca plm_rsh_agent "python rsh.py" python user_function.py

问题点就是:

  • plm_rsh_agent "python rsh.py" 的作用是什么?
  • rsh.py 之中,有哪些 trick?如何调用远程 mpi 程序?
  • python user_function.py 是在 rsh.py 之后运行吗?

3.2 名词解释

3.2.1 orterun & orted

最开始在看到这个命令时候,容易让人很晕。因为代码中没有任何提及。

其实,outed 就是 mpi 可执行程序。

mpirun 是 orterun 的别名,而 ortedrun 会最终调用到 orted

具体解释如下,信息来源为 http://cn.voidcc.com/question/p-wkloammx-bha.html:

mpirunmpiexec基本上是相同的 - 许多MPI实现中的进程启动器的名称。 MPI标准没有提到如何启动和控制等级,但它建议(尽管不要求),如果有任何类型的启动器,它应该被命名为mpiexec。一些MPI实现以mpirun开始,然后采用mpiexec以实现兼容性。其他实现则相反。最后,大多数实现都使用两个名称来提供它们的启动器。在实践中,mpirunmpiexec所做的事情应该没有什么不同。

不同的MPI实现有不同的启动和控制过程的方法。 MPICH从一个名为MPD(多用途守护进程或其他)的基础架构开始。然后切换到新的Hydra流程管理器。由于Hydra的功能与MPD不同,因此基于Hydra的mpiexec采用的命令行参数不同于基于MPD的命令行参数,并且使用户可以明确选择基于Hydra的命令行参数,因此它可用作mpiexec.hydra。旧的称为mpiexec.mpd。可能有一个基于MPICH的MPI库只提供Hydra启动程序,然后mpiexecmpiexec.hydra将是相同的可执行文件。英特尔MPI基于MPICH,其新版本使用Hydra进程管理器。

Open MPI建立在开放运行环境(ORTE)的基础上,其自身的进程启动器被称为orterun。为了兼容,orterun也符号链接为mpirunmpiexec

总结:

  • mpiexec.something是MPI进程启动的给定实现的特定版本
  • mpiexecmpirun是通用名称的符号链接到实际发射通常副本或
  • mpiexecmpirun应该这样做
  • 某些实现命名他们的发射器mpiexec,有些人命名它mpirun,有人将其命名为两者,当系统路径中同时有多个MPI实现可用时,这通常是混淆的来源(例如,当从发行版安装时)

3.2.2 mpi orterun 源码

mpi之中 orterun 对应的源码如下,最主要是调用了 orte_submit_job 提交 job。

int orterun(int argc, char *argv[])
{
    orte_submit_status_t launchst, completest;

    /* orte_submit_init() will also check if the user is running as
       root (and may issue a warning/exit). */
    if (ORTE_SUCCESS != orte_submit_init(argc, argv, NULL)) {
        exit(1);
    }

    /* setup to listen for commands sent specifically to me, even though I would probably
     * be the one sending them! Unfortunately, since I am a participating daemon,
     * there are times I need to send a command to "all daemons", and that means *I* have
     * to receive it too
     */
    orte_rml.recv_buffer_nb(ORTE_NAME_WILDCARD, ORTE_RML_TAG_DAEMON,
                            ORTE_RML_PERSISTENT, orte_daemon_recv, NULL);

    /* if the user just wants us to terminate a DVM, then do so */
    if (orte_cmd_options.terminate_dvm) {
        // 省略部分代码
    } else {
        /* spawn the job and its daemons */
        memset(&launchst, 0, sizeof(launchst));
        memset(&completest, 0, sizeof(completest));
        launchst.active = true;
        completest.active = true;
      
        // 在这里进行提交 job
        if (ORTE_SUCCESS != orte_submit_job(argv, NULL,
                                            launched, &launchst,
                                            completed, &completest)) {
            ORTE_UPDATE_EXIT_STATUS(1);
            goto DONE;
        }
    }

    // wait for response and unpack the status, jobid
    // 省略部分代码
}

3.3 实验设计

3.3.1 组件

有如下几个组件,其作用分别如下:

  • host 文件。作用是指定本次运行有哪些host,以及host之上运行几个MPI进程。
  • rsh.py。作用是作为 rsh agent 来给远端机器下达命令。
    • MPI 用户也可以通过其他方式给远程机器下发命令。
    • 用户可以对每个主机使用远程 shell(sshrsh)而无需登录主机。默认情况下,mpirun 使用 ssh
    • 如果 mpirun 使用 ssh 出现问题,可以尝试在 mpirun 命令中使用 --mca plm_rsh_agent rsh 选项,以使用 rsh 命令进行连接。
  • user_function.py。就是用户希望执行的函数。

3.3.2 host 文件 remote_hosts

remote_hosts 文件内容如下:

1.1.1.1:2
2.2.2.2:2

其意义是:

  • 1.1.1.1 这个 ip 运行 2 个 slot,即两个 MPI 进程。
  • 2.2.2.2 这个 ip 运行 2 个 slot,即两个 MPI 进程。

3.3.3 rsh.py

rsh.py 内容如下,作用就是打印 MPI 传入的 command,然后在远端host之上启动的 MPI 进程中运行新命令:

import os
import sys
import subprocess

if __name__ == '__main__':
  command = " ".join(sys.argv[0:])
  print(command)
  new_command = " ".join(sys.argv[2:])
  print(new_command)
  subprocess.Popen(new_command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

3.3.4 user_function.py

内容如下,就是为了测试,打印一条语句。

print('hello world')

3.4 实验结果

我们在 1.1.1.1 之上运行 mpi 命令。

mpirun --allow-run-as-root -n 4 --hostfile ./remote_hosts -mca plm_rsh_agent "python rsh.py" python user_function.py

结果如下:

# 以下是 command 内容,就是 MPI 传递给 rsh.py 的内容,这里居然有 plm_rsh_agent "python rsh.py" 
rsh.py 1.1.1.1 orted -mca ess "env" -mca ess_base_jobid "41114481152" -mca ess_base_vpid 1 -mca ess_base_num_procs "4" -mca ored_node_regex "ip-[2:1]-1-1-1,[2:1]2.2.2@0(2)" -mca orted_hnp_uri "41114481152.0,tcp://1.1.1.1:53405" -mca plm "rsh" --tree-spawn -mca orte_parent_uri "41114481152.0,tcp://1.1.1.1:53405"  -mca plm_rsh_agent "python rsh.py" -mca pmix "^s1,s2,cray,isolated"

# 以下是 new_command 内容,就是 在远端host上执行 用户代码 的方法,这里居然有 plm_rsh_agent "python rsh.py" 
orted -mca ess "env" -mca ess_base_jobid "41114481152" -mca ess_base_vpid 1 -mca ess_base_num_procs "4" -mca ored_node_regex "ip-[2:1]-1-1-1,[2:1]2.2.2@0(2)" -mca orted_hnp_uri "41114481152.0,tcp://1.1.1.1:53405" -mca plm "rsh" --tree-spawn -mca orte_parent_uri "41114481152.0,tcp://1.1.1.1:53405"  -mca plm_rsh_agent "python rsh.py" -mca pmix "^s1,s2,cray,isolated"

# 以下是 user_function.py 的执行内容
hello world

因此我们知道

  • plm_rsh_agent "python rsh.py" 的作用是在远端运行 MPI orted。
  • python user_function.py 是在 rsh 之后运行的,而且是在远端的 orted 之中运行。
  • 在 rsh.py 执行过程中,其接受到的命令内容有些奇怪

3.5 运行过程

运行过程如下:

  1. mpirun 运行 mpirun -mca plm_rsh_agent "python rsh.py" python user_function.py,此时在远端会运行一个 MPI daemon,用来响应处理;
  2. mpirun 调用了 rsh.py;
  3. rsh.py 使用 subprocess(orted -mca plm_rsh_agent "python rsh.py") 在远端启动 orted(会与 daemon 沟通),运行用户代码;

具体如下图:

                                                         1.1.1.1        +          2.2.2.2
                                                                        |
                                                                        |
                                                                        |  1      +---------------+
mpirun -mca plm_rsh_agent "python rsh.py" python user_function.py  +----------->  |  MPI deamon   |
                                                                        |         +-------+-------+
             +                                                          |                 |
             |                                                          |                 |
             | 2                                                        |                 |
             |                                                          |                 |  3
             |                                                          |                 |
             |  rsh.py 1.1.1.1 orted -mca plm_rsh_agent "python rsh.py" |                 |
             |                                                          |                 v
             |                                                          |         +-------+--------------------------+
             |                                                          |         | orted                            |
             |                                                          |         |                                  |
             v                                                          |         |                                  |
+------------+------------------------------------------------------+   |         |   +---------------------------+  |
| rsh.py                                                            |   |         |   | user_function.py          |  |
|                                                                   |   |         |   |                           |  |
|    rsh.py 1.1.1.1 orted -mca plm_rsh_agent "python rsh.py"        |   |         |   |                           |  |
|                                                                   |   |   3     |   |      print('hello world') |  |
|    subprocess(orted -mca plm_rsh_agent "python rsh.py") +-------------------->  |   |                           |  |
|                                                                   |   |         |   +---------------------------+  |
+-------------------------------------------------------------------+   +         +----------------------------------+

手机如下:

3.6 Trick 分析

我们发现有几个奇怪的点:

  • mpirun 运行 mpirun -mca plm_rsh_agent "python rsh.py" python user_function.py
  • mpirun 调用了 rsh.py,但是在 rsh.py 收到的 argv 中,居然也有 plm_rsh_agent "python rsh.py" 。按说这时候不应该有这个参数了,因为 rsh.py 已经调用了,就不应该再有这个参数
  • rsh.py 运行远端 MPI,使用的是 orted -mca plm_rsh_agent "python rsh.py",这里居然还有 plm_rsh_agent "python rsh.py" 这个参数。这时候也不应该,因为 orted 已经运行在远端了,这时候也传入一个用来远端控制的 rsh agent 参数,太奇怪了

就是说plm_rsh_agent "python rsh.py" 这个参数居然被 MPI 传递到各个阶段,无论是 rsh agent 或者 远端 mpi

rsh agent 就是 trick。不知道 MPI 为什么要把 plm_rsh_agent "python rsh.py" 在各个阶段传递的意图,可能是为了更好的控制。

因为没有精力来分析 MPI 源码,所以初步判断,远端 MPI daemon 在运行 orted -mca plm_rsh_agent "python rsh.py"时候,会判断是否已经是远端,如果是远端,就不再运行 rsh agent 了。

所以,我们在后面分析中,在 Spark task 之中 发现 类似 plm_rsh_agent "python rsh.py" ,就不用再疑惑了

0x04 MPI 实现

一般来说,Horovod on Spark 是以 MPI 模式来运行,所以我们重点看这里。

4.1 mpi_run in spark

mpi_run 代码位于:horovod/spark/mpi_run.py,作用是:

  • 依据各种配置生成remote shell的agent;
  • 依据各种配置生成可执行命令;
  • 调用hr_mpi_run(horovod.runner.mpi_run 就是普通模式下的 mpi_run)运行命令;

比如得到 rsh_agent 大致如下:

("/usr/bin/python", "-m", "horovod.spark.driver.mpirun_rsh", "xxxxx", "yyy")

得到 command 大致如下:

("/usr/bin/python", "-m", "horovod.spark.task.mpirun_exec_fn", "xxxxx", "yyy")

具体代码如下:

from horovod.runner.mpi_run import mpi_run as hr_mpi_run

def mpi_run(settings, nics, driver, env, stdout=None, stderr=None):
    """
    Runs mpirun.

    :param settings: Settings for running MPI.
                     Note: settings.num_proc and settings.hosts must not be None.
    :param nics: Interfaces to include by MPI.
    :param driver: The Spark driver service that tasks are connected to.
    :param env: Environment dictionary to use for running MPI.  Can be None.
    :param stdout: Stdout of the mpi process.
                   Only used when settings.run_func_mode is True.
    :param stderr: Stderr of the mpi process.
                   Only used when settings.run_func_mode is True.
    """
    env = {} if env is None else copy.copy(env)  # copy env so we do not leak env modifications

    # Pass secret key through the environment variables.
    env[secret.HOROVOD_SECRET_KEY] = codec.dumps_base64(settings.key)
    # we don't want the key to be serialized along with settings from here on
    settings.key = None

    # 拼接出rsh_agent
    rsh_agent = (sys.executable,
                 '-m', 'horovod.spark.driver.mpirun_rsh',
                 codec.dumps_base64(driver.addresses()),
                 codec.dumps_base64(settings))
    settings.extra_mpi_args = ('{extra_mpi_args} -x NCCL_DEBUG=INFO -mca plm_rsh_agent "{rsh_agent}"'
                               .format(extra_mpi_args=settings.extra_mpi_args if settings.extra_mpi_args else '',
                                       rsh_agent=' '.join(rsh_agent)))
    # 拼接出command
    command = (sys.executable,
               '-m', 'horovod.spark.task.mpirun_exec_fn',
               codec.dumps_base64(driver.addresses()),
               codec.dumps_base64(settings))
    hr_mpi_run(settings, nics, env, command, stdout=stdout, stderr=stderr)

4.2 mpi_run in normal

上面代码最后是运行 hr_mpi_run,其实 hr_mpi_run 是 horovod.runner.mpi_run,就是普通模式下的 mpi_run。

horovod.runner.mpi_run 首先 就是依据各种配置以及参数来构建 mpirun 命令的所有参数,比如 ssh 的参数,mpi 参数,nccl 参数等等。

得到了 command 大致如下:

mpirun --allow-run-as-root --map-by slot -x SSH_CONNECITION -mca pls_rsh_agent "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx" /usr/bin/python -m horovod.spark.task.mpurun_exec_fn xxxxx

具体代码如下:

def mpi_run(settings, nics, env, command, stdout=None, stderr=None):
    """
    Runs mpi_run.

    Args:
        settings: Settings for running MPI.
                  Note: settings.num_proc and settings.hosts must not be None.
        nics: Interfaces to include by MPI.
        env: Environment dictionary to use for running command.
        command: Command and arguments to run as a list of string.
        stdout: Stdout of the mpi process.
                Only used when settings.run_func_mode is True.
        stderr: Stderr of the mpi process.
                Only used when settings.run_func_mode is True.
    """

    # 获取mpi相关配置
    mpi_impl_flags, impl_binding_args, mpi = _get_mpi_implementation_flags(settings.tcp_flag, env=env)
    impi = _IMPI_IMPL == mpi

    # 获取ssh配置
    ssh_args = []
    if settings.ssh_port:
        ssh_args += [f'-p {settings.ssh_port}']
    if settings.ssh_identity_file:
        ssh_args += [f'-i {settings.ssh_identity_file}']

    mpi_ssh_args = ''
    if ssh_args:
        joined_ssh_args = ' '.join(ssh_args)
        mpi_ssh_args = f'-bootstrap=ssh -bootstrap-exec-args \"{joined_ssh_args}\"' if impi else f'-mca plm_rsh_args \"{joined_ssh_args}\"'

    # 网卡相关信息
    tcp_intf_arg = '-mca btl_tcp_if_include {nics}'.format(
        nics=','.join(nics)) if nics and not impi else ''
    nccl_socket_intf_arg = '-{opt} NCCL_SOCKET_IFNAME={nics}'.format(
        opt='genv' if impi else 'x',
        nics=','.join(nics)) if nics else ''

    # On large cluster runs (e.g. Summit), we need extra settings to work around OpenMPI issues
    host_names, host_to_slots = hosts.parse_hosts_and_slots(settings.hosts)
    if not impi and host_names and len(host_names) >= _LARGE_CLUSTER_THRESHOLD:
        mpi_impl_flags.append('-mca plm_rsh_no_tree_spawn true')
        mpi_impl_flags.append('-mca plm_rsh_num_concurrent {}'.format(len(host_names)))

    # if user does not specify any hosts, mpirun by default uses local host.
    # There is no need to specify localhost.
    hosts_arg = '-{opt} {hosts}'.format(opt='hosts' if impi else 'H',
                hosts=','.join(host_names) if host_names and impi else settings.hosts)

    ppn_arg = ' '
    if host_to_slots and impi:
        ppn = host_to_slots[host_names[0]]
        for h_name in host_names[1:]:
            if ppn != host_to_slots[h_name]:
                raise Exception('''Different slots in -hosts parameter are not supported in Intel(R) MPI.
                                 Use -machinefile <machine_file> for this purpose.''')
        ppn_arg = ' -ppn {} '.format(ppn)

    if settings.prefix_output_with_timestamp and not impi:
        mpi_impl_flags.append('--timestamp-output')

    binding_args = settings.binding_args if settings.binding_args and not impi else ' '.join(impl_binding_args)

    basic_args = '-l' if impi else '--allow-run-as-root --tag-output'

    output = []
    if settings.output_filename:
        output.append('-outfile-pattern' if impi else '--output-filename')
        output.append(settings.output_filename)

    env_list = '' if impi else ' '.join(
                    '-x %s' % key for key in sorted(env.keys()) if env_util.is_exportable(key))

    # Pass all the env variables to the mpirun command.
    mpirun_command = (
        'mpirun {basic_args} '
        '-np {num_proc}{ppn_arg}{hosts_arg} '
        '{binding_args} '
        '{mpi_args} '
        '{mpi_ssh_args} '
        '{tcp_intf_arg} '
        '{nccl_socket_intf_arg} '
        '{output_filename_arg} '
        '{env} {extra_mpi_args} {command}'  # expect a lot of environment variables
        .format(basic_args=basic_args,
                num_proc=settings.num_proc,
                ppn_arg=ppn_arg,
                hosts_arg=hosts_arg,
                binding_args=binding_args,
                mpi_args=' '.join(mpi_impl_flags),
                tcp_intf_arg=tcp_intf_arg,
                nccl_socket_intf_arg=nccl_socket_intf_arg,
                mpi_ssh_args=mpi_ssh_args,
                output_filename_arg=' '.join(output),
                env=env_list,
                extra_mpi_args=settings.extra_mpi_args if settings.extra_mpi_args else '',
                command=' '.join(quote(par) for par in command))
    )

    # we need the driver's PATH and PYTHONPATH in env to run mpirun,
    # env for mpirun is different to env encoded in mpirun_command
    for var in ['PATH', 'PYTHONPATH']:
        if var not in env and var in os.environ:
            # copy env so we do not leak env modifications
            env = copy.copy(env)
            # copy var over from os.environ
            env[var] = os.environ[var]

    # Execute the mpirun command.
    if settings.run_func_mode:
        exit_code = safe_shell_exec.execute(mpirun_command, env=env, stdout=stdout, stderr=stderr)
    else:
        os.execve('/bin/sh', ['/bin/sh', '-c', mpirun_command], env)

4.3 执行命令

目前得到的命令是:

mpirun --allow-run-as-root --map-by slot -x SSH_CONNECITION -mca pls_rsh_agent "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx" /usr/bin/python -m horovod.spark.task.mpurun_exec_fn xxxxx

所以我们接着分析。

当 mpi_run 准备好命令之后,他调用 safe_shell_exec.execute 或者 bin/sh 执行命令。对于 safe_shell_exec.execute 来说,它需要执行的命令是:

mpirun --allow-run-as-root --map-by slot -x SSH_CONNECITION -mca pls_rsh_agent "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx" /usr/bin/python -m horovod.spark.task.mpurun_exec_fn xxxxx

这样,就是先调用 safe_shell_exec.execute 或者 bin/sh 执行 "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx",然后执行 horovod.spark.task.mpurun_exec_fn xxxxx。

4.3.1 mpi 参数

对于 mpirun 来说,参数 --mca pls_rsh_agent rsh 告诉节点间通讯用rsh。

这样我们就知道 horovod.spark.driver.mpirun_rsh 就是在节点通讯时候,首先执行的脚本。

就是说,当 mpirun 想在异地节点运行一个 程序(horovod.spark.task.mpurun_exec_fn) 时候,首先运行 horovod.spark.driver.mpirun_rsh 从而在异地节点上启动一个 orted,其次在这个 异地 orted 之上运行 horovod.spark.task.mpurun_exec_fn

4.3.3 mpirun_rsh.py

所以,horovod.spark.driver.mpirun_rsh 会最先运行,我们需要首先看看,就是下图中最下面部分

mpirun_rsh.py 的作用如其注释所言,目的是被 MPI 调用以便连接到一个 host,并且执行指定的命令

命令通常是 orted ,用来创建 MPI cluster。orted 进程然后被用来启动远端进程(Horovod 用户的 Python方法)。 orted 进程将运行在最低index的 task上,同一个host 的其他task将执行 no-op 并且等待 orted task 结束

Method run by MPI to connect to a host hash and execute the given command.

The command is usually `orted` to setup the MPI cluster. That `orted` process
is then used to spin-up the actual remote process, the Horovod user's Python method.
The `orted` process will run on the lowest task index and all other tasks with the
same host hash are expected to no-op (see `horovod.spark._task_fn`)
and wait for the first task to terminate.

但是实际上代码其实很简单,就是直接调用了 rsh,所以我们还得接着看。

if len(sys.argv) < 5:
    print('Usage: %s <service addresses> <settings> <host hash> '
          '<command...>' % sys.argv[0])
    sys.exit(1)

addresses = codec.loads_base64(sys.argv[1])
key = codec.loads_base64(os.environ.get(secret.HOROVOD_SECRET_KEY))
settings = codec.loads_base64(sys.argv[2])
host_hash = sys.argv[3]
command = " ".join(sys.argv[4:])
env = {}  # orted does not need any env vars, the target training code gets env from mpirun

# Since tasks with the same host hash have shared memory,
# we will run only one orted process on the first task.
rsh(addresses, key, host_hash, command, env, 0, settings.verbose) # 直接调用

4.3.4 rsh

这里才是上述逻辑的具体实现,所以rsh 的作用就是:

  • 与在 Spark Driver 上运行的 SparkDriverService 进行交互,从 SparkDriverService 获取需要运行 task 的所需信息;
  • 与 Spark Executor 中的 SparkTaskService 交互,运行 command;

具体到代码就是:

  • 利用 driver_client.task_host_hash_indices(host_hash) 从在 Spark Driver 上运行的 SparkDriverService 获取某一个 host 上的所有 task;
  • 利用 task_indices[local_rank] 获取到对应的 task;
  • 利用 driver_client.all_task_addresses(task_index) 获取 task 的地址;
  • 利用 task_service.SparkTaskClient.run_command 来运行 command;

command 举例如下,此时 command 已经被 mpirun 处理转义

/usr/local/bin/orted -mca ess "env" -mcc ess_base_num_procs "2" -mca orte_hnp_uri "xxxx" -mca pls_rsh_agent "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx"

具体代码是:

def rsh(driver_addresses, key, host_hash, command, env, local_rank, verbose,
        stdout=None, stderr=None, prefix_output_with_timestamp=False,
        background=True, events=None):
    """
    Method to run a command remotely given a host hash, local rank and driver addresses.

    This method connects to the SparkDriverService running on the Spark driver,
    retrieves all information required to connect to the task with given local rank
    of that host hash and invoke the command there.

    The method returns immediately after launching the command if background is True (default).
    When background is set to False, this method waits for command termination and returns
    command's result. If there is an exception while waiting for the result (i.e. connection reset)
    it returns -1.

    :param driver_addresses: driver's addresses
    :param key: used for encryption of parameters passed across the hosts
    :param host_hash: host hash to connect to
    :param command: command and arguments to invoke
    :param env: environment to use
    :param local_rank: local rank on the host of task to run the command in
    :param verbose: verbosity level
    :param stdout: Task stdout is redirected to this stream.
    :param stderr: Task stderr is redirected to this stream.
    :param prefix_output_with_timestamp: shows timestamp in stdout/stderr forwarding on the driver if True
    :param background: run command in background if True, returns command result otherwise
    :param events: events to abort the command, only if background is True
    :return exit code if background is False
    """
    driver_client = driver_service.SparkDriverClient(driver_addresses, key, verbose=verbose)
    task_indices = driver_client.task_host_hash_indices(host_hash)
    task_index = task_indices[local_rank]
    task_addresses = driver_client.all_task_addresses(task_index)
    task_client = task_service.SparkTaskClient(task_index, task_addresses, key, verbose=verbose)
    task_client.stream_command_output(stdout, stderr)
    task_client.run_command(command, env,
                            capture_stdout=stdout is not None,
                            capture_stderr=stderr is not None,
                            prefix_output_with_timestamp=prefix_output_with_timestamp)

    if not background:
        events = events or []
        stop = threading.Event()
        for event in events:
            on_event(event, task_client.abort_command, stop=stop)

        try:
            exit_code = task_client.wait_for_command_exit_code()
            return exit_code
        except:
            traceback.print_exc()
            return -1
        finally:
            stop.set()

4.3.5 发送命令

具体 run_command 就是向 SparkTaskService 发送 RunCommandRequest。

class BasicTaskClient(network.BasicClient):

  def run_command(self, command, env,
                    capture_stdout=False, capture_stderr=False,
                    prefix_output_with_timestamp=False):
        self._send(RunCommandRequest(command, env,
                                     capture_stdout, capture_stderr,
                                     prefix_output_with_timestamp))

具体如下图逻辑所示:

与之前的测试代码对比如下:

                                                   Our test code     +    Horovod on spark
                                                                     |
                                                                     |
 mpirun -mca plm_rsh_agent "python rsh.py" python user_function.py   |    mpirun pls_rsh_agent "python mpirun_rsh" python -m mpurun_exec_fn
                                                                     |
          +                                                          |           +
          |                                                          |           |
          |  rsh.py 1.1.1.1 orted -mca plm_rsh_agent "python rsh.py" |           |    orted -mca pls_rsh_agent "python -m mpirun_rsh"
          |                                                          |           |
          v                                                                      v
+----------------------------------------------------------------+   |    +------+---------------------------------------------------+
| rsh.py (via SSH)                                               |   |    | mpirun_rsh                                               |
|                                                                |   |    |                                                          |
|    rsh.py 1.1.1.1 orted -mca plm_rsh_agent "python rsh.py"     |   |    +------+---------------------------------------------------+
|                                                                |   |           |
|                                                                |   |           |
|                                                                |   |           v
|                                                                |   |    +----------------------------------------------------------+
|                                                                |   |    | rsh (via RPC)                                            |
|                                                                |   |    |                                                          |
|    subprocess(orted -mca plm_rsh_agent "python rsh.py")        |   |    |                                                          |
|                                                                |   |    |  task_client = task_service.SparkTaskClient              |
|                                                                |   |    |                                                          |
|                                                                |   |    |  task_client.run_command(                                |
|                                                                |   |    |       orted -mca pls_rsh_agent "python -m mpirun_rsh"    |
|                                                                |   |    |  )                                                       |
+---------+------------------------------------------------------+   |    +------+---------------------------------------------------+
          |                                                          |           |
          |                                                          |           |
          v                                                          |           v
+---------+------------------------------------------------------+   |    +------+---------------------------------------------------+
| user_function.py                                               |   |    | mpirun_exec_fn.py                                        |
|                                                                |   |    |                                                          |
|    print('hello world')                                        |   |    |              task_exec +--------> user_function          |
|                                                                |   |    |                                                          |
+----------------------------------------------------------------+   |    +----------------------------------------------------------+
                                                                     +

手机如下:

因此,下面就会进入到 spark executor 去运行

4.4 Run in Spark Executor

再次注意,这里已经是 远端的 Spark Executor 了

上节提到,系统会利用 task_service.SparkTaskClient.run_command 来运行command;

command 举例如下,此时 command 已经被 mpirun 处理转义:

/usr/local/bin/orted -mca ess "env" -mcc ess_base_num_procs "2" -mca orte_hnp_uri "xxxx" -mca pls_rsh_agent "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx"

需要依据上图留意一点:系统在 Spark Executor 上执行 command 之后,会接着运行 mpirun_exec_fn

我们接下来就看看如何处理 RunCommandRequest。具体都是在 BasicTaskService 之中完成。

4.4.1 RunCommandRequest

可以看到,接受到消息之后,是调用 _run_command 完成。

def _handle(self, req, client_address):
    if isinstance(req, RunCommandRequest):
        self._wait_cond.acquire()
        try:
            if self._command_thread is None:
                # we add req.env to _command_env and make this available to the executed command
                if self._command_env:
                    env = self._command_env.copy()
                    self._add_envs(env, req.env)
                    req.env = env

                # We only permit executing exactly one command, so this is idempotent.
                self._command_abort = threading.Event()
                self._command_stdout = Pipe() if req.capture_stdout else None
                self._command_stderr = Pipe() if req.capture_stderr else None
                args = (req.command, req.env, self._command_abort,
                        self._command_stdout, self._command_stderr,
                        self._index,
                        req.prefix_output_with_timestamp)
                self._command_thread = in_thread(self._run_command, args)
        finally:
            self._wait_cond.notify_all()
            self._wait_cond.release()
        return network.AckResponse()

4.4.2 _run_command

_run_command 就是调用 safe_shell_exec.execute 直接运行。

def _run_command(self, command, env, event,
                 stdout=None, stderr=None, index=None,
                 prefix_output_with_timestamp=False):
    self._command_exit_code = safe_shell_exec.execute(
        command,
        env=env,
        stdout=stdout, stderr=stderr,
        index=index,
        prefix_output_with_timestamp=prefix_output_with_timestamp,
        events=[event])
    if stdout:
        stdout.close()
    if stderr:
        stderr.close()

因此,接下来就是在 Spark Executor 之中,开始执行

/usr/local/bin/orted -mca ess "env" -mcc ess_base_num_procs "2" -mca orte_hnp_uri "xxxx" -mca pls_rsh_agent "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx"

注意,此时是在 Spark Executor 之中,所以接下来行为会和之前不同。

4.4.3 mpirun_rsh

mpirun_rsh 依然是调用 rsh。

addresses = codec.loads_base64(sys.argv[1])
key = codec.loads_base64(os.environ.get(secret.HOROVOD_SECRET_KEY))
settings = codec.loads_base64(sys.argv[2])
host_hash = sys.argv[3]
command = " ".join(sys.argv[4:])
env = {}  # orted does not need any env vars, the target training code gets env from mpirun

# Since tasks with the same host hash have shared memory,
# we will run only one orted process on the first task.
rsh(addresses, key, host_hash, command, env, 0, settings.verbose)

4.4.4 rsh

代码如下:

def rsh(driver_addresses, key, host_hash, command, env, local_rank, verbose,
        stdout=None, stderr=None, prefix_output_with_timestamp=False,
        background=True, events=None):
    driver_client = driver_service.SparkDriverClient(driver_addresses, key, verbose=verbose)
    task_indices = driver_client.task_host_hash_indices(host_hash)
    task_index = task_indices[local_rank]
    task_addresses = driver_client.all_task_addresses(task_index)
    task_client = task_service.SparkTaskClient(task_index, task_addresses, key, verbose=verbose)
    task_client.stream_command_output(stdout, stderr)
    task_client.run_command(command, env,
                            capture_stdout=stdout is not None,
                            capture_stderr=stderr is not None,
                            prefix_output_with_timestamp=prefix_output_with_timestamp)

但是此时运行就出现了与之前的不同之处

此时在 Spark Executor 再次调用

/usr/local/bin/orted -mca ess "env" -mcc ess_base_num_procs "2" -mca orte_hnp_uri "xxxx" -mca pls_rsh_agent "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx"

回忆一下 0x03 MPI 实验 的结果,我们知道,pls_rsh_agent "/usr/bin/python -m horovod.spark.driver.mpirun_rsh xxxxx" 这部分在远端上其实不会有实际效果远端 orted 转而会继续运行传过来的 mpirun_exec_fn

如果哪位朋友对 MPI 有深入了解,还望赐教

4.4.5 mpirun_exec_fn

代码位于:horovod/spark/task/mpirun_exec_fn.py。

就是调用到了task_exec。

def main(driver_addresses, settings):
    # prepend HOROVOD_SPARK_PYTHONPATH to PYTHONPATH
    if 'HOROVOD_SPARK_PYTHONPATH' in os.environ:
        ppath = os.environ['HOROVOD_SPARK_PYTHONPATH']

        # add injected HOROVOD_SPARK_PYTHONPATH to sys.path
        for p in reversed(ppath.split(os.pathsep)):
            sys.path.insert(1, p)  # don't put it in front which is usually .

        if 'PYTHONPATH' in os.environ:
            ppath = os.pathsep.join([ppath, os.environ['PYTHONPATH']])
        os.environ['PYTHONPATH'] = ppath

    # change current working dir to where the Spark worker runs
    # because orted runs this script where mpirun was executed
    # this env var is injected by the Spark task service
    work_dir = os.environ.get('HOROVOD_SPARK_WORK_DIR')
    if work_dir:
        os.chdir(work_dir)

    task_exec(driver_addresses, settings, 'OMPI_COMM_WORLD_RANK', 'OMPI_COMM_WORLD_LOCAL_RANK')

0x05 第五阶段 : 运行用户代码

5.1 task_exec

task_exec 就是用来运行用户代码。

可以看到,是从 Driver 之中取出之前存储的代码,然后运行。

def task_exec(driver_addresses, settings, rank_env, local_rank_env):
    # Die if parent process terminates
    in_thread(target=_parent_process_monitor, args=(os.getppid(),))

    key = codec.loads_base64(os.environ[secret.HOROVOD_SECRET_KEY])
    rank = int(os.environ[rank_env])
    local_rank = int(os.environ[local_rank_env])
    driver_client = driver_service.SparkDriverClient(driver_addresses, key,
                                                     verbose=settings.verbose)

    # tell driver about local rank and rank
    # in elastic mode the driver already knows this mapping
    # for simplicity we keep code paths the same for elastic and static mode
    host_hash = os.environ['HOROVOD_HOSTNAME']
    task_index = driver_client.set_local_rank_to_rank(host_hash, local_rank, rank)

    # gather available resources from task service
    task_addresses = driver_client.all_task_addresses(task_index)
    task_client = task_service.SparkTaskClient(task_index, task_addresses, key,
                                               verbose=settings.verbose)
    task_info.set_resources(task_client.resources())

    fn, args, kwargs = driver_client.code()
    result = fn(*args, **kwargs)
    task_client.register_code_result(result)

5.2 获取训练代码

在MapReduce之中,则是把Jar包(二进制库)分发到各个节点,然后各个节点执行jar包之中相应的代码。其实这样很不方便。

Spark提出了函数序列化功能,可以很好的解决这个问题,这是Spark对分布式编程的一个贡献。Spark系统会把你写的那些自定义函数(你的业务功能)自动序列化到各个节点去执行。函数序列化发送功能给Spark带来的另外好处是:用户可以使用spark-shell在命令行直接写分布式代码,实时操作,实时得到结果。

比如,初始化/协调等工作是在Driver程序中进行,但是代码实际执行是在Worker节点中的Executor中进行。当Executor端执行时需要用到Driver端封装的class对象时,Driver端就需要把Driver端的class对象通过序列化传输到Executor端,这个class对象则需要实现Serializable序列化方法。

Horovod on spark 这里就是直接传输 python 训练代码原始文本,因为 pyton 的脚本特性,所以可以直接运行代码原始文本

获取训练代码的函数如下,在 SparkDriverClient 类之中就是给 Driver 发送 CodeRequest 请求:

def code(self):
    resp = self._send(CodeRequest())
    return resp.fn, resp.args, resp.kwargs

在 SparkDriverService 之中,收到 CodeRequest 请求之后,会进行处理。

if isinstance(req, CodeRequest):
    return CodeResponse(self._fn, self._args, self._kwargs)

就是把 SparkDriverService 之前存储的训练代码 _fn 以及其参数一起发给 SparkTaskService。

class CodeResponse(object):
    def __init__(self, fn, args, kwargs):
        self.fn = fn
        """Function."""

        self.args = args
        """Function args."""

        self.kwargs = kwargs
        """Function kwargs."""

最终逻辑大致如下:

+---------------------------------+                     +---------------------------------+
| Horovod Main thread             |                     | Spark Executor                  |
|                                 |                     |                                 |
|                                 |                     |                                 |
|  +-------------------------+    |       1 register    |        +----------------------+ |
|  |     SparkDriverService  | <---------------------------------+  SparkTaskService    | |
|  |                         |    |                     |        |                      | |
|  |                         |    |      2 notify start |        |                      | |
|  |                         | +-------------------------------> |                      | |
|  |                         |    |                     |        |                      | |
|  |                         |    |                     |        |                      | |
|  |                         |    | 3 RunCommandRequest |        |                      | |
|  |                         | +--------------------------------------> orted mpirun_rsh| |
|  |                         |    |                     |        |        +             | |
|  |                         |    |                     |        |        | 4           | |
|  |                         |    |                     |        |        |             | |
|  |                         |    |                     |        |        v             | |
|  |                         |    |                     |        |      task_exec       | |
|  |                         |    |                     |        |        +             | |
|  |                         |    |                     |        |        | 5           | |
|  |                         |    |                     +        |        |             | |
|  |                         |    |6 set_local_rank_to_rank      |        v             | |
|  |                         | +------------------------+---------> SparkTaskClient     | |
|  |                         |    |                     |        |                      | |
|  |                         |    |    7 code()         |        |                      | |
|  |                         | +---------------------------------------> 8 fn()         | |
|  |                         |    |                     |        |                      | |
|  +-------------------------+    |                     |        +----------------------+ |
+---------------------------------+                     +---------------------------------+

手机如下:

至此,spark on MPI 分析结束,我们下文介绍 spark on GLOO。

0xEE 个人信息

★★★★★★关于生活和技术的思考★★★★★★

微信公众账号:罗西的思考

如果您想及时得到个人撰写文章的消息推送,或者想看看个人推荐的技术资料,敬请关注。

在这里插入图片描述

0xFF

mpirun,mpiexec和mpiexec.hydra有什么区别和关系?

posted @ 2021-07-05 06:25  罗西的思考  阅读(864)  评论(0编辑  收藏  举报