diff --git a/train.py b/train.py
index b7772860..9226cdf6 100755
--- a/train.py
+++ b/train.py
@@ -325,7 +325,7 @@ group.add_argument('--amp-dtype', default='float16', type=str,
                    help='lower precision AMP dtype (default: float16)')
 group.add_argument('--amp-impl', default='native', type=str,
                    help='AMP impl to use, "native" or "apex" (default: native)')
-group.add_argument('--iters_to_accumulate', default='1', type=int,
+group.add_argument('--iters-to-accumulate', default='1', type=int,
                    help='number of batches evaluated before performing an optimizer step. Used for Gradient accumulation')
 group.add_argument('--no-ddp-bb', action='store_true', default=False,
                    help='Force broadcast buffers for native DDP to off.')
@@ -715,7 +715,7 @@ def main():
             f.write(args_text)

     # setup learning rate schedule and starting epoch
-    updates_per_epoch = len(loader_train)
+    updates_per_epoch = len(loader_train) // args.iters_to_accumulate
     lr_scheduler, num_epochs = create_scheduler_v2(
         optimizer,
         **scheduler_kwargs(args),
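
For context, a minimal sketch of the gradient-accumulation loop these options support: the loss is scaled by 1/iters_to_accumulate and the optimizer steps once every iters_to_accumulate batches, which is why updates_per_epoch is divided by args.iters_to_accumulate so the LR scheduler counts actual optimizer steps rather than batches. The names below (train_one_epoch_sketch, loss_fn, etc.) are illustrative placeholders, not the actual train_one_epoch in train.py.

    # Illustrative sketch only, assuming a plain (non-AMP, non-DDP) loop.
    def train_one_epoch_sketch(model, loader_train, optimizer, loss_fn, iters_to_accumulate=1):
        model.train()
        optimizer.zero_grad()
        for batch_idx, (inputs, targets) in enumerate(loader_train):
            loss = loss_fn(model(inputs), targets)
            # Scale the loss so the accumulated gradient matches one large batch.
            (loss / iters_to_accumulate).backward()
            # Step only every `iters_to_accumulate` batches, hence
            # updates_per_epoch = len(loader_train) // iters_to_accumulate.
            if (batch_idx + 1) % iters_to_accumulate == 0:
                optimizer.step()
                optimizer.zero_grad()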