From fa2e5c6f169fa74a976b258a2eea556f77163bc3 Mon Sep 17 00:00:00 2001 From: romamartyanov Date: Sat, 28 Nov 2020 18:04:19 +0300 Subject: [PATCH] Added set_deterministic function This function can be useful for checking the representativeness of experiment. --- train.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/train.py b/train.py index 66f75f20..8a99733b 100755 --- a/train.py +++ b/train.py @@ -23,6 +23,8 @@ from contextlib import suppress from datetime import datetime from fire import Fire from addict import Dict +import numpy as np +import random import torch import torch.nn as nn @@ -89,9 +91,20 @@ def _parse_args(config_path): return args, args_text +def set_deterministic(seed=42, precision=13): + np.random.seed(seed) + random.seed(seed) + # torch.backends.cudnn.benchmarks = False + # torch.backends.cudnn.deterministic = True + torch.cuda.manual_seed_all(seed) + torch.manual_seed(seed) + torch.set_printoptions(precision=precision) + + def main(): setup_default_logging() args, args_text = _parse_args('configs/train.yaml') + set_deterministic(args.seed) args.prefetcher = not args.no_prefetcher args.distributed = False