Torch.distributed 使い方

pytorchの分散パッケージであるtorch.distributedのtutorialを自分なりにまとめた。
公式サイトを参考に、一般的な分散処理の手法について学んだ。

torch.distributedを使用すると、プロセスやクラスターマシンの計算の並列化を簡単に行うことができる。
まだわからないことが多いが、是非マスターしたい。
このページの素晴らしい参考元サイトは以下
https://pytorch.org/tutorials/intermediate/dist_tuto.html#our-own-ring-allreduce
https://pytorch.org/docs/stable/distributed.html#initialization

今回の環境: python3, pytorch version 1.4.0,

分散処理の用語確認

分散処理に詳しくないので、説明によく出てくる用語を確認する。

  • プロセス(process): プログラムの実行単位。カーネルによって管理される。
  • ジョブ(job):シェルを実行する作業単位で、シェルによって管理される。プロセスの集まり。
  • フォーク(fork):コピーすること。processをforkするは、processを複製することを表す。
  • メッセージパッシング(message passing): プロセス間の通信で使用される通信方式。
  • マスター(master):全体を制御するプロセスのこと。
  • スレイブ(slave):制御されるプロセスのこと。
  • 親プロセス:呼び出し元のプロセスのこと。
  • 子プロセス:他のプロセスから呼び出されたプロセスのこと。

分散環境のセットアップ

プロセスやクラスターマシンで計算を並列化するには、メッセージパッシングを利用して、各プロセスが、他のプロセスとデータ通信をする必要がある。torch.distributed は、torch.multiprocessing パッケージと異なり、プロセス同士が異なる通信バックエンドを使用することができる。以下の実装では1台のマシンを使用し、複数のプロセスをフォークする。

#!/usr/bin/env python
import os
import torch
import torch.distributed as dist
from torch.multiprocessing import Process

def run(rank, size):
    """ Distributed function to be implemented later. """
    pass

def init_process(rank, size, fn, backend='gloo'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)


if __name__ == "__main__":
    size = 2
    processes = []
    for rank in range(size):
        p = Process(target=init_process, args=(rank, size, run))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

上のコードでは分散環境のセットアップを行っている。dist.init_process_groupでプロセスグループを初期化し、指定したrun関数を実行するための2つのプロセスを生成している。

init_process関数の解説

dist.init_process_groupによって、すべてのプロセスが同じIPアドレスとポートを使用することで、マスターを介して調整できるようになる。

環境変数の初期化

以下の環境変数を設定することで、情報の取得方法をカスタマイズできる。

  • MATER_PORT: rank0のマシンの空いているポート
  • MASTER_ADDR: rank0ノードのアドレス

dist.init_process_groupの解説

  • 役割
    • プロセスグループの初期化
    • 分散パッケージの初期化
  • 引数
    • backend:使用するバックエンドを指定
    • world_size:ジョブに参加しているプロセスの数
    • rank:現在のプロセスの番号(ランク)

どのバックエンドを使用すればよいのかについて、公式によると

  • 分散型のGPU学習にはNCCLバックエンド
  • 分散CPU学習にはGLOOバックエンド

を使うのが経験則的に良いとされている。

以降、run関数の中身によって通信の手法が変化する。

Point-to-Point 通信

あるプロセスから別のプロセスへのデータの転送は、point-to-point通信と呼ばれる。これらはsend, recvもしくは、isend, irecvを使用することで実現できる。

Blocking point-to-point 通信

def run(rank, size):
    tensor = torch.zeros(1)
    if rank == 0:
        tensor += 1
        # process 1 へ tensorを送る
        dist.send(tensor=tensor, dst=1)
    else:
        # process 0 から tensorを受け取る
        dist.recv(tensor=tensor, src=0)
    print('Rank ', rank, ' has data ', tensor[0])
#出力
Rank  0  has data tensor(1.)
Rank  1  has data tensor(1.)

