[源码解析] 深度学习分布式训练框架 horovod (11) --- on spark --- GLOO 方案

·  阅读 283

0x00 摘要

Horovod 是Uber于2017年发布的一个易于使用的高性能的分布式训练框架,在业界得到了广泛应用。

本系列将通过源码分析来带领大家了解 Horovod。本文是系列第十一篇,看看horovod 如何运行在 spark 之上(GLOO实现)。

Horovod on Spark 具体有两种底层实现:MPI,GLOO。因为篇幅所限,本文介绍 GLOO 实现。为了单篇可以成文,所以本文和上文有部分重复,望谅解。

本系列其他文章如下:

[源码解析] 深度学习分布式训练框架 Horovod (1) --- 基础知识

[源码解析] 深度学习分布式训练框架 horovod (2) --- 从使用者角度切入

[源码解析] 深度学习分布式训练框架 horovod (3) --- Horovodrun背后做了什么

[源码解析] 深度学习分布式训练框架 horovod (4) --- 网络基础 & Driver

[源码解析] 深度学习分布式训练框架 horovod (5) --- 融合框架

[源码解析] 深度学习分布式训练框架 horovod (6) --- 后台架构

[源码解析] 深度学习分布式训练框架 horovod (6) --- 线程实现

[源码解析] 深度学习分布式训练框架 horovod (7) --- DistributedOptimizer

[源码解析] 深度学习分布式训练框架 horovod (8) --- on spark

[源码解析] 深度学习分布式训练框架 horovod (9) --- 启动 on spark

[源码解析] 深度学习分布式训练框架 horovod (10) --- run on spark

0x01 回顾

1.1 总体序列图

我们首先要回顾下 Horovod on Spark 的总体序列图,需要注意的是:这个总体序列图之中,从mpi_run开始,是 mpi 相关的实现,但本文是Gloo方案,所以会从 mpi_run 那里开始不同。

img

1.2 总体逻辑

总体来说,Horovod on Spark 的总体逻辑分为以下阶段:

  • 启动 SparkDriverService 服务,利用 _make_spark_thread 启动 Spark task,然后 horovod 会等待启动结束;
  • 多线程在 spark executor 之中启动 spark task,每个task之中运行一个 SparkTaskService,SparkTaskService 会向 hovorod 主进程中的 SparkDriverTask 进行注册,并且等待下一步运行启动的指令;
  • Horovod 收到所有 task 结束的信息之后,通知各个 task,进入下一阶段;
  • Horovod 调用 mpi_run (又利用到 mpirun_rsh.py)在每一个 spark executor 上启动 orted(这里是通过 SparkTaskService 来启动 orted),以启动 MPI cluster;
  • orted 在每一个 spark executor 之上运行训练代码;

前文已经分析了前面三个阶段,本文继续后面两个阶段的分析。

0x02 第四阶段 : 启动 Job

下面我们看看第四阶段,就是如何运行训练 job。

2.1 GLOO VS MPI

本文的问题点就是:Gloo 与 MPI 实现有何不同。

2.1.1 MPI 麻烦之处

MPI 麻烦之处是因为:

  • 通常 MPI 会通过 SSH 来连接 hosts,但是这种方式无法在 Spark Executor 之中启动 Python funtion。
  • Orted 需要运行在 Spark Executor 之中,但是 mpirun 在启动时候,没办法知道 Spark Executor 的 IP : PORT 这个组合,所以没法直接启动。
  • 因此 MPI 使用RPC 来启动用户代码:
    • 通过 SparkDriverService 和 SparkTaskService 等交互才可以知道这个 IP : PORT 组合信息,即,在 Spark Executor 之中启动 SparkTaskService ,然后把 SparkTaskService 的 IP : PORT 注册到 Horovod 主进程的 SparkDriverService 之中。
    • 使用 horovod.spark.driver.mpirun_rsh 来连接每个 Executor,然后 “remote shell” 到这些 executors 之中。
    • 直接使用 SparkTaskService 来启动 orted。

2.1.2 Gloo关键点

我们看看Gloo的关键点,在普通模式下,Gloo方案会:

  • 创建一个带有 KVStore 的 RendezvousServer,driver 会将参与通信的 worker 的 ip 等信息存入 KVstore 中。这些信息被用于帮助 worker 调用 gloo 构造 AllReduce 通信环。
  • 然后 worker 就可以调用 gloo 来访问 RendezvousServer 构造通信环了。

