Remove commented code, add more consistent seed fn

pull/556/head
Ross Wightman 4 years ago
parent 364dd6a58e
commit 7c97e66f7c

@ -515,52 +515,6 @@ def create_block(block: Union[str, nn.Module], **kwargs):
        return _block_registry[block](**kwargs)
# class Stem(nn.Module):
#
# def __init__(self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool',
# num_rep=3, num_act=None, chs_decay=0.5, layers: LayerFn = None):
# super().__init__()
# assert stride in (2, 4)
# if pool:
# assert stride == 4
# layers = layers or LayerFn()
#
# if isinstance(out_chs, (list, tuple)):
# num_rep = len(out_chs)
# stem_chs = out_chs
# else:
# stem_chs = [round(out_chs * chs_decay ** i) for i in range(num_rep)][::-1]
#
# self.stride = stride
# stem_strides = [2] + [1] * (num_rep - 1)
# if stride == 4 and not pool:
# # set last conv in stack to be strided if stride == 4 and no pooling layer
# stem_strides[-1] = 2
#
# num_act = num_rep if num_act is None else num_act
# # if num_act < num_rep, first convs in stack won't have bn + act
# stem_norm_acts = [False] * (num_rep - num_act) + [True] * num_act
# prev_chs = in_chs
# convs = []
# for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)):
# layer_fn = layers.conv_norm_act if na else create_conv2d
# convs.append(layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s))
# prev_chs = ch
# self.conv = nn.Sequential(*convs) if len(convs) > 1 else convs[0]
#
# if not pool:
# self.pool = nn.Identity()
# elif 'max' in pool.lower():
# self.pool = nn.MaxPool2d(3, 2, 1) if pool else nn.Identity()
# else:
# assert False, "Unknown pooling type"
#
# def forward(self, x):
# x = self.conv(x)
# x = self.pool(x)
# return x
class Stem(nn.Sequential):
    def __init__(self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool',

@ -9,4 +9,5 @@ from .metrics import AverageMeter, accuracy
from .misc import natural_key, add_bool_arg
from .model import unwrap_model, get_state_dict
from .model_ema import ModelEma, ModelEmaV2
from .random import random_seed
from .summary import update_summary, get_outdir

@ -0,0 +1,9 @@
import random
import numpy as np
import torch
def random_seed(seed=42, rank=0):
    """Seed Python, NumPy and PyTorch RNGs in one call.

    The effective seed is ``seed + rank`` so that each distributed worker
    (identified by ``rank``) draws a distinct but reproducible stream.

    Args:
        seed: base random seed (default 42).
        rank: per-process offset added to the base seed (default 0).
    """
    effective_seed = seed + rank
    random.seed(effective_seed)
    np.random.seed(effective_seed)
    torch.manual_seed(effective_seed)

@ -329,7 +329,7 @@ def main():
        _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
                        "Install NVIDA apex or upgrade to PyTorch 1.6")
torch.manual_seed(args.seed + args.rank) random_seed(args.seed, args.rank)
    model = create_model(
args.model, args.model,

Loading…
Cancel
Save