@@ -325,6 +325,8 @@ group.add_argument('--amp-dtype', default='float16', type=str,
                    help='lower precision AMP dtype (default: float16)')
 group.add_argument('--amp-impl', default='native', type=str,
                    help='AMP impl to use, "native" or "apex" (default: native)')
+group.add_argument('--iters_to_accumulate', default=1, type=int,
+                   help='number of batches evaluated before performing an optimizer step, used for gradient accumulation')
 group.add_argument('--no-ddp-bb', action='store_true', default=False,
                    help='Force broadcast buffers for native DDP to off.')
 group.add_argument('--pin-mem', action='store_true', default=False,
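A brief note on the new flag (the numbers below are illustrative, not from this patch): gradients from `iters_to_accumulate` consecutive batches are summed before a single optimizer step, so the effective batch size per process becomes `batch_size * iters_to_accumulate`:

```python
# illustrative arithmetic only; 128 stands in for whatever --batch-size is used
batch_size, iters_to_accumulate = 128, 4
effective_batch_size = batch_size * iters_to_accumulate  # 512 samples per optimizer step
```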
@@ -841,11 +843,11 @@ def train_one_epoch(
     losses_m = utils.AverageMeter()
 
     model.train()
-
+    optimizer.zero_grad()
     end = time.time()
     num_batches_per_epoch = len(loader)
     last_idx = num_batches_per_epoch - 1
-    num_updates = epoch * num_batches_per_epoch
+    num_updates = epoch * num_batches_per_epoch // args.iters_to_accumulate
     for batch_idx, (input, target) in enumerate(loader):
         last_batch = batch_idx == last_idx
         data_time_m.update(time.time() - end)
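Dividing by `args.iters_to_accumulate` keeps `num_updates` in optimizer-step units rather than batch units, which is what per-update LR schedules expect when resuming mid-training. A small worked example with illustrative numbers (not from this patch):

```python
# with 1000 batches per epoch and 4-step accumulation, epoch 3 resumes at
# optimizer update 750 rather than 3000
epoch, num_batches_per_epoch, iters_to_accumulate = 3, 1000, 4
num_updates = epoch * num_batches_per_epoch // iters_to_accumulate
assert num_updates == 750
```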
@@ -858,19 +860,20 @@ def train_one_epoch(
 
         with amp_autocast():
             output = model(input)
-            loss = loss_fn(output, target)
+            loss = loss_fn(output, target) / args.iters_to_accumulate
 
         if not args.distributed:
-            losses_m.update(loss.item(), input.size(0))
+            losses_m.update(loss.item() * args.iters_to_accumulate, input.size(0))
 
-        optimizer.zero_grad()
         if loss_scaler is not None:
             loss_scaler(
                 loss, optimizer,
                 clip_grad=args.clip_grad,
                 clip_mode=args.clip_mode,
                 parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
-                create_graph=second_order
+                create_graph=second_order,
+                batch_idx=batch_idx,
+                iters_to_accumulate=args.iters_to_accumulate
             )
         else:
             loss.backward(create_graph=second_order)
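This hunk forwards `batch_idx` and `iters_to_accumulate` into `loss_scaler`, but the matching change to the scaler helper is not part of this diff. A minimal sketch of what such a wrapper could look like, assuming a timm-style `NativeScaler` call signature (the class name and internals here are hypothetical, and only norm-based clipping is shown):

```python
import torch

class AccumNativeScaler:
    """Hypothetical AMP scaler wrapper that only steps the optimizer once per
    accumulation window; an assumption about what the modified loss_scaler does."""

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm',
                 parameters=None, create_graph=False,
                 batch_idx=0, iters_to_accumulate=1):
        # backward on every batch so gradients keep accumulating in .grad
        self._scaler.scale(loss).backward(create_graph=create_graph)
        # unscale, clip, step, and zero only on the last micro-batch of the window
        if (batch_idx + 1) % iters_to_accumulate == 0:
            if clip_grad is not None:
                self._scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            self._scaler.step(optimizer)
            self._scaler.update()
            optimizer.zero_grad()
```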
@@ -878,23 +881,24 @@ def train_one_epoch(
                 utils.dispatch_clip_grad(
                     model_parameters(model, exclude_head='agc' in args.clip_mode),
                     value=args.clip_grad,
-                    mode=args.clip_mode
+                    mode=args.clip_mode,
                 )
-            optimizer.step()
+            if (batch_idx + 1) % args.iters_to_accumulate == 0:
+                optimizer.step()
+                optimizer.zero_grad()
 
         if model_ema is not None:
             model_ema.update(model)
 
         torch.cuda.synchronize()
-        num_updates += 1
         batch_time_m.update(time.time() - end)
         if last_batch or batch_idx % args.log_interval == 0:
             lrl = [param_group['lr'] for param_group in optimizer.param_groups]
             lr = sum(lrl) / len(lrl)
 
             if args.distributed:
-                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
+                reduced_loss = utils.reduce_tensor(loss.data * args.iters_to_accumulate, args.world_size)
                 losses_m.update(reduced_loss.item(), input.size(0))
 
             if utils.is_primary(args):
@@ -928,8 +932,10 @@ def train_one_epoch(
                 last_batch or (batch_idx + 1) % args.recovery_interval == 0):
             saver.save_recovery(epoch, batch_idx=batch_idx)
 
-        if lr_scheduler is not None:
-            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
+        if (batch_idx + 1) % args.iters_to_accumulate == 0:
+            num_updates += 1
+            if lr_scheduler is not None:
+                lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
 
         end = time.time()
         # end for
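Taken together, the hunks follow the standard gradient-accumulation recipe: divide the loss by the window size so the summed gradients match one large batch, call backward on every batch, and only step/zero the optimizer (and advance `num_updates` and the LR scheduler) once per window; the logging paths multiply the loss back by `args.iters_to_accumulate` so reported losses stay in per-batch units. A condensed standalone sketch of that recipe (simplified Python, not timm code; AMP, clipping, EMA, and logging omitted):

```python
def train_steps(model, loader, optimizer, loss_fn, iters_to_accumulate=4):
    model.train()
    optimizer.zero_grad()
    for batch_idx, (x, y) in enumerate(loader):
        # scale the loss so accumulated gradients average over the window
        loss = loss_fn(model(x), y) / iters_to_accumulate
        loss.backward()  # gradients keep accumulating in .grad
        if (batch_idx + 1) % iters_to_accumulate == 0:
            optimizer.step()       # one update per iters_to_accumulate batches
            optimizer.zero_grad()  # start the next accumulation window
```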