diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 15c718f5..2f6dbbef 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -515,52 +515,6 @@ def create_block(block: Union[str, nn.Module], **kwargs): return _block_registry[block](**kwargs) -# class Stem(nn.Module): -# -# def __init__(self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool', -# num_rep=3, num_act=None, chs_decay=0.5, layers: LayerFn = None): -# super().__init__() -# assert stride in (2, 4) -# if pool: -# assert stride == 4 -# layers = layers or LayerFn() -# -# if isinstance(out_chs, (list, tuple)): -# num_rep = len(out_chs) -# stem_chs = out_chs -# else: -# stem_chs = [round(out_chs * chs_decay ** i) for i in range(num_rep)][::-1] -# -# self.stride = stride -# stem_strides = [2] + [1] * (num_rep - 1) -# if stride == 4 and not pool: -# # set last conv in stack to be strided if stride == 4 and no pooling layer -# stem_strides[-1] = 2 -# -# num_act = num_rep if num_act is None else num_act -# # if num_act < num_rep, first convs in stack won't have bn + act -# stem_norm_acts = [False] * (num_rep - num_act) + [True] * num_act -# prev_chs = in_chs -# convs = [] -# for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)): -# layer_fn = layers.conv_norm_act if na else create_conv2d -# convs.append(layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s)) -# prev_chs = ch -# self.conv = nn.Sequential(*convs) if len(convs) > 1 else convs[0] -# -# if not pool: -# self.pool = nn.Identity() -# elif 'max' in pool.lower(): -# self.pool = nn.MaxPool2d(3, 2, 1) if pool else nn.Identity() -# else: -# assert False, "Unknown pooling type" -# -# def forward(self, x): -# x = self.conv(x) -# x = self.pool(x) -# return x - - class Stem(nn.Sequential): def __init__(self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool', diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py index 1c526e8c..d02e62d2 100644 --- a/timm/utils/__init__.py +++ b/timm/utils/__init__.py @@ -9,4 +9,5 @@ from .metrics import AverageMeter, accuracy from .misc import natural_key, add_bool_arg from .model import unwrap_model, get_state_dict from .model_ema import ModelEma, ModelEmaV2 +from .random import random_seed from .summary import update_summary, get_outdir diff --git a/timm/utils/random.py b/timm/utils/random.py new file mode 100644 index 00000000..a9679983 --- /dev/null +++ b/timm/utils/random.py @@ -0,0 +1,9 @@ +import random +import numpy as np +import torch + + +def random_seed(seed=42, rank=0): + torch.manual_seed(seed + rank) + np.random.seed(seed + rank) + random.seed(seed + rank) diff --git a/train.py b/train.py index 89ade4a1..49a93eb4 100755 --- a/train.py +++ b/train.py @@ -329,7 +329,7 @@ def main(): _logger.warning("Neither APEX or native Torch AMP is available, using float32. " "Install NVIDA apex or upgrade to PyTorch 1.6") - torch.manual_seed(args.seed + args.rank) + random_seed(args.seed, args.rank) model = create_model( args.model,