在 Horovod on Spark 之中,关键点就是:

  • 如何构造RendezvousServer,RendezvousServer如何知道Executor(或者类似实体)的 ip:port?
  • Executor上的 SparkTaskService 如何与 RendezvousServer 沟通,从而知道自己和邻居的网络信息?

让我们从代码中寻求下答案。

2.2 回顾启动过程

我们首先要回顾下之前的启动过程。

Horovod.spark.run 的逻辑是:

  • 处理各种配置,比如timeout,nice...;
  • 获取 spark 信息,比如从 pyspark 之中获取SparkContext;
  • 构建驱动 SparkDriverService(Spark driver service);
  • 利用 _make_spark_thread 来启动 spark executor(以及在每一个 spark executor 之中启动一个SparkTaskService),这样就构建了 cluster;
  • 每个 SparkTaskService 会通过 driver_service.SparkDriverClient.register_task 来向 horovod 中的 Driver 注册这就是关键之处,通过这里 RendezvousServer 就知道了 SparkTaskService 的 IP :PORT
  • 利用 _notify_and_register_task_addresses 等待所有 spark task 都结束;
  • 利用 _launch_job 启动训练;
  • 利用 spark_thread.join 来收集训练结果;

以上关键点是:SparkTaskService 本身内部有一个 http server,会把自己的IP:PORT 信息注册到Driver之中。

2.3 _launch_job

我们从_launch_job 开始分析。

_launch_job 很简单:

  • 首先 driver.get_common_interfaces 获取网络路由信息,这个网络路由信息就将被 RendezvousServer 记录下来,最终将被 Executor上的 SparkTaskService 利用;
  • 其次 调用 run_contoller 来启动 job;
def _launch_job(use_mpi, use_gloo, settings, driver, env, stdout=None, stderr=None):
    nics = driver.get_common_interfaces()
    # 在 gloo_run 调用时候传输网络路由信息。
    run_controller(use_gloo, lambda: gloo_run(settings, nics, driver, env, stdout, stderr),
                   use_mpi, lambda: mpi_run(settings, nics, driver, env, stdout, stderr),
                   False, lambda: None,
                   settings.verbose)
复制代码

2.3 获取路由信息

Driver 的 get_common_interfaces 与普通模式下的 get_common_interfaces 不同。因为此时,Spark Executor 之中的 SparkTaskService 的信息已经保存在 Driver 之中,直接获取即可。

def get_common_interfaces(self):
    if self._nics is not None:
        return self._nics

    nics = None
    if len(self._task_addresses_for_tasks) > 0:
        # in Elastic Horovod on Spark with auto-scaling
        # keys in task_addresses are in range(max_np or proc_num)
        # but not all keys may exist, so we don't do for index in range(proc_num)
        indices = list(self._task_addresses_for_tasks.keys())
        nics = set(self._task_addresses_for_tasks[indices[0]].keys())
        for index in indices[1:]:
            nics.intersection_update(self._task_addresses_for_tasks[index].keys())

    return nics
复制代码

2.4 run_controller

就是依据配置和编译情况来进行处理,选择 gloo,js,还是 mpi。

def run_controller(use_gloo, gloo_run, use_mpi, mpi_run, use_jsrun, js_run, verbosity):
    if use_gloo:
        gloo_run()
    elif use_mpi:
        mpi_run()
    elif use_jsrun:
        js_run()
    else:
        if mpi_built(verbose=verbose):
            if lsf.LSFUtils.using_lsf() and is_jsrun_installed():
                js_run()
            else:
                mpi_run()
        elif gloo_built(verbose=verbose):
            gloo_run()

复制代码

所以我们开始启动 job,下面就 GLOO进行分析。

0x03 Gloo 实现

相比 MPI,Gloo 这部分就比较清晰了。

3.1 gloo_run

回到 2.3 run_controller

就是依据配置和编译情况来进行处理,选择 gloo,js,还是 mpi。

def run_controller(use_gloo, gloo_run, use_mpi, mpi_run, use_jsrun, js_run, verbosity):
    if use_gloo:
        gloo_run() # 本文调用到这里
    elif use_mpi:
        mpi_run() # mpi会调用到这里
    elif use_jsrun:
        js_run()
    else:
        if mpi_built(verbose=verbose):
            if lsf.LSFUtils.using_lsf() and is_jsrun_installed():
                js_run()
            else:
                mpi_run() # mpi会调用到这里
        elif gloo_built(verbose=verbose):
            gloo_run() # 本文调用到这里

