From 76de984a5fd2894a21b6dae270548b09f2e3f602 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 21 Apr 2021 13:02:53 -0700 Subject: [PATCH] Fix some bugs with XLA support, logger, add hacky xla dist launch script since torch.dist.launch doesn't work --- launch_xla.py | 66 ++++++++++++++++++++++++++++++++++++ timm/bits/device_env_cuda.py | 6 ++-- timm/bits/device_env_xla.py | 14 ++++++-- timm/bits/logger.py | 7 ++-- timm/bits/updater_xla.py | 8 ++++- timm/data/loader.py | 9 +++-- timm/metrics/accuracy.py | 2 ++ timm/metrics/tensor_avg.py | 2 ++ train.py | 8 +++-- 9 files changed, 104 insertions(+), 18 deletions(-) create mode 100644 launch_xla.py diff --git a/launch_xla.py b/launch_xla.py new file mode 100644 index 00000000..9e60556c --- /dev/null +++ b/launch_xla.py @@ -0,0 +1,66 @@ +""" +Adapatation of (pre-elastic) torch.distributed.launch for pytorch xla. + +`torch.distributed.launch` is a module that spawns up multiple distributed +training processes on each of the training nodes. + +""" + + +import sys +import subprocess +import importlib +import os +from argparse import ArgumentParser, REMAINDER +from typing import Optional, IO + +import torch_xla.distributed.xla_multiprocessing as xmp + + +def parse_args(): + """ + Helper function parsing the command line options + @retval ArgumentParser + """ + parser = ArgumentParser( + description="PyTorch distributed training launch helper utility" + "that will spawn up multiple distributed processes") + + # Optional arguments for the launch helper + parser.add_argument("--num-devices", type=int, default=1, + help="The number of XLA devices to use for distributed training") + + # positional + parser.add_argument( + "script", type=str, + help="The full path to the single device training script to be launched" + "in parallel, followed by all the arguments for the training script") + + # rest from the training program + parser.add_argument('script_args', nargs=REMAINDER) + return parser.parse_args() + + +def main(): + args = parse_args() + + # set PyTorch distributed related environmental variables + # current_env = os.environ.copy() + # current_env["MASTER_ADDR"] = args.master_addr + # current_env["MASTER_PORT"] = str(args.master_port) + # current_env["WORLD_SIZE"] = str(dist_world_size) + # if 'OMP_NUM_THREADS' not in os.environ and args.nproc_per_node > 1: + # current_env["OMP_NUM_THREADS"] = str(1) + + script_abs = os.path.abspath(args.script) + script_base, script_rel = os.path.split(script_abs) + sys.path.append(script_base) + mod = importlib.import_module(os.path.splitext(script_rel)[0]) + + sys.argv = [args.script] + args.script_args + + xmp.spawn(mod._mp_entry, args=(), nprocs=args.num_devices) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/timm/bits/device_env_cuda.py b/timm/bits/device_env_cuda.py index 29c4d8f6..d609bd2a 100644 --- a/timm/bits/device_env_cuda.py +++ b/timm/bits/device_env_cuda.py @@ -83,8 +83,10 @@ class DeviceEnvCuda(DeviceEnv): return self._autocast def wrap_distributed(self, *modules, **kwargs): - return [DistributedDataParallel(m, device_ids=[self._local_rank], **kwargs) for m in modules] + wrapped = [DistributedDataParallel(m, device_ids=[self._local_rank], **kwargs) for m in modules] + return wrapped[0] if len(wrapped) == 1 else wrapped def to_device(self, *modules: torch.nn.Module): # FIXME handling dtype / memformat... disable flags, enable flags, diff fn? - return [m.to(device=self._device, memory_format=self._memory_format) for m in modules] + moved = [m.to(device=self._device, memory_format=self._memory_format) for m in modules] + return moved[0] if len(moved) == 1 else moved diff --git a/timm/bits/device_env_xla.py b/timm/bits/device_env_xla.py index 385b8626..18e0fd3b 100644 --- a/timm/bits/device_env_xla.py +++ b/timm/bits/device_env_xla.py @@ -10,6 +10,12 @@ except ImportError as e: xm = None _HAS_XLA = False +try: + # only the very latest XLA builds have AMP + import torch_xla.amp as xa +except ImportError as e: + xa = None + from .device_env import DeviceEnv @@ -25,7 +31,6 @@ class DeviceEnvXla(DeviceEnv): def __init__(self, xla_device_type=None, device_idx=None, local_rank=0, amp=False): self._device = xm.xla_device(n=device_idx, devkind=xla_device_type) - print(self._device) self._local_rank = xm.get_local_ordinal(local_rank) self._world_size = xm.xrt_world_size() self._distributed = self._world_size > 1 @@ -33,6 +38,7 @@ class DeviceEnvXla(DeviceEnv): if self._distributed: self._global_rank = xm.get_ordinal() if amp: + assert xa is not None, 'XLA AMP is not present on this build' self._autocast = xa.autocast else: self._autocast = suppress @@ -76,10 +82,12 @@ class DeviceEnvXla(DeviceEnv): def wrap_distributed(self, *modules): # NO-OP - return tuple([m for m in modules]) + wrapped = [m for m in modules] + return wrapped[0] if len(wrapped) == 1 else wrapped def to_device(self, *modules: torch.nn.Module): - return [m.to(device=self._device, memory_format=self._memory_format) for m in modules] + moved = [m.to(device=self._device, memory_format=self._memory_format) for m in modules] + return moved[0] if len(moved) == 1 else moved def mark_step(self): xm.mark_step() diff --git a/timm/bits/logger.py b/timm/bits/logger.py index 2e2cd9da..d9ad41af 100644 --- a/timm/bits/logger.py +++ b/timm/bits/logger.py @@ -61,8 +61,8 @@ def summary_row_dict(results, index=None, index_name='epoch'): return row_dict if isinstance(next(iter(results.values())), dict): # each key in results is a per-phase results dict, flatten by prefixing with phase name - for p, pr in results.keys(): - assert isinstance(dict, pr) + for p, pr in results.items(): + assert isinstance(pr, dict) row_dict.update([('_'.join([p, k]), v) for k, v in pr.items()]) else: row_dict.update(results) @@ -81,7 +81,7 @@ class SummaryCsv: if self.needs_header: # first iteration (epoch == 1 can't be used) dw.writeheader() self.needs_header = False - dw.writerow(row_dict) + dw.writerow(row_dict) def _add_kwargs(text_update, name_map=None, **kwargs): @@ -212,7 +212,6 @@ class Logger: index: value for row index (typically epoch #) index_name: name for row index header (typically 'epoch') """ - row_dict = summary_row_dict(index=index, index_name=index_name, results=results) if self.csv_writer: self.csv_writer.update(row_dict) diff --git a/timm/bits/updater_xla.py b/timm/bits/updater_xla.py index 0789f06f..25287ad9 100644 --- a/timm/bits/updater_xla.py +++ b/timm/bits/updater_xla.py @@ -4,12 +4,17 @@ import torch try: import torch_xla.core.xla_model as xm - import torch_xla.amp as xa _HAS_XLA = True except ImportError as e: xm = None _HAS_XLA = False +try: + # only the very latest XLA builds have AMP + import torch_xla.amp as xa +except ImportError as e: + xa = None + from .updater import Updater @@ -26,6 +31,7 @@ class UpdaterXla(Updater): super().__init__(optimizer=optimizer, clip_value=clip_value, clip_mode=clip_mode) self.after_step_closure = True if use_scaler: + assert xa is not None, 'XLA AMP not present in this build' self.scaler = xa.GradScaler(**scaler_kwargs) def apply(self, loss: torch.Tensor, accumulate: bool = False): diff --git a/timm/data/loader.py b/timm/data/loader.py index 9b15eb02..45d40908 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -40,11 +40,9 @@ def create_loader( mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_workers=1, - distributed=False, crop_pct=None, collate_fn=None, pin_memory=False, - fp16=False, tf_preprocessing=False, use_multi_epochs_loader=False, persistent_workers=True, @@ -80,13 +78,14 @@ def create_loader( dev_env = get_device() sampler = None - if distributed and not isinstance(dataset, torch.utils.data.IterableDataset): + if dev_env.is_distributed and not isinstance(dataset, torch.utils.data.IterableDataset): if is_training: - sampler = torch.utils.data.distributed.DistributedSampler(dataset) + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, num_replicas=dev_env.world_size, rank=dev_env.global_rank) else: # This will add extra duplicate entries to result in equal num # of samples per-process, will slightly alter validation results - sampler = OrderedDistributedSampler(dataset) + sampler = OrderedDistributedSampler(dataset, num_replicas=dev_env.world_size, rank=dev_env.global_rank) if collate_fn is None: collate_fn = fast_collate diff --git a/timm/metrics/accuracy.py b/timm/metrics/accuracy.py index 98aa59eb..b58a3781 100644 --- a/timm/metrics/accuracy.py +++ b/timm/metrics/accuracy.py @@ -36,6 +36,7 @@ class AccuracyTopK(torch.nn.Module): self.device = device self.topk = topk self.maxk = max(topk) + # FIXME handle distributed operation # statistics / counts self.reset() @@ -63,6 +64,7 @@ class AccuracyTopK(torch.nn.Module): pass def compute(self) -> Dict[str, torch.Tensor]: + # FIXME handle distributed reduction return {f'top{k}': 100 * getattr(self, f'_correct_top{k}') / self._total_sum for k in self.topk} diff --git a/timm/metrics/tensor_avg.py b/timm/metrics/tensor_avg.py index ac2fb6ed..c9a3489b 100644 --- a/timm/metrics/tensor_avg.py +++ b/timm/metrics/tensor_avg.py @@ -8,6 +8,7 @@ class TensorAvg: self.sum = None self.count = None self.reset() + # FIXME handle distributed operation def reset(self): self.sum = None @@ -32,6 +33,7 @@ class TensorEma: self.init_zero = init_zero self.val = None self.reset() + # FIXME handle distributed operation def reset(self): self.val = None diff --git a/train.py b/train.py index f105e525..de627929 100755 --- a/train.py +++ b/train.py @@ -426,7 +426,6 @@ def main(): mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, - distributed=dev_env.is_distributed, collate_fn=collate_fn, pin_memory=args.pin_mem, use_multi_epochs_loader=args.use_multi_epochs_loader @@ -441,7 +440,6 @@ def main(): mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, - distributed=dev_env.is_distributed, crop_pct=data_config['crop_pct'], pin_memory=args.pin_mem, ) @@ -519,7 +517,7 @@ def main(): lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) if logger is not None: - logger.write_summary(index=epoch, results=dict(train=train_metrics, eval=eval_metric)) + logger.write_summary(index=epoch, results=dict(train=train_metrics, eval=eval_metrics)) if saver is not None: # save proper checkpoint with eval metric @@ -657,5 +655,9 @@ def evaluate( return results +def _mp_entry(*args): + main() + + if __name__ == '__main__': main()