|
|
@ -325,7 +325,7 @@ group.add_argument('--amp-dtype', default='float16', type=str,
|
|
|
|
help='lower precision AMP dtype (default: float16)')
|
|
|
|
help='lower precision AMP dtype (default: float16)')
|
|
|
|
group.add_argument('--amp-impl', default='native', type=str,
|
|
|
|
group.add_argument('--amp-impl', default='native', type=str,
|
|
|
|
help='AMP impl to use, "native" or "apex" (default: native)')
|
|
|
|
help='AMP impl to use, "native" or "apex" (default: native)')
|
|
|
|
group.add_argument('--iters_to_accumulate', default='1', type=int,
|
|
|
|
group.add_argument('--iters-to-accumulate', default='1', type=int,
|
|
|
|
help='number of batches evaluated before performing an optimizer step. Used for Gradient accumulation')
|
|
|
|
help='number of batches evaluated before performing an optimizer step. Used for Gradient accumulation')
|
|
|
|
group.add_argument('--no-ddp-bb', action='store_true', default=False,
|
|
|
|
group.add_argument('--no-ddp-bb', action='store_true', default=False,
|
|
|
|
help='Force broadcast buffers for native DDP to off.')
|
|
|
|
help='Force broadcast buffers for native DDP to off.')
|
|
|
@ -715,7 +715,7 @@ def main():
|
|
|
|
f.write(args_text)
|
|
|
|
f.write(args_text)
|
|
|
|
|
|
|
|
|
|
|
|
# setup learning rate schedule and starting epoch
|
|
|
|
# setup learning rate schedule and starting epoch
|
|
|
|
updates_per_epoch = len(loader_train)
|
|
|
|
updates_per_epoch = len(loader_train) // args.iters_to_accumulate
|
|
|
|
lr_scheduler, num_epochs = create_scheduler_v2(
|
|
|
|
lr_scheduler, num_epochs = create_scheduler_v2(
|
|
|
|
optimizer,
|
|
|
|
optimizer,
|
|
|
|
**scheduler_kwargs(args),
|
|
|
|
**scheduler_kwargs(args),
|
|
|
|