复制代码

如果是配置了gloo,则我们用到了 gloo_run:

def gloo_run(settings, nics, driver, env, stdout=None, stderr=None):
    """
    Run distributed gloo jobs.

    :param settings: Settings for running the distributed jobs.
                     Note: settings.num_proc and settings.hosts must not be None.
    :param nics: Interfaces to use by gloo.
    :param driver: The Spark driver service that tasks are connected to.
    :param env: Environment dictionary to use for running gloo jobs.  Can be None.
    :param stdout: Horovod stdout is redirected to this stream.
    :param stderr: Horovod stderr is redirected to this stream.
    """
    if env is None:
        env = {}

    # we don't want the key to be serialized along with settings from here on
    key = settings.key
    settings.key = None

    # Each thread will use SparkTaskClient to launch the job on each remote host. If an
    # error occurs in one thread, entire process will be terminated. Otherwise,
    # threads will keep running and ssh session.
    iface = list(nics)[0]
    server_ip = driver.addresses()[iface][0][0]
    # 这里构建了需要执行的命令
    command = (sys.executable,
               '-m', 'horovod.spark.task.gloo_exec_fn', # 这个就是在task里面运行的代码
               codec.dumps_base64(driver.addresses()),
               codec.dumps_base64(settings))

    # 可以认为_exec_command_fn这里是一种执行命令的能力
    exec_command = _exec_command_fn(driver, key, settings, env,
                                    stdout, stderr, settings.prefix_output_with_timestamp)
    # 这里传入了路由信息
    launch_gloo(command, exec_command, settings, nics, {}, server_ip)
复制代码

需要注意的是,这里的 _exec_command_fn 如下,可以认为_exec_command_fn这里是一种执行命令的能力:

def _exec_command_fn(driver, key, settings, env, stdout, stderr, prefix_output_with_timestamp):
    def _exec_command(command, slot_info, events):
        host = slot_info.hostname #host名字
        local_rank = slot_info.local_rank # 本地rank
        verbose = settings.verbose
        # 用rsh封装的运行能力
        result = rsh(driver.addresses(), key, host, command, env, local_rank, verbose,
                     stdout, stderr, prefix_output_with_timestamp, False, events)
        return result, time.time()
    return _exec_command
复制代码

即调用了 from horovod.spark.driver.rsh import rsh。这里是关键。

3.2 launch_gloo

这里主要是:

  • 首先,要注意,参数中,
    • command 大致为:'python','-m','horovod.spark.task.gloo_exec_fn';
    • exec_command 大致为:rsh xxxx。因为exec_command可以认为是一种利用rsh执行command的能力,所以这里的xxx对应本文就是 “python -m horovod.spark.task.gloo_exec_fn”;
  • 建立了 RendezvousServer;
  • 构建了 slot_info_to_command,这里指定了在哪一个slot上面运行;
  • 调用 execute_function_multithreaded 来使用多线程来运行命令;
def launch_gloo(command, exec_command, settings, nics, env, server_ip):
    """
    Launches the given command multiple times using gloo.
    Each command is launched via exec_command.

    :param command: command to launch
    :param exec_command: means to execute a single command
    :param settings: settings for the distribution
    :param nics: common interfaces
    :param env: environment to use
    :param server_ip: ip to use for rendezvous server
    """
		......
    
    # start global rendezvous server and get port that it is listening on
    # 建立 RendezvousServer,这个会被底层 Gloo C++ 环境使用到
    rendezvous = RendezvousServer(settings.verbose)

    # allocate processes into slots
    # 来根据host进行分配slot,就是horovod的哪个rank应该在哪个host上的哪个slot之上运行
    hosts = parse_hosts(settings.hosts)
    host_alloc_plan = get_host_assignments(hosts, settings.num_proc)

    # start global rendezvous server and get port that it is listening on
    global_rendezv_port = rendezvous.start()
    rendezvous.init(host_alloc_plan)
    # 获取到可执行命令
    run_command = get_run_command(command, server_ip, nics, global_rendezv_port)

    # 得到在slot之上可执行的 slot command
    slot_info_to_command = _slot_info_to_command_fn(run_command, env)
    event = register_shutdown_event()
    # 依据 slot_info_to_command_fn 构建 args_list,这个 list 之中,每一个arg就是一个 slot command
    args_list = [[slot_info_to_command(slot_info), slot_info, [event]]
                 for slot_info in host_alloc_plan]

    # If an error occurs in one thread, entire process will be terminated.
    # Otherwise, threads will keep running.
    # 多线程执行,在每一个 exec_command 之上执行每一个 arg(slot command),args_list 包括 HOROVOD_GLOO_RENDEZVOUS_ADDR 等信息
    res = threads.execute_function_multithreaded(exec_command,
                                                 args_list,
                                                 block_until_all_done=True)

    ......
