""" Adapatation of (pre-elastic) torch.distributed.launch for pytorch xla. `torch.distributed.launch` is a module that spawns up multiple distributed training processes on each of the training nodes. """ import sys import subprocess import importlib import os from argparse import ArgumentParser, REMAINDER from typing import Optional, IO import torch_xla.distributed.xla_multiprocessing as xmp def parse_args(): """ Helper function parsing the command line options @retval ArgumentParser """ parser = ArgumentParser( description="PyTorch distributed training launch helper utility" "that will spawn up multiple distributed processes") # Optional arguments for the launch helper parser.add_argument("--num-devices", type=int, default=1, help="The number of XLA devices to use for distributed training") # positional parser.add_argument( "script", type=str, help="The full path to the single device training script to be launched" "in parallel, followed by all the arguments for the training script") # rest from the training program parser.add_argument('script_args', nargs=REMAINDER) return parser.parse_args() def main(): args = parse_args() # set PyTorch distributed related environmental variables # current_env = os.environ.copy() # current_env["MASTER_ADDR"] = args.master_addr # current_env["MASTER_PORT"] = str(args.master_port) # current_env["WORLD_SIZE"] = str(dist_world_size) # if 'OMP_NUM_THREADS' not in os.environ and args.nproc_per_node > 1: # current_env["OMP_NUM_THREADS"] = str(1) script_abs = os.path.abspath(args.script) script_base, script_rel = os.path.split(script_abs) sys.path.append(script_base) mod = importlib.import_module(os.path.splitext(script_rel)[0]) sys.argv = [args.script] + args.script_args xmp.spawn(mod._mp_entry, args=(), nprocs=args.num_devices) if __name__ == "__main__": main()