diff --git a/timm/bits/device_env.py b/timm/bits/device_env.py
index 7307823e..bac9b0ab 100644
--- a/timm/bits/device_env.py
+++ b/timm/bits/device_env.py
@@ -90,8 +90,6 @@ class DeviceEnv:
         pass  # NO-OP for non-XLA devices
 
     def all_reduce_(self, tensor: TensorList, op=dist.ReduceOp.SUM, average=False):
-        print(len(tensor), type(tensor))
-        print(tensor.shape)
         dist.all_reduce(tensor, op=op)
         if average:
             tensor.div_(self.world_size)
diff --git a/timm/bits/device_env_xla.py b/timm/bits/device_env_xla.py
index cc9ea3dd..71d350fd 100644
--- a/timm/bits/device_env_xla.py
+++ b/timm/bits/device_env_xla.py
@@ -23,12 +23,12 @@
 from .device_env import DeviceEnv, DeviceEnvType, TensorList
 
 
 _PT_TO_XM_OP = {
-    ReduceOp.SUM: 'sum',
-    ReduceOp.PRODUCT: 'prod',
-    ReduceOp.MIN: 'min',
-    ReduceOp.MAX: 'max',
-    ReduceOp.BAND: 'and',
-    ReduceOp.BOR: 'or',
+    ReduceOp.SUM: xm.REDUCE_SUM,
+    ReduceOp.PRODUCT: xm.REDUCE_MUL,
+    ReduceOp.MIN: xm.REDUCE_MIN,
+    ReduceOp.MAX: xm.REDUCE_MAX,
+    ReduceOp.BAND: xm.REDUCE_AND,
+    ReduceOp.BOR: xm.REDUCE_OR,
 }
 
@@ -77,20 +77,16 @@ class DeviceEnvXla(DeviceEnv):
     def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, average=False):
         assert isinstance(tensor, torch.Tensor)  # unlike in-place variant, lists/tuples not allowed
         op = _PT_TO_XM_OP[op]
-        scale = 1.0
-        if average:
-            scale /= self.world_size
+        scale = 1.0 / self.world_size if average else 1.0
         return xm.all_reduce(op, tensor, scale=scale)
 
     def all_reduce_(self, tensor: TensorList, op=ReduceOp.SUM, average=False):
         op = _PT_TO_XM_OP[op]
-        scale = 1.0
+        scale = 1.0 / self.world_size if average else 1.0
         wrapped = False
         if isinstance(tensor, torch.Tensor):
             tensor = [tensor]  # bare tensors are not operated on in-place
             wrapped = True
-        if average:
-            scale /= self.world_size
         xm.all_reduce(op, tensor, scale=scale)
         if wrapped:
             tensor = tensor[0]
diff --git a/timm/bits/train_setup.py b/timm/bits/train_setup.py
index 992546a7..3884958b 100644
--- a/timm/bits/train_setup.py
+++ b/timm/bits/train_setup.py
@@ -89,6 +89,7 @@ def setup_model_and_optimizer(
     train_state = TrainState(model=model, updater=updater, model_ema=model_ema)
 
     if resume_path:
+        # FIXME this is not implemented yet, do a hack job before proper TrainState serialization?
         resume_train_checkpoint(
             train_state,
             resume_path,
diff --git a/train.py b/train.py
index 3f18c8e5..95f5cb7e 100755
--- a/train.py
+++ b/train.py
@@ -283,12 +283,16 @@ def main():
     else:
         _logger.info('Training with a single process on 1 device.')
 
-    random_seed(args.seed, dev_env.global_rank)
-
     mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
 
+    random_seed(args.seed, 0)  # Set all random seeds the same for model/state init (mandatory for XLA)
+
     train_state, train_cfg = setup_train_task(args, dev_env, mixup_active)
 
+    # Set random seeds across ranks differently for train
+    # FIXME perhaps keep the same and just set diff seeds for dataloader worker process? what about TFDS?
+    random_seed(args.seed, dev_env.global_rank)
+
     data_config, loader_eval, loader_train = setup_data(args, train_state.model.default_cfg, dev_env, mixup_active)
 
     # setup checkpoint saver
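
Minimal usage sketch of the xm.all_reduce call pattern the XLA changes above rely on (not part of the diff; assumes a working torch_xla install, and the _mp_fn entry point and tensor values are illustrative only):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
    device = xm.xla_device()
    world_size = xm.xrt_world_size()
    # Passing a list reduces each tensor in place across replicas; the
    # scale=1/world_size factor turns the summed result into a mean,
    # matching what all_reduce_(..., average=True) computes above.
    tensors = [torch.ones(4, device=device), torch.full((2, 2), float(index), device=device)]
    xm.all_reduce(xm.REDUCE_SUM, tensors, scale=1.0 / world_size)
    xm.mark_step()


if __name__ == '__main__':
    xmp.spawn(_mp_fn)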