复制代码

具体如下图所示:

               launch_gloo( command ='python','+m','horovod.spark.task.gloo_exec_fn'
                    +       exec_command = rsh xxxx)
                    |
                    |
                    |
                    |
                    |
                    v
               RendezvousServer
                    +
                    |
                    |   get_run_command
                    |
                    |
                    v
 run_command = HOROVOD_GLOO_RENDEZVOUS_ADDR=1.1.1.1 HOROVOD_GLOO_RENDEZVOUS_PORT=2222
               HOROVOD_CPU_OPERATIONS=gloo HOROVOD_GLOO_IFACE=lo ......
               python +m horovod.spark.task.gloo_exec_fn
 exec_command = rsh xxxx

                    +
                    |
                    |   _slot_info_to_command_fn
                    |
                    v

slot_info_to_command = rank=0,local_rank=0,socket+ifname=eth0,cpu_operations=gloo......
                       HOROVOD_GLOO_RENDEZVOUS_ADDR=1.1.1.1 HOROVOD_GLOO_RENDEZVOUS_PORT=2222
                       HOROVOD_CPU_OPERATIONS=gloo HOROVOD_GLOO_IFACE=lo ......
                       python -m horovod.spark.task.gloo_exec_fn
        exec_command = rsh xxxx
                    +
                    |
                    |
                    |
                    v
               threads.execute_function_multithreaded
                    +
                    |
                    |
                    v
复制代码

手机如下:

img

3.2.1 get_run_command

launch_gloo 代码之中所用到的get_run_command十分关键,它会调用 create_run_env_vars 得到gloo需要信息,并据此构建 run_command,其格式如下:

HOROVOD_GLOO_RENDEZVOUS_ADDR=1.1.1.1 HOROVOD_GLOO_RENDEZVOUS_PORT=2222 HOROVOD_CPU_OPERATIONS=gloo HOROVOD_GLOO_IFACE=lo HOROVOD_CONTROLLER=gloo python train.py


复制代码

代码如下:

def get_run_command(command, server_ip, nics, port, elastic=False):
    env_vars = create_run_env_vars(server_ip, nics, port, elastic)
    env_string = " ".join(
        [f"{k}={str(v)}" for k, v in env_vars.items()])
    run_command = (
        '{env_string} '
        '{command}'  # expect a lot of environment variables
        .format(env_string=env_string,
                command=' '.join(quote(par) for par in command)))
    return run_command


复制代码

3.2.2 create_run_env_vars

create_run_env_vars 函数会把 gloo 运行的相关信息构建出来,这些信息最后会传给 Spark Executor。

def create_run_env_vars(server_ip, nics, port, elastic=False):
    run_envs = {
        'HOROVOD_GLOO_RENDEZVOUS_ADDR': server_ip,
        'HOROVOD_GLOO_RENDEZVOUS_PORT': port,
        'HOROVOD_CONTROLLER': "gloo",
        'HOROVOD_CPU_OPERATIONS': "gloo",
        'HOROVOD_GLOO_IFACE': list(nics)[0],   # TODO: add multiple ifaces in future
        'NCCL_SOCKET_IFNAME': ','.join(nics), # 这里就是构建环需要的网络路由信息
    }
    if elastic:
        run_envs["HOROVOD_ELASTIC"] = "1"
    return run_envs


复制代码

3.3 rsh

在 execute_function_multithreaded 之中,调用了 rsh,并最终与 Spark Executor 交互。

