Added set_deterministic function

This function can be useful for checking the representativeness of experiment.
pull/290/head
romamartyanov 5 years ago
parent 41e1f8d282
commit fa2e5c6f16

@ -23,6 +23,8 @@ from contextlib import suppress
from datetime import datetime from datetime import datetime
from fire import Fire from fire import Fire
from addict import Dict from addict import Dict
import numpy as np
import random
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -89,9 +91,20 @@ def _parse_args(config_path):
return args, args_text return args, args_text
def set_deterministic(seed=42, precision=13):
np.random.seed(seed)
random.seed(seed)
# torch.backends.cudnn.benchmarks = False
# torch.backends.cudnn.deterministic = True
torch.cuda.manual_seed_all(seed)
torch.manual_seed(seed)
torch.set_printoptions(precision=precision)
def main(): def main():
setup_default_logging() setup_default_logging()
args, args_text = _parse_args('configs/train.yaml') args, args_text = _parse_args('configs/train.yaml')
set_deterministic(args.seed)
args.prefetcher = not args.no_prefetcher args.prefetcher = not args.no_prefetcher
args.distributed = False args.distributed = False

Loading…
Cancel
Save