From 80075b0b8a4372c6be82ff4dc88517572d91b94a Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 28 Sep 2021 16:37:45 -0700
Subject: [PATCH] Add worker_seeding arg to allow selecting old vs updated
 data loader worker seed for (old) experiment repeatability

---
 timm/data/loader.py | 25 ++++++++++++++++++++-----
 train.py            |  5 ++++-
 2 files changed, 24 insertions(+), 6 deletions(-)

diff --git a/timm/data/loader.py b/timm/data/loader.py
index 7d5aa1e5..a02399a3 100644
--- a/timm/data/loader.py
+++ b/timm/data/loader.py
@@ -3,8 +3,11 @@
 Prefetcher and Fast Collate inspired by NVIDIA APEX example at
 https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf
 
-Hacked together by / Copyright 2020 Ross Wightman
+Hacked together by / Copyright 2021 Ross Wightman
 """
+import random
+from functools import partial
+from typing import Callable
 
 import torch.utils.data
 import numpy as np
@@ -125,10 +128,20 @@ class PrefetchLoader:
             self.loader.collate_fn.mixup_enabled = x
 
 
-def _worker_init(worker_id):
+def _worker_init(worker_id, worker_seeding='all'):
     worker_info = torch.utils.data.get_worker_info()
     assert worker_info.id == worker_id
-    np.random.seed(worker_info.seed % (2**32-1))
+    if isinstance(worker_seeding, Callable):
+        seed = worker_seeding(worker_info)
+        random.seed(seed)
+        torch.manual_seed(seed)
+        np.random.seed(seed % (2 ** 32 - 1))
+    else:
+        assert worker_seeding in ('all', 'part')
+        # random / torch seed already called in dataloader iter class w/ worker_info.seed
+        # to reproduce some old results (same seed + hparam combo), partial seeding is required (skip numpy re-seed)
+        if worker_seeding == 'all':
+            np.random.seed(worker_info.seed % (2 ** 32 - 1))
 
 
 def create_loader(
@@ -162,6 +175,7 @@ def create_loader(
         tf_preprocessing=False,
         use_multi_epochs_loader=False,
         persistent_workers=True,
+        worker_seeding='all',
 ):
     re_num_splits = 0
     if re_split:
@@ -219,8 +233,9 @@ def create_loader(
         collate_fn=collate_fn,
         pin_memory=pin_memory,
         drop_last=is_training,
-        worker_init_fn=_worker_init,
-        persistent_workers=persistent_workers)
+        worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
+        persistent_workers=persistent_workers
+    )
     try:
         loader = loader_class(dataset, **loader_args)
     except TypeError as e:
diff --git a/train.py b/train.py
index 55aba416..d95611ad 100755
--- a/train.py
+++ b/train.py
@@ -252,6 +252,8 @@ parser.add_argument('--model-ema-decay', type=float, default=0.9998,
 # Misc
 parser.add_argument('--seed', type=int, default=42, metavar='S',
                     help='random seed (default: 42)')
+parser.add_argument('--worker-seeding', type=str, default='all',
+                    help='worker seed mode (default: all)')
 parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                     help='how many batches to wait before logging training status')
 parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
@@ -535,7 +537,8 @@ def main():
         distributed=args.distributed,
         collate_fn=collate_fn,
         pin_memory=args.pin_mem,
-        use_multi_epochs_loader=args.use_multi_epochs_loader
+        use_multi_epochs_loader=args.use_multi_epochs_loader,
+        worker_seeding=args.worker_seeding,
     )
 
     loader_eval = create_loader(
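
Note (not part of the commit): a minimal sketch of how the new worker_seeding argument might be exercised. create_dataset and create_loader are real timm APIs at this commit, but the dataset root, input size, batch size, and worker count below are illustrative assumptions only.

    from timm.data import create_dataset, create_loader

    # placeholder dataset; root path is an assumption, any compatible dataset works
    dataset = create_dataset('', root='path/to/imagenet', split='train')

    # 'all' (default): each worker additionally re-seeds numpy from worker_info.seed
    # (python random / torch are already seeded by the DataLoader worker itself)
    loader = create_loader(
        dataset, input_size=(3, 224, 224), batch_size=64,
        is_training=True, num_workers=4, worker_seeding='all')

    # 'part': skip the numpy re-seed to reproduce results from before this change
    loader = create_loader(
        dataset, input_size=(3, 224, 224), batch_size=64,
        is_training=True, num_workers=4, worker_seeding='part')

    # a callable mapping torch worker_info -> seed is also accepted; this one
    # derives a deterministic per-worker seed from the worker id
    loader = create_loader(
        dataset, input_size=(3, 224, 224), batch_size=64,
        is_training=True, num_workers=4,
        worker_seeding=lambda worker_info: 42 + worker_info.id)

From the command line, the same choice is exposed via the new --worker-seeding flag, e.g. python train.py --worker-seeding part ... to reproduce older runs.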