Fix some bugs with XLA support and the logger; add a hacky XLA distributed launch script since torch.distributed.launch doesn't work
parent 12d9a6d4d2
commit 76de984a5f

@@ -0,0 +1,66 @@
"""
Adaptation of (pre-elastic) torch.distributed.launch for PyTorch XLA.

`torch.distributed.launch` is a module that spawns up multiple distributed
training processes on each of the training nodes.
"""

import sys
import importlib
import os
from argparse import ArgumentParser, REMAINDER

import torch_xla.distributed.xla_multiprocessing as xmp


def parse_args():
    """
    Helper function parsing the command line options
    @retval ArgumentParser
    """
    parser = ArgumentParser(
        description="PyTorch distributed training launch helper utility "
                    "that will spawn up multiple distributed processes")

    # Optional arguments for the launch helper
    parser.add_argument("--num-devices", type=int, default=1,
                        help="The number of XLA devices to use for distributed training")

    # positional
    parser.add_argument(
        "script", type=str,
        help="The full path to the single device training script to be launched "
             "in parallel, followed by all the arguments for the training script")

    # rest from the training program
    parser.add_argument('script_args', nargs=REMAINDER)
    return parser.parse_args()


def main():
    args = parse_args()

    # set PyTorch distributed related environment variables
    # current_env = os.environ.copy()
    # current_env["MASTER_ADDR"] = args.master_addr
    # current_env["MASTER_PORT"] = str(args.master_port)
    # current_env["WORLD_SIZE"] = str(dist_world_size)
    # if 'OMP_NUM_THREADS' not in os.environ and args.nproc_per_node > 1:
    #     current_env["OMP_NUM_THREADS"] = str(1)

    # Import the training script as a module so its _mp_entry function can be
    # handed to xmp.spawn; make the script's directory importable first.
    script_abs = os.path.abspath(args.script)
    script_base, script_rel = os.path.split(script_abs)
    sys.path.append(script_base)
    mod = importlib.import_module(os.path.splitext(script_rel)[0])

    # Forward the remaining command line arguments to the training script as
    # if it had been invoked directly.
    sys.argv = [args.script] + args.script_args

    # Spawn one process per XLA device, each running the script's _mp_entry.
    xmp.spawn(mod._mp_entry, args=(), nprocs=args.num_devices)


if __name__ == "__main__":
    main()
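
For reference, a minimal sketch of how a training script is expected to plug into this launcher: the launcher imports the script as a module and passes its _mp_entry function to xmp.spawn, which calls it in every spawned process with the process index as its first argument. The file names train_mnist.py and xla_launch.py and the function bodies below are illustrative assumptions, not part of this commit.

# train_mnist.py -- hypothetical training script (name and contents are assumptions for illustration)
import torch_xla.core.xla_model as xm


def train():
    device = xm.xla_device()  # each spawned process gets its own XLA device
    # ... build the model, move it to `device`, run the training loop ...


def _mp_entry(index):
    # Entry point handed to xmp.spawn() by the launcher; `index` is the
    # per-process ordinal supplied by torch_xla.
    train()


if __name__ == "__main__":
    # The script stays runnable as a plain single-device program.
    train()

Assuming the launcher file above is saved as xla_launch.py, the expected invocation is along the lines of `python xla_launch.py --num-devices 8 train_mnist.py --lr 0.01`; everything after the script path is forwarded to the training script through sys.argv.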