diff --git a/timm/utils/cuda.py b/timm/utils/cuda.py
index 9e7bddf3..5983ffcc 100644
--- a/timm/utils/cuda.py
+++ b/timm/utils/cuda.py
@@ -17,12 +17,14 @@ from .clip_grad import dispatch_clip_grad
 class ApexScaler:
     state_dict_key = "amp"

-    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
+    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False, batch_idx=0, iters_to_accumulate=1):
         with amp.scale_loss(loss, optimizer) as scaled_loss:
             scaled_loss.backward(create_graph=create_graph)
-        if clip_grad is not None:
-            dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
-        optimizer.step()
+        if (batch_idx + 1) % iters_to_accumulate == 0:
+            if clip_grad is not None:
+                dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
+            optimizer.step()
+            optimizer.zero_grad()

     def state_dict(self):
         if 'state_dict' in amp.__dict__:
@@ -39,14 +41,17 @@ class NativeScaler:
     def __init__(self):
         self._scaler = torch.cuda.amp.GradScaler()

-    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
+    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False,
+                 batch_idx=0, iters_to_accumulate=1):
         self._scaler.scale(loss).backward(create_graph=create_graph)
-        if clip_grad is not None:
-            assert parameters is not None
-            self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
-            dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
-        self._scaler.step(optimizer)
-        self._scaler.update()
+        if (batch_idx + 1) % iters_to_accumulate == 0:
+            if clip_grad is not None:
+                assert parameters is not None
+                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
+                dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
+            self._scaler.step(optimizer)
+            self._scaler.update()
+            optimizer.zero_grad()

     def state_dict(self):
         return self._scaler.state_dict()
diff --git a/train.py b/train.py
index e51d7c90..b7772860 100755
--- a/train.py
+++ b/train.py
@@ -325,6 +325,8 @@ group.add_argument('--amp-dtype', default='float16', type=str,
                    help='lower precision AMP dtype (default: float16)')
 group.add_argument('--amp-impl', default='native', type=str,
                    help='AMP impl to use, "native" or "apex" (default: native)')
+group.add_argument('--iters_to_accumulate', default=1, type=int,
+                   help='number of batches to accumulate gradients over before each optimizer step (default: 1, no accumulation)')
 group.add_argument('--no-ddp-bb', action='store_true', default=False,
                    help='Force broadcast buffers for native DDP to off.')
 group.add_argument('--pin-mem', action='store_true', default=False,
@@ -841,11 +843,11 @@ def train_one_epoch(
     losses_m = utils.AverageMeter()

     model.train()
-
+    optimizer.zero_grad()
     end = time.time()
     num_batches_per_epoch = len(loader)
     last_idx = num_batches_per_epoch - 1
-    num_updates = epoch * num_batches_per_epoch
+    num_updates = epoch * num_batches_per_epoch // args.iters_to_accumulate
     for batch_idx, (input, target) in enumerate(loader):
         last_batch = batch_idx == last_idx
         data_time_m.update(time.time() - end)
@@ -858,19 +860,20 @@ def train_one_epoch(

         with amp_autocast():
             output = model(input)
-            loss = loss_fn(output, target)
+            loss = loss_fn(output, target) / args.iters_to_accumulate

         if not args.distributed:
-            losses_m.update(loss.item(), input.size(0))
+            losses_m.update(loss.item() * args.iters_to_accumulate, input.size(0))

-        optimizer.zero_grad()
         if loss_scaler is not None:
             loss_scaler(
                 loss, optimizer,
                 clip_grad=args.clip_grad,
                 clip_mode=args.clip_mode,
                 parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
-                create_graph=second_order
+                create_graph=second_order,
+                batch_idx=batch_idx,
+                iters_to_accumulate=args.iters_to_accumulate
             )
         else:
             loss.backward(create_graph=second_order)
@@ -878,23 +881,24 @@ def train_one_epoch(
                 utils.dispatch_clip_grad(
                     model_parameters(model, exclude_head='agc' in args.clip_mode),
                     value=args.clip_grad,
-                    mode=args.clip_mode
+                    mode=args.clip_mode,
                 )
-            optimizer.step()
+            if (batch_idx + 1) % args.iters_to_accumulate == 0:
+                optimizer.step()
+                optimizer.zero_grad()

         if model_ema is not None:
             model_ema.update(model)

         torch.cuda.synchronize()
-        num_updates += 1
         batch_time_m.update(time.time() - end)
         if last_batch or batch_idx % args.log_interval == 0:
             lrl = [param_group['lr'] for param_group in optimizer.param_groups]
             lr = sum(lrl) / len(lrl)

             if args.distributed:
-                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
+                reduced_loss = utils.reduce_tensor(loss.data * args.iters_to_accumulate, args.world_size)
                 losses_m.update(reduced_loss.item(), input.size(0))

             if utils.is_primary(args):
@@ -928,8 +932,10 @@ def train_one_epoch(
                 last_batch or (batch_idx + 1) % args.recovery_interval == 0):
             saver.save_recovery(epoch, batch_idx=batch_idx)

-        if lr_scheduler is not None:
-            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
+        if (batch_idx + 1) % args.iters_to_accumulate == 0:
+            num_updates += 1
+            if lr_scheduler is not None:
+                lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

         end = time.time()
         # end for
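
Note: for reference, the loop below is a minimal, self-contained sketch of the accumulation pattern this patch implements with the native GradScaler path: the loss is divided by the accumulation count, and the optimizer is stepped and gradients zeroed only once every iters_to_accumulate batches. The model, data, and iters_to_accumulate=4 are illustrative placeholders, not part of the patch.

import torch

# Illustrative setup only; any model/optimizer/data would do.
model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()
iters_to_accumulate = 4  # gradients from 4 batches are accumulated per optimizer step

batches = [(torch.randn(8, 10).cuda(), torch.randint(0, 2, (8,)).cuda()) for _ in range(16)]

optimizer.zero_grad()
for batch_idx, (input, target) in enumerate(batches):
    with torch.cuda.amp.autocast():
        output = model(input)
        # Divide the loss so the accumulated gradients match those of one large batch.
        loss = loss_fn(output, target) / iters_to_accumulate

    scaler.scale(loss).backward()

    # Step (and clear gradients) only once per accumulation window.
    if (batch_idx + 1) % iters_to_accumulate == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

With the patch applied, the same behaviour in timm training would be requested with something like: python train.py /path/to/imagenet --model resnet50 -b 64 --iters_to_accumulate 4 (path and values are examples only).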