Fix some bugs with XLA support, logger, add hacky xla dist launch script since torch.dist.launch doesn't work

pull/1239/head
Ross Wightman 3 years ago
parent 12d9a6d4d2
commit 76de984a5f

@ -0,0 +1,66 @@
"""
Adapatation of (pre-elastic) torch.distributed.launch for pytorch xla.
`torch.distributed.launch` is a module that spawns up multiple distributed
training processes on each of the training nodes.
"""
import sys
import subprocess
import importlib
import os
from argparse import ArgumentParser, REMAINDER
from typing import Optional, IO
import torch_xla.distributed.xla_multiprocessing as xmp
def parse_args():
"""
Helper function parsing the command line options
@retval ArgumentParser
"""
parser = ArgumentParser(
description="PyTorch distributed training launch helper utility"
"that will spawn up multiple distributed processes")
# Optional arguments for the launch helper
parser.add_argument("--num-devices", type=int, default=1,
help="The number of XLA devices to use for distributed training")
# positional
parser.add_argument(
"script", type=str,
help="The full path to the single device training script to be launched"
"in parallel, followed by all the arguments for the training script")
# rest from the training program
parser.add_argument('script_args', nargs=REMAINDER)
return parser.parse_args()
def main():
args = parse_args()
# set PyTorch distributed related environmental variables
# current_env = os.environ.copy()
# current_env["MASTER_ADDR"] = args.master_addr
# current_env["MASTER_PORT"] = str(args.master_port)
# current_env["WORLD_SIZE"] = str(dist_world_size)
# if 'OMP_NUM_THREADS' not in os.environ and args.nproc_per_node > 1:
# current_env["OMP_NUM_THREADS"] = str(1)
script_abs = os.path.abspath(args.script)
script_base, script_rel = os.path.split(script_abs)
sys.path.append(script_base)
mod = importlib.import_module(os.path.splitext(script_rel)[0])
sys.argv = [args.script] + args.script_args
xmp.spawn(mod._mp_entry, args=(), nprocs=args.num_devices)
if __name__ == "__main__":
main()

@ -83,8 +83,10 @@ class DeviceEnvCuda(DeviceEnv):
return self._autocast
def wrap_distributed(self, *modules, **kwargs):
return [DistributedDataParallel(m, device_ids=[self._local_rank], **kwargs) for m in modules]
wrapped = [DistributedDataParallel(m, device_ids=[self._local_rank], **kwargs) for m in modules]
return wrapped[0] if len(wrapped) == 1 else wrapped
def to_device(self, *modules: torch.nn.Module):
# FIXME handling dtype / memformat... disable flags, enable flags, diff fn?
return [m.to(device=self._device, memory_format=self._memory_format) for m in modules]
moved = [m.to(device=self._device, memory_format=self._memory_format) for m in modules]
return moved[0] if len(moved) == 1 else moved

@ -10,6 +10,12 @@ except ImportError as e:
xm = None
_HAS_XLA = False
try:
# only the very latest XLA builds have AMP
import torch_xla.amp as xa
except ImportError as e:
xa = None
from .device_env import DeviceEnv
@ -25,7 +31,6 @@ class DeviceEnvXla(DeviceEnv):
def __init__(self, xla_device_type=None, device_idx=None, local_rank=0, amp=False):
self._device = xm.xla_device(n=device_idx, devkind=xla_device_type)
print(self._device)
self._local_rank = xm.get_local_ordinal(local_rank)
self._world_size = xm.xrt_world_size()
self._distributed = self._world_size > 1
@ -33,6 +38,7 @@ class DeviceEnvXla(DeviceEnv):
if self._distributed:
self._global_rank = xm.get_ordinal()
if amp:
assert xa is not None, 'XLA AMP is not present on this build'
self._autocast = xa.autocast
else:
self._autocast = suppress
@ -76,10 +82,12 @@ class DeviceEnvXla(DeviceEnv):
def wrap_distributed(self, *modules):
# NO-OP
return tuple([m for m in modules])
wrapped = [m for m in modules]
return wrapped[0] if len(wrapped) == 1 else wrapped
def to_device(self, *modules: torch.nn.Module):
return [m.to(device=self._device, memory_format=self._memory_format) for m in modules]
moved = [m.to(device=self._device, memory_format=self._memory_format) for m in modules]
return moved[0] if len(moved) == 1 else moved
def mark_step(self):
xm.mark_step()

