Merge remote-tracking branch 'origin/norm_norm_norm' into bits_and_tpu

pull/1414/head
Ross Wightman 2 years ago
commit 749856cf25

@@ -23,6 +23,39 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
## What's New
### March 21, 2022
* Merge `norm_norm_norm`. **IMPORTANT** this update, for an upcoming 0.6.x release, will likely destabilize the master branch for a while. Branch `0.5.x` or a previous 0.5.x release can be used if stability is required.
* Significant weights update (all TPU trained) as described in this [release](https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-tpu-weights)
* `regnety_040` - 82.3 @ 224, 82.96 @ 288
* `regnety_064` - 83.0 @ 224, 83.65 @ 288
* `regnety_080` - 83.17 @ 224, 83.86 @ 288
* `regnetv_040` - 82.44 @ 224, 83.18 @ 288 (timm pre-act)
* `regnetv_064` - 83.1 @ 224, 83.71 @ 288 (timm pre-act)
* `regnetz_040` - 83.67 @ 256, 84.25 @ 320
* `regnetz_040h` - 83.77 @ 256, 84.5 @ 320 (w/ extra fc in head)
* `resnetv2_50d_gn` - 80.8 @ 224, 81.96 @ 288 (pre-act GroupNorm)
* `resnetv2_50d_evos` - 80.77 @ 224, 82.04 @ 288 (pre-act EvoNormS)
* `regnetz_c16_evos` - 81.9 @ 256, 82.64 @ 320 (EvoNormS)
* `regnetz_d8_evos` - 83.42 @ 256, 84.04 @ 320 (EvoNormS)
* `xception41p` - 82 @ 299 (timm pre-act)
* `xception65` - 83.17 @ 299
* `xception65p` - 83.14 @ 299 (timm pre-act)
* `resnext101_64x4d` - 82.46 @ 224, 83.16 @ 288
* `seresnext101_32x8d` - 83.57 @ 224, 84.27 @ 288
* `resnetrs200` - 83.85 @ 256, 84.44 @ 320
* HuggingFace hub support fixed w/ initial groundwork for allowing alternative 'config sources' for pretrained model definitions and weights (generic local file / remote url support soon)
* SwinTransformer-V2 implementation added. Submitted by [Christoph Reich](https://github.com/ChristophReich1996). Training experiments and model changes by myself are ongoing, so expect compat breaks.
* MobileViT models w/ weights adapted from https://github.com/apple/ml-cvnets
* PoolFormer models w/ weights adapted from https://github.com/sail-sg/poolformer
* VOLO models w/ weights adapted from https://github.com/sail-sg/volo
* Significant work experimenting with non-BatchNorm norm layers such as EvoNorm, FilterResponseNorm, GroupNorm, etc.
* Enhanced support for alternate norm + act ('NormAct') layers in a number of models, esp. EfficientNet/MobileNetV3, RegNet, and aligned Xception
* Grouped conv support added to EfficientNet family
* Add 'group matching' API to all models to allow grouping model parameters for application of 'layer-wise' LR decay; an LR scale option was also added to the LR scheduler (see the sketch after this list)
* Gradient checkpointing support added to many models
* `forward_head(x, pre_logits=False)` fn added to all models to allow separate calls of `forward_features` + `forward_head`
* All vision transformer and vision MLP models updated to return non-pooled / non-token-selected features from `forward_features`, for consistency with CNN models; token selection or pooling is now applied in `forward_head` (see the examples after this list)
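
For example, a minimal sketch of the new model interfaces (model name, tensor shapes, and the hub repo below are illustrative assumptions, not specific recommendations):

```python
import torch
import timm

model = timm.create_model('regnety_040', pretrained=False)  # any classification model
model.eval()

x = torch.randn(2, 3, 224, 224)
features = model.forward_features(x)                         # unpooled features (now also for ViT / MLP models)
pre_logits = model.forward_head(features, pre_logits=True)   # pooled, pre-classifier features
logits = model.forward_head(features)                        # final classifier output

model.set_grad_checkpointing(True)                           # on supported models, trade compute for memory

# with the fixed HF hub support, an 'hf-hub:' prefixed identifier can also be passed to
# create_model, e.g. timm.create_model('hf-hub:some-org/some-model', pretrained=True)  # hypothetical repo
```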
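
And a rough sketch of how the new 'group matching' output might be consumed to build layer-wise LR decay parameter groups. This uses the coarse MobileNetV3 matcher shown later in this diff; the grouping logic, stage count, base LR, and decay factor are simplified illustration values, not the actual scheduler/optimizer factory integration:

```python
import re
import torch
import timm

model = timm.create_model('mobilenetv3_large_100', pretrained=False)
matcher = model.group_matcher(coarse=True)  # e.g. {'stem': r'^conv_stem|bn1', 'blocks': r'^blocks\.(\d+)'}

def group_id(name, num_block_groups):
    # stem -> group 0, block stage i -> group 1 + i, anything else (head) -> last group
    if re.match(matcher['stem'], name):
        return 0
    m = re.match(matcher['blocks'], name)
    if m:
        return 1 + int(m.group(1))
    return 1 + num_block_groups

num_block_groups = 7          # assumed stage count, for illustration only
base_lr, decay = 1e-3, 0.75
groups = {}
for name, param in model.named_parameters():
    gid = group_id(name, num_block_groups)
    lr = base_lr * decay ** (num_block_groups + 1 - gid)    # deeper groups keep more of the base LR
    groups.setdefault(gid, {'params': [], 'lr': lr})['params'].append(param)

optimizer = torch.optim.AdamW(list(groups.values()))
```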
### Feb 2, 2022
* [Chris Hughes](https://github.com/Chris-hughes10) posted an exhaustive run through of `timm` on his blog yesterday. Well worth a read. [Getting Started with PyTorch Image Models (timm): A Practitioners Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055)
* I'm currently prepping to merge the `norm_norm_norm` branch back to master (ver 0.6.x) in the next week or so.

@@ -145,7 +145,7 @@ torch.Size([2, 1512, 7, 7])
### Select specific feature levels or limit the stride
There are two additional creation arguments impacting the output features, illustrated in the sketch below.
* `out_indices` selects which indices to output
* `output_stride` limits the feature output stride of the network (also works in classification mode BTW)
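
For example, a small sketch (model name and shapes are illustrative):

```python
import torch
import timm

# take features from two specific stages and cap the network stride at 8
m = timm.create_model('resnet50', features_only=True, out_indices=(2, 4), output_stride=8, pretrained=False)
print(m.feature_info.channels(), m.feature_info.reduction())
outputs = m(torch.randn(2, 3, 224, 224))
for o in outputs:
    print(o.shape)
```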

@@ -157,15 +157,13 @@ default_cfgs = {
'regnetz_b16_evos': _cfgr(
url='',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 224, 224), pool_size=(7, 7), test_input_size=(3, 288, 288), first_conv='stem.conv',
crop_pct=0.94),
'regnetz_c16_evos': _cfgr(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_c16_evos_ch-d8311942.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), first_conv='stem.conv', crop_pct=0.95),
'regnetz_d8_evos': _cfgr(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_d8_evos_ch-2bc12646.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95),
'regnetz_e8_evos': _cfgr(
url='',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95),
}
@@ -658,24 +656,6 @@ model_cfgs = dict(
attn_kwargs=dict(rd_ratio=0.25),
block_kwargs=dict(bottle_in=True, linear_out=True),
),
regnetz_e8_evos=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=3, c=96, s=1, gs=8, br=4),
ByoBlockCfg(type='bottle', d=8, c=192, s=2, gs=8, br=4),
ByoBlockCfg(type='bottle', d=16, c=384, s=2, gs=8, br=4),
ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=8, br=4),
),
stem_chs=64,
stem_type='deep',
stem_pool='',
downsample='',
num_features=2048,
act_layer='silu',
attn_layer='se',
attn_kwargs=dict(rd_ratio=0.25),
norm_layer=partial(EvoNorm2dS0a, group_size=16),
block_kwargs=dict(bottle_in=True, linear_out=True),
),
)
@register_model @register_model
@@ -920,13 +900,6 @@ def regnetz_d8_evos(pretrained=False, **kwargs):
return _create_byobnet('regnetz_d8_evos', pretrained=pretrained, **kwargs)
@register_model
def regnetz_e8_evos(pretrained=False, **kwargs):
"""
"""
return _create_byobnet('regnetz_e8_evos', pretrained=pretrained, **kwargs)
def expand_blocks_cfg(stage_blocks_cfg: Union[ByoBlockCfg, Sequence[ByoBlockCfg]]) -> List[ByoBlockCfg]:
if not isinstance(stage_blocks_cfg, Sequence):
stage_blocks_cfg = (stage_blocks_cfg,)
@@ -1031,9 +1004,10 @@ class BottleneckBlock(nn.Module):
""" ResNet-like Bottleneck Block - 1x1 - kxk - 1x1
"""
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
downsample='avg', attn_last=False, linear_out=False, extra_conv=False, bottle_in=False,
layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
super(BottleneckBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
@@ -1088,9 +1062,10 @@ class DarkBlock(nn.Module):
for more optimal compute.
"""
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None,
drop_path_rate=0.):
super(DarkBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible(out_chs * bottle_ratio)
@@ -1138,9 +1113,10 @@ class EdgeBlock(nn.Module):
FIXME is there a more common 3x3 + 1x1 conv block to name this after?
"""
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='avg', attn_last=False, linear_out=False, layers: LayerFn = None,
drop_block=None, drop_path_rate=0.):
super(EdgeBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible(out_chs * bottle_ratio)
@@ -1185,8 +1161,9 @@ class RepVggBlock(nn.Module):
This version does not currently support the deploy optimization. It is currently fixed in 'train' mode.
"""
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='', layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
super(RepVggBlock, self).__init__()
layers = layers or LayerFn()
groups = num_groups(group_size, in_chs)
@@ -1226,9 +1203,10 @@ class SelfAttnBlock(nn.Module):
""" ResNet-like Bottleneck Block - 1x1 - optional kxk - self attn - 1x1
"""
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
downsample='avg', extra_conv=False, linear_out=False, bottle_in=False, post_attn_na=True,
feat_size=None, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
super(SelfAttnBlock, self).__init__()
assert layers is not None
mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
@@ -1295,8 +1273,9 @@ def create_block(block: Union[str, nn.Module], **kwargs):
class Stem(nn.Sequential):
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool',
num_rep=3, num_act=None, chs_decay=0.5, layers: LayerFn = None):
super().__init__()
assert stride in (2, 4)
layers = layers or LayerFn()
@@ -1505,11 +1484,13 @@ class ByobNet(nn.Module):
Current assumption is that both stem and blocks are in conv-bn-act order (w/ block ending in act).
"""
def __init__(
self, cfg: ByoModelCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
zero_init_last=True, img_size=None, drop_rate=0., drop_path_rate=0.):
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
self.grad_checkpointing = False
layers = get_layer_fns(cfg)
if cfg.fixed_input_size:
assert img_size is not None, 'img_size argument is required for fixed input size model'
@@ -1540,6 +1521,22 @@ class ByobNet(nn.Module):
# init weights
named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^stem',
blocks=[
(r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
(r'^final_conv', (99999,))
]
)
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
@@ -1548,13 +1545,19 @@ class ByobNet(nn.Module):
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.stages, x)
else:
x = self.stages(x)
x = self.final_conv(x)
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x

@@ -164,8 +164,9 @@ class CrossAttention(nn.Module):
class CrossAttentionBlock(nn.Module):
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = CrossAttention(

@@ -157,9 +157,10 @@ class ResBottleneck(nn.Module):
""" ResNe(X)t Bottleneck Block
"""
def __init__(
self, in_chs, out_chs, dilation=1, bottle_ratio=0.25, groups=1,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_last=False,
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
super(ResBottleneck, self).__init__()
mid_chs = int(round(out_chs * bottle_ratio))
ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
@@ -199,9 +200,10 @@ class DarkBlock(nn.Module):
""" DarkNet Block
"""
def __init__(
self, in_chs, out_chs, dilation=1, bottle_ratio=0.5, groups=1,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None,
drop_block=None, drop_path=None):
super(DarkBlock, self).__init__()
mid_chs = int(round(out_chs * bottle_ratio))
ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
@@ -229,9 +231,10 @@ class DarkBlock(nn.Module):
class CrossStage(nn.Module):
"""Cross Stage."""
def __init__(
self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., exp_ratio=1.,
groups=1, first_dilation=None, down_growth=False, cross_linear=False, block_dpr=None,
block_fn=ResBottleneck, **block_kwargs):
super(CrossStage, self).__init__()
first_dilation = first_dilation or dilation
down_chs = out_chs if down_growth else in_chs  # grow downsample channels to output channels
@@ -280,8 +283,9 @@ class CrossStage(nn.Module):
class DarkStage(nn.Module):
"""DarkNet stage."""
def __init__(
self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., groups=1,
first_dilation=None, block_fn=ResBottleneck, block_dpr=None, **block_kwargs):
super(DarkStage, self).__init__()
first_dilation = first_dilation or dilation
@@ -387,10 +391,10 @@ class CspNet(nn.Module):
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^stem',
blocks=r'^stages\.(\d+)' if coarse else [
(r'^stages\.(\d+)\.blocks\.(\d+)', None),
(r'^stages\.(\d+)\..*transition', MATCH_PREV_GROUP),  # map to last block in stage
(r'^stages\.(\d+)', (0,)),
]
)
return matcher

@@ -72,7 +72,7 @@ class VisionTransformerDistilled(VisionTransformer):
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + self.num_tokens, self.embed_dim))
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
self.distilled_training = False  # must set this True to train w/ distillation token
self.init_weights(weight_init)
@@ -85,7 +85,7 @@ class VisionTransformerDistilled(VisionTransformer):
return dict(
stem=r'^cls_token|pos_embed|patch_embed|dist_token',
blocks=[
(r'^blocks\.(\d+)', None),
(r'^norm', (99999,))]  # final norm w/ last block
)

@@ -45,8 +45,9 @@ default_cfgs = {
class DenseLayer(nn.Module):
def __init__(
self, num_input_features, growth_rate, bn_size, norm_layer=BatchNormAct2d,
drop_rate=0., memory_efficient=False):
super(DenseLayer, self).__init__()
self.add_module('norm1', norm_layer(num_input_features)),
self.add_module('conv1', nn.Conv2d(
@@ -113,8 +114,9 @@ class DenseLayer(nn.Module):
class DenseBlock(nn.ModuleDict):
_version = 2
def __init__(
self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=nn.ReLU,
drop_rate=0., memory_efficient=False):
super(DenseBlock, self).__init__()
for i in range(num_layers):
layer = DenseLayer(
@@ -164,8 +166,8 @@ class DenseNet(nn.Module):
def __init__(
self, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=1000, in_chans=3, global_pool='avg',
bn_size=4, stem_type='', norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0,
memory_efficient=False, aa_stem_only=True):
self.num_classes = num_classes
self.drop_rate = drop_rate
super(DenseNet, self).__init__()
@@ -252,10 +254,10 @@ class DenseNet(nn.Module):
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^features\.conv[012]|features\.norm[012]|features\.pool[012]',
blocks=r'^features\.(?:denseblock|transition)(\d+)' if coarse else [
(r'^features\.denseblock(\d+)\.denselayer(\d+)', None),
(r'^features\.transition(\d+)', MATCH_PREV_GROUP)  # FIXME combine with previous denselayer
]
)
return matcher

@@ -323,8 +323,8 @@ class DLA(nn.Module):
stem=r'^base_layer',
blocks=r'^level(\d+)' if coarse else [
# an unusual arch, this achieves somewhat more granularity without getting super messy
(r'^level(\d+)\.tree(\d+)', None),
(r'^level(\d+)\.root', (2,)),
(r'^level(\d+)', (1,))
]
)

@@ -243,10 +243,10 @@ class DPN(nn.Module):
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^features\.conv1',
blocks=[
(r'^features\.conv(\d+)' if coarse else r'^features\.conv(\d+)_(\d+)', None),
(r'^features\.conv5_bn_ac', (99999,))
]
)
return matcher

@@ -518,7 +518,7 @@ class EfficientNet(nn.Module):
return dict(
stem=r'^conv_stem|bn1',
blocks=[
(r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)', None),
(r'conv_head|bn2', (99999,))
]
)

@@ -193,7 +193,7 @@ class GhostNet(nn.Module):
matcher = dict(
stem=r'^conv_stem|bn1',
blocks=[
(r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)', None),
(r'conv_head', (99999,))
]
)

@@ -184,7 +184,7 @@ class Xception65(nn.Module):
matcher = dict(
stem=r'^conv[12]|bn[12]',
blocks=[
(r'^mid\.block(\d+)', None),
(r'^block(\d+)', None),
(r'^conv[345]|bn[345]', (99,)),
],

@@ -125,6 +125,7 @@ def _resolve_pretrained_source(pretrained_cfg):
# hf-hub specified as source via model identifier
load_from = 'hf-hub'
assert hf_hub_id
pretrained_loc = hf_hub_id
else:
# default source == timm or unspecified
if pretrained_file:
@@ -407,16 +408,6 @@ def pretrained_cfg_for_features(pretrained_cfg):
return pretrained_cfg
# def overlay_external_pretrained_cfg(pretrained_cfg, kwargs):
# """ Overlay 'external_pretrained_cfg' in kwargs on top of pretrained_cfg arg.
# """
# external_pretrained_cfg = kwargs.pop('external_pretrained_cfg', None)
# if external_pretrained_cfg:
# pretrained_cfg.pop('url', None) # url should come from external cfg
# pretrained_cfg.pop('hf_hub', None) # hf hub id should come from external cfg
# pretrained_cfg.update(external_pretrained_cfg)
def set_default_kwargs(kwargs, names, pretrained_cfg):
for n in names:
# for legacy reasons, model __init__args uses img_size + in_chans as separate args while

@@ -686,8 +686,8 @@ class HighResolutionNet(nn.Module):
matcher = dict(
stem=r'^conv[12]|bn[12]',
blocks=r'^(?:layer|stage|transition)(\d+)' if coarse else [
(r'^layer(\d+)\.(\d+)', None),
(r'^stage(\d+)\.(\d+)', None),
(r'^transition(\d+)', (99999,)),
],
)

@@ -0,0 +1,156 @@
from typing import Optional
import torch
from torch import nn
from torch import nn, Tensor
from torch.nn.modules.transformer import _get_activation_fn
def add_ml_decoder_head(model):
if hasattr(model, 'global_pool') and hasattr(model, 'fc'): # most CNN models, like Resnet50
model.global_pool = nn.Identity()
del model.fc
num_classes = model.num_classes
num_features = model.num_features
model.fc = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
elif hasattr(model, 'global_pool') and hasattr(model, 'classifier'): # EfficientNet
model.global_pool = nn.Identity()
del model.classifier
num_classes = model.num_classes
num_features = model.num_features
model.classifier = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
elif 'RegNet' in model._get_name() or 'TResNet' in model._get_name(): # hasattr(model, 'head')
del model.head
num_classes = model.num_classes
num_features = model.num_features
model.head = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
else:
print("Model code-writing is not aligned currently with ml-decoder")
exit(-1)
if hasattr(model, 'drop_rate'): # Ml-Decoder has inner dropout
model.drop_rate = 0
return model
class TransformerDecoderLayerOptimal(nn.Module):
def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1, activation="relu",
layer_norm_eps=1e-5) -> None:
super(TransformerDecoderLayerOptimal, self).__init__()
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.dropout = nn.Dropout(dropout)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.activation = _get_activation_fn(activation)
def __setstate__(self, state):
if 'activation' not in state:
state['activation'] = torch.nn.functional.relu
super(TransformerDecoderLayerOptimal, self).__setstate__(state)
def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
tgt = tgt + self.dropout1(tgt)
tgt = self.norm1(tgt)
tgt2 = self.multihead_attn(tgt, memory, memory)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
return tgt
# @torch.jit.script
# class ExtrapClasses(object):
# def __init__(self, num_queries: int, group_size: int):
# self.num_queries = num_queries
# self.group_size = group_size
#
# def __call__(self, h: torch.Tensor, class_embed_w: torch.Tensor, class_embed_b: torch.Tensor, out_extrap:
# torch.Tensor):
# # h = h.unsqueeze(-1).expand(-1, -1, -1, self.group_size)
# h = h[..., None].repeat(1, 1, 1, self.group_size) # torch.Size([bs, 5, 768, groups])
# w = class_embed_w.view((self.num_queries, h.shape[2], self.group_size))
# out = (h * w).sum(dim=2) + class_embed_b
# out = out.view((h.shape[0], self.group_size * self.num_queries))
# return out
@torch.jit.script
class GroupFC(object):
def __init__(self, embed_len_decoder: int):
self.embed_len_decoder = embed_len_decoder
def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor):
for i in range(self.embed_len_decoder):
h_i = h[:, i, :]
w_i = duplicate_pooling[i, :, :]
out_extrap[:, i, :] = torch.matmul(h_i, w_i)
class MLDecoder(nn.Module):
def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial_num_features=2048):
super(MLDecoder, self).__init__()
embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups
if embed_len_decoder > num_classes:
embed_len_decoder = num_classes
# switching to 768 initial embeddings
decoder_embedding = 768 if decoder_embedding < 0 else decoder_embedding
self.embed_standart = nn.Linear(initial_num_features, decoder_embedding)
# decoder
decoder_dropout = 0.1
num_layers_decoder = 1
dim_feedforward = 2048
layer_decode = TransformerDecoderLayerOptimal(d_model=decoder_embedding,
dim_feedforward=dim_feedforward, dropout=decoder_dropout)
self.decoder = nn.TransformerDecoder(layer_decode, num_layers=num_layers_decoder)
# non-learnable queries
self.query_embed = nn.Embedding(embed_len_decoder, decoder_embedding)
self.query_embed.requires_grad_(False)
# group fully-connected
self.num_classes = num_classes
self.duplicate_factor = int(num_classes / embed_len_decoder + 0.999)
self.duplicate_pooling = torch.nn.Parameter(
torch.Tensor(embed_len_decoder, decoder_embedding, self.duplicate_factor))
self.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(num_classes))
torch.nn.init.xavier_normal_(self.duplicate_pooling)
torch.nn.init.constant_(self.duplicate_pooling_bias, 0)
self.group_fc = GroupFC(embed_len_decoder)
def forward(self, x):
if len(x.shape) == 4: # [bs,2048, 7,7]
embedding_spatial = x.flatten(2).transpose(1, 2)
else: # [bs, 197,468]
embedding_spatial = x
embedding_spatial_786 = self.embed_standart(embedding_spatial)
embedding_spatial_786 = torch.nn.functional.relu(embedding_spatial_786, inplace=True)
bs = embedding_spatial_786.shape[0]
query_embed = self.query_embed.weight
# tgt = query_embed.unsqueeze(1).repeat(1, bs, 1)
tgt = query_embed.unsqueeze(1).expand(-1, bs, -1) # no allocation of memory with expand
h = self.decoder(tgt, embedding_spatial_786.transpose(0, 1)) # [embed_len_decoder, batch, 768]
h = h.transpose(0, 1)
out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype)
self.group_fc(h, self.duplicate_pooling, out_extrap)
h_out = out_extrap.flatten(1)[:, :self.num_classes]
h_out += self.duplicate_pooling_bias
logits = h_out
return logits
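# Usage sketch: a minimal smoke test of the MLDecoder head defined above. The shapes and
# num_classes below are illustrative assumptions; in timm the head is normally attached via
# add_ml_decoder_head(model), which replaces the model's global pool + classifier so the
# decoder receives the unpooled feature map.
if __name__ == '__main__':
    decoder = MLDecoder(num_classes=80, initial_num_features=2048)
    feat = torch.randn(2, 2048, 7, 7)  # e.g. an unpooled [bs, C, H, W] CNN feature map
    logits = decoder(feat)
    print(logits.shape)  # torch.Size([2, 80])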

@@ -496,7 +496,7 @@ class Levit(nn.Module):
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
)
return matcher
@@ -539,7 +539,7 @@ class LevitDistilled(Levit):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.head_dist = NormLinear(self.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity()
self.distilled_training = False  # must set this True to train w/ distillation token
@torch.jit.ignore
def get_classifier(self):

@@ -291,7 +291,7 @@ class MlpMixer(nn.Module):
def group_matcher(self, coarse=False):
return dict(
stem=r'^stem',  # stem and embed
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
)
@torch.jit.ignore

@@ -171,7 +171,7 @@ class MobileNetV3(nn.Module):
def group_matcher(self, coarse=False):
return dict(
stem=r'^conv_stem|bn1',
blocks=r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)'
)
@torch.jit.ignore

@@ -334,8 +334,8 @@ class Nest(nn.Module):
matcher = dict(
stem=r'^patch_embed',  # stem and embed
blocks=[
(r'^levels\.(\d+)' if coarse else r'^levels\.(\d+)\.transformer_encoder\.(\d+)', None),
(r'^levels\.(\d+)\.(?:pool|pos_embed)', (0,)),
(r'^norm', (99999,))
]
)

@@ -194,7 +194,6 @@ def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', ski
return cfg
model_cfgs = dict(
# NFNet-F models w/ GELU compatible with DeepMind weights
dm_nfnet_f0=_dm_nfnet_cfg(depths=(1, 2, 6, 3)),
@@ -550,7 +549,7 @@ class NormFreeNet(nn.Module):
matcher = dict(
stem=r'^stem',
blocks=[
(r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
(r'^final_conv', (99999,))
]
)

@@ -147,9 +147,10 @@ class PoolingVisionTransformer(nn.Module):
A PyTorch implement of 'Rethinking Spatial Dimensions of Vision Transformers'
- https://arxiv.org/abs/2103.16302
"""
def __init__(
self, img_size, patch_size, stride, base_dims, depth, heads,
mlp_ratio, num_classes=1000, in_chans=3, global_pool='token',
distilled=False, attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0):
super(PoolingVisionTransformer, self).__init__()
assert global_pool in ('token',)
@@ -193,6 +194,7 @@ class PoolingVisionTransformer(nn.Module):
self.head_dist = None
if distilled:
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
self.distilled_training = False  # must set this True to train w/ distillation token
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
@@ -207,6 +209,10 @@ class PoolingVisionTransformer(nn.Module):
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
@torch.jit.ignore
def set_distilled_training(self, enable=True):
self.distilled_training = enable
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
assert not enable, 'gradient checkpointing not supported'
@@ -231,16 +237,30 @@ class PoolingVisionTransformer(nn.Module):
cls_tokens = self.norm(cls_tokens)
return cls_tokens
def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
if self.head_dist is not None:
assert self.global_pool == 'token'
x, x_dist = x[:, 0], x[:, 1]
if not pre_logits:
x = self.head(x)
x_dist = self.head_dist(x_dist)
if self.distilled_training and self.training and not torch.jit.is_scripting():
# only return separate classification predictions when training in distilled mode
return x, x_dist
else:
# during standard train / finetune, inference average the classifier predictions
return (x + x_dist) / 2
else:
if self.global_pool == 'token':
x = x[:, 0]
if not pre_logits:
x = self.head(x)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def checkpoint_filter_fn(state_dict, model):

@@ -137,9 +137,15 @@ default_cfgs = dict(
regnety_032=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth',
crop_pct=1.0, test_input_size=(3, 288, 288)),
regnety_040=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_040_ra3-670e1166.pth',
crop_pct=1.0, test_input_size=(3, 288, 288)),
regnety_064=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_064_ra3-aa26dc7d.pth',
crop_pct=1.0, test_input_size=(3, 288, 288)),
regnety_080=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_080_ra3-1fdc4344.pth',
crop_pct=1.0, test_input_size=(3, 288, 288)),
regnety_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth'),
regnety_160=_cfg(
url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth',  # from Facebook DeiT GitHub repository
@@ -147,12 +153,20 @@ default_cfgs = dict(
regnety_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'),
regnety_040s_gn=_cfg(url=''),
regnetv_040=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetv_040_ra3-c248f51f.pth',
first_conv='stem'),
regnetv_064=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetv_064_ra3-530616c2.pth',
first_conv='stem'),
regnetz_005=_cfg(url=''),
regnetz_040=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_040_ra3-9007edf5.pth',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320)),
regnetz_040h=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_040h_ra3-f594343b.pth',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320)),
)
@@ -444,7 +458,7 @@ class RegNet(nn.Module):
def group_matcher(self, coarse=False):
return dict(
stem=r'^stem',
blocks=r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.blocks\.(\d+)',
)
@torch.jit.ignore

@@ -106,7 +106,7 @@ default_cfgs = {
'resnext101_32x4d': _cfg(url=''),
'resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth'),
'resnext101_64x4d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/resnext101_64x4d_c-0d0e0cc0.pth',
interpolation='bicubic', crop_pct=1.0, test_input_size=(3, 288, 288)),
'tv_resnext50_32x4d': _cfg(url='https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth'),
@@ -197,8 +197,8 @@ default_cfgs = {
url='',
interpolation='bicubic'),
'seresnext101_32x8d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/seresnext101_32x8d_ah-e6bc4c0a.pth',
interpolation='bicubic', test_input_size=(3, 288, 288), crop_pct=1.0),
'senet154': _cfg(
url='',
interpolation='bicubic',
@@ -283,7 +283,7 @@ default_cfgs = {
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320),
interpolation='bicubic', first_conv='conv1.0'),
'resnetrs200': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/resnetrs200_c-6b698b88.pth',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320),
interpolation='bicubic', first_conv='conv1.0'),
'resnetrs270': _cfg(
@@ -315,9 +315,10 @@ def create_aa(aa_layer, channels, stride=2, enable=True):
class BasicBlock(nn.Module):
expansion = 1
def __init__(
self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
super(BasicBlock, self).__init__()
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
@@ -379,9 +380,10 @@ class BasicBlock(nn.Module):
class Bottleneck(nn.Module):
expansion = 4
def __init__(
self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
super(Bottleneck, self).__init__()
width = int(math.floor(planes * (base_width / 64)) * cardinality)
@@ -561,48 +563,35 @@ class ResNet(nn.Module):
Parameters
----------
block : Block, class for the residual block. Options are BasicBlockGl, BottleneckGl.
layers : list of int, number of layers in each block
num_classes : int, default 1000, number of classification classes.
in_chans : int, default 3, number of input (color) channels.
output_stride : int, default 32, output stride of the network, 32, 16, or 8.
global_pool : str, Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
cardinality : int, default 1, number of convolution groups for 3x3 conv in Bottleneck.
base_width : int, default 64, factor determining bottleneck channels. `planes * base_width / 64 * cardinality`
stem_width : int, default 64, number of channels in stem convolutions
stem_type : str, default ''
The type of stem:
* '', default - a single 7x7 conv with a width of stem_width
* 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2
* 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2
block_reduce_first : int, default 1
Reduction factor for first convolution output width of residual blocks, 1 for all archs except senets, where 2
down_kernel_size : int, default 1, kernel size of residual block downsample path, 1x1 for most, 3x3 for senets
avg_down : bool, default False, use average pooling for projection skip connection between stages/downsample.
act_layer : nn.Module, activation layer
norm_layer : nn.Module, normalization layer
aa_layer : nn.Module, anti-aliasing layer
drop_rate : float, default 0. Dropout probability before classifier, for training
"""
def __init__(
self, block, layers, num_classes=1000, in_chans=3, output_stride=32, global_pool='avg',
cardinality=1, base_width=64, stem_width=64, stem_type='', replace_stem_pool=False, block_reduce_first=1,
down_kernel_size=1, avg_down=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None,
drop_rate=0.0, drop_path_rate=0., drop_block_rate=0., zero_init_last=True, block_args=None):
super(ResNet, self).__init__()
block_args = block_args or dict()
assert output_stride in (8, 16, 32)
@@ -712,12 +701,15 @@ class ResNet(nn.Module):
x = self.layer4(x)
return x
def forward_head(self, x, pre_logits: bool = False):
x = self.global_pool(x)
if self.drop_rate:
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
return x if pre_logits else self.fc(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x

@@ -122,13 +122,13 @@ default_cfgs = {
interpolation='bicubic', first_conv='stem.conv1'),
'resnetv2_50d_gn': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/resnetv2_50d_gn_ah-c415c11a.pth',
interpolation='bicubic', first_conv='stem.conv1', test_input_size=(3, 288, 288), crop_pct=0.95),
'resnetv2_50d_evob': _cfg(
interpolation='bicubic', first_conv='stem.conv1'),
'resnetv2_50d_evos': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/resnetv2_50d_evos_ah-7c4dd548.pth',
interpolation='bicubic', first_conv='stem.conv1', test_input_size=(3, 288, 288), crop_pct=0.95),
'resnetv2_50d_frn': _cfg(
interpolation='bicubic', first_conv='stem.conv1'),
}
@@ -411,8 +411,8 @@ class ResNetV2(nn.Module):
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^stem',
blocks=r'^stages\.(\d+)' if coarse else [
(r'^stages\.(\d+)\.blocks\.(\d+)', None),
(r'^norm', (99999,))
]
)
@@ -693,21 +693,13 @@ def resnetv2_50d_evob(pretrained=False, **kwargs):
@register_model
def resnetv2_50d_evos(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d_evos', pretrained=pretrained,
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dS0,
stem_type='deep', avg_down=True, **kwargs)
@register_model
def resnetv2_50d_evos1(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d_evos1', pretrained=pretrained,
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=partial(EvoNorm2dS1, group_size=16),
stem_type='deep', avg_down=True, **kwargs)
@register_model
def resnetv2_50d_frn(pretrained=False, **kwargs):
return _create_resnetv2(

@@ -173,7 +173,7 @@ class ReXNetV1(nn.Module):
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^stem',
blocks=r'^features\.(\d+)',
)
return matcher

@@ -360,7 +360,7 @@ class SENet(nn.Module):
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(stem=r'^layer0', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+)\.(\d+)')
return matcher
@torch.jit.ignore

@@ -525,9 +525,9 @@ class SwinTransformer(nn.Module):
def group_matcher(self, coarse=False):
return dict(
stem=r'^absolute_pos_embed|patch_embed',  # stem and embed
blocks=r'^layers\.(\d+)' if coarse else [
(r'^layers\.(\d+).downsample', (0,)),
(r'^layers\.(\d+)\.\w+\.(\d+)', None),
(r'^norm', (99999,)),
]
)

@@ -12,6 +12,7 @@ This implementation is experimental and subject to change in manners that will b
GitHub link above. It needs further investigation as throughput vs mem tradeoff doesn't appear beneficial.
* num_heads per stage is not detailed for Huge and Giant model variants
* 'Giant' is 3B params in paper but ~2.6B here despite matching paper dim + block counts
* experiments are ongoing wrt to 'main branch' norm layer use and weight init scheme
Noteworthy additions over official Swin v1:
* MLP relative position embedding is looking promising and adapts to different image/window sizes
@@ -67,27 +68,29 @@ default_cfgs = {
'swin_v2_cr_tiny_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_tiny_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_tiny_ns_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_small_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_small_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_base_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_base_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_large_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_large_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_huge_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_huge_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_giant_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_giant_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
}
@ -175,7 +178,7 @@ class WindowMultiHeadAttention(nn.Module):
hidden_features=meta_hidden_dim, hidden_features=meta_hidden_dim,
out_features=num_heads, out_features=num_heads,
act_layer=nn.ReLU, act_layer=nn.ReLU,
drop=0. # FIXME should we add stochasticity? drop=0.1 # FIXME should there be stochasticity, appears to 'overfit' without?
) )
self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads))) self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads)))
self._make_pair_wise_relative_positions() self._make_pair_wise_relative_positions()
@ -336,7 +339,8 @@ class SwinTransformerBlock(nn.Module):
self.norm2 = norm_layer(dim) self.norm2 = norm_layer(dim)
self.drop_path2 = DropPath(drop_prob=drop_path) if drop_path > 0.0 else nn.Identity() self.drop_path2 = DropPath(drop_prob=drop_path) if drop_path > 0.0 else nn.Identity()
# extra norm layer mentioned for Huge/Giant models in V2 paper (FIXME may be in wrong spot?) # Extra main branch norm layer mentioned for Huge/Giant models in V2 paper.
# Also being used as final network norm and optional stage ending norm while still in a C-last format.
self.norm3 = norm_layer(dim) if extra_norm else nn.Identity() self.norm3 = norm_layer(dim) if extra_norm else nn.Identity()
self._make_attention_mask() self._make_attention_mask()
@ -393,7 +397,8 @@ class SwinTransformerBlock(nn.Module):
# cyclic shift # cyclic shift
sh, sw = self.shift_size sh, sw = self.shift_size
if any(self.shift_size): do_shift: bool = any(self.shift_size)
if do_shift:
# FIXME PyTorch XLA needs cat impl, roll not lowered # FIXME PyTorch XLA needs cat impl, roll not lowered
# x = torch.cat([x[:, sh:], x[:, :sh]], dim=1) # x = torch.cat([x[:, sh:], x[:, :sh]], dim=1)
# x = torch.cat([x[:, :, sw:], x[:, :, :sw]], dim=2) # x = torch.cat([x[:, :, sw:], x[:, :, :sw]], dim=2)
@ -411,7 +416,7 @@ class SwinTransformerBlock(nn.Module):
x = window_reverse(attn_windows, self.window_size, self.feat_size) # B H' W' C x = window_reverse(attn_windows, self.window_size, self.feat_size) # B H' W' C
# reverse cyclic shift # reverse cyclic shift
if any(self.shift_size): if do_shift:
# FIXME PyTorch XLA needs cat impl, roll not lowered # FIXME PyTorch XLA needs cat impl, roll not lowered
# x = torch.cat([x[:, -sh:], x[:, :-sh]], dim=1) # x = torch.cat([x[:, -sh:], x[:, :-sh]], dim=1)
# x = torch.cat([x[:, :, -sw:], x[:, :, :-sw]], dim=2) # x = torch.cat([x[:, :, -sw:], x[:, :, :-sw]], dim=2)
@ -432,7 +437,7 @@ class SwinTransformerBlock(nn.Module):
# NOTE post-norm branches (op -> norm -> drop) # NOTE post-norm branches (op -> norm -> drop)
x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x))) x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x)))
x = x + self.drop_path2(self.norm2(self.mlp(x))) x = x + self.drop_path2(self.norm2(self.mlp(x)))
x = self.norm3(x) # main-branch norm enabled for some blocks (every 6 for Huge/Giant) x = self.norm3(x) # main-branch norm enabled for some blocks / stages (every 6 for Huge/Giant)
return x return x
@ -502,8 +507,8 @@ class SwinTransformerStage(nn.Module):
drop_attn (float): Dropout rate of attention map drop_attn (float): Dropout rate of attention map
drop_path (float): Dropout in main path drop_path (float): Dropout in main path
norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
grad_checkpointing (bool): If true checkpointing is utilized
extra_norm_period (int): Insert extra norm layer on main branch every N (period) blocks extra_norm_period (int): Insert extra norm layer on main branch every N (period) blocks
extra_norm_stage (bool): End each stage with an extra norm layer in main branch
sequential_attn (bool): If true sequential self-attention is performed sequential_attn (bool): If true sequential self-attention is performed
""" """
@ -520,17 +525,23 @@ class SwinTransformerStage(nn.Module):
drop_attn: float = 0.0, drop_attn: float = 0.0,
drop_path: Union[List[float], float] = 0.0, drop_path: Union[List[float], float] = 0.0,
norm_layer: Type[nn.Module] = nn.LayerNorm, norm_layer: Type[nn.Module] = nn.LayerNorm,
grad_checkpointing: bool = False,
extra_norm_period: int = 0, extra_norm_period: int = 0,
extra_norm_stage: bool = False,
sequential_attn: bool = False, sequential_attn: bool = False,
) -> None: ) -> None:
super(SwinTransformerStage, self).__init__() super(SwinTransformerStage, self).__init__()
self.downscale: bool = downscale self.downscale: bool = downscale
self.grad_checkpointing: bool = grad_checkpointing self.grad_checkpointing: bool = False
self.feat_size: Tuple[int, int] = (feat_size[0] // 2, feat_size[1] // 2) if downscale else feat_size self.feat_size: Tuple[int, int] = (feat_size[0] // 2, feat_size[1] // 2) if downscale else feat_size
self.downsample = PatchMerging(embed_dim, norm_layer=norm_layer) if downscale else nn.Identity() self.downsample = PatchMerging(embed_dim, norm_layer=norm_layer) if downscale else nn.Identity()
def _extra_norm(index):
i = index + 1
if extra_norm_period and i % extra_norm_period == 0:
return True
return i == depth if extra_norm_stage else False
embed_dim = embed_dim * 2 if downscale else embed_dim embed_dim = embed_dim * 2 if downscale else embed_dim
self.blocks = nn.Sequential(*[ self.blocks = nn.Sequential(*[
SwinTransformerBlock( SwinTransformerBlock(
@ -543,7 +554,7 @@ class SwinTransformerStage(nn.Module):
drop=drop, drop=drop,
drop_attn=drop_attn, drop_attn=drop_attn,
drop_path=drop_path[index] if isinstance(drop_path, list) else drop_path, drop_path=drop_path[index] if isinstance(drop_path, list) else drop_path,
extra_norm=not (index + 1) % extra_norm_period if extra_norm_period else False, extra_norm=_extra_norm(index),
sequential_attn=sequential_attn, sequential_attn=sequential_attn,
norm_layer=norm_layer, norm_layer=norm_layer,
) )
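The `_extra_norm` helper added above decides which blocks end with the extra main-branch norm. A quick standalone sketch of the block positions it selects for a given stage depth (the `extra_norm_blocks` name is made up for this example):

```python
def extra_norm_blocks(depth: int, extra_norm_period: int = 0, extra_norm_stage: bool = False):
    """Return 1-based block positions that get the extra main-branch norm,
    mirroring the _extra_norm logic in SwinTransformerStage above."""
    selected = []
    for index in range(depth):
        i = index + 1
        if (extra_norm_period and i % extra_norm_period == 0) or (extra_norm_stage and i == depth):
            selected.append(i)
    return selected

print(extra_norm_blocks(12, extra_norm_period=6))   # [6, 12]
print(extra_norm_blocks(6, extra_norm_stage=True))  # [6]
print(extra_norm_blocks(6))                         # [] -> plain post-norm blocks only
```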
@ -605,9 +616,9 @@ class SwinTransformerV2Cr(nn.Module):
attn_drop_rate (float): Dropout rate of attention map. Default: 0.0 attn_drop_rate (float): Dropout rate of attention map. Default: 0.0
drop_path_rate (float): Stochastic depth rate. Default: 0.0 drop_path_rate (float): Stochastic depth rate. Default: 0.0
norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
grad_checkpointing (bool): If true checkpointing is utilized. Default: False extra_norm_period (int): Insert extra norm layer on main branch every N (period) blocks in stage
extra_norm_stage (bool): End each stage with an extra norm layer in main branch
sequential_attn (bool): If true sequential self-attention is performed. Default: False sequential_attn (bool): If true sequential self-attention is performed. Default: False
use_deformable (bool): If true deformable block is used. Default: False
""" """
def __init__( def __init__(
@ -626,10 +637,11 @@ class SwinTransformerV2Cr(nn.Module):
attn_drop_rate: float = 0.0, attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0, drop_path_rate: float = 0.0,
norm_layer: Type[nn.Module] = nn.LayerNorm, norm_layer: Type[nn.Module] = nn.LayerNorm,
grad_checkpointing: bool = False,
extra_norm_period: int = 0, extra_norm_period: int = 0,
extra_norm_stage: bool = False,
sequential_attn: bool = False, sequential_attn: bool = False,
global_pool: str = 'avg', global_pool: str = 'avg',
weight_init='skip',
**kwargs: Any **kwargs: Any
) -> None: ) -> None:
super(SwinTransformerV2Cr, self).__init__() super(SwinTransformerV2Cr, self).__init__()
@ -643,7 +655,7 @@ class SwinTransformerV2Cr(nn.Module):
self.window_size: int = window_size self.window_size: int = window_size
self.num_features: int = int(embed_dim * 2 ** (len(depths) - 1)) self.num_features: int = int(embed_dim * 2 ** (len(depths) - 1))
self.patch_embed: nn.Module = PatchEmbed( self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, img_size=img_size, patch_size=patch_size, in_chans=in_chans,
embed_dim=embed_dim, norm_layer=norm_layer) embed_dim=embed_dim, norm_layer=norm_layer)
patch_grid_size: Tuple[int, int] = self.patch_embed.grid_size patch_grid_size: Tuple[int, int] = self.patch_embed.grid_size
@ -664,8 +676,8 @@ class SwinTransformerV2Cr(nn.Module):
drop=drop_rate, drop=drop_rate,
drop_attn=attn_drop_rate, drop_attn=attn_drop_rate,
drop_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])], drop_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])],
grad_checkpointing=grad_checkpointing,
extra_norm_period=extra_norm_period, extra_norm_period=extra_norm_period,
extra_norm_stage=extra_norm_stage or (index + 1) == len(depths), # last stage ends w/ norm
sequential_attn=sequential_attn, sequential_attn=sequential_attn,
norm_layer=norm_layer, norm_layer=norm_layer,
) )
@ -673,12 +685,12 @@ class SwinTransformerV2Cr(nn.Module):
self.stages = nn.Sequential(*stages) self.stages = nn.Sequential(*stages)
self.global_pool: str = global_pool self.global_pool: str = global_pool
self.head: nn.Module = nn.Linear( self.head = nn.Linear(self.num_features, num_classes) if num_classes else nn.Identity()
in_features=self.num_features, out_features=num_classes) if num_classes else nn.Identity()
# FIXME weight init TBD, PyTorch default init appears to be working well, # current weight init skips custom init and uses pytorch layer defaults, seems to work well
# but differs from usual ViT or Swin init. # FIXME more experiments needed
# named_apply(init_weights, self) if weight_init != 'skip':
named_apply(init_weights, self)
def update_input_size( def update_input_size(
self, self,
@ -709,13 +721,28 @@ class SwinTransformerV2Cr(nn.Module):
new_img_size=(new_patch_grid_size[0] // stage_scale, new_patch_grid_size[1] // stage_scale), new_img_size=(new_patch_grid_size[0] // stage_scale, new_patch_grid_size[1] // stage_scale),
) )
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^patch_embed', # stem and embed
blocks=r'^stages\.(\d+)' if coarse else [
(r'^stages\.(\d+).downsample', (0,)),
(r'^stages\.(\d+)\.\w+\.(\d+)', None),
]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
for s in self.stages:
s.grad_checkpointing = enable
@torch.jit.ignore()
def get_classifier(self) -> nn.Module: def get_classifier(self) -> nn.Module:
"""Method returns the classification head of the model. """Method returns the classification head of the model.
Returns: Returns:
head (nn.Module): Current classification head head (nn.Module): Current classification head
""" """
head: nn.Module = self.head return self.head
return head
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None: def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
"""Method results the classification head """Method results the classification head
@ -727,8 +754,7 @@ class SwinTransformerV2Cr(nn.Module):
self.num_classes: int = num_classes self.num_classes: int = num_classes
if global_pool is not None: if global_pool is not None:
self.global_pool = global_pool self.global_pool = global_pool
self.head: nn.Module = nn.Linear( self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
in_features=self.num_features, out_features=num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x: torch.Tensor) -> torch.Tensor: def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x) x = self.patch_embed(x)
@ -747,12 +773,14 @@ class SwinTransformerV2Cr(nn.Module):
def init_weights(module: nn.Module, name: str = ''): def init_weights(module: nn.Module, name: str = ''):
# FIXME WIP # FIXME WIP determining if there's a better weight init
if isinstance(module, nn.Linear): if isinstance(module, nn.Linear):
if 'qkv' in name: if 'qkv' in name:
# treat the weights of Q, K, V separately # treat the weights of Q, K, V separately
val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1])) val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1]))
nn.init.uniform_(module.weight, -val, val) nn.init.uniform_(module.weight, -val, val)
elif 'head' in name:
nn.init.zeros_(module.weight)
else: else:
nn.init.xavier_uniform_(module.weight) nn.init.xavier_uniform_(module.weight)
if module.bias is not None: if module.bias is not None:
@ -790,6 +818,21 @@ def swin_v2_cr_tiny_224(pretrained=False, **kwargs):
return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_224', pretrained=pretrained, **model_kwargs) return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_v2_cr_tiny_ns_224(pretrained=False, **kwargs):
"""Swin-T V2 CR @ 224x224, trained ImageNet-1k w/ extra stage norms.
** Experimental, may make default if results are improved. **
"""
model_kwargs = dict(
embed_dim=96,
depths=(2, 2, 6, 2),
num_heads=(3, 6, 12, 24),
extra_norm_stage=True,
**kwargs
)
return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_ns_224', pretrained=pretrained, **model_kwargs)
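A brief usage sketch for the new variant, assuming a timm build that includes this branch (no pretrained weights are assumed here):

```python
import timm
import torch

model = timm.create_model('swin_v2_cr_tiny_ns_224', pretrained=False)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000]) with the default 1000-class head
```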
@register_model @register_model
def swin_v2_cr_small_384(pretrained=False, **kwargs): def swin_v2_cr_small_384(pretrained=False, **kwargs):
"""Swin-S V2 CR @ 384x384, trained ImageNet-1k""" """Swin-S V2 CR @ 384x384, trained ImageNet-1k"""

@ -217,7 +217,7 @@ class TNT(nn.Module):
matcher = dict( matcher = dict(
stem=r'^cls_token|patch_pos|pixel_pos|pixel_embed|norm[12]_proj|proj', # stem and embed / pos stem=r'^cls_token|patch_pos|pixel_pos|pixel_embed|norm[12]_proj|proj', # stem and embed / pos
blocks=[ blocks=[
(r'^blocks.(\d+)', None), (r'^blocks\.(\d+)', None),
(r'^norm', (99999,)), (r'^norm', (99999,)),
] ]
) )

@ -233,7 +233,7 @@ class TResNet(nn.Module):
@torch.jit.ignore @torch.jit.ignore
def group_matcher(self, coarse=False): def group_matcher(self, coarse=False):
matcher = dict(stem=r'^body.conv1', blocks=r'^body.layer(\d+)' if coarse else r'^body.layer(\d+).(\d+)') matcher = dict(stem=r'^body\.conv1', blocks=r'^body\.layer(\d+)' if coarse else r'^body\.layer(\d+)\.(\d+)')
return matcher return matcher
@torch.jit.ignore @torch.jit.ignore

@ -327,11 +327,11 @@ class Twins(nn.Module):
matcher = dict( matcher = dict(
stem=r'^patch_embeds.0', # stem and embed stem=r'^patch_embeds.0', # stem and embed
blocks=[ blocks=[
(r'^(?:blocks|patch_embeds|pos_block).(\d+)', None), (r'^(?:blocks|patch_embeds|pos_block)\.(\d+)', None),
('^norm', (99999,)) ('^norm', (99999,))
] if coarse else [ ] if coarse else [
(r'^blocks.(\d+).(\d+)', None), (r'^blocks\.(\d+)\.(\d+)', None),
(r'^(?:patch_embeds|pos_block).(\d+)', (0,)), (r'^(?:patch_embeds|pos_block)\.(\d+)', (0,)),
(r'^norm', (99999,)) (r'^norm', (99999,))
] ]
) )

@ -136,7 +136,7 @@ class VGG(nn.Module):
@torch.jit.ignore @torch.jit.ignore
def group_matcher(self, coarse=False): def group_matcher(self, coarse=False):
# this treats BN layers as separate groups for bn variants, a lot of effort to fix that # this treats BN layers as separate groups for bn variants, a lot of effort to fix that
return dict(stem=r'^features.0', blocks=r'^features.(\d+)') return dict(stem=r'^features\.0', blocks=r'^features\.(\d+)')
@torch.jit.ignore @torch.jit.ignore
def set_grad_checkpointing(self, enable=True): def set_grad_checkpointing(self, enable=True):

@ -271,7 +271,7 @@ class Visformer(nn.Module):
return dict( return dict(
stem=r'^patch_embed1|pos_embed1|stem', # stem and embed stem=r'^patch_embed1|pos_embed1|stem', # stem and embed
blocks=[ blocks=[
(r'^stage(\d+).(\d+)' if coarse else r'^stage(\d+).(\d+)', None), (r'^stage(\d+)\.(\d+)' if coarse else r'^stage(\d+)\.(\d+)', None),
(r'^(?:patch_embed|pos_embed)(\d+)', (0,)), (r'^(?:patch_embed|pos_embed)(\d+)', (0,)),
(r'^norm', (99999,)) (r'^norm', (99999,))
] ]

@ -331,7 +331,7 @@ class VisionTransformer(nn.Module):
def group_matcher(self, coarse=False): def group_matcher(self, coarse=False):
return dict( return dict(
stem=r'^cls_token|pos_embed|patch_embed', # stem and embed stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
blocks=[(r'^blocks.(\d+)', None), (r'^norm', (99999,))] blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
) )
@torch.jit.ignore @torch.jit.ignore

@ -6,7 +6,7 @@ A PyTorch implement of the Hybrid Vision Transformers as described in:
- https://arxiv.org/abs/2010.11929 - https://arxiv.org/abs/2010.11929
`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
- https://arxiv.org/abs/2106.TODO - https://arxiv.org/abs/2106.10270
NOTE These hybrid model definitions depend on code in vision_transformer.py. NOTE These hybrid model definitions depend on code in vision_transformer.py.
They were moved here to keep file sizes sane. They were moved here to keep file sizes sane.
@ -359,4 +359,4 @@ def vit_base_resnet50d_224(pretrained=False, **kwargs):
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer_hybrid( model = _create_vision_transformer_hybrid(
'vit_base_resnet50d_224', backbone=backbone, pretrained=pretrained, **model_kwargs) 'vit_base_resnet50d_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model return model

@ -327,7 +327,7 @@ class VovNet(nn.Module):
def group_matcher(self, coarse=False): def group_matcher(self, coarse=False):
return dict( return dict(
stem=r'^stem', stem=r'^stem',
blocks=r'^stages.(\d+)' if coarse else r'^stages.(\d+).blocks.(\d+)', blocks=r'^stages\.(\d+)' if coarse else r'^stages\.(\d+).blocks\.(\d+)',
) )
@torch.jit.ignore @torch.jit.ignore

@ -34,12 +34,20 @@ default_cfgs = dict(
xception41=_cfg( xception41=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_41-e6439c97.pth'), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_41-e6439c97.pth'),
xception65=_cfg( xception65=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_65-c9ae96e8.pth'), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/xception65_ra3-1447db8d.pth',
crop_pct=0.94,
),
xception71=_cfg( xception71=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth'), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth'),
xception41p=_cfg(url=''), xception41p=_cfg(
xception65p=_cfg(url=''), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/xception41p_ra3-33195bc8.pth',
crop_pct=0.94,
),
xception65p=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/xception65p_ra3-3c6114e4.pth',
crop_pct=0.94,
),
) )
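The updated `xception65` entry points at new weights with a 0.94 crop fraction; a hedged sketch of pulling the matching eval transform from the model's default config using standard timm data helpers (behavior assumed from the cfg above):

```python
import timm
from timm.data import resolve_data_config, create_transform

model = timm.create_model('xception65', pretrained=True)
config = resolve_data_config({}, model=model)  # picks up input size, mean/std, crop_pct
transform = create_transform(**config)
print(config['crop_pct'])  # expected 0.94 per the cfg above
```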
@ -213,7 +221,7 @@ class XceptionAligned(nn.Module):
def group_matcher(self, coarse=False): def group_matcher(self, coarse=False):
return dict( return dict(
stem=r'^stem', stem=r'^stem',
blocks=r'^blocks.(\d+)', blocks=r'^blocks\.(\d+)',
) )
@torch.jit.ignore @torch.jit.ignore

@ -412,8 +412,8 @@ class XCiT(nn.Module):
def group_matcher(self, coarse=False): def group_matcher(self, coarse=False):
return dict( return dict(
stem=r'^cls_token|pos_embed|patch_embed', # stem and embed stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
blocks=r'^blocks.(\d+)', blocks=r'^blocks\.(\d+)',
cls_attn_blocks=[(r'^cls_attn_blocks.(\d+)', None), (r'^norm', (99999,))] cls_attn_blocks=[(r'^cls_attn_blocks\.(\d+)', None), (r'^norm', (99999,))]
) )
@torch.jit.ignore @torch.jit.ignore

@ -43,7 +43,7 @@ class PlateauLRScheduler(Scheduler):
min_lr=lr_min min_lr=lr_min
) )
self.noise_range = noise_range_t self.noise_range_t = noise_range_t
self.noise_pct = noise_pct self.noise_pct = noise_pct
self.noise_type = noise_type self.noise_type = noise_type
self.noise_std = noise_std self.noise_std = noise_std
@ -82,25 +82,12 @@ class PlateauLRScheduler(Scheduler):
self.lr_scheduler.step(metric, epoch) # step the base scheduler self.lr_scheduler.step(metric, epoch) # step the base scheduler
if self.noise_range is not None: if self._is_apply_noise(epoch):
if isinstance(self.noise_range, (list, tuple)): self._apply_noise(epoch)
apply_noise = self.noise_range[0] <= epoch < self.noise_range[1]
else:
apply_noise = epoch >= self.noise_range
if apply_noise:
self._apply_noise(epoch)
def _apply_noise(self, epoch): def _apply_noise(self, epoch):
g = torch.Generator() noise = self._calculate_noise(epoch)
g.manual_seed(self.noise_seed + epoch)
if self.noise_type == 'normal':
while True:
# resample if noise out of percent limit, brute force but shouldn't spin much
noise = torch.randn(1, generator=g).item()
if abs(noise) < self.noise_pct:
break
else:
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
# apply the noise on top of previous LR, cache the old value so we can restore for normal # apply the noise on top of previous LR, cache the old value so we can restore for normal
# stepping of base scheduler # stepping of base scheduler

@ -90,21 +90,30 @@ class Scheduler:
param_group[self.param_group_field] = value param_group[self.param_group_field] = value
def _add_noise(self, lrs, t): def _add_noise(self, lrs, t):
if self._is_apply_noise(t):
noise = self._calculate_noise(t)
lrs = [v + v * noise for v in lrs]
return lrs
def _is_apply_noise(self, t) -> bool:
"""Return True if scheduler in noise range."""
apply_noise = False
if self.noise_range_t is not None: if self.noise_range_t is not None:
if isinstance(self.noise_range_t, (list, tuple)): if isinstance(self.noise_range_t, (list, tuple)):
apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
else: else:
apply_noise = t >= self.noise_range_t apply_noise = t >= self.noise_range_t
if apply_noise: return apply_noise
g = torch.Generator()
g.manual_seed(self.noise_seed + t) def _calculate_noise(self, t) -> float:
if self.noise_type == 'normal': g = torch.Generator()
while True: g.manual_seed(self.noise_seed + t)
# resample if noise out of percent limit, brute force but shouldn't spin much if self.noise_type == 'normal':
noise = torch.randn(1, generator=g).item() while True:
if abs(noise) < self.noise_pct: # resample if noise out of percent limit, brute force but shouldn't spin much
break noise = torch.randn(1, generator=g).item()
else: if abs(noise) < self.noise_pct:
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct return noise
lrs = [v + v * noise for v in lrs] else:
return lrs noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
return noise
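The refactor above splits the LR-noise logic into a range check and a noise draw shared by the schedulers. A self-contained sketch of that flow (function names and the example values are illustrative, mirroring the new helpers):

```python
import torch

def is_apply_noise(t, noise_range_t):
    """Mirror of Scheduler._is_apply_noise: is step t inside the noise range?"""
    if noise_range_t is None:
        return False
    if isinstance(noise_range_t, (list, tuple)):
        return noise_range_t[0] <= t < noise_range_t[1]
    return t >= noise_range_t

def calculate_noise(t, noise_seed=42, noise_type='normal', noise_pct=0.67):
    """Mirror of Scheduler._calculate_noise: seeded, bounded multiplicative noise."""
    g = torch.Generator()
    g.manual_seed(noise_seed + t)
    if noise_type == 'normal':
        while True:
            noise = torch.randn(1, generator=g).item()
            if abs(noise) < noise_pct:  # resample until within the percent limit
                return noise
    return 2 * (torch.rand(1, generator=g).item() - 0.5) * noise_pct

lrs = [0.1, 0.01]
if is_apply_noise(t=30, noise_range_t=20):
    noise = calculate_noise(t=30)
    lrs = [lr + lr * noise for lr in lrs]
print(lrs)
```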

@ -89,7 +89,7 @@ parser.add_argument('--crop-pct', default=None, type=float,
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset') help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of of dataset') help='Override std deviation of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME', parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)') help='Image resize interpolation type (overrides model)')
parser.add_argument('-b', '--batch-size', type=int, default=256, metavar='N', parser.add_argument('-b', '--batch-size', type=int, default=256, metavar='N',
