Add gradient accumulation option to train.py

option: iters-to-accum(iterations to accmulate)

Gradient accumulation improves training performance(samples/s).
It can reduce the number of parameter sharing between each node.
This option can be helpful when network is bottleneck.

Signed-off-by: Taeksang Kim <voidbag@puzzle-ai.com>
pull/1659/head
Taeksang Kim 2 years ago
parent 7a13be67a5
commit 7f29a46d44

@ -17,12 +17,13 @@ from .clip_grad import dispatch_clip_grad
class ApexScaler:
state_dict_key = "amp"
def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False, need_step=True):
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward(create_graph=create_graph)
if clip_grad is not None:
dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
optimizer.step()
if need_step:
optimizer.step()
def state_dict(self):
if 'state_dict' in amp.__dict__:
@ -39,14 +40,15 @@ class NativeScaler:
def __init__(self):
self._scaler = torch.cuda.amp.GradScaler()
def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False, need_step=True):
self._scaler.scale(loss).backward(create_graph=create_graph)
if clip_grad is not None:
assert parameters is not None
self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
self._scaler.step(optimizer)
self._scaler.update()
if need_step:
self._scaler.step(optimizer)
self._scaler.update()
def state_dict(self):
return self._scaler.state_dict()

@ -130,6 +130,8 @@ group.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
help='Input batch size for training (default: 128)')
group.add_argument('--iters-to-accum', type=int, default=1, metavar='N',
help='The number of iterations to accumulate gradients (default: 1)')
group.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
help='Validation batch size override (default: None)')
group.add_argument('--channels-last', action='store_true', default=False,
@ -399,6 +401,9 @@ def main():
if args.amp_dtype == 'bfloat16':
amp_dtype = torch.bfloat16
# check if iters_to_accum is smaller than or equal to 0.
assert args.iters_to_accum > 0, 'The argument "iters-to-accum" must be greater than zero.'
utils.random_seed(args.seed, args.rank)
if args.fuser:
@ -851,11 +856,23 @@ def train_one_epoch(
model.train()
end = time.time()
num_batches_per_epoch = len(loader)
last_idx = num_batches_per_epoch - 1
num_batches_per_epoch = (len(loader) + args.iters_to_accum - 1) // args.iters_to_accum
last_idx = len(loader) - 1
last_iters_to_accum = len(loader) % args.iters_to_accum
last_idx_to_accum = len(loader) - last_iters_to_accum
num_updates = epoch * num_batches_per_epoch
optimizer.zero_grad()
num_step_samples = 0
for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_idx
iters_to_accum = args.iters_to_accum
if batch_idx >= last_idx_to_accum:
iters_to_accum = last_iters_to_accum
need_step = False
if (batch_idx + 1) % args.iters_to_accum == 0 or last_batch:
need_step =True
data_time_m.update(time.time() - end)
if not args.prefetcher:
input, target = input.to(device), target.to(device)
@ -864,82 +881,101 @@ def train_one_epoch(
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
with amp_autocast():
output = model(input)
loss = loss_fn(output, target)
def _forward():
with amp_autocast():
output = model(input)
loss = loss_fn(output, target)
loss /= iters_to_accum
return loss
if not args.distributed:
losses_m.update(loss.item(), input.size(0))
optimizer.zero_grad()
if loss_scaler is not None:
loss_scaler(
loss, optimizer,
clip_grad=args.clip_grad,
clip_mode=args.clip_mode,
parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
create_graph=second_order
)
if need_step is not True and hasattr(model, "no_sync"):
with model.no_sync():
loss = _forward()
else:
loss.backward(create_graph=second_order)
if args.clip_grad is not None:
utils.dispatch_clip_grad(
model_parameters(model, exclude_head='agc' in args.clip_mode),
value=args.clip_grad,
mode=args.clip_mode
)
optimizer.step()
if model_ema is not None:
model_ema.update(model)
torch.cuda.synchronize()
num_updates += 1
batch_time_m.update(time.time() - end)
if last_batch or batch_idx % args.log_interval == 0:
lrl = [param_group['lr'] for param_group in optimizer.param_groups]
lr = sum(lrl) / len(lrl)
loss = _forward()
if args.distributed:
reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
losses_m.update(reduced_loss.item(), input.size(0))
if utils.is_primary(args):
_logger.info(
'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
'({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
'LR: {lr:.3e} '
'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
epoch,
batch_idx, len(loader),
100. * batch_idx / last_idx,
loss=losses_m,
batch_time=batch_time_m,
rate=input.size(0) * args.world_size / batch_time_m.val,
rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
lr=lr,
data_time=data_time_m)
if not args.distributed:
losses_m.update(loss.item() * iters_to_accum, input.size(0))
def _backward():
if loss_scaler is not None:
loss_scaler(
loss, optimizer,
clip_grad=args.clip_grad,
clip_mode=args.clip_mode,
parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
create_graph=second_order,
need_step=need_step
)
if args.save_images and output_dir:
torchvision.utils.save_image(
input,
os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
padding=0,
normalize=True
else:
loss.backward(create_graph=second_order)
if args.clip_grad is not None:
utils.dispatch_clip_grad(
model_parameters(model, exclude_head='agc' in args.clip_mode),
value=args.clip_grad,
mode=args.clip_mode
)
if need_step:
optimizer.step()
if saver is not None and args.recovery_interval and (
last_batch or (batch_idx + 1) % args.recovery_interval == 0):
saver.save_recovery(epoch, batch_idx=batch_idx)
if lr_scheduler is not None:
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
num_step_samples += input.size(0)
if need_step is not True and hasattr(model, "no_sync"):
with model.no_sync():
_backward()
else:
_backward()
if need_step:
optimizer.zero_grad()
if model_ema is not None:
model_ema.update(model)
end = time.time()
torch.cuda.synchronize()
num_updates += 1
batch_time_m.update(time.time() - end)
if (batch_idx // args.iters_to_accum) % args.log_interval == 0:
lrl = [param_group['lr'] for param_group in optimizer.param_groups]
lr = sum(lrl) / len(lrl)
if args.distributed:
reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
losses_m.update(reduced_loss.item() * iters_to_accum, input.size(0))
if utils.is_primary(args):
_logger.info(
'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
'({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
'LR: {lr:.3e} '
'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
epoch,
batch_idx, len(loader),
100. * batch_idx / last_idx,
loss=losses_m,
batch_time=batch_time_m,
rate=num_step_samples * args.world_size / batch_time_m.val,
rate_avg=num_step_samples * args.world_size / batch_time_m.avg,
lr=lr,
data_time=data_time_m)
)
if args.save_images and output_dir:
torchvision.utils.save_image(
input,
os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
padding=0,
normalize=True
)
if saver is not None and args.recovery_interval and (
(batch_idx // args.iters_to_accum + 1) % args.recovery_interval == 0):
saver.save_recovery(epoch, batch_idx=batch_idx)
if lr_scheduler is not None:
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
num_step_samples = 0
end = time.time()
# end for
if hasattr(optimizer, 'sync_lookahead'):

Loading…
Cancel
Save