具体会:

  • 获取到 driver handle;
  • 利用driver handle调用 SparkDriverClient 获取 task 相关信息;
  • 获取 task handle;
  • 调用 SparkTaskClient 的 run_command 方法 来进行发送命令给 Spark Executor,这里的参数 command 内容大致为 “'python -m horovod.spark.task.gloo_exec_fn”;
  • 等待运行结果;

在调用 rsh 时候,command 会包括 类似 HOROVOD_GLOO_RENDEZVOUS_ADDR=1.1.1.1 HOROVOD_GLOO_RENDEZVOUS_PORT=2222 HOROVOD_CPU_OPERATIONS=gloo HOROVOD_GLOO_IFACE=lo 等信息,这样 SparkDriverService 就知道如何构建 Ring 路由了。

def rsh(driver_addresses, key, host_hash, command, env, local_rank, verbose,
        stdout=None, stderr=None, prefix_output_with_timestamp=False,
        background=True, events=None):
    """
    Method to run a command remotely given a host hash, local rank and driver addresses.

    This method connects to the SparkDriverService running on the Spark driver,
    retrieves all information required to connect to the task with given local rank
    of that host hash and invoke the command there.

    The method returns immediately after launching the command if background is True (default).
    When background is set to False, this method waits for command termination and returns
    command's result. If there is an exception while waiting for the result (i.e. connection reset)
    it returns -1.

    :param driver_addresses: driver's addresses
    :param key: used for encryption of parameters passed across the hosts
    :param host_hash: host hash to connect to
    :param command: command and arguments to invoke
    :param env: environment to use
    :param local_rank: local rank on the host of task to run the command in
    :param verbose: verbosity level
    :param stdout: Task stdout is redirected to this stream.
    :param stderr: Task stderr is redirected to this stream.
    :param prefix_output_with_timestamp: shows timestamp in stdout/stderr forwarding on the driver if True
    :param background: run command in background if True, returns command result otherwise
    :param events: events to abort the command, only if background is True
    :return exit code if background is False
    """
    if ':' in host_hash:
        raise Exception('Illegal host hash provided. Are you using Open MPI 4.0.0+?')

    # 获取到 driver handle    
    driver_client = driver_service.SparkDriverClient(driver_addresses, key, verbose=verbose)
    # 利用配置确定是哪一个task来运行
    task_indices = driver_client.task_host_hash_indices(host_hash)
    task_index = task_indices[local_rank]
    task_addresses = driver_client.all_task_addresses(task_index)
    # 获取task handle
    task_client = task_service.SparkTaskClient(task_index, task_addresses, key, verbose=verbose)
    task_client.stream_command_output(stdout, stderr)
    # 要求task运行命令,command就是 python -m horovod.spark.task.gloo_exec_fn
    task_client.run_command(command, env,
                            capture_stdout=stdout is not None,
                            capture_stderr=stderr is not None,
                            prefix_output_with_timestamp=prefix_output_with_timestamp)

    if not background:
        events = events or []
        stop = threading.Event()
        for event in events:
            on_event(event, task_client.abort_command, stop=stop)
        try:
            exit_code = task_client.wait_for_command_exit_code()
            return exit_code
        except:
            traceback.print_exc()
            return -1
        finally:
            stop.set()


复制代码

所以,此时逻辑如下,最终在spark executor 运行python -m horovod.spark.task.gloo_exec_fn

                                                                                                          Horovod Job    +    Spark Host
                                                                                                                         |
