Added set_deterministic function

This function can be useful for checking the representativeness of experiment.
pull/290/head
romamartyanov 5 years ago
parent 41e1f8d282
commit fa2e5c6f16

@ -23,6 +23,8 @@ from contextlib import suppress
from datetime import datetime
from fire import Fire
from addict import Dict
import numpy as np
import random
import torch
import torch.nn as nn
@ -89,9 +91,20 @@ def _parse_args(config_path):
return args, args_text
def set_deterministic(seed=42, precision=13):
np.random.seed(seed)
random.seed(seed)
# torch.backends.cudnn.benchmarks = False
# torch.backends.cudnn.deterministic = True
torch.cuda.manual_seed_all(seed)
torch.manual_seed(seed)
torch.set_printoptions(precision=precision)
def main():
setup_default_logging()
args, args_text = _parse_args('configs/train.yaml')
set_deterministic(args.seed)
args.prefetcher = not args.no_prefetcher
args.distributed = False

Loading…
Cancel
Save