Gradient accumulation added to the training script

Added parameter iters_to_accumulate to perform gradient accumulation
pull/1590/head
Lorenzo Baraldi 2 years ago
parent e7da205345
commit e09b4d5c7f
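
The change relies on a standard PyTorch property: backward() adds into each parameter's .grad rather than overwriting it, so an optimizer step can be deferred across several small batches to emulate one large batch. A minimal, self-contained sketch of the pattern this commit wires into the scalers and the training loop (the toy model, data, and sizes are illustrative only, not part of the commit):

import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()
iters_to_accumulate = 4  # step the optimizer once every 4 micro-batches

optimizer.zero_grad()
for batch_idx in range(16):
    input = torch.randn(8, 10)            # micro-batch of 8 samples
    target = torch.randint(0, 2, (8,))
    # Divide so the summed gradients match one step over 8 * 4 samples.
    loss = loss_fn(model(input), target) / iters_to_accumulate
    loss.backward()                       # accumulates into .grad
    if (batch_idx + 1) % iters_to_accumulate == 0:
        optimizer.step()
        optimizer.zero_grad()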

timm/utils/cuda.py:

@@ -17,9 +17,10 @@ from .clip_grad import dispatch_clip_grad
 class ApexScaler:
     state_dict_key = "amp"
 
-    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
+    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False, batch_idx=0, iters_to_accumulate=1):
         with amp.scale_loss(loss, optimizer) as scaled_loss:
             scaled_loss.backward(create_graph=create_graph)
-        if clip_grad is not None:
-            dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
-        optimizer.step()
+        if (batch_idx + 1) % iters_to_accumulate == 0:
+            if clip_grad is not None:
+                dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
+            optimizer.step()
@@ -39,14 +40,17 @@ class NativeScaler:
     def __init__(self):
         self._scaler = torch.cuda.amp.GradScaler()
 
-    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
+    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False,
+                 batch_idx=0, iters_to_accumulate=1):
         self._scaler.scale(loss).backward(create_graph=create_graph)
-        if clip_grad is not None:
-            assert parameters is not None
-            self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
-            dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
-        self._scaler.step(optimizer)
-        self._scaler.update()
+        if (batch_idx + 1) % iters_to_accumulate == 0:
+            if clip_grad is not None:
+                assert parameters is not None
+                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
+                dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
+            self._scaler.step(optimizer)
+            self._scaler.update()
+            optimizer.zero_grad()
 
     def state_dict(self):
         return self._scaler.state_dict()
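
Both scalers keep running backward() on every call but now gate unscaling, clipping, the optimizer step, the scaler update, and zero_grad behind the accumulation boundary. A sketch of how the extended NativeScaler would be driven once this patch is applied (the toy model, data, and clip value are illustrative; requires a CUDA device):

import torch
from timm.utils import NativeScaler

model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()
loss_scaler = NativeScaler()
iters_to_accumulate = 4

optimizer.zero_grad()
for batch_idx in range(8):
    input = torch.randn(16, 10, device='cuda')
    target = torch.randint(0, 2, (16,), device='cuda')
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        loss = loss_fn(model(input), target) / iters_to_accumulate
    # backward() runs on every call; unscale_/clip/step/update/zero_grad
    # run only when (batch_idx + 1) % iters_to_accumulate == 0.
    loss_scaler(
        loss, optimizer,
        clip_grad=1.0, parameters=model.parameters(),
        batch_idx=batch_idx, iters_to_accumulate=iters_to_accumulate,
    )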

train.py:

@@ -325,6 +325,8 @@ group.add_argument('--amp-dtype', default='float16', type=str,
                    help='lower precision AMP dtype (default: float16)')
 group.add_argument('--amp-impl', default='native', type=str,
                    help='AMP impl to use, "native" or "apex" (default: native)')
+group.add_argument('--iters_to_accumulate', default=1, type=int,
+                   help='number of batches to accumulate before each optimizer step (gradient accumulation, default: 1)')
 group.add_argument('--no-ddp-bb', action='store_true', default=False,
                    help='Force broadcast buffers for native DDP to off.')
 group.add_argument('--pin-mem', action='store_true', default=False,
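
Assuming the flag lands as spelled above, enabling accumulation from the command line might look like the following (dataset path, model, and sizes are illustrative); the optimizer then steps every fourth batch, for an effective per-process batch size of 64 * 4 = 256:

python train.py /data/imagenet --model resnet50 --batch-size 64 --iters_to_accumulate 4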
@@ -841,11 +843,11 @@ def train_one_epoch(
     losses_m = utils.AverageMeter()
 
     model.train()
+    optimizer.zero_grad()
 
     end = time.time()
     num_batches_per_epoch = len(loader)
     last_idx = num_batches_per_epoch - 1
-    num_updates = epoch * num_batches_per_epoch
+    num_updates = epoch * num_batches_per_epoch // args.iters_to_accumulate
     for batch_idx, (input, target) in enumerate(loader):
         last_batch = batch_idx == last_idx
         data_time_m.update(time.time() - end)
@@ -858,19 +860,20 @@ def train_one_epoch(
         with amp_autocast():
             output = model(input)
-            loss = loss_fn(output, target)
+            loss = loss_fn(output, target) / args.iters_to_accumulate
 
         if not args.distributed:
-            losses_m.update(loss.item(), input.size(0))
+            losses_m.update(loss.item() * args.iters_to_accumulate, input.size(0))
 
-        optimizer.zero_grad()
         if loss_scaler is not None:
             loss_scaler(
                 loss, optimizer,
                 clip_grad=args.clip_grad,
                 clip_mode=args.clip_mode,
                 parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
-                create_graph=second_order
+                create_graph=second_order,
+                batch_idx=batch_idx,
+                iters_to_accumulate=args.iters_to_accumulate
             )
         else:
             loss.backward(create_graph=second_order)
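
Dividing the loss by args.iters_to_accumulate is what makes the accumulated gradient match a single large-batch step, since backward() sums into .grad across calls; the logging sites multiply by the same factor so the reported loss stays on its usual scale. A small self-contained check of the equivalence (toy tensors, equal-sized micro-batches assumed):

import torch

# Gradient of two micro-batch losses, each divided by 2, summed via
# accumulation, equals the gradient of the mean loss over both batches.
w = torch.ones(3, requires_grad=True)
a, b = torch.tensor([1., 2., 3.]), torch.tensor([4., 5., 6.])

(w * a).mean().div(2).backward()
(w * b).mean().div(2).backward()
accumulated = w.grad.clone()

w.grad = None
torch.cat([w * a, w * b]).mean().backward()
assert torch.allclose(accumulated, w.grad)  # same effective gradient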
@@ -878,23 +881,24 @@ def train_one_epoch(
                 utils.dispatch_clip_grad(
                     model_parameters(model, exclude_head='agc' in args.clip_mode),
                     value=args.clip_grad,
-                    mode=args.clip_mode
+                    mode=args.clip_mode,
                 )
-            optimizer.step()
+            if (batch_idx + 1) % args.iters_to_accumulate == 0:
+                optimizer.step()
+                optimizer.zero_grad()
 
         if model_ema is not None:
             model_ema.update(model)
 
         torch.cuda.synchronize()
-        num_updates += 1
         batch_time_m.update(time.time() - end)
 
         if last_batch or batch_idx % args.log_interval == 0:
             lrl = [param_group['lr'] for param_group in optimizer.param_groups]
             lr = sum(lrl) / len(lrl)
 
             if args.distributed:
-                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
+                reduced_loss = utils.reduce_tensor(loss.data * args.iters_to_accumulate, args.world_size)
                 losses_m.update(reduced_loss.item(), input.size(0))
 
             if utils.is_primary(args):
@@ -928,6 +932,8 @@ def train_one_epoch(
                 last_batch or (batch_idx + 1) % args.recovery_interval == 0):
             saver.save_recovery(epoch, batch_idx=batch_idx)
 
+        if (batch_idx + 1) % args.iters_to_accumulate == 0:
+            num_updates += 1
         if lr_scheduler is not None:
             lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
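
With the num_updates increment now behind the same modulo test, step-based LR schedules advance once per optimizer update rather than once per batch, consistent with the rescaled initialization num_updates = epoch * num_batches_per_epoch // args.iters_to_accumulate above. One consequence worth noting: a trailing partial accumulation at the end of an epoch never triggers a step, and its leftover gradients are cleared by the optimizer.zero_grad() now placed at the top of train_one_epoch. A tiny sketch of the resulting update count (numbers illustrative):

# With 10 batches per epoch and iters_to_accumulate=4, only batches
# 3 and 7 (0-indexed) trigger a step; the last 2 batches' gradients
# are discarded at the start of the next epoch.
iters_to_accumulate = 4
updates = sum(1 for batch_idx in range(10)
              if (batch_idx + 1) % iters_to_accumulate == 0)
assert updates == 10 // iters_to_accumulate  # == 2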
