Fix some bugs with XLA support, logger, add hacky xla dist launch script since torch.dist.launch doesn't work

pull/1239/head
Ross Wightman 3 years ago
parent 12d9a6d4d2
commit 76de984a5f

@ -0,0 +1,66 @@
"""
Adapatation of (pre-elastic) torch.distributed.launch for pytorch xla.
`torch.distributed.launch` is a module that spawns up multiple distributed
training processes on each of the training nodes.
"""
import sys
import subprocess
import importlib
import os
from argparse import ArgumentParser, REMAINDER
from typing import Optional, IO
import torch_xla.distributed.xla_multiprocessing as xmp
def parse_args():
"""
Helper function parsing the command line options
@retval ArgumentParser
"""
parser = ArgumentParser(
description="PyTorch distributed training launch helper utility"
"that will spawn up multiple distributed processes")
# Optional arguments for the launch helper
parser.add_argument("--num-devices", type=int, default=1,
help="The number of XLA devices to use for distributed training")
# positional
parser.add_argument(
"script", type=str,
help="The full path to the single device training script to be launched"
"in parallel, followed by all the arguments for the training script")
# rest from the training program
parser.add_argument('script_args', nargs=REMAINDER)
return parser.parse_args()
def main():
args = parse_args()
# set PyTorch distributed related environmental variables
# current_env = os.environ.copy()
# current_env["MASTER_ADDR"] = args.master_addr
# current_env["MASTER_PORT"] = str(args.master_port)
# current_env["WORLD_SIZE"] = str(dist_world_size)
# if 'OMP_NUM_THREADS' not in os.environ and args.nproc_per_node > 1:
# current_env["OMP_NUM_THREADS"] = str(1)
script_abs = os.path.abspath(args.script)
script_base, script_rel = os.path.split(script_abs)
sys.path.append(script_base)
mod = importlib.import_module(os.path.splitext(script_rel)[0])
sys.argv = [args.script] + args.script_args
xmp.spawn(mod._mp_entry, args=(), nprocs=args.num_devices)
if __name__ == "__main__":
main()

@ -83,8 +83,10 @@ class DeviceEnvCuda(DeviceEnv):
return self._autocast return self._autocast
def wrap_distributed(self, *modules, **kwargs): def wrap_distributed(self, *modules, **kwargs):
return [DistributedDataParallel(m, device_ids=[self._local_rank], **kwargs) for m in modules] wrapped = [DistributedDataParallel(m, device_ids=[self._local_rank], **kwargs) for m in modules]
return wrapped[0] if len(wrapped) == 1 else wrapped
def to_device(self, *modules: torch.nn.Module): def to_device(self, *modules: torch.nn.Module):
# FIXME handling dtype / memformat... disable flags, enable flags, diff fn? # FIXME handling dtype / memformat... disable flags, enable flags, diff fn?
return [m.to(device=self._device, memory_format=self._memory_format) for m in modules] moved = [m.to(device=self._device, memory_format=self._memory_format) for m in modules]
return moved[0] if len(moved) == 1 else moved

