Fix model init for XLA, remove some prints.

pull/1239/head
Ross Wightman 3 years ago
parent 6d90fcf282
commit cbd4ee737f

@@ -90,8 +90,6 @@ class DeviceEnv:
         pass  # NO-OP for non-XLA devices

     def all_reduce_(self, tensor: TensorList, op=dist.ReduceOp.SUM, average=False):
-        print(len(tensor), type(tensor))
-        print(tensor.shape)
         dist.all_reduce(tensor, op=op)
         if average:
             tensor.div_(self.world_size)
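
For context, the non-XLA path averages by summing across processes and then dividing in place. A minimal standalone sketch of that pattern (the helper name all_reduce_mean_ is mine, not part of the repo; a single-process gloo group is used purely so the sketch runs end to end):

import torch
import torch.distributed as dist

def all_reduce_mean_(tensor: torch.Tensor, world_size: int) -> torch.Tensor:
    # Sum the tensor across all ranks in place, then divide so every rank holds the mean.
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    tensor.div_(world_size)
    return tensor

if __name__ == '__main__':
    # One-process gloo group just to make the sketch runnable without a cluster.
    dist.init_process_group('gloo', init_method='tcp://127.0.0.1:29501', rank=0, world_size=1)
    t = torch.tensor([3.0])
    print(all_reduce_mean_(t, dist.get_world_size()))  # tensor([3.]) with a single rank
    dist.destroy_process_group()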

@@ -23,12 +23,12 @@ from .device_env import DeviceEnv, DeviceEnvType, TensorList
 _PT_TO_XM_OP = {
-    ReduceOp.SUM: 'sum',
-    ReduceOp.PRODUCT: 'prod',
-    ReduceOp.MIN: 'min',
-    ReduceOp.MAX: 'max',
-    ReduceOp.BAND: 'and',
-    ReduceOp.BOR: 'or',
+    ReduceOp.SUM: xm.REDUCE_SUM,
+    ReduceOp.PRODUCT: xm.REDUCE_MUL,
+    ReduceOp.MIN: xm.REDUCE_MIN,
+    ReduceOp.MAX: xm.REDUCE_MAX,
+    ReduceOp.BAND: xm.REDUCE_AND,
+    ReduceOp.BOR: xm.REDUCE_OR,
 }
@@ -77,20 +77,16 @@ class DeviceEnvXla(DeviceEnv):
     def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, average=False):
         assert isinstance(tensor, torch.Tensor)  # unlike in-place variant, lists/tuples not allowed
         op = _PT_TO_XM_OP[op]
-        scale = 1.0
-        if average:
-            scale /= self.world_size
+        scale = 1.0 / self.world_size if average else 1.0
         return xm.all_reduce(op, tensor, scale=scale)

     def all_reduce_(self, tensor: TensorList, op=ReduceOp.SUM, average=False):
         op = _PT_TO_XM_OP[op]
-        scale = 1.0
+        scale = 1.0 / self.world_size if average else 1.0
         wrapped = False
         if isinstance(tensor, torch.Tensor):
             tensor = [tensor]  # bare tensors are not operated on in-place
             wrapped = True
-        if average:
-            scale /= self.world_size
         xm.all_reduce(op, tensor, scale=scale)
         if wrapped:
             tensor = tensor[0]
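
The rewrite folds the optional average into xm.all_reduce's scale argument instead of dividing afterwards. A rough sketch of that pattern, assuming torch_xla is installed and an XLA device is available (the function name xla_all_reduce_mean is mine, not the repo's API):

import torch
import torch_xla.core.xla_model as xm  # requires a torch_xla install

def xla_all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    # Fold the average into the reduce: scale is applied to the summed result,
    # so no separate div_ by world size is needed afterwards.
    scale = 1.0 / xm.xrt_world_size()
    # A bare tensor is reduced out-of-place and returned; a list of tensors
    # would instead be reduced in place, which is why the in-place variant
    # above wraps a bare tensor in a list before calling xm.all_reduce.
    return xm.all_reduce(xm.REDUCE_SUM, tensor, scale=scale)

# Usage on an XLA device (the "mean" over a world of one is a no-op):
# t = torch.ones(4, device=xm.xla_device())
# t = xla_all_reduce_mean(t)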

@@ -89,6 +89,7 @@ def setup_model_and_optimizer(
     train_state = TrainState(model=model, updater=updater, model_ema=model_ema)

     if resume_path:
+        # FIXME this is not implemented yet, do a hack job before proper TrainState serialization?
         resume_train_checkpoint(
             train_state,
             resume_path,

@ -283,12 +283,16 @@ def main():
else:
_logger.info('Training with a single process on 1 device.')
random_seed(args.seed, dev_env.global_rank)
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
random_seed(args.seed, 0) # Set all random seeds the same for model/state init (mandatory for XLA)
train_state, train_cfg = setup_train_task(args, dev_env, mixup_active)
# Set random seeds across ranks differently for train
# FIXME perhaps keep the same and just set diff seeds for dataloader worker process? what about TFDS?
random_seed(args.seed, dev_env.global_rank)
data_config, loader_eval, loader_train = setup_data(args, train_state.model.default_cfg, dev_env, mixup_active)
# setup checkpoint saver
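
The seeding change is the "fix model init for XLA" part of the commit: every replica must build identical initial weights, so all ranks are seeded the same (offset 0) before the model/train state is created, and only re-seeded per rank afterwards so training-time randomness diverges. A standalone sketch of that ordering; seed_everything is a stand-in helper, not necessarily timm's random_seed:

import random
import numpy as np
import torch

def seed_everything(seed: int, rank: int = 0) -> None:
    # Offset the base seed by the rank; pass rank 0 to get identical state everywhere.
    torch.manual_seed(seed + rank)
    np.random.seed(seed + rank)
    random.seed(seed + rank)

seed, global_rank = 42, 0  # global_rank would normally come from the device env

seed_everything(seed, 0)            # same seed on every rank -> identical model init (required for XLA)
model = torch.nn.Linear(8, 8)       # stand-in for setup_train_task(...)

seed_everything(seed, global_rank)  # per-rank seeds for the training loop / augmentations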
