Gradient accumulation included in training script

Added parameter iters_to_accumulate to perform gradient accumulation
pull/1590/head
Lorenzo Baraldi 2 years ago
parent e7da205345
commit e09b4d5c7f
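
For context, gradient accumulation sums gradients over several small batches and only performs an optimizer step once every few batches, simulating a larger effective batch size. A minimal PyTorch sketch of the pattern this commit wires into the timm training loop (the model, data and sizes below are purely illustrative):

    import torch

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = torch.nn.CrossEntropyLoss()
    iters_to_accumulate = 4  # plays the role of the new --iters_to_accumulate flag

    optimizer.zero_grad()
    for batch_idx in range(16):
        input, target = torch.randn(8, 10), torch.randint(0, 2, (8,))
        # scale each micro-batch loss so the summed gradients match one large-batch gradient
        loss = loss_fn(model(input), target) / iters_to_accumulate
        loss.backward()
        if (batch_idx + 1) % iters_to_accumulate == 0:
            optimizer.step()
            optimizer.zero_grad()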

@@ -17,9 +17,10 @@ from .clip_grad import dispatch_clip_grad
class ApexScaler:
state_dict_key = "amp"
def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False, batch_idx=0, iters_to_accumulate=1):
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward(create_graph=create_graph)
if (batch_idx + 1) % iters_to_accumulate == 0:
if clip_grad is not None:
dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
optimizer.step()
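
The behavioural change in this hunk is the new gating test: gradient clipping and the Apex optimizer step now only fire on every iters_to_accumulate-th batch, and with the default value of 1 every batch still steps, so existing runs are unchanged. A tiny illustration of the modulo condition (numbers made up):

    iters_to_accumulate = 4
    stepping = [batch_idx for batch_idx in range(8)
                if (batch_idx + 1) % iters_to_accumulate == 0]
    print(stepping)  # [3, 7] -> an optimizer step after the 4th and the 8th micro-batch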
@@ -39,14 +40,17 @@ class NativeScaler:
def __init__(self):
self._scaler = torch.cuda.amp.GradScaler()
def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False,
batch_idx=0, iters_to_accumulate=1):
self._scaler.scale(loss).backward(create_graph=create_graph)
if (batch_idx + 1) % iters_to_accumulate == 0:
if clip_grad is not None:
assert parameters is not None
self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
self._scaler.step(optimizer)
self._scaler.update()
optimizer.zero_grad()
def state_dict(self):
return self._scaler.state_dict()
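
A usage sketch of the extended NativeScaler call (the scaler wrappers presumably live in timm/utils/cuda.py); it assumes a CUDA device and a timm build that already contains this patch:

    import torch
    from timm.utils import NativeScaler

    model = torch.nn.Linear(10, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_scaler = NativeScaler()
    iters_to_accumulate = 2

    for batch_idx in range(4):
        input = torch.randn(8, 10, device='cuda')
        target = torch.randint(0, 2, (8,), device='cuda')
        with torch.cuda.amp.autocast():
            loss = torch.nn.functional.cross_entropy(model(input), target)
        # backward() runs on every call; unscale/clip/step/update/zero_grad only
        # fire on every iters_to_accumulate-th batch
        loss_scaler(loss / iters_to_accumulate, optimizer,
                    batch_idx=batch_idx, iters_to_accumulate=iters_to_accumulate)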

@@ -325,6 +325,8 @@ group.add_argument('--amp-dtype', default='float16', type=str,
help='lower precision AMP dtype (default: float16)')
group.add_argument('--amp-impl', default='native', type=str,
help='AMP impl to use, "native" or "apex" (default: native)')
group.add_argument('--iters_to_accumulate', default=1, type=int,
help='Number of batches to accumulate before performing an optimizer step (gradient accumulation)')
group.add_argument('--no-ddp-bb', action='store_true', default=False,
help='Force broadcast buffers for native DDP to off.')
group.add_argument('--pin-mem', action='store_true', default=False,
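
As a usage note, the effective batch size per optimizer step becomes the per-process batch size times iters_to_accumulate (times the number of processes under DDP), so something like ./train.py <data-dir> --batch-size 128 --iters_to_accumulate 4 behaves roughly like a batch of 512 on a single GPU. Illustrative arithmetic only:

    batch_size = 128          # --batch-size, per process
    iters_to_accumulate = 4   # --iters_to_accumulate
    world_size = 2            # DDP processes, if distributed
    samples_per_step = batch_size * iters_to_accumulate * world_size
    print(samples_per_step)   # 1024 samples contribute to each optimizer step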
@@ -841,11 +843,11 @@ def train_one_epoch(
losses_m = utils.AverageMeter()
model.train()
optimizer.zero_grad()
end = time.time()
num_batches_per_epoch = len(loader)
last_idx = num_batches_per_epoch - 1
num_updates = epoch * num_batches_per_epoch
num_updates = epoch * num_batches_per_epoch // args.iters_to_accumulate
for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_idx
data_time_m.update(time.time() - end)
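
The num_updates seed above now counts optimizer steps rather than batches, so the LR scheduler resumes from the right position when accumulating. A worked example with made-up sizes:

    num_batches_per_epoch = 1000
    iters_to_accumulate = 4
    epoch = 3
    # optimizer steps completed in earlier epochs, used to seed lr_scheduler.step_update
    num_updates = epoch * num_batches_per_epoch // iters_to_accumulate
    print(num_updates)  # 750, rather than the 3000 you would get with iters_to_accumulate=1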
@@ -858,19 +860,20 @@
with amp_autocast():
output = model(input)
loss = loss_fn(output, target)
loss = loss_fn(output, target) / args.iters_to_accumulate
if not args.distributed:
losses_m.update(loss.item(), input.size(0))
losses_m.update(loss.item() * args.iters_to_accumulate, input.size(0))
optimizer.zero_grad()
if loss_scaler is not None:
loss_scaler(
loss, optimizer,
clip_grad=args.clip_grad,
clip_mode=args.clip_mode,
parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
create_graph=second_order
create_graph=second_order,
batch_idx=batch_idx,
iters_to_accumulate=args.iters_to_accumulate
)
else:
loss.backward(create_graph=second_order)
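
Dividing the loss by args.iters_to_accumulate above is what makes the accumulated gradient equal the gradient of the mean over the full effective batch; the multiplication by the same factor when updating losses_m simply restores the unscaled value for logging. A small self-contained check of that equivalence (plain tensors, no AMP):

    import torch

    w = torch.ones(3, requires_grad=True)
    chunks = [torch.randn(4, 3) for _ in range(2)]     # two micro-batches
    for x in chunks:
        # scaled micro-batch loss, as in the patched training loop
        loss = (x @ w).pow(2).mean() / len(chunks)
        loss.backward()
    accumulated = w.grad.clone()

    w.grad = None
    full = torch.cat(chunks)                           # the same data as one big batch
    (full @ w).pow(2).mean().backward()
    print(torch.allclose(accumulated, w.grad))         # True, up to float rounding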
@@ -878,23 +881,24 @@
utils.dispatch_clip_grad(
model_parameters(model, exclude_head='agc' in args.clip_mode),
value=args.clip_grad,
mode=args.clip_mode
mode=args.clip_mode,
)
if (batch_idx + 1) % args.iters_to_accumulate == 0:
optimizer.step()
optimizer.zero_grad()
if model_ema is not None:
model_ema.update(model)
torch.cuda.synchronize()
num_updates += 1
batch_time_m.update(time.time() - end)
if last_batch or batch_idx % args.log_interval == 0:
lrl = [param_group['lr'] for param_group in optimizer.param_groups]
lr = sum(lrl) / len(lrl)
if args.distributed:
reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
reduced_loss = utils.reduce_tensor(loss.data * args.iters_to_accumulate, args.world_size)
losses_m.update(reduced_loss.item(), input.size(0))
if utils.is_primary(args):
@@ -928,6 +932,8 @@
last_batch or (batch_idx + 1) % args.recovery_interval == 0):
saver.save_recovery(epoch, batch_idx=batch_idx)
if (batch_idx + 1) % args.iters_to_accumulate == 0:
num_updates += 1
if lr_scheduler is not None:
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
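
With the num_updates increment gated the same way, any schedule that decays per update (timm schedulers created with t_in_epochs=False) now advances once per optimizer step instead of once per batch. A rough sketch of that interaction, assuming timm's CosineLRScheduler and made-up step counts:

    import torch
    from timm.scheduler import CosineLRScheduler

    optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
    scheduler = CosineLRScheduler(optimizer, t_initial=100, t_in_epochs=False)

    iters_to_accumulate = 4
    num_updates = 0
    for batch_idx in range(8):
        if (batch_idx + 1) % iters_to_accumulate == 0:
            num_updates += 1
            scheduler.step_update(num_updates=num_updates)
    print(num_updates)  # only 2 scheduler updates over these 8 batches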