SparkDriverService                           horovod.spark.run                                                           |                    SparkTaskService
         +                                        +                                                                      |                           +
         |                                        |                                                                      |                           |
         |                                        v                                                                      |                           |
         |                                                                                                               |                           |
         |                                   launch_gloo( command ='python','+m','horovod.spark.task.gloo_exec_fn'       |                           |
         |                                        +       exec_command = rsh xxxx)                                       |                           |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        v                                                                      |                           |
         |                                   RendezvousServer                                                            |                           |
         |                                        +                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        |   get_run_command                                                    |                           |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        v                                                                      |                           |
         |                    run_command = HOROVOD_GLOO_RENDEZVOUS_ADDR=1.1.1.1 HOROVOD_GLOO_RENDEZVOUS_PORT=2222       |                           |
         |                                  HOROVOD_CPU_OPERATIONS=gloo HOROVOD_GLOO_IFACE=lo ......                     |                           |
         |                                  python +m horovod.spark.task.gloo_exec_fn                                    |                           |
         |                    exec_command = rsh xxxx                                                                    |                           |
         |                                                                                                               |                           |
         |                                        +                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        |   _slot_info_to_command_fn                                           |                           |
         |                                        |                                                                      |                           |
         |                                        v                                                                      |                           |
         |                                                                                                               |                           |
         |                    slot_info_to_command = rank=0,local_rank=0,socket+ifname=eth0,cpu_operations=gloo......    |                           |
         |                                        HOROVOD_GLOO_RENDEZVOUS_ADDR=1.1.1.1 HOROVOD_GLOO_RENDEZVOUS_PORT=2222 |                           |
         |                                        HOROVOD_CPU_OPERATIONS=gloo HOROVOD_GLOO_IFACE=lo ......               |                           |
         |                                        python +m horovod.spark.task.gloo_exec_fn                              |                           |
         |                            exec_command = rsh xxxx                                                            |                           |
         |                                        +                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        v                                                                      |                           |
         |                                   threads.execute_function_multithreaded                                      |                           |
         |                                        +                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        v                                                                      |                           |
         |                                       rsh                                                                     |                           |
         |                                        +                                                                      |                           |
         |  <----------------------------------+  |                                                                      |                           |
         |      task_host_hash_indices            |                                                                      |                           |
         |                                        |                                                                      |                           |
         |  <----------------------------------+  |                     run_command(command, env)                        |    RunCommandRequest      |
         |      all_task_addresses                |                                                                      |                           |
         |                                        | +--------------------------------------------------------------------------------------------->  |
         |                                        |                                                                      |                           +
         |                                        |                                                                      |                      run command
         |                                        |                                                                      |                           +
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         v                                        v                                                                      |                           |
                                                                                                                         +                           v



复制代码

手机如下:

img

3.4 gloo_exec_fn

注意,此时已经在 Spark Host 上的 Executor 中运行了。

gloo_exec_fn 就对应了前面 mpi版本的 mpirun_exec_fn

spark 在 Executor 上运行 horovod.spark.task.gloo_exec_fn

horovod.spark.task.gloo_exec_fn 内容如下:

from horovod.spark.task import task_exec
from horovod.runner.common.util import codec

def main(driver_addresses, settings):
    task_exec(driver_addresses, settings, 'HOROVOD_RANK', 'HOROVOD_LOCAL_RANK')

if __name__ == '__main__':
    if len(sys.argv) != 3:
        print('Usage: %s <driver addresses> <settings>' % sys.argv[0])
        sys.exit(1)
    main(codec.loads_base64(sys.argv[1]), codec.loads_base64(sys.argv[2]))


复制代码

0x04 第五阶段 : 运行用户代码

task_exec 函数就是运行用户代码进行训练。

task_exec 位于:horovod/spark/task/__init__.py

具体会:

  • 调用 SparkDriverClient 获取 task 相关信息;
  • 调用 SparkTaskClient 来进行获取 用户代码;
  • 执行用户代码等等。
def task_exec(driver_addresses, settings, rank_env, local_rank_env):
    # Die if parent process terminates
    in_thread(target=_parent_process_monitor, args=(os.getppid(),))

    key = codec.loads_base64(os.environ[secret.HOROVOD_SECRET_KEY])
    rank = int(os.environ[rank_env])
    local_rank = int(os.environ[local_rank_env])
    driver_client = driver_service.SparkDriverClient(driver_addresses, key,
                                                     verbose=settings.verbose)

    # tell driver about local rank and rank
    # in elastic mode the driver already knows this mapping
    # for simplicity we keep code paths the same for elastic and static mode
    host_hash = os.environ['HOROVOD_HOSTNAME']
    task_index = driver_client.set_local_rank_to_rank(host_hash, local_rank, rank)

    # gather available resources from task service
    task_addresses = driver_client.all_task_addresses(task_index)
    task_client = task_service.SparkTaskClient(task_index, task_addresses, key,
                                               verbose=settings.verbose)
    task_info.set_resources(task_client.resources())

    fn, args, kwargs = driver_client.code()
    result = fn(*args, **kwargs)
    task_client.register_code_result(result)


复制代码

最终代码如下:

                                                                                                          Horovod Job    +    Spark Host
                                                                                                                         |
