You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/launch_xla.py

66 lines
2.0 KiB

"""
Adaptation of (pre-elastic) torch.distributed.launch for pytorch xla.
`torch.distributed.launch` is a module that spawns up multiple distributed
training processes on each of the training nodes.
"""
import sys
import subprocess
import importlib
import os
from argparse import ArgumentParser, REMAINDER
from typing import Optional, IO
import torch_xla.distributed.xla_multiprocessing as xmp
def parse_args():
    """Parse the launcher's command line.

    Returns:
        argparse.Namespace with:
            num_devices (int): number of XLA devices to spawn processes for.
            script (str): path to the single-device training script.
            script_args (list[str]): remaining args forwarded to the script.
    """
    parser = ArgumentParser(
        # NOTE: adjacent string literals need an explicit separating space,
        # otherwise the help text reads "utilitythat will ...".
        description="PyTorch distributed training launch helper utility "
                    "that will spawn up multiple distributed processes")
    # Optional arguments for the launch helper
    parser.add_argument("--num-devices", type=int, default=1,
                        help="The number of XLA devices to use for distributed training")
    # positional
    parser.add_argument(
        "script", type=str,
        help="The full path to the single device training script to be launched "
             "in parallel, followed by all the arguments for the training script")
    # rest from the training program
    parser.add_argument('script_args', nargs=REMAINDER)
    return parser.parse_args()
def main():
    """Import the training script as a module and launch it on XLA devices.

    The script's directory is appended to ``sys.path`` so it can be imported
    by module name; ``sys.argv`` is rewritten so the script sees only its own
    arguments, then one process per device is spawned via ``xmp.spawn``.
    """
    args = parse_args()

    # Removed: dead commented-out MASTER_ADDR/MASTER_PORT/WORLD_SIZE env setup
    # inherited from torch.distributed.launch — xmp.spawn handles process
    # bootstrapping itself.

    # Make the training script importable: split "/path/to/train.py" into
    # its directory (added to sys.path) and its module name ("train").
    script_abs = os.path.abspath(args.script)
    script_base, script_rel = os.path.split(script_abs)
    sys.path.append(script_base)
    mod = importlib.import_module(os.path.splitext(script_rel)[0])

    # Forward the remaining CLI arguments to the training script.
    sys.argv = [args.script] + args.script_args

    # NOTE(review): the training module is assumed to expose an `_mp_entry`
    # callable (the per-process entry point xmp.spawn invokes) — confirm
    # against the training scripts in this repo.
    xmp.spawn(mod._mp_entry, args=(), nprocs=args.num_devices)
# Standard entry guard: run the launcher only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()