|
|
|
@ -101,11 +101,12 @@ default_cfgs = dict(
|
|
|
|
|
url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)),
|
|
|
|
|
|
|
|
|
|
nfnet_l0a=_dcfg(
|
|
|
|
|
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
|
|
|
|
|
url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288)),
|
|
|
|
|
nfnet_l0b=_dcfg(
|
|
|
|
|
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
|
|
|
|
|
url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288)),
|
|
|
|
|
nfnet_l0c=_dcfg(
|
|
|
|
|
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nfnet_l0c-ad1045c2.pth',
|
|
|
|
|
pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288), crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
nf_regnet_b0=_dcfg(
|
|
|
|
|
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'),
|
|
|
|
@ -376,9 +377,9 @@ class NormFreeBlock(nn.Module):
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None):
|
|
|
|
|
def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None, preact_feature=True):
|
|
|
|
|
stem_stride = 2
|
|
|
|
|
stem_feature = dict(num_chs=out_chs, reduction=2, module='')
|
|
|
|
|
stem_feature = dict(num_chs=out_chs, reduction=2, module='stem.conv')
|
|
|
|
|
stem = OrderedDict()
|
|
|
|
|
assert stem_type in ('', 'deep', 'deep_tiered', 'deep_quad', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool')
|
|
|
|
|
if 'deep' in stem_type:
|
|
|
|
@ -388,14 +389,14 @@ def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None):
|
|
|
|
|
stem_chs = (out_chs // 8, out_chs // 4, out_chs // 2, out_chs)
|
|
|
|
|
strides = (2, 1, 1, 2)
|
|
|
|
|
stem_stride = 4
|
|
|
|
|
stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.act4')
|
|
|
|
|
stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.conv3')
|
|
|
|
|
else:
|
|
|
|
|
if 'tiered' in stem_type:
|
|
|
|
|
stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs) # 'T' resnets in resnet.py
|
|
|
|
|
else:
|
|
|
|
|
stem_chs = (out_chs // 2, out_chs // 2, out_chs) # 'D' ResNets
|
|
|
|
|
strides = (2, 1, 1)
|
|
|
|
|
stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.act3')
|
|
|
|
|
stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.conv2')
|
|
|
|
|
last_idx = len(stem_chs) - 1
|
|
|
|
|
for i, (c, s) in enumerate(zip(stem_chs, strides)):
|
|
|
|
|
stem[f'conv{i + 1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s)
|
|
|
|
@ -477,7 +478,7 @@ class NormFreeNet(nn.Module):
|
|
|
|
|
self.stem, stem_stride, stem_feat = create_stem(
|
|
|
|
|
in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer)
|
|
|
|
|
|
|
|
|
|
self.feature_info = [stem_feat] if stem_stride == 4 else []
|
|
|
|
|
self.feature_info = [stem_feat]
|
|
|
|
|
drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
|
|
|
|
|
prev_chs = stem_chs
|
|
|
|
|
net_stride = stem_stride
|
|
|
|
@ -486,8 +487,6 @@ class NormFreeNet(nn.Module):
|
|
|
|
|
stages = []
|
|
|
|
|
for stage_idx, stage_depth in enumerate(cfg.depths):
|
|
|
|
|
stride = 1 if stage_idx == 0 and stem_stride > 2 else 2
|
|
|
|
|
if stride == 2:
|
|
|
|
|
self.feature_info += [dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}.0.act1')]
|
|
|
|
|
if net_stride >= output_stride and stride > 1:
|
|
|
|
|
dilation *= stride
|
|
|
|
|
stride = 1
|
|
|
|
@ -522,6 +521,7 @@ class NormFreeNet(nn.Module):
|
|
|
|
|
expected_var += cfg.alpha ** 2 # Even if reset occurs, increment expected variance
|
|
|
|
|
first_dilation = dilation
|
|
|
|
|
prev_chs = out_chs
|
|
|
|
|
self.feature_info += [dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')]
|
|
|
|
|
stages += [nn.Sequential(*blocks)]
|
|
|
|
|
self.stages = nn.Sequential(*stages)
|
|
|
|
|
|
|
|
|
@ -529,11 +529,11 @@ class NormFreeNet(nn.Module):
|
|
|
|
|
# The paper NFRegNet models have an EfficientNet-like final head convolution.
|
|
|
|
|
self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div)
|
|
|
|
|
self.final_conv = conv_layer(prev_chs, self.num_features, 1)
|
|
|
|
|
self.feature_info[-1] = dict(num_chs=self.num_features, reduction=net_stride, module=f'final_conv')
|
|
|
|
|
else:
|
|
|
|
|
self.num_features = prev_chs
|
|
|
|
|
self.final_conv = nn.Identity()
|
|
|
|
|
self.final_act = act_layer(inplace=cfg.num_features > 0)
|
|
|
|
|
self.feature_info += [dict(num_chs=self.num_features, reduction=net_stride, module='final_act')]
|
|
|
|
|
|
|
|
|
|
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
|
|
|
|
|
|
|
|
|
@ -572,10 +572,6 @@ class NormFreeNet(nn.Module):
|
|
|
|
|
def _create_normfreenet(variant, pretrained=False, **kwargs):
|
|
|
|
|
model_cfg = model_cfgs[variant]
|
|
|
|
|
feature_cfg = dict(flatten_sequential=True)
|
|
|
|
|
feature_cfg['feature_cls'] = 'hook' # pre-act models need hooks to grab feat from act1 in bottleneck blocks
|
|
|
|
|
if 'pool' in model_cfg.stem_type and 'deep' not in model_cfg.stem_type:
|
|
|
|
|
feature_cfg['out_indices'] = (1, 2, 3, 4) # no stride 2 feat for stride 4, 1 layer maxpool stems
|
|
|
|
|
|
|
|
|
|
return build_model_with_cfg(
|
|
|
|
|
NormFreeNet, variant, pretrained,
|
|
|
|
|
default_cfg=default_cfgs[variant],
|
|
|
|
|