Fix some bugs with XLA support and the logger; add a hacky XLA distributed launch script, since torch.distributed.launch doesn't work
parent
12d9a6d4d2
commit
76de984a5f
@@ -0,0 +1,66 @@
"""
Adaptation of (pre-elastic) torch.distributed.launch for PyTorch XLA.

`torch.distributed.launch` is a module that spawns up multiple distributed
training processes on each of the training nodes.
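
Usage sketch (the launcher's filename and the training script's arguments below
are illustrative; the training script must expose a module-level `_mp_entry(index)`
function, which is what gets handed to `xmp.spawn`):

    python launch_xla.py --num-devices 8 /path/to/train.py --batch-size 32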

"""


import sys
import subprocess
import importlib
import os
from argparse import ArgumentParser, REMAINDER
from typing import Optional, IO

import torch_xla.distributed.xla_multiprocessing as xmp


def parse_args():
    """
    Helper function parsing the command line options
    @retval ArgumentParser
    """
    parser = ArgumentParser(
        description="PyTorch distributed training launch helper utility "
                    "that will spawn up multiple distributed processes")

    # Optional arguments for the launch helper
    parser.add_argument("--num-devices", type=int, default=1,
                        help="The number of XLA devices to use for distributed training")

    # positional
    parser.add_argument(
        "script", type=str,
        help="The full path to the single device training script to be launched "
             "in parallel, followed by all the arguments for the training script")

    # rest from the training program
    parser.add_argument('script_args', nargs=REMAINDER)
    return parser.parse_args()


def main():
    args = parse_args()

    # set PyTorch distributed related environment variables
    # current_env = os.environ.copy()
    # current_env["MASTER_ADDR"] = args.master_addr
    # current_env["MASTER_PORT"] = str(args.master_port)
    # current_env["WORLD_SIZE"] = str(dist_world_size)
    # if 'OMP_NUM_THREADS' not in os.environ and args.nproc_per_node > 1:
    #     current_env["OMP_NUM_THREADS"] = str(1)

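    # Import the training script as a module so its _mp_entry can be handed to xmp.spawn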
    script_abs = os.path.abspath(args.script)
    script_base, script_rel = os.path.split(script_abs)
    sys.path.append(script_base)
    mod = importlib.import_module(os.path.splitext(script_rel)[0])

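    # Rewrite argv so the imported script sees only its own arguments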
    sys.argv = [args.script] + args.script_args

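    # Spawn num_devices processes; xmp.spawn calls mod._mp_entry(index) in each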
    xmp.spawn(mod._mp_entry, args=(), nprocs=args.num_devices)


if __name__ == "__main__":
    main()
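
For reference, the launcher assumes the training script defines a module-level `_mp_entry`
that `xmp.spawn` can call with the process index. A minimal sketch of a compatible script
(file name, model, and training-loop contents are illustrative, not part of this commit):

    # train.py -- illustrative only
    import torch
    import torch_xla.core.xla_model as xm

    def _mp_entry(index):
        # `index` is the process index passed in by xmp.spawn
        device = xm.xla_device()
        model = torch.nn.Linear(10, 1).to(device)
        # ... parse sys.argv, build the data loader and optimizer, run the training loop ...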