SparkDriverService                           horovod.spark.run                                                           |                    SparkTaskService
         +                                        +                                                                      |                           +
         |                                        |                                                                      |                           |
         |                                        v                                                                      |                           |
         |                                                                                                               |                           |
         |                                   launch_gloo( command ='python','+m','horovod.spark.task.gloo_exec_fn'       |                           |
         |                                        +       exec_command = rsh xxxx)                                       |                           |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        v                                                                      |                           |
         |                                   RendezvousServer                                                            |                           |
         |                                        +                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        |   get_run_command                                                    |                           |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        v                                                                      |                           |
         |                     run_command = rendevous_addr, rendevous_port python -m horovod.spark.task.gloo_exec_fn    |                           |
         |                    exec_command = rsh xxxx                                                                    |                           |
         |                                                                                                               |                           |
         |                                        +                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        |   _slot_info_to_command_fn                                           |                           |
         |                                        |                                                                      |                           |
         |                                        v                                                                      |                           |
         |                                                                                                               |                           |
         |                    slot_info_to_command = rank=0,local_rank=0,socket+ifname=eth0,cpu_operations=gloo......    |                           |
         |                                     rendevous_addr, rendevous_port python -m horovod.spark.task.gloo_exec_fn  |                           |
         |                            exec_command = rsh xxxx                                                            |                           |
         |                                        +                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        v                                                                      |                           |
         |                                   threads.execute_function_multithreaded                                      |                           |
         |                                        +                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        v                                                                      |                           |
         |                                       rsh                                                                     |                           |
         |                                        |                                                                      |                           |
         |  <----------------------------------+  |                                                                      |                           |
         |      task_host_hash_indices            |                                                                      |                           |
         |                                        |                                                                      |                           |
         |  <----------------------------------+  |                     run_command(command, env)                        |    RunCommandRequest      |
         |      all_task_addresses                |                                                                      |                           |
         |                                        | +--------------------------------------------------------------------------------------------->  |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                      run command
         |                                        |                                                                      |                           +
         |                                        |                                                                      |      code()               |
         |  <-------------------------------------------------------------------------------------------------------------------------------------+  |
         |                                        |                                                                      |                           |
         |  +------------------------------------------------------------------------------------------------------------------------------------->  |
         |                                        |                                                                      |  code  of gloo_exec_fn    |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                     gloo_exec_fn
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                           |
         |                                        |                                                                      |                      task_exec
         v                                        |                                                                      |                           |
                                                  v                                                                      |                           |
                                                                                                                         +                           v



复制代码

手机如下:

img

0x05 总结

在普通模式下,Gloo方案会:

  • 创建一个带有 KVStore 的 RendezvousServer,driver 会将参与通信的 worker 的 ip 等信息存入 KVstore 中。
  • 然后 worker 就可以调用 gloo 来访问 RendezvousServer 构造通信环了。

在 Horovod on Spark via GLOO 之中,关键点就是:

  • 如何构造RendezvousServer,RendezvousServer如何知道Executor的 ip:port?
    • 答案为:
      • 在 Horovod 的 driver 之中,会创建RendezvousServer。
      • 在之前的初始化过程中,每个 SparkTaskService 会通过 driver_service.SparkDriverClient.register_task 来向 horovod 中的 Driver 注册这就是关键之处,通过这里 RendezvousServer 就可以知道 SparkTaskService 的 IP :PORT
  • Executor上的 SparkTaskService 如何与 RendezvousServer 沟通,从而知道自己和邻居的网络信息?
    • 答案为:
      • 在 execute_function_multithreaded 之中,调用了 rsh,并最终与 Spark Executor 交互。
      • 在调用 rsh 时候,会把类似 HOROVOD_GLOO_RENDEZVOUS_ADDR=1.1.1.1 HOROVOD_GLOO_RENDEZVOUS_PORT=2222 HOROVOD_CPU_OPERATIONS=gloo HOROVOD_GLOO_IFACE=lo 信息传递过去,此信息中包括了 RendezvousServer 的地址,这样 Spark Executor 中的 SparkTaskService 就知道了如何找到RendezvousServer,进而就会知道如何构建 ring。

至此,Horovod on spark解析完毕,从下一篇开始解析弹性训练。

0xEE 个人信息

★★★★★★关于生活和技术的思考★★★★★★

微信公众账号:罗西的思考

如果您想及时得到个人撰写文章的消息推送,或者想看看个人推荐的技术资料,敬请关注。