From 008f25430b50f6f7a539731f6b736026c2d950e1 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 6 Sep 2021 19:06:26 +0100 Subject: [PATCH] add deterministic flag + functionality --- train.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/train.py b/train.py index f1c1581e..4221c62d 100755 --- a/train.py +++ b/train.py @@ -242,6 +242,8 @@ parser.add_argument('--model-ema-decay', type=float, default=0.9998, # Misc parser.add_argument('--seed', type=int, default=42, metavar='S', help='random seed (default: 42)') +parser.add_argument('--deterministic', action='store_true', default=False, + help="Whether to use deterministic algorithms (may slightly decrease run time performance)") parser.add_argument('--log-interval', type=int, default=50, metavar='N', help='how many batches to wait before logging training status') parser.add_argument('--recovery-interval', type=int, default=0, metavar='N', @@ -345,6 +347,8 @@ def main(): "Install NVIDA apex or upgrade to PyTorch 1.6") random_seed(args.seed, args.rank) + # NOTE we want to make sure deterministic is set after random_seed + torch.backends.cudnn.deterministic = args.deterministic model = create_model( args.model,