diff --git a/train.py b/train.py index 84d8b2ea..d16cae1b 100755 --- a/train.py +++ b/train.py @@ -250,6 +250,8 @@ parser.add_argument('--model-ema-decay', type=float, default=0.9998, # Misc parser.add_argument('--seed', type=int, default=42, metavar='S', help='random seed (default: 42)') +parser.add_argument('--deterministic', action='store_true', default=False, + help="Whether to use deterministic algorithms (may slightly decrease run time performance)") parser.add_argument('--log-interval', type=int, default=50, metavar='N', help='how many batches to wait before logging training status') parser.add_argument('--recovery-interval', type=int, default=0, metavar='N', @@ -353,6 +355,8 @@ def main(): "Install NVIDA apex or upgrade to PyTorch 1.6") random_seed(args.seed, args.rank) + # NOTE we want to make sure deterministic is set after random_seed + torch.backends.cudnn.deterministic = args.deterministic model = create_model( args.model,