@ -10,6 +10,12 @@ except ImportError as e:
xm = None xm = None
_HAS_XLA = False _HAS_XLA = False
try:
# only the very latest XLA builds have AMP
import torch_xla.amp as xa
except ImportError as e:
xa = None
from .device_env import DeviceEnv from .device_env import DeviceEnv
@ -25,7 +31,6 @@ class DeviceEnvXla(DeviceEnv):
def __init__(self, xla_device_type=None, device_idx=None, local_rank=0, amp=False): def __init__(self, xla_device_type=None, device_idx=None, local_rank=0, amp=False):
self._device = xm.xla_device(n=device_idx, devkind=xla_device_type) self._device = xm.xla_device(n=device_idx, devkind=xla_device_type)
print(self._device)
self._local_rank = xm.get_local_ordinal(local_rank) self._local_rank = xm.get_local_ordinal(local_rank)
self._world_size = xm.xrt_world_size() self._world_size = xm.xrt_world_size()
self._distributed = self._world_size > 1 self._distributed = self._world_size > 1
@ -33,6 +38,7 @@ class DeviceEnvXla(DeviceEnv):
if self._distributed: if self._distributed:
self._global_rank = xm.get_ordinal() self._global_rank = xm.get_ordinal()
if amp: if amp:
assert xa is not None, 'XLA AMP is not present on this build'
self._autocast = xa.autocast self._autocast = xa.autocast
else: else:
self._autocast = suppress self._autocast = suppress
@ -76,10 +82,12 @@ class DeviceEnvXla(DeviceEnv):
def wrap_distributed(self, *modules): def wrap_distributed(self, *modules):
# NO-OP # NO-OP
return tuple([m for m in modules]) wrapped = [m for m in modules]
return wrapped[0] if len(wrapped) == 1 else wrapped
def to_device(self, *modules: torch.nn.Module): def to_device(self, *modules: torch.nn.Module):
return [m.to(device=self._device, memory_format=self._memory_format) for m in modules] moved = [m.to(device=self._device, memory_format=self._memory_format) for m in modules]
return moved[0] if len(moved) == 1 else moved
def mark_step(self): def mark_step(self):
xm.mark_step() xm.mark_step()

