From 7f29a46d4467f1119a6c0a90faa3cfb9117787b4 Mon Sep 17 00:00:00 2001 From: Taeksang Kim Date: Mon, 6 Feb 2023 07:32:21 +0900 Subject: [PATCH] Add gradient accumulation option to train.py option: iters-to-accum (iterations to accumulate) Gradient accumulation improves training performance (samples/s). It can reduce the number of parameter synchronizations between nodes. This option can be helpful when the network is a bottleneck. Signed-off-by: Taeksang Kim --- timm/utils/cuda.py | 12 +-- train.py | 178 +++++++++++++++++++++++++++------------------ 2 files changed, 114 insertions(+), 76 deletions(-) diff --git a/timm/utils/cuda.py b/timm/utils/cuda.py index 9e7bddf3..d26efb64 100644 --- a/timm/utils/cuda.py +++ b/timm/utils/cuda.py @@ -17,12 +17,13 @@ from .clip_grad import dispatch_clip_grad class ApexScaler: state_dict_key = "amp" - def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False): + def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False, need_step=True): with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward(create_graph=create_graph) if clip_grad is not None: dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode) - optimizer.step() + if need_step: + optimizer.step() def state_dict(self): if 'state_dict' in amp.__dict__: @@ -39,14 +40,15 @@ class NativeScaler: def __init__(self): self._scaler = torch.cuda.amp.GradScaler() - def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False): + def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False, need_step=True): self._scaler.scale(loss).backward(create_graph=create_graph) if clip_grad is not None: assert parameters is not None self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place dispatch_clip_grad(parameters, clip_grad, mode=clip_mode) - 
self._scaler.step(optimizer) - self._scaler.update() + if need_step: + self._scaler.step(optimizer) + self._scaler.update() def state_dict(self): return self._scaler.state_dict() diff --git a/train.py b/train.py index 9f450ab8..5471324a 100755 --- a/train.py +++ b/train.py @@ -130,6 +130,8 @@ group.add_argument('--interpolation', default='', type=str, metavar='NAME', help='Image resize interpolation type (overrides model)') group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N', help='Input batch size for training (default: 128)') +group.add_argument('--iters-to-accum', type=int, default=1, metavar='N', + help='The number of iterations to accumulate gradients (default: 1)') group.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N', help='Validation batch size override (default: None)') group.add_argument('--channels-last', action='store_true', default=False, @@ -399,6 +401,9 @@ def main(): if args.amp_dtype == 'bfloat16': amp_dtype = torch.bfloat16 + # check if iters_to_accum is smaller than or equal to 0. + assert args.iters_to_accum > 0, 'The argument "iters-to-accum" must be greater than zero.' 
+ utils.random_seed(args.seed, args.rank) if args.fuser: @@ -851,11 +856,23 @@ def train_one_epoch( model.train() end = time.time() - num_batches_per_epoch = len(loader) - last_idx = num_batches_per_epoch - 1 + num_batches_per_epoch = (len(loader) + args.iters_to_accum - 1) // args.iters_to_accum + last_idx = len(loader) - 1 + last_iters_to_accum = len(loader) % args.iters_to_accum + last_idx_to_accum = len(loader) - last_iters_to_accum num_updates = epoch * num_batches_per_epoch + + optimizer.zero_grad() + num_step_samples = 0 for batch_idx, (input, target) in enumerate(loader): last_batch = batch_idx == last_idx + iters_to_accum = args.iters_to_accum + if batch_idx >= last_idx_to_accum: + iters_to_accum = last_iters_to_accum + need_step = False + if (batch_idx + 1) % args.iters_to_accum == 0 or last_batch: + need_step =True + data_time_m.update(time.time() - end) if not args.prefetcher: input, target = input.to(device), target.to(device) @@ -864,82 +881,101 @@ def train_one_epoch( if args.channels_last: input = input.contiguous(memory_format=torch.channels_last) - with amp_autocast(): - output = model(input) - loss = loss_fn(output, target) + def _forward(): + with amp_autocast(): + output = model(input) + loss = loss_fn(output, target) + loss /= iters_to_accum + return loss - if not args.distributed: - losses_m.update(loss.item(), input.size(0)) - - optimizer.zero_grad() - if loss_scaler is not None: - loss_scaler( - loss, optimizer, - clip_grad=args.clip_grad, - clip_mode=args.clip_mode, - parameters=model_parameters(model, exclude_head='agc' in args.clip_mode), - create_graph=second_order - ) + if need_step is not True and hasattr(model, "no_sync"): + with model.no_sync(): + loss = _forward() else: - loss.backward(create_graph=second_order) - if args.clip_grad is not None: - utils.dispatch_clip_grad( - model_parameters(model, exclude_head='agc' in args.clip_mode), - value=args.clip_grad, - mode=args.clip_mode - ) - optimizer.step() - - if model_ema is not 
None: - model_ema.update(model) - - torch.cuda.synchronize() - - num_updates += 1 - batch_time_m.update(time.time() - end) - if last_batch or batch_idx % args.log_interval == 0: - lrl = [param_group['lr'] for param_group in optimizer.param_groups] - lr = sum(lrl) / len(lrl) + loss = _forward() - if args.distributed: - reduced_loss = utils.reduce_tensor(loss.data, args.world_size) - losses_m.update(reduced_loss.item(), input.size(0)) - - if utils.is_primary(args): - _logger.info( - 'Train: {} [{:>4d}/{} ({:>3.0f}%)] ' - 'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) ' - 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s ' - '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' - 'LR: {lr:.3e} ' - 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format( - epoch, - batch_idx, len(loader), - 100. * batch_idx / last_idx, - loss=losses_m, - batch_time=batch_time_m, - rate=input.size(0) * args.world_size / batch_time_m.val, - rate_avg=input.size(0) * args.world_size / batch_time_m.avg, - lr=lr, - data_time=data_time_m) + if not args.distributed: + losses_m.update(loss.item() * iters_to_accum, input.size(0)) + + def _backward(): + if loss_scaler is not None: + loss_scaler( + loss, optimizer, + clip_grad=args.clip_grad, + clip_mode=args.clip_mode, + parameters=model_parameters(model, exclude_head='agc' in args.clip_mode), + create_graph=second_order, + need_step=need_step ) - - if args.save_images and output_dir: - torchvision.utils.save_image( - input, - os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx), - padding=0, - normalize=True + else: + loss.backward(create_graph=second_order) + if args.clip_grad is not None: + utils.dispatch_clip_grad( + model_parameters(model, exclude_head='agc' in args.clip_mode), + value=args.clip_grad, + mode=args.clip_mode ) + if need_step: + optimizer.step() - if saver is not None and args.recovery_interval and ( - last_batch or (batch_idx + 1) % args.recovery_interval == 0): - saver.save_recovery(epoch, batch_idx=batch_idx) - - if lr_scheduler is not 
None: - lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) + num_step_samples += input.size(0) + if need_step is not True and hasattr(model, "no_sync"): + with model.no_sync(): + _backward() + else: + _backward() + if need_step: + optimizer.zero_grad() + if model_ema is not None: + model_ema.update(model) - end = time.time() + torch.cuda.synchronize() + num_updates += 1 + batch_time_m.update(time.time() - end) + + if (batch_idx // args.iters_to_accum) % args.log_interval == 0: + lrl = [param_group['lr'] for param_group in optimizer.param_groups] + lr = sum(lrl) / len(lrl) + + if args.distributed: + reduced_loss = utils.reduce_tensor(loss.data, args.world_size) + losses_m.update(reduced_loss.item() * iters_to_accum, input.size(0)) + + if utils.is_primary(args): + _logger.info( + 'Train: {} [{:>4d}/{} ({:>3.0f}%)] ' + 'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) ' + 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s ' + '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' + 'LR: {lr:.3e} ' + 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format( + epoch, + batch_idx, len(loader), + 100. * batch_idx / last_idx, + loss=losses_m, + batch_time=batch_time_m, + rate=num_step_samples * args.world_size / batch_time_m.val, + rate_avg=num_step_samples * args.world_size / batch_time_m.avg, + lr=lr, + data_time=data_time_m) + ) + + if args.save_images and output_dir: + torchvision.utils.save_image( + input, + os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx), + padding=0, + normalize=True + ) + + if saver is not None and args.recovery_interval and ( + (batch_idx // args.iters_to_accum + 1) % args.recovery_interval == 0): + saver.save_recovery(epoch, batch_idx=batch_idx) + + if lr_scheduler is not None: + lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) + num_step_samples = 0 + end = time.time() # end for if hasattr(optimizer, 'sync_lookahead'):