@ -61,8 +61,8 @@ def summary_row_dict(results, index=None, index_name='epoch'):
return row_dict
if isinstance(next(iter(results.values())), dict):
# each key in results is a per-phase results dict, flatten by prefixing with phase name
for p, pr in results.keys():
assert isinstance(dict, pr)
for p, pr in results.items():
assert isinstance(pr, dict)
row_dict.update([('_'.join([p, k]), v) for k, v in pr.items()])
else:
row_dict.update(results)
@ -81,7 +81,7 @@ class SummaryCsv:
if self.needs_header: # first iteration (epoch == 1 can't be used)
dw.writeheader()
self.needs_header = False
dw.writerow(row_dict)
dw.writerow(row_dict)
def _add_kwargs(text_update, name_map=None, **kwargs):
@ -212,7 +212,6 @@ class Logger:
index: value for row index (typically epoch #)
index_name: name for row index header (typically 'epoch')
"""
row_dict = summary_row_dict(index=index, index_name=index_name, results=results)
if self.csv_writer:
self.csv_writer.update(row_dict)

@ -4,12 +4,17 @@ import torch
try:
import torch_xla.core.xla_model as xm
import torch_xla.amp as xa
_HAS_XLA = True
except ImportError as e:
xm = None
_HAS_XLA = False
try:
# only the very latest XLA builds have AMP
import torch_xla.amp as xa
except ImportError as e:
xa = None
from .updater import Updater
@ -26,6 +31,7 @@ class UpdaterXla(Updater):
super().__init__(optimizer=optimizer, clip_value=clip_value, clip_mode=clip_mode)
self.after_step_closure = True
if use_scaler:
assert xa is not None, 'XLA AMP not present in this build'
self.scaler = xa.GradScaler(**scaler_kwargs)
def apply(self, loss: torch.Tensor, accumulate: bool = False):

@ -40,11 +40,9 @@ def create_loader(
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
num_workers=1,
distributed=False,
crop_pct=None,
collate_fn=None,
pin_memory=False,
fp16=False,
tf_preprocessing=False,
use_multi_epochs_loader=False,
persistent_workers=True,
@ -80,13 +78,14 @@ def create_loader(
dev_env = get_device()
sampler = None
if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
if dev_env.is_distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
if is_training:
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=dev_env.world_size, rank=dev_env.global_rank)
else:
# This will add extra duplicate entries to result in equal num
# of samples per-process, will slightly alter validation results
sampler = OrderedDistributedSampler(dataset)
sampler = OrderedDistributedSampler(dataset, num_replicas=dev_env.world_size, rank=dev_env.global_rank)
if collate_fn is None:
collate_fn = fast_collate

@ -36,6 +36,7 @@ class AccuracyTopK(torch.nn.Module):
self.device = device
self.topk = topk
self.maxk = max(topk)
# FIXME handle distributed operation
# statistics / counts
self.reset()
@ -63,6 +64,7 @@ class AccuracyTopK(torch.nn.Module):
pass
def compute(self) -> Dict[str, torch.Tensor]:
# FIXME handle distributed reduction
return {f'top{k}': 100 * getattr(self, f'_correct_top{k}') / self._total_sum for k in self.topk}

@ -8,6 +8,7 @@ class TensorAvg:
self.sum = None
self.count = None
self.reset()
# FIXME handle distributed operation
def reset(self):
self.sum = None
@ -32,6 +33,7 @@ class TensorEma:
self.init_zero = init_zero
self.val = None
self.reset()
# FIXME handle distributed operation
def reset(self):
self.val = None

@ -426,7 +426,6 @@ def main():
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
distributed=dev_env.is_distributed,
collate_fn=collate_fn,
pin_memory=args.pin_mem,
use_multi_epochs_loader=args.use_multi_epochs_loader
@ -441,7 +440,6 @@ def main():
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
distributed=dev_env.is_distributed,
crop_pct=data_config['crop_pct'],
pin_memory=args.pin_mem,
)
@ -519,7 +517,7 @@ def main():
lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
if logger is not None:
logger.write_summary(index=epoch, results=dict(train=train_metrics, eval=eval_metric))
logger.write_summary(index=epoch, results=dict(train=train_metrics, eval=eval_metrics))
if saver is not None:
# save proper checkpoint with eval metric
@ -657,5 +655,9 @@ def evaluate(
return results
def _mp_entry(*args):
main()
if __name__ == '__main__':
main()

Loading…
Cancel
Save