@ -61,8 +61,8 @@ def summary_row_dict(results, index=None, index_name='epoch'):
return row_dict return row_dict
if isinstance(next(iter(results.values())), dict): if isinstance(next(iter(results.values())), dict):
# each key in results is a per-phase results dict, flatten by prefixing with phase name # each key in results is a per-phase results dict, flatten by prefixing with phase name
for p, pr in results.keys(): for p, pr in results.items():
assert isinstance(dict, pr) assert isinstance(pr, dict)
row_dict.update([('_'.join([p, k]), v) for k, v in pr.items()]) row_dict.update([('_'.join([p, k]), v) for k, v in pr.items()])
else: else:
row_dict.update(results) row_dict.update(results)
@ -81,7 +81,7 @@ class SummaryCsv:
if self.needs_header: # first iteration (epoch == 1 can't be used) if self.needs_header: # first iteration (epoch == 1 can't be used)
dw.writeheader() dw.writeheader()
self.needs_header = False self.needs_header = False
dw.writerow(row_dict) dw.writerow(row_dict)
def _add_kwargs(text_update, name_map=None, **kwargs): def _add_kwargs(text_update, name_map=None, **kwargs):
@ -212,7 +212,6 @@ class Logger:
index: value for row index (typically epoch #) index: value for row index (typically epoch #)
index_name: name for row index header (typically 'epoch') index_name: name for row index header (typically 'epoch')
""" """
row_dict = summary_row_dict(index=index, index_name=index_name, results=results) row_dict = summary_row_dict(index=index, index_name=index_name, results=results)
if self.csv_writer: if self.csv_writer:
self.csv_writer.update(row_dict) self.csv_writer.update(row_dict)

@ -4,12 +4,17 @@ import torch
try: try:
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
import torch_xla.amp as xa
_HAS_XLA = True _HAS_XLA = True
except ImportError as e: except ImportError as e:
xm = None xm = None
_HAS_XLA = False _HAS_XLA = False
try:
# only the very latest XLA builds have AMP
import torch_xla.amp as xa
except ImportError as e:
xa = None
from .updater import Updater from .updater import Updater
@ -26,6 +31,7 @@ class UpdaterXla(Updater):
super().__init__(optimizer=optimizer, clip_value=clip_value, clip_mode=clip_mode) super().__init__(optimizer=optimizer, clip_value=clip_value, clip_mode=clip_mode)
self.after_step_closure = True self.after_step_closure = True
if use_scaler: if use_scaler:
assert xa is not None, 'XLA AMP not present in this build'
self.scaler = xa.GradScaler(**scaler_kwargs) self.scaler = xa.GradScaler(**scaler_kwargs)
def apply(self, loss: torch.Tensor, accumulate: bool = False): def apply(self, loss: torch.Tensor, accumulate: bool = False):

@ -40,11 +40,9 @@ def create_loader(
mean=IMAGENET_DEFAULT_MEAN, mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD, std=IMAGENET_DEFAULT_STD,
num_workers=1, num_workers=1,
distributed=False,
crop_pct=None, crop_pct=None,
collate_fn=None, collate_fn=None,
pin_memory=False, pin_memory=False,
fp16=False,
tf_preprocessing=False, tf_preprocessing=False,
use_multi_epochs_loader=False, use_multi_epochs_loader=False,
persistent_workers=True, persistent_workers=True,
@ -80,13 +78,14 @@ def create_loader(
dev_env = get_device() dev_env = get_device()
sampler = None sampler = None
if distributed and not isinstance(dataset, torch.utils.data.IterableDataset): if dev_env.is_distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
if is_training: if is_training:
sampler = torch.utils.data.distributed.DistributedSampler(dataset) sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=dev_env.world_size, rank=dev_env.global_rank)
else: else:
# This will add extra duplicate entries to result in equal num # This will add extra duplicate entries to result in equal num
# of samples per-process, will slightly alter validation results # of samples per-process, will slightly alter validation results
sampler = OrderedDistributedSampler(dataset) sampler = OrderedDistributedSampler(dataset, num_replicas=dev_env.world_size, rank=dev_env.global_rank)
if collate_fn is None: if collate_fn is None:
collate_fn = fast_collate collate_fn = fast_collate

@ -36,6 +36,7 @@ class AccuracyTopK(torch.nn.Module):
self.device = device self.device = device
self.topk = topk self.topk = topk
self.maxk = max(topk) self.maxk = max(topk)
# FIXME handle distributed operation
# statistics / counts # statistics / counts
self.reset() self.reset()
@ -63,6 +64,7 @@ class AccuracyTopK(torch.nn.Module):
pass pass
def compute(self) -> Dict[str, torch.Tensor]: def compute(self) -> Dict[str, torch.Tensor]:
# FIXME handle distributed reduction
return {f'top{k}': 100 * getattr(self, f'_correct_top{k}') / self._total_sum for k in self.topk} return {f'top{k}': 100 * getattr(self, f'_correct_top{k}') / self._total_sum for k in self.topk}

@ -8,6 +8,7 @@ class TensorAvg:
self.sum = None self.sum = None
self.count = None self.count = None
self.reset() self.reset()
# FIXME handle distributed operation
def reset(self): def reset(self):
self.sum = None self.sum = None
@ -32,6 +33,7 @@ class TensorEma:
self.init_zero = init_zero self.init_zero = init_zero
self.val = None self.val = None
self.reset() self.reset()
# FIXME handle distributed operation
def reset(self): def reset(self):
self.val = None self.val = None

@ -426,7 +426,6 @@ def main():
mean=data_config['mean'], mean=data_config['mean'],
std=data_config['std'], std=data_config['std'],
num_workers=args.workers, num_workers=args.workers,
distributed=dev_env.is_distributed,
collate_fn=collate_fn, collate_fn=collate_fn,
pin_memory=args.pin_mem, pin_memory=args.pin_mem,
use_multi_epochs_loader=args.use_multi_epochs_loader use_multi_epochs_loader=args.use_multi_epochs_loader
@ -441,7 +440,6 @@ def main():
mean=data_config['mean'], mean=data_config['mean'],
std=data_config['std'], std=data_config['std'],
num_workers=args.workers, num_workers=args.workers,
distributed=dev_env.is_distributed,
crop_pct=data_config['crop_pct'], crop_pct=data_config['crop_pct'],
pin_memory=args.pin_mem, pin_memory=args.pin_mem,
) )
@ -519,7 +517,7 @@ def main():
lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
if logger is not None: if logger is not None:
logger.write_summary(index=epoch, results=dict(train=train_metrics, eval=eval_metric)) logger.write_summary(index=epoch, results=dict(train=train_metrics, eval=eval_metrics))
if saver is not None: if saver is not None:
# save proper checkpoint with eval metric # save proper checkpoint with eval metric
@ -657,5 +655,9 @@ def evaluate(
return results return results
def _mp_entry(*args):
main()
if __name__ == '__main__': if __name__ == '__main__':
main() main()

Loading…
Cancel
Save