上記の例の流れは以下の通り。

  1. 両方のプロセスでゼロテンソルを初期化
  2. プロセス0がテンソルをインクリメント(1増やす)
  3. プロセス0からプロセス1に送信
  4. 両方のプロセスが1.0になって終了

両方のプロセスは、通信が完了するまで停止する。send, recvにブロッキングの働きがある。簡単な仕様は以下。

  • torch.distributed..send(tensor, dst)
    • 役割
      • テンソルを同期的に送信
    • 引数
      • tensor(Tensor):送信するTensor
      • dst(python: int): 宛先ランク
  • torch.distributed.recv(tensor, src)
    • 役割
      • 役割テンソルを同期的に受信
    • 引数
      • tensor(Tensor):送信するTensor
      • dst(python: int): ソースのランク。指定なしのときは任意のプロセスから受け取る。
    • 戻り値
      • 送信元のランク、グループに属していないときは-1

Non-blocking point-to-point 通信

def run(rank, size):
    tensor = torch.zeros(1)
    req = None
    if rank == 0:
        tensor += 1
        # process 1 へ tensorを送る
        req = dist.isend(tensor=tensor, dst=1)
        print('Rank 0 started sending')
    else:
        # process 0 から tensorを受け取る
        req = dist.irecv(tensor=tensor, src=0)
        print('Rank 1 started receiving')
    req.wait()
    print('Rank ', rank, ' has data ', tensor[0])
#出力
Rank 1 started receiving
Rank 0 started sending
Rank  1 has data  tensor(1.)
Rank  0 has data  tensor(1.)

non-blockingでは、実行を継続し、メソッドはworkオブジェクトを返し、すぐにwork()を選択できる。
non-blocking時には、送信及び受信テンソルに気をつける必要がある。データが他のプロセスにいつ通信されるかわからないので、req.wait()が完了する前に送信テンソルを変更したり、受信テンソルにアクセスしてはいけない。言い換えると、

  • dist.isend()のあとにtensorを書き込むと、未定義の動作になる(予測不能の動作)。
  • dist.irecv()のあとにtensorを読み込むと、未定義の動作になる。

ただし、req.wait()が実行されたあとは、通信が行われ、保存された値tensor[0]が1.0であることが保証される。

  • torch.distributed..isend(tensor, dst)
    • 役割
      • テンソルを非同期的に送信
    • 引数
      • tensor(Tensor):送信するTensor
      • dst(python: int): 宛先ランク
    • 戻り値
      • 分散リクエストオブジェクト。グループに属していないときはなし。
  • torch.distributed.irecv(tensor, src)
    • 役割
      • 役割テンソルを非同期で受信
    • 引数
      • tensor(Tensor):送信するTensor
      • dst(python: int): ソースのランク。指定なしのときは任意のプロセスから受け取る。
    • 戻り値
      • 分散リクエストオブジェクト。グループに属していないときはなし。

point-to-point通信はプロセスの通信をきめ細かく制御したい場合に役立つ。

集団通信(Collective Communication)

def run(rank, size):
    """ Simple point-to-point communication. """
    group = dist.new_group([0, 1])
    tensor = torch.ones(1)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    print('Rank ', rank, ' has data ', tensor[0])
#出力
Rank  1  has data  tensor(2.)
Rank  0  has data  tensor(2.)

集合通信では、グループ内のすべてのプロセスとの通信パターンが可能になる。通信パターンについては公式の図がわかりやすい。
グループは、dist.new_group(group)で作成できる。デフォルトではworldとも呼ばれるすべてのプロセスで実行される。
例えばすべてのプロセスですべてのテンソルの合計を取得するには、dist.all_reduce(tensor, op, group) を使う。
上のコードでは、すべてのプロセスですべてのテンソルの合計を取得した。ReduceOp.SUMの部分を他のオプションに変えれば合計だけではなく、他の処理も可能。例(PRODUCT, MAX, MIN)

次回は実際にtorch.distributedを使用して深層学習モデルのtrainを行う。