|
|
|
@@ -130,6 +130,8 @@ group.add_argument('--interpolation', default='', type=str, metavar='NAME',
|
|
|
|
|
help='Image resize interpolation type (overrides model)')
|
|
|
|
|
group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
|
|
|
|
|
help='Input batch size for training (default: 128)')
|
|
|
|
|
group.add_argument('--iters-to-accum', type=int, default=1, metavar='N',
|
|
|
|
|
help='The number of iterations to accumulate gradients (default: 1)')
|
|
|
|
|
group.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
|
|
|
|
|
help='Validation batch size override (default: None)')
|
|
|
|
|
group.add_argument('--channels-last', action='store_true', default=False,
|
|
|
|
@@ -399,6 +401,9 @@ def main():
|
|
|
|
|
if args.amp_dtype == 'bfloat16':
|
|
|
|
|
amp_dtype = torch.bfloat16
|
|
|
|
|
|
|
|
|
|
# check if iters_to_accum is smaller than or equal to 0.
|
|
|
|
|
assert args.iters_to_accum > 0, 'The argument "iters-to-accum" must be greater than zero.'
|
|
|
|
|
|
|
|
|
|
utils.random_seed(args.seed, args.rank)
|
|
|
|
|
|
|
|
|
|
if args.fuser:
|
|
|
|
@@ -851,11 +856,23 @@ def train_one_epoch(
|
|
|
|
|
model.train()
|
|
|
|
|
|
|
|
|
|
end = time.time()
|
|
|
|
|
num_batches_per_epoch = len(loader)
|
|
|
|
|
last_idx = num_batches_per_epoch - 1
|
|
|
|
|
num_batches_per_epoch = (len(loader) + args.iters_to_accum - 1) // args.iters_to_accum
|
|
|
|
|
last_idx = len(loader) - 1
|
|
|
|
|
last_iters_to_accum = len(loader) % args.iters_to_accum
|
|
|
|
|
last_idx_to_accum = len(loader) - last_iters_to_accum
|
|
|
|
|
num_updates = epoch * num_batches_per_epoch
|
|
|
|
|
|
|
|
|
|
optimizer.zero_grad()
|
|
|
|
|
num_step_samples = 0
|
|
|
|
|
for batch_idx, (input, target) in enumerate(loader):
|
|
|
|
|
last_batch = batch_idx == last_idx
|
|
|
|
|
iters_to_accum = args.iters_to_accum
|
|
|
|
|
if batch_idx >= last_idx_to_accum:
|
|
|
|
|
iters_to_accum = last_iters_to_accum
|
|
|
|
|
need_step = False
|
|
|
|
|
if (batch_idx + 1) % args.iters_to_accum == 0 or last_batch:
|
|
|
|
|
need_step = True
|
|
|
|
|
|
|
|
|
|
data_time_m.update(time.time() - end)
|
|
|
|
|
if not args.prefetcher:
|
|
|
|
|
input, target = input.to(device), target.to(device)
|
|
|
|
@@ -864,82 +881,101 @@ def train_one_epoch(
|
|
|
|
|
if args.channels_last:
|
|
|
|
|
input = input.contiguous(memory_format=torch.channels_last)
|
|
|
|
|
|
|
|
|
|
with amp_autocast():
|
|
|
|
|
output = model(input)
|
|
|
|
|
loss = loss_fn(output, target)
|
|
|
|
|
def _forward():
|
|
|
|
|
with amp_autocast():
|
|
|
|
|
output = model(input)
|
|
|
|
|
loss = loss_fn(output, target)
|
|
|
|
|
loss /= iters_to_accum
|
|
|
|
|
return loss
|
|
|
|
|
|
|
|
|
|
if not args.distributed:
|
|
|
|
|
losses_m.update(loss.item(), input.size(0))
|
|
|
|
|
|
|
|
|
|
optimizer.zero_grad()
|
|
|
|
|
if loss_scaler is not None:
|
|
|
|
|
loss_scaler(
|
|
|
|
|
loss, optimizer,
|
|
|
|
|
clip_grad=args.clip_grad,
|
|
|
|
|
clip_mode=args.clip_mode,
|
|
|
|
|
parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
|
|
|
|
|
create_graph=second_order
|
|
|
|
|
)
|
|
|
|
|
if need_step is not True and hasattr(model, "no_sync"):
|
|
|
|
|
with model.no_sync():
|
|
|
|
|
loss = _forward()
|
|
|
|
|
else:
|
|
|
|
|
loss.backward(create_graph=second_order)
|
|
|
|
|
if args.clip_grad is not None:
|
|
|
|
|
utils.dispatch_clip_grad(
|
|
|
|
|
model_parameters(model, exclude_head='agc' in args.clip_mode),
|
|
|
|
|
value=args.clip_grad,
|
|
|
|
|
mode=args.clip_mode
|
|
|
|
|
)
|
|
|
|
|
optimizer.step()
|
|
|
|
|
|
|
|
|
|
if model_ema is not None:
|
|
|
|
|
model_ema.update(model)
|
|
|
|
|
|
|
|
|
|
torch.cuda.synchronize()
|
|
|
|
|
|
|
|
|
|
num_updates += 1
|
|
|
|
|
batch_time_m.update(time.time() - end)
|
|
|
|
|
if last_batch or batch_idx % args.log_interval == 0:
|
|
|
|
|
lrl = [param_group['lr'] for param_group in optimizer.param_groups]
|
|
|
|
|
lr = sum(lrl) / len(lrl)
|
|
|
|
|
loss = _forward()
|
|
|
|
|
|
|
|
|
|
if args.distributed:
|
|
|
|
|
reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
|
|
|
|
|
losses_m.update(reduced_loss.item(), input.size(0))
|
|
|
|
|
|
|
|
|
|
if utils.is_primary(args):
|
|
|
|
|
_logger.info(
|
|
|
|
|
'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
|
|
|
|
|
'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
|
|
|
|
|
'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
|
|
|
|
|
'({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
|
|
|
|
|
'LR: {lr:.3e} '
|
|
|
|
|
'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
|
|
|
|
|
epoch,
|
|
|
|
|
batch_idx, len(loader),
|
|
|
|
|
100. * batch_idx / last_idx,
|
|
|
|
|
loss=losses_m,
|
|
|
|
|
batch_time=batch_time_m,
|
|
|
|
|
rate=input.size(0) * args.world_size / batch_time_m.val,
|
|
|
|
|
rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
|
|
|
|
|
lr=lr,
|
|
|
|
|
data_time=data_time_m)
|
|
|
|
|
if not args.distributed:
|
|
|
|
|
losses_m.update(loss.item() * iters_to_accum, input.size(0))
|
|
|
|
|
|
|
|
|
|
def _backward():
|
|
|
|
|
if loss_scaler is not None:
|
|
|
|
|
loss_scaler(
|
|
|
|
|
loss, optimizer,
|
|
|
|
|
clip_grad=args.clip_grad,
|
|
|
|
|
clip_mode=args.clip_mode,
|
|
|
|
|
parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
|
|
|
|
|
create_graph=second_order,
|
|
|
|
|
need_step=need_step
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if args.save_images and output_dir:
|
|
|
|
|
torchvision.utils.save_image(
|
|
|
|
|
input,
|
|
|
|
|
os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
|
|
|
|
|
padding=0,
|
|
|
|
|
normalize=True
|
|
|
|
|
else:
|
|
|
|
|
loss.backward(create_graph=second_order)
|
|
|
|
|
if args.clip_grad is not None:
|
|
|
|
|
utils.dispatch_clip_grad(
|
|
|
|
|
model_parameters(model, exclude_head='agc' in args.clip_mode),
|
|
|
|
|
value=args.clip_grad,
|
|
|
|
|
mode=args.clip_mode
|
|
|
|
|
)
|
|
|
|
|
if need_step:
|
|
|
|
|
optimizer.step()
|
|
|
|
|
|
|
|
|
|
if saver is not None and args.recovery_interval and (
|
|
|
|
|
last_batch or (batch_idx + 1) % args.recovery_interval == 0):
|
|
|
|
|
saver.save_recovery(epoch, batch_idx=batch_idx)
|
|
|
|
|
|
|
|
|
|
if lr_scheduler is not None:
|
|
|
|
|
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
|
|
|
|
|
num_step_samples += input.size(0)
|
|
|
|
|
if need_step is not True and hasattr(model, "no_sync"):
|
|
|
|
|
with model.no_sync():
|
|
|
|
|
_backward()
|
|
|
|
|
else:
|
|
|
|
|
_backward()
|
|
|
|
|
if need_step:
|
|
|
|
|
optimizer.zero_grad()
|
|
|
|
|
if model_ema is not None:
|
|
|
|
|
model_ema.update(model)
|
|
|
|
|
|
|
|
|
|
end = time.time()
|
|
|
|
|
torch.cuda.synchronize()
|
|
|
|
|
num_updates += 1
|
|
|
|
|
batch_time_m.update(time.time() - end)
|
|
|
|
|
|
|
|
|
|
if (batch_idx // args.iters_to_accum) % args.log_interval == 0:
|
|
|
|
|
lrl = [param_group['lr'] for param_group in optimizer.param_groups]
|
|
|
|
|
lr = sum(lrl) / len(lrl)
|
|
|
|
|
|
|
|
|
|
if args.distributed:
|
|
|
|
|
reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
|
|
|
|
|
losses_m.update(reduced_loss.item() * iters_to_accum, input.size(0))
|
|
|
|
|
|
|
|
|
|
if utils.is_primary(args):
|
|
|
|
|
_logger.info(
|
|
|
|
|
'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
|
|
|
|
|
'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
|
|
|
|
|
'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
|
|
|
|
|
'({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
|
|
|
|
|
'LR: {lr:.3e} '
|
|
|
|
|
'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
|
|
|
|
|
epoch,
|
|
|
|
|
batch_idx, len(loader),
|
|
|
|
|
100. * batch_idx / last_idx,
|
|
|
|
|
loss=losses_m,
|
|
|
|
|
batch_time=batch_time_m,
|
|
|
|
|
rate=num_step_samples * args.world_size / batch_time_m.val,
|
|
|
|
|
rate_avg=num_step_samples * args.world_size / batch_time_m.avg,
|
|
|
|
|
lr=lr,
|
|
|
|
|
data_time=data_time_m)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if args.save_images and output_dir:
|
|
|
|
|
torchvision.utils.save_image(
|
|
|
|
|
input,
|
|
|
|
|
os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
|
|
|
|
|
padding=0,
|
|
|
|
|
normalize=True
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if saver is not None and args.recovery_interval and (
|
|
|
|
|
(batch_idx // args.iters_to_accum + 1) % args.recovery_interval == 0):
|
|
|
|
|
saver.save_recovery(epoch, batch_idx=batch_idx)
|
|
|
|
|
|
|
|
|
|
if lr_scheduler is not None:
|
|
|
|
|
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
|
|
|
|
|
num_step_samples = 0
|
|
|
|
|
end = time.time()
|
|
|
|
|
# end for
|
|
|
|
|
|
|
|
|
|
if hasattr(optimizer, 'sync_lookahead'):
|
|
|
|
|