|
|
|
@ -23,6 +23,8 @@ from contextlib import suppress
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
from fire import Fire
|
|
|
|
|
from addict import Dict
|
|
|
|
|
import numpy as np
|
|
|
|
|
import random
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
@ -89,9 +91,20 @@ def _parse_args(config_path):
|
|
|
|
|
return args, args_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def set_deterministic(seed=42, precision=13):
|
|
|
|
|
np.random.seed(seed)
|
|
|
|
|
random.seed(seed)
|
|
|
|
|
# torch.backends.cudnn.benchmarks = False
|
|
|
|
|
# torch.backends.cudnn.deterministic = True
|
|
|
|
|
torch.cuda.manual_seed_all(seed)
|
|
|
|
|
torch.manual_seed(seed)
|
|
|
|
|
torch.set_printoptions(precision=precision)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
|
|
|
|
|
setup_default_logging()
|
|
|
|
|
args, args_text = _parse_args('configs/train.yaml')
|
|
|
|
|
set_deterministic(args.seed)
|
|
|
|
|
|
|
|
|
|
args.prefetcher = not args.no_prefetcher
|
|
|
|
|
args.distributed = False
|
|
|
|
|