Related lr_scheduler to gradient batch accumulation

Fixed updates_per_epoch in training script
pull/1590/head
Lorenzo Baraldi 2 years ago
parent e09b4d5c7f
commit 29918b4c3f

@@ -325,7 +325,7 @@ group.add_argument('--amp-dtype', default='float16', type=str,
                    help='lower precision AMP dtype (default: float16)')
 group.add_argument('--amp-impl', default='native', type=str,
                    help='AMP impl to use, "native" or "apex" (default: native)')
-group.add_argument('--iters_to_accumulate', default='1', type=int,
+group.add_argument('--iters-to-accumulate', default='1', type=int,
                    help='number of batches evaluated before performing an optimizer step. Used for Gradient accumulation')
 group.add_argument('--no-ddp-bb', action='store_true', default=False,
                    help='Force broadcast buffers for native DDP to off.')
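The renamed --iters-to-accumulate flag sets how many batches are evaluated before each optimizer step. As an illustrative sketch only (not the repository's train_one_epoch; train_one_epoch_sketch, model, loader, optimizer and loss_fn are hypothetical placeholders), a flag like this is typically consumed as follows:

def train_one_epoch_sketch(model, loader, optimizer, loss_fn, iters_to_accumulate=1):
    # Hypothetical loop showing what --iters-to-accumulate would control.
    model.train()
    optimizer.zero_grad()
    for batch_idx, (inputs, targets) in enumerate(loader):
        loss = loss_fn(model(inputs), targets)
        # Scale the loss so the summed gradients match one large-batch step
        # of size batch_size * iters_to_accumulate.
        (loss / iters_to_accumulate).backward()
        if (batch_idx + 1) % iters_to_accumulate == 0:
            optimizer.step()       # one optimizer update per accumulation window
            optimizer.zero_grad()

Run with, e.g., --iters-to-accumulate 4, this keeps the effective batch size at 4x the per-iteration batch size while only one micro-batch is resident in memory at a time.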
@@ -715,7 +715,7 @@ def main():
             f.write(args_text)

     # setup learning rate schedule and starting epoch
-    updates_per_epoch = len(loader_train)
+    updates_per_epoch = len(loader_train) // args.iters_to_accumulate
     lr_scheduler, num_epochs = create_scheduler_v2(
         optimizer,
         **scheduler_kwargs(args),
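Dividing updates_per_epoch by args.iters_to_accumulate matters because the LR scheduler is advanced once per optimizer update, not once per batch. A minimal sketch of the affected wiring, assuming the create_scheduler_v2 / scheduler_kwargs usage visible in the surrounding train.py context (build_scheduler is a hypothetical wrapper; loader_train and args stand in for the script's objects):

from timm.scheduler import create_scheduler_v2, scheduler_kwargs

def build_scheduler(optimizer, loader_train, args):
    # With gradient accumulation, only len(loader_train) // args.iters_to_accumulate
    # optimizer updates happen per epoch.  Passing the raw loader length would make
    # a per-update schedule assume iters_to_accumulate times more steps than actually
    # occur, so the schedule would not have fully decayed by the end of training.
    updates_per_epoch = len(loader_train) // args.iters_to_accumulate
    lr_scheduler, num_epochs = create_scheduler_v2(
        optimizer,
        **scheduler_kwargs(args),
        updates_per_epoch=updates_per_epoch,
    )
    return lr_scheduler, num_epochs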
