You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
66 lines
2.0 KiB
66 lines
2.0 KiB
"""
|
|
Adapatation of (pre-elastic) torch.distributed.launch for pytorch xla.
|
|
|
|
`torch.distributed.launch` is a module that spawns up multiple distributed
|
|
training processes on each of the training nodes.
|
|
|
|
"""
|
|
|
|
|
|
import sys
|
|
import subprocess
|
|
import importlib
|
|
import os
|
|
from argparse import ArgumentParser, REMAINDER
|
|
from typing import Optional, IO
|
|
|
|
import torch_xla.distributed.xla_multiprocessing as xmp
|
|
|
|
|
|
def parse_args(args=None):
    """
    Parse the launcher's command-line options.

    Args:
        args: Optional list of argument strings to parse. When ``None`` (the
            default), ``sys.argv[1:]`` is used, exactly as
            ``ArgumentParser.parse_args`` does.

    Returns:
        argparse.Namespace with attributes:
            num_devices (int): number of processes to spawn (default 1).
            script (str): path of the training script to launch.
            script_args (list of str): every argument after ``script``,
                passed through untouched to the training script.
    """
    parser = ArgumentParser(
        # Note the trailing space: adjacent literals are concatenated as-is.
        description="PyTorch distributed training launch helper utility "
                    "that will spawn up multiple distributed processes")

    # Optional arguments for the launch helper
    parser.add_argument("--num-devices", type=int, default=1,
                        help="The number of XLA devices to use for distributed training")

    # positional
    parser.add_argument(
        "script", type=str,
        help="The full path to the single device training script to be launched "
             "in parallel, followed by all the arguments for the training script")

    # Everything after `script` belongs to the training program. REMAINDER
    # collects those strings verbatim, so flags like `--lr` are not claimed
    # by this parser.
    parser.add_argument('script_args', nargs=REMAINDER)

    return parser.parse_args(args)
|
|
|
|
|
|
def main():
    """
    Launch a training script on multiple XLA devices.

    Parses the launcher options, imports the training script as a module and
    hands its ``_mp_entry`` function to ``xmp.spawn`` with
    ``nprocs=--num-devices``. The training script sees ``sys.argv`` as
    ``[script, *script_args]``, as if it had been invoked directly.

    Raises:
        AttributeError: if the training script does not define ``_mp_entry``.
    """
    args = parse_args()

    script_abs = os.path.abspath(args.script)
    script_base, script_rel = os.path.split(script_abs)

    # Rewrite argv *before* importing the script. Otherwise code that parses
    # flags at module level would see the launcher's own options
    # (e.g. --num-devices) instead of the script's.
    sys.argv = [args.script] + args.script_args

    # Put the script's directory first, mirroring `python script.py`, so the
    # script's module wins over any same-named module already on sys.path.
    sys.path.insert(0, script_base)
    mod = importlib.import_module(os.path.splitext(script_rel)[0])

    entry = getattr(mod, "_mp_entry", None)
    if entry is None:
        raise AttributeError(
            "{} must define an `_mp_entry` function to be launched with "
            "xmp.spawn".format(script_abs))

    xmp.spawn(entry, args=(), nprocs=args.num_devices)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the launcher only when this file is executed, not when imported.
    main()