Merge remote-tracking branch 'origin/norm_norm_norm' into bits_and_tpu

pull/1414/head
Ross Wightman 2 years ago
commit 749856cf25

@ -23,6 +23,39 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
## What's New
### March 21, 2022
* Merge `norm_norm_norm`. **IMPORTANT** this update for a coming 0.6.x release will likely de-stabilize the master branch for a while. Branch `0.5.x` or a previous 0.5.x release can be used if stability is required.
* Significant weights update (all TPU trained) as described in this [release](https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-tpu-weights)
* `regnety_040` - 82.3 @ 224, 82.96 @ 288
* `regnety_064` - 83.0 @ 224, 83.65 @ 288
* `regnety_080` - 83.17 @ 224, 83.86 @ 288
* `regnetv_040` - 82.44 @ 224, 83.18 @ 288 (timm pre-act)
* `regnetv_064` - 83.1 @ 224, 83.71 @ 288 (timm pre-act)
* `regnetz_040` - 83.67 @ 256, 84.25 @ 320
* `regnetz_040h` - 83.77 @ 256, 84.5 @ 320 (w/ extra fc in head)
* `resnetv2_50d_gn` - 80.8 @ 224, 81.96 @ 288 (pre-act GroupNorm)
* `resnetv2_50d_evos` - 80.77 @ 224, 82.04 @ 288 (pre-act EvoNormS)
* `regnetz_c16_evos` - 81.9 @ 256, 82.64 @ 320 (EvoNormS)
* `regnetz_d8_evos` - 83.42 @ 256, 84.04 @ 320 (EvoNormS)
* `xception41p` - 82 @ 299 (timm pre-act)
* `xception65` - 83.17 @ 299
* `xception65p` - 83.14 @ 299 (timm pre-act)
* `resnext101_64x4d` - 82.46 @ 224, 83.16 @ 288
* `seresnext101_32x8d` - 83.57 @ 224, 84.270 @ 288
* `resnetrs200` - 83.85 @ 256, 84.44 @ 320
* HuggingFace hub support fixed w/ initial groundwork for allowing alternative 'config sources' for pretrained model definitions and weights (generic local file / remote url support soon)
* SwinTransformer-V2 implementation added. Submitted by [Christoph Reich](https://github.com/ChristophReich1996). Training experiments and model changes by myself are ongoing so expect compat breaks.
* MobileViT models w/ weights adapted from https://github.com/apple/ml-cvnets
* PoolFormer models w/ weights adapted from https://github.com/sail-sg/poolformer
* VOLO models w/ weights adapted from https://github.com/sail-sg/volo
* Significant work experimenting with non-BatchNorm norm layers such as EvoNorm, FilterResponseNorm, GroupNorm, etc
* Enhanced support for alternate norm + act ('NormAct') layers added to a number of models, esp EfficientNet/MobileNetV3, RegNet, and aligned Xception
* Grouped conv support added to EfficientNet family
* Add 'group matching' API to all models to allow grouping model parameters for application of 'layer-wise' LR decay, lr scale added to LR scheduler
* Gradient checkpointing support added to many models
* `forward_head(x, pre_logits=False)` fn added to all models to allow separate calls of `forward_features` + `forward_head`
* All vision transformer and vision MLP models updated to return non-pooled / non-token selected features from `forward_features`, for consistency with CNN models; token selection or pooling is now applied in `forward_head`
### Feb 2, 2022
* [Chris Hughes](https://github.com/Chris-hughes10) posted an exhaustive run through of `timm` on his blog yesterday. Well worth a read. [Getting Started with PyTorch Image Models (timm): A Practitioners Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055)
* I'm currently prepping to merge the `norm_norm_norm` branch back to master (ver 0.6.x) in the next week or so.

@ -145,7 +145,7 @@ torch.Size([2, 1512, 7, 7])
### Select specific feature levels or limit the stride
There are to additional creation arguments impacting the output features.
There are two additional creation arguments impacting the output features.
* `out_indices` selects which indices to output
* `output_stride` limits the feature output stride of the network (also works in classification mode BTW)

@ -157,15 +157,13 @@ default_cfgs = {
'regnetz_b16_evos': _cfgr(
url='',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 224, 224), pool_size=(7, 7), test_input_size=(3, 288, 288), first_conv='stem.conv', crop_pct=0.94),
input_size=(3, 224, 224), pool_size=(7, 7), test_input_size=(3, 288, 288), first_conv='stem.conv',
crop_pct=0.94),
'regnetz_c16_evos': _cfgr(
url='',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_c16_evos_ch-d8311942.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), first_conv='stem.conv', crop_pct=0.95),
'regnetz_d8_evos': _cfgr(
url='',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95),
'regnetz_e8_evos': _cfgr(
url='',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_d8_evos_ch-2bc12646.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95),
}
@ -658,24 +656,6 @@ model_cfgs = dict(
attn_kwargs=dict(rd_ratio=0.25),
block_kwargs=dict(bottle_in=True, linear_out=True),
),
regnetz_e8_evos=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=3, c=96, s=1, gs=8, br=4),
ByoBlockCfg(type='bottle', d=8, c=192, s=2, gs=8, br=4),
ByoBlockCfg(type='bottle', d=16, c=384, s=2, gs=8, br=4),
ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=8, br=4),
),
stem_chs=64,
stem_type='deep',
stem_pool='',
downsample='',
num_features=2048,
act_layer='silu',
attn_layer='se',
attn_kwargs=dict(rd_ratio=0.25),
norm_layer=partial(EvoNorm2dS0a, group_size=16),
block_kwargs=dict(bottle_in=True, linear_out=True),
),
)
@register_model
@ -920,13 +900,6 @@ def regnetz_d8_evos(pretrained=False, **kwargs):
return _create_byobnet('regnetz_d8_evos', pretrained=pretrained, **kwargs)
@register_model
def regnetz_e8_evos(pretrained=False, **kwargs):
"""
"""
return _create_byobnet('regnetz_e8_evos', pretrained=pretrained, **kwargs)
def expand_blocks_cfg(stage_blocks_cfg: Union[ByoBlockCfg, Sequence[ByoBlockCfg]]) -> List[ByoBlockCfg]:
if not isinstance(stage_blocks_cfg, Sequence):
stage_blocks_cfg = (stage_blocks_cfg,)
@ -1031,9 +1004,10 @@ class BottleneckBlock(nn.Module):
""" ResNet-like Bottleneck Block - 1x1 - kxk - 1x1
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
downsample='avg', attn_last=False, linear_out=False, extra_conv=False, bottle_in=False,
layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
downsample='avg', attn_last=False, linear_out=False, extra_conv=False, bottle_in=False,
layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
super(BottleneckBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
@ -1088,9 +1062,10 @@ class DarkBlock(nn.Module):
for more optimal compute.
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None,
drop_path_rate=0.):
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None,
drop_path_rate=0.):
super(DarkBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible(out_chs * bottle_ratio)
@ -1138,9 +1113,10 @@ class EdgeBlock(nn.Module):
FIXME is there a more common 3x3 + 1x1 conv block to name this after?
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='avg', attn_last=False, linear_out=False, layers: LayerFn = None,
drop_block=None, drop_path_rate=0.):
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='avg', attn_last=False, linear_out=False, layers: LayerFn = None,
drop_block=None, drop_path_rate=0.):
super(EdgeBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible(out_chs * bottle_ratio)
@ -1185,8 +1161,9 @@ class RepVggBlock(nn.Module):
This version does not currently support the deploy optimization. It is currently fixed in 'train' mode.
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='', layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='', layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
super(RepVggBlock, self).__init__()
layers = layers or LayerFn()
groups = num_groups(group_size, in_chs)
@ -1226,9 +1203,10 @@ class SelfAttnBlock(nn.Module):
""" ResNet-like Bottleneck Block - 1x1 - optional kxk - self attn - 1x1
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
downsample='avg', extra_conv=False, linear_out=False, bottle_in=False, post_attn_na=True,
feat_size=None, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
downsample='avg', extra_conv=False, linear_out=False, bottle_in=False, post_attn_na=True,
feat_size=None, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
super(SelfAttnBlock, self).__init__()
assert layers is not None
mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
@ -1295,8 +1273,9 @@ def create_block(block: Union[str, nn.Module], **kwargs):
class Stem(nn.Sequential):
def __init__(self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool',
num_rep=3, num_act=None, chs_decay=0.5, layers: LayerFn = None):
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool',
num_rep=3, num_act=None, chs_decay=0.5, layers: LayerFn = None):
super().__init__()
assert stride in (2, 4)
layers = layers or LayerFn()
@ -1505,11 +1484,13 @@ class ByobNet(nn.Module):
Current assumption is that both stem and blocks are in conv-bn-act order (w/ block ending in act).
"""
def __init__(self, cfg: ByoModelCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
zero_init_last=True, img_size=None, drop_rate=0., drop_path_rate=0.):
def __init__(
self, cfg: ByoModelCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
zero_init_last=True, img_size=None, drop_rate=0., drop_path_rate=0.):
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
self.grad_checkpointing = False
layers = get_layer_fns(cfg)
if cfg.fixed_input_size:
assert img_size is not None, 'img_size argument is required for fixed input size model'
@ -1540,6 +1521,22 @@ class ByobNet(nn.Module):
# init weights
named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
@torch.jit.ignore
def group_matcher(self, coarse=False):
    """Return the regex patterns used to group this model's parameters.

    Args:
        coarse: when true, the stage pattern captures only the stage index;
            otherwise it captures the stage index and the index below it.

    Returns:
        A dict with a ``stem`` pattern and a ``blocks`` list of
        ``(pattern, ordinal)`` pairs.
    """
    if coarse:
        stage_pattern = r'^stages\.(\d+)'
    else:
        stage_pattern = r'^stages\.(\d+)\.(\d+)'
    return {
        'stem': r'^stem',
        'blocks': [
            (stage_pattern, None),
            # presumably the large ordinal sorts final_conv after every stage -- confirm against the consumer
            (r'^final_conv', (99999,)),
        ],
    }
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
    """Record on the instance whether gradient checkpointing should be used.

    Only the ``grad_checkpointing`` flag is stored here; no module is rebuilt.
    """
    self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
    """Return the ``fc`` attribute of ``self.head``."""
    return self.head.fc
@ -1548,13 +1545,19 @@ class ByobNet(nn.Module):
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.stages, x)
else:
x = self.stages(x)
x = self.final_conv(x)
return x
def forward_head(self, x, pre_logits: bool = False):
    """Run ``self.head`` on ``x``, forwarding the ``pre_logits`` flag unchanged."""
    return self.head(x, pre_logits=pre_logits)
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
x = self.forward_head(x)
return x

@ -164,8 +164,9 @@ class CrossAttention(nn.Module):
class CrossAttentionBlock(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = CrossAttention(

@ -157,9 +157,10 @@ class ResBottleneck(nn.Module):
""" ResNe(X)t Bottleneck Block
"""
def __init__(self, in_chs, out_chs, dilation=1, bottle_ratio=0.25, groups=1,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_last=False,
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
def __init__(
self, in_chs, out_chs, dilation=1, bottle_ratio=0.25, groups=1,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_last=False,
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
super(ResBottleneck, self).__init__()
mid_chs = int(round(out_chs * bottle_ratio))
ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
@ -199,9 +200,10 @@ class DarkBlock(nn.Module):
""" DarkNet Block
"""
def __init__(self, in_chs, out_chs, dilation=1, bottle_ratio=0.5, groups=1,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None,
drop_block=None, drop_path=None):
def __init__(
self, in_chs, out_chs, dilation=1, bottle_ratio=0.5, groups=1,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None,
drop_block=None, drop_path=None):
super(DarkBlock, self).__init__()
mid_chs = int(round(out_chs * bottle_ratio))
ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
@ -229,9 +231,10 @@ class DarkBlock(nn.Module):
class CrossStage(nn.Module):
"""Cross Stage."""
def __init__(self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., exp_ratio=1.,
groups=1, first_dilation=None, down_growth=False, cross_linear=False, block_dpr=None,
block_fn=ResBottleneck, **block_kwargs):
def __init__(
self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., exp_ratio=1.,
groups=1, first_dilation=None, down_growth=False, cross_linear=False, block_dpr=None,
block_fn=ResBottleneck, **block_kwargs):
super(CrossStage, self).__init__()
first_dilation = first_dilation or dilation
down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels
@ -280,8 +283,9 @@ class CrossStage(nn.Module):
class DarkStage(nn.Module):
"""DarkNet stage."""
def __init__(self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., groups=1,
first_dilation=None, block_fn=ResBottleneck, block_dpr=None, **block_kwargs):
def __init__(
self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., groups=1,
first_dilation=None, block_fn=ResBottleneck, block_dpr=None, **block_kwargs):
super(DarkStage, self).__init__()
first_dilation = first_dilation or dilation
@ -387,10 +391,10 @@ class CspNet(nn.Module):
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^stem',
blocks=r'^stages.(\d+)' if coarse else [
(r'^stages.(\d+).blocks.(\d+)', None),
(r'^stages.(\d+).*transition', MATCH_PREV_GROUP), # map to last block in stage
(r'^stages.(\d+)', (0,)),
blocks=r'^stages\.(\d+)' if coarse else [
(r'^stages\.(\d+)\.blocks\.(\d+)', None),
(r'^stages\.(\d+)\..*transition', MATCH_PREV_GROUP), # map to last block in stage
(r'^stages\.(\d+)', (0,)),
]
)
return matcher

@ -72,7 +72,7 @@ class VisionTransformerDistilled(VisionTransformer):
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + self.num_tokens, self.embed_dim))
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
self.distilled_training = False
self.distilled_training = False # must set this True to train w/ distillation token
self.init_weights(weight_init)
@ -85,7 +85,7 @@ class VisionTransformerDistilled(VisionTransformer):
return dict(
stem=r'^cls_token|pos_embed|patch_embed|dist_token',
blocks=[
(r'^blocks.(\d+)', None),
(r'^blocks\.(\d+)', None),
(r'^norm', (99999,))] # final norm w/ last block
)

@ -45,8 +45,9 @@ default_cfgs = {
class DenseLayer(nn.Module):
def __init__(self, num_input_features, growth_rate, bn_size, norm_layer=BatchNormAct2d,
drop_rate=0., memory_efficient=False):
def __init__(
self, num_input_features, growth_rate, bn_size, norm_layer=BatchNormAct2d,
drop_rate=0., memory_efficient=False):
super(DenseLayer, self).__init__()
self.add_module('norm1', norm_layer(num_input_features)),
self.add_module('conv1', nn.Conv2d(
@ -113,8 +114,9 @@ class DenseLayer(nn.Module):
class DenseBlock(nn.ModuleDict):
_version = 2
def __init__(self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=nn.ReLU,
drop_rate=0., memory_efficient=False):
def __init__(
self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=nn.ReLU,
drop_rate=0., memory_efficient=False):
super(DenseBlock, self).__init__()
for i in range(num_layers):
layer = DenseLayer(
@ -164,8 +166,8 @@ class DenseNet(nn.Module):
def __init__(
self, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=1000, in_chans=3, global_pool='avg',
bn_size=4, stem_type='', norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0, memory_efficient=False,
aa_stem_only=True):
bn_size=4, stem_type='', norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0,
memory_efficient=False, aa_stem_only=True):
self.num_classes = num_classes
self.drop_rate = drop_rate
super(DenseNet, self).__init__()
@ -252,10 +254,10 @@ class DenseNet(nn.Module):
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^features.conv[012]|features.norm[012]|features.pool[012]',
blocks=r'^features.(?:denseblock|transition)(\d+)' if coarse else [
(r'^features.denseblock(\d+).denselayer(\d+)', None),
(r'^features.transition(\d+)', MATCH_PREV_GROUP) # FIXME combine with previous denselayer
stem=r'^features\.conv[012]|features\.norm[012]|features\.pool[012]',
blocks=r'^features\.(?:denseblock|transition)(\d+)' if coarse else [
(r'^features\.denseblock(\d+)\.denselayer(\d+)', None),
(r'^features\.transition(\d+)', MATCH_PREV_GROUP) # FIXME combine with previous denselayer
]
)
return matcher

@ -323,8 +323,8 @@ class DLA(nn.Module):
stem=r'^base_layer',
blocks=r'^level(\d+)' if coarse else [
# an unusual arch, this achieves somewhat more granularity without getting super messy
(r'^level(\d+).tree(\d+)', None),
(r'^level(\d+).root', (2,)),
(r'^level(\d+)\.tree(\d+)', None),
(r'^level(\d+)\.root', (2,)),
(r'^level(\d+)', (1,))
]
)

@ -243,10 +243,10 @@ class DPN(nn.Module):
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^features.conv1',
stem=r'^features\.conv1',
blocks=[
(r'^features.conv(\d+)' if coarse else r'^features.conv(\d+)_(\d+)', None),
(r'^features.conv5_bn_ac', (99999,))
(r'^features\.conv(\d+)' if coarse else r'^features\.conv(\d+)_(\d+)', None),
(r'^features\.conv5_bn_ac', (99999,))
]
)
return matcher

@ -518,7 +518,7 @@ class EfficientNet(nn.Module):
return dict(
stem=r'^conv_stem|bn1',
blocks=[
(r'^blocks.(\d+)' if coarse else r'^blocks.(\d+).(\d+)', None),
(r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)', None),
(r'conv_head|bn2', (99999,))
]
)

@ -193,7 +193,7 @@ class GhostNet(nn.Module):
matcher = dict(
stem=r'^conv_stem|bn1',
blocks=[
(r'^blocks.(\d+)' if coarse else r'^blocks.(\d+).(\d+)', None),
(r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)', None),
(r'conv_head', (99999,))
]
)

@ -184,7 +184,7 @@ class Xception65(nn.Module):
matcher = dict(
stem=r'^conv[12]|bn[12]',
blocks=[
(r'^mid.block(\d+)', None),
(r'^mid\.block(\d+)', None),
(r'^block(\d+)', None),
(r'^conv[345]|bn[345]', (99,)),
],

@ -125,6 +125,7 @@ def _resolve_pretrained_source(pretrained_cfg):
# hf-hub specified as source via model identifier
load_from = 'hf-hub'
assert hf_hub_id
pretrained_loc = hf_hub_id
else:
# default source == timm or unspecified
if pretrained_file:
@ -407,16 +408,6 @@ def pretrained_cfg_for_features(pretrained_cfg):
return pretrained_cfg
# def overlay_external_pretrained_cfg(pretrained_cfg, kwargs):
# """ Overlay 'external_pretrained_cfg' in kwargs on top of pretrained_cfg arg.
# """
# external_pretrained_cfg = kwargs.pop('external_pretrained_cfg', None)
# if external_pretrained_cfg:
# pretrained_cfg.pop('url', None) # url should come from external cfg
# pretrained_cfg.pop('hf_hub', None) # hf hub id should come from external cfg
# pretrained_cfg.update(external_pretrained_cfg)
def set_default_kwargs(kwargs, names, pretrained_cfg):
for n in names:
# for legacy reasons, model __init__args uses img_size + in_chans as separate args while

@ -686,8 +686,8 @@ class HighResolutionNet(nn.Module):
matcher = dict(
stem=r'^conv[12]|bn[12]',
blocks=r'^(?:layer|stage|transition)(\d+)' if coarse else [
(r'^layer(\d+).(\d+)', None),
(r'^stage(\d+).(\d+)', None),
(r'^layer(\d+)\.(\d+)', None),
(r'^stage(\d+)\.(\d+)', None),
(r'^transition(\d+)', (99999,)),
],
)

@ -0,0 +1,156 @@
from typing import Optional
import torch
from torch import nn
from torch import nn, Tensor
from torch.nn.modules.transformer import _get_activation_fn
def add_ml_decoder_head(model):
    """Replace a model's classification head with an ``MLDecoder``, in place.

    The head attribute is chosen by inspecting the model:

    * ``global_pool`` + ``fc``: ``fc`` is replaced and ``global_pool`` becomes identity.
    * ``global_pool`` + ``classifier``: ``classifier`` is replaced and
      ``global_pool`` becomes identity.
    * class name containing ``RegNet`` or ``TResNet``: ``head`` is replaced.

    The new head is sized from ``model.num_classes`` and ``model.num_features``.
    If the model has a ``drop_rate`` attribute it is set to 0.

    Args:
        model: the model to modify.

    Returns:
        The same ``model`` object.

    Raises:
        SystemExit: with code -1 (after printing a message) when the model
            matches none of the supported layouts.
    """
    has_pool = hasattr(model, 'global_pool')
    if has_pool and hasattr(model, 'fc'):  # most CNN models, like Resnet50
        head_attr, reset_pool = 'fc', True
    elif has_pool and hasattr(model, 'classifier'):  # EfficientNet
        head_attr, reset_pool = 'classifier', True
    elif 'RegNet' in model._get_name() or 'TResNet' in model._get_name():  # hasattr(model, 'head')
        head_attr, reset_pool = 'head', False
    else:
        print("Model code-writing is not aligned currently with ml-decoder")
        # Raise SystemExit directly: the exit() helper is installed by the `site`
        # module for interactive use, may be absent, and closes stdin first.
        raise SystemExit(-1)

    if reset_pool:
        # the decoder consumes the un-pooled feature map
        model.global_pool = nn.Identity()

    # Drop the old head before building the new one so it is not kept alive.
    delattr(model, head_attr)
    new_head = MLDecoder(num_classes=model.num_classes, initial_num_features=model.num_features)
    setattr(model, head_attr, new_head)

    if hasattr(model, 'drop_rate'):  # Ml-Decoder has inner dropout
        model.drop_rate = 0
    return model
class TransformerDecoderLayerOptimal(nn.Module):
    """Decoder layer made of a cross-attention step and a feed-forward step.

    There is no self-attention module: the first residual step adds a
    dropout-ed copy of ``tgt`` to itself before the first LayerNorm.
    """

    def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1, activation="relu",
                 layer_norm_eps=1e-5) -> None:
        super().__init__()
        # Submodules are registered in a fixed order so state_dict key order stays stable.
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout = nn.Dropout(dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        # cross-attention from the queries (tgt) to the memory tokens
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # position-wise feed-forward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # state restored without an activation entry falls back to ReLU
        state.setdefault('activation', torch.nn.functional.relu)
        super().__setstate__(state)

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """Apply the layer to ``tgt`` using ``memory`` as attention keys and values.

        The four mask arguments are accepted but never read by this method.
        """
        # residual of tgt with a dropout-ed copy of itself (no self-attention here)
        queries = self.norm1(tgt + self.dropout1(tgt))

        # cross-attention block; only the attention output is used, not the weights
        attn_out, _ = self.multihead_attn(queries, memory, memory)
        queries = self.norm2(queries + self.dropout2(attn_out))

        # feed-forward block
        hidden = self.dropout(self.activation(self.linear1(queries)))
        ff_out = self.linear2(hidden)
        return self.norm3(queries + self.dropout3(ff_out))
# @torch.jit.script
# class ExtrapClasses(object):
# def __init__(self, num_queries: int, group_size: int):
# self.num_queries = num_queries
# self.group_size = group_size
#
# def __call__(self, h: torch.Tensor, class_embed_w: torch.Tensor, class_embed_b: torch.Tensor, out_extrap:
# torch.Tensor):
# # h = h.unsqueeze(-1).expand(-1, -1, -1, self.group_size)
# h = h[..., None].repeat(1, 1, 1, self.group_size) # torch.Size([bs, 5, 768, groups])
# w = class_embed_w.view((self.num_queries, h.shape[2], self.group_size))
# out = (h * w).sum(dim=2) + class_embed_b
# out = out.view((h.shape[0], self.group_size * self.num_queries))
# return out
# Per-group linear projection, compiled as a TorchScript class.
# Results are written into the caller-supplied ``out_extrap`` buffer; nothing is returned.
@torch.jit.script
class GroupFC(object):
    def __init__(self, embed_len_decoder: int):
        # number of groups to iterate over along dim 1 of ``h``
        self.embed_len_decoder = embed_len_decoder

    def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor):
        for group_idx in range(self.embed_len_decoder):
            # slice for this group along dim 1 of h, and its weight matrix along dim 0
            group_tokens = h[:, group_idx]
            group_weight = duplicate_pooling[group_idx]
            out_extrap[:, group_idx, :] = torch.matmul(group_tokens, group_weight)
class MLDecoder(nn.Module):
    """Classification head that decodes class logits from query embeddings.

    Frozen query embeddings attend to the projected input features through a
    one-layer decoder. Each decoded query is then expanded by a per-group linear
    map into ``duplicate_factor`` outputs, and the first ``num_classes`` of those
    outputs (plus a bias) are the logits.

    Args:
        num_classes: number of output logits.
        num_of_groups: number of query embeddings; a negative value selects 100.
            The count is capped at ``num_classes``.
        decoder_embedding: decoder width; a negative value selects 768.
        initial_num_features: size of the last dimension of the incoming features.
    """

    @staticmethod
    def _ceil_div(numerator: int, denominator: int) -> int:
        """Return ``ceil(numerator / denominator)`` using integer arithmetic only."""
        return -(-numerator // denominator)

    def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial_num_features=2048):
        super().__init__()
        embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups
        # never use more queries than there are classes
        embed_len_decoder = min(embed_len_decoder, num_classes)

        # switching to 768 initial embeddings
        decoder_embedding = 768 if decoder_embedding < 0 else decoder_embedding
        # NOTE: the attribute name (including its spelling) is part of the state_dict keys
        self.embed_standart = nn.Linear(initial_num_features, decoder_embedding)

        # decoder
        decoder_dropout = 0.1
        num_layers_decoder = 1
        dim_feedforward = 2048
        layer_decode = TransformerDecoderLayerOptimal(
            d_model=decoder_embedding, dim_feedforward=dim_feedforward, dropout=decoder_dropout)
        self.decoder = nn.TransformerDecoder(layer_decode, num_layers=num_layers_decoder)

        # non-learnable queries
        self.query_embed = nn.Embedding(embed_len_decoder, decoder_embedding)
        self.query_embed.requires_grad_(False)

        # group fully-connected
        self.num_classes = num_classes
        # Exact ceiling: the former ``int(ratio + 0.999)`` rounded down whenever the
        # fractional part of the ratio was below 0.001, leaving too few outputs.
        self.duplicate_factor = self._ceil_div(num_classes, embed_len_decoder)
        self.duplicate_pooling = torch.nn.Parameter(
            torch.Tensor(embed_len_decoder, decoder_embedding, self.duplicate_factor))
        self.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(num_classes))
        torch.nn.init.xavier_normal_(self.duplicate_pooling)
        torch.nn.init.constant_(self.duplicate_pooling_bias, 0)
        self.group_fc = GroupFC(embed_len_decoder)

    def forward(self, x):
        """Compute logits of shape ``[batch, num_classes]`` from ``x``.

        A 4-D input is flattened over its last two dimensions and transposed so
        that features are last; any other input is used as given.
        """
        if len(x.shape) == 4:  # e.g. [bs, 2048, 7, 7] -> [bs, 49, 2048]
            embedding_spatial = x.flatten(2).transpose(1, 2)
        else:  # presumably already [bs, tokens, features] -- confirm against callers
            embedding_spatial = x
        tokens = self.embed_standart(embedding_spatial)
        tokens = torch.nn.functional.relu(tokens, inplace=True)

        bs = tokens.shape[0]
        query_embed = self.query_embed.weight
        # tgt = query_embed.unsqueeze(1).repeat(1, bs, 1)
        tgt = query_embed.unsqueeze(1).expand(-1, bs, -1)  # no allocation of memory with expand
        h = self.decoder(tgt, tokens.transpose(0, 1))  # [embed_len_decoder, batch, decoder_embedding]
        h = h.transpose(0, 1)

        out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype)
        self.group_fc(h, self.duplicate_pooling, out_extrap)  # fills out_extrap in place
        # drop the surplus outputs when groups * duplicate_factor exceeds num_classes
        h_out = out_extrap.flatten(1)[:, :self.num_classes]
        h_out += self.duplicate_pooling_bias
        logits = h_out
        return logits

@ -496,7 +496,7 @@ class Levit(nn.Module):
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
blocks=[(r'^blocks.(\d+)', None), (r'^norm', (99999,))]
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
)
return matcher
@ -539,7 +539,7 @@ class LevitDistilled(Levit):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.head_dist = NormLinear(self.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity()
self.distilled_training = False
self.distilled_training = False # must set this True to train w/ distillation token
@torch.jit.ignore
def get_classifier(self):

@ -291,7 +291,7 @@ class MlpMixer(nn.Module):
def group_matcher(self, coarse=False):
return dict(
stem=r'^stem', # stem and embed
blocks=[(r'^blocks.(\d+)', None), (r'^norm', (99999,))]
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
)
@torch.jit.ignore

@ -171,7 +171,7 @@ class MobileNetV3(nn.Module):
def group_matcher(self, coarse=False):
return dict(
stem=r'^conv_stem|bn1',
blocks=r'^blocks.(\d+)' if coarse else r'^blocks.(\d+).(\d+)'
blocks=r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)'
)
@torch.jit.ignore

@ -334,8 +334,8 @@ class Nest(nn.Module):
matcher = dict(
stem=r'^patch_embed', # stem and embed
blocks=[
(r'^levels.(\d+)' if coarse else r'^levels.(\d+).transformer_encoder.(\d+)', None),
(r'^levels.(\d+).(?:pool|pos_embed)', (0,)),
(r'^levels\.(\d+)' if coarse else r'^levels\.(\d+)\.transformer_encoder\.(\d+)', None),
(r'^levels\.(\d+)\.(?:pool|pos_embed)', (0,)),
(r'^norm', (99999,))
]
)

@ -194,7 +194,6 @@ def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', ski
return cfg
model_cfgs = dict(
# NFNet-F models w/ GELU compatible with DeepMind weights
dm_nfnet_f0=_dm_nfnet_cfg(depths=(1, 2, 6, 3)),
@ -550,7 +549,7 @@ class NormFreeNet(nn.Module):
matcher = dict(
stem=r'^stem',
blocks=[
(r'^stages.(\d+)' if coarse else r'^stages.(\d+).(\d+)', None),
(r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
(r'^final_conv', (99999,))
]
)

@ -147,9 +147,10 @@ class PoolingVisionTransformer(nn.Module):
A PyTorch implement of 'Rethinking Spatial Dimensions of Vision Transformers'
- https://arxiv.org/abs/2103.16302
"""
def __init__(self, img_size, patch_size, stride, base_dims, depth, heads,
mlp_ratio, num_classes=1000, in_chans=3, distilled=False, global_pool='token',
attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0):
def __init__(
self, img_size, patch_size, stride, base_dims, depth, heads,
mlp_ratio, num_classes=1000, in_chans=3, global_pool='token',
distilled=False, attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0):
super(PoolingVisionTransformer, self).__init__()
assert global_pool in ('token',)
@ -193,6 +194,7 @@ class PoolingVisionTransformer(nn.Module):
self.head_dist = None
if distilled:
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
self.distilled_training = False # must set this True to train w/ distillation token
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
@ -207,6 +209,10 @@ class PoolingVisionTransformer(nn.Module):
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
@torch.jit.ignore
def set_distilled_training(self, enable=True):
self.distilled_training = enable
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
assert not enable, 'gradient checkpointing not supported'
@ -231,16 +237,30 @@ class PoolingVisionTransformer(nn.Module):
cls_tokens = self.norm(cls_tokens)
return cls_tokens
def forward(self, x):
x = self.forward_features(x)
def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
if self.head_dist is not None:
x, x_dist = self.head(x[:, 0]), self.head_dist(x[:, 1]) # x must be a tuple
if self.training and not torch.jit.is_scripting():
assert self.global_pool == 'token'
x, x_dist = x[:, 0], x[:, 1]
if not pre_logits:
x = self.head(x)
x_dist = self.head_dist(x_dist)
if self.distilled_training and self.training and not torch.jit.is_scripting():
# only return separate classification predictions when training in distilled mode
return x, x_dist
else:
# during standard train / finetune, inference average the classifier predictions
return (x + x_dist) / 2
else:
return self.head(x[:, 0])
if self.global_pool == 'token':
x = x[:, 0]
if not pre_logits:
x = self.head(x)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def checkpoint_filter_fn(state_dict, model):

@ -137,9 +137,15 @@ default_cfgs = dict(
regnety_032=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth',
crop_pct=1.0, test_input_size=(3, 288, 288)),
regnety_040=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_040-f0d569f9.pth'),
regnety_064=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth'),
regnety_080=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth'),
regnety_040=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_040_ra3-670e1166.pth',
crop_pct=1.0, test_input_size=(3, 288, 288)),
regnety_064=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_064_ra3-aa26dc7d.pth',
crop_pct=1.0, test_input_size=(3, 288, 288)),
regnety_080=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnety_080_ra3-1fdc4344.pth',
crop_pct=1.0, test_input_size=(3, 288, 288)),
regnety_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth'),
regnety_160=_cfg(
url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth', # from Facebook DeiT GitHub repository
@ -147,12 +153,20 @@ default_cfgs = dict(
regnety_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'),
regnety_040s_gn=_cfg(url=''),
regnetv_040=_cfg(url='', first_conv='stem'),
regnetv_064=_cfg(url='', first_conv='stem'),
regnetv_040=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetv_040_ra3-c248f51f.pth',
first_conv='stem'),
regnetv_064=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetv_064_ra3-530616c2.pth',
first_conv='stem'),
regnetz_005=_cfg(url=''),
regnetz_040=_cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
regnetz_040h=_cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
regnetz_040=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_040_ra3-9007edf5.pth',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320)),
regnetz_040h=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_040h_ra3-f594343b.pth',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320)),
)
@ -444,7 +458,7 @@ class RegNet(nn.Module):
def group_matcher(self, coarse=False):
    """Return the ``stem`` / ``blocks`` regex patterns for this model.

    Args:
        coarse: If True, ``blocks`` captures only the stage index; otherwise
            it captures the stage index and the block index within the stage.
    """
    return dict(
        stem=r'^stem',
        # dots are escaped so they match only the literal '.' separator
        blocks=r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.blocks\.(\d+)',
    )
@torch.jit.ignore

@ -106,7 +106,7 @@ default_cfgs = {
'resnext101_32x4d': _cfg(url=''),
'resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth'),
'resnext101_64x4d': _cfg(
url='',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/resnext101_64x4d_c-0d0e0cc0.pth',
interpolation='bicubic', crop_pct=1.0, test_input_size=(3, 288, 288)),
'tv_resnext50_32x4d': _cfg(url='https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth'),
@ -197,8 +197,8 @@ default_cfgs = {
url='',
interpolation='bicubic'),
'seresnext101_32x8d': _cfg(
url='',
interpolation='bicubic'),
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/seresnext101_32x8d_ah-e6bc4c0a.pth',
interpolation='bicubic', test_input_size=(3, 288, 288), crop_pct=1.0),
'senet154': _cfg(
url='',
interpolation='bicubic',
@ -283,7 +283,7 @@ default_cfgs = {
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320),
interpolation='bicubic', first_conv='conv1.0'),
'resnetrs200': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs200_ema-623d2f59.pth',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/resnetrs200_c-6b698b88.pth',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320),
interpolation='bicubic', first_conv='conv1.0'),
'resnetrs270': _cfg(
@ -315,9 +315,10 @@ def create_aa(aa_layer, channels, stride=2, enable=True):
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
def __init__(
self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
super(BasicBlock, self).__init__()
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
@ -379,9 +380,10 @@ class BasicBlock(nn.Module):
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
def __init__(
self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
super(Bottleneck, self).__init__()
width = int(math.floor(planes * (base_width / 64)) * cardinality)
@ -561,48 +563,35 @@ class ResNet(nn.Module):
Parameters
----------
block : Block
Class for the residual block. Options are BasicBlockGl, BottleneckGl.
layers : list of int
Numbers of layers in each block
num_classes : int, default 1000
Number of classification classes.
in_chans : int, default 3
Number of input (color) channels.
cardinality : int, default 1
Number of convolution groups for 3x3 conv in Bottleneck.
base_width : int, default 64
Factor determining bottleneck channels. `planes * base_width / 64 * cardinality`
stem_width : int, default 64
Number of channels in stem convolutions
block : Block, class for the residual block. Options are BasicBlockGl, BottleneckGl.
layers : list of int, number of layers in each block
num_classes : int, default 1000, number of classification classes.
in_chans : int, default 3, number of input (color) channels.
output_stride : int, default 32, output stride of the network, 32, 16, or 8.
global_pool : str, Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
cardinality : int, default 1, number of convolution groups for 3x3 conv in Bottleneck.
base_width : int, default 64, factor determining bottleneck channels. `planes * base_width / 64 * cardinality`
stem_width : int, default 64, number of channels in stem convolutions
stem_type : str, default ''
The type of stem:
* '', default - a single 7x7 conv with a width of stem_width
* 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2
* 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2
block_reduce_first: int, default 1
Reduction factor for first convolution output width of residual blocks,
1 for all archs except senets, where 2
down_kernel_size: int, default 1
Kernel size of residual block downsampling path, 1x1 for most archs, 3x3 for senets
avg_down : bool, default False
Whether to use average pooling for projection skip connection between stages/downsample.
output_stride : int, default 32
Set the output stride of the network, 32, 16, or 8. Typically used in segmentation.
block_reduce_first : int, default 1
Reduction factor for first convolution output width of residual blocks, 1 for all archs except senets, where 2
down_kernel_size : int, default 1, kernel size of residual block downsample path, 1x1 for most, 3x3 for senets
avg_down : bool, default False, use average pooling for projection skip connection between stages/downsample.
act_layer : nn.Module, activation layer
norm_layer : nn.Module, normalization layer
aa_layer : nn.Module, anti-aliasing layer
drop_rate : float, default 0.
Dropout probability before classifier, for training
global_pool : str, default 'avg'
Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
drop_rate : float, default 0. Dropout probability before classifier, for training
"""
def __init__(self, block, layers, num_classes=1000, in_chans=3,
cardinality=1, base_width=64, stem_width=64, stem_type='', replace_stem_pool=False,
output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0.,
drop_block_rate=0., global_pool='avg', zero_init_last=True, block_args=None):
def __init__(
self, block, layers, num_classes=1000, in_chans=3, output_stride=32, global_pool='avg',
cardinality=1, base_width=64, stem_width=64, stem_type='', replace_stem_pool=False, block_reduce_first=1,
down_kernel_size=1, avg_down=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None,
drop_rate=0.0, drop_path_rate=0., drop_block_rate=0., zero_init_last=True, block_args=None):
super(ResNet, self).__init__()
block_args = block_args or dict()
assert output_stride in (8, 16, 32)
@ -712,12 +701,15 @@ class ResNet(nn.Module):
x = self.layer4(x)
return x
def forward_head(self, x, pre_logits: bool = False):
    """Pool the feature map, apply dropout, and optionally classify.

    Args:
        x: Output of ``forward_features``.
        pre_logits: If True, return the pooled (and dropped-out) features
            without passing them through ``self.fc``.
    """
    x = self.global_pool(x)
    if self.drop_rate:
        x = F.dropout(x, p=float(self.drop_rate), training=self.training)
    # fc is applied only here, so pre_logits really yields pre-classifier features
    return x if pre_logits else self.fc(x)


def forward(self, x):
    """Run ``forward_features`` followed by ``forward_head`` on ``x``."""
    x = self.forward_features(x)
    x = self.forward_head(x)
    return x

@ -122,13 +122,13 @@ default_cfgs = {
interpolation='bicubic', first_conv='stem.conv1'),
'resnetv2_50d_gn': _cfg(
interpolation='bicubic', first_conv='stem.conv1'),
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/resnetv2_50d_gn_ah-c415c11a.pth',
interpolation='bicubic', first_conv='stem.conv1', test_input_size=(3, 288, 288), crop_pct=0.95),
'resnetv2_50d_evob': _cfg(
interpolation='bicubic', first_conv='stem.conv1'),
'resnetv2_50d_evos0': _cfg(
interpolation='bicubic', first_conv='stem.conv1'),
'resnetv2_50d_evos1': _cfg(
interpolation='bicubic', first_conv='stem.conv1'),
'resnetv2_50d_evos': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/resnetv2_50d_evos_ah-7c4dd548.pth',
interpolation='bicubic', first_conv='stem.conv1', test_input_size=(3, 288, 288), crop_pct=0.95),
'resnetv2_50d_frn': _cfg(
interpolation='bicubic', first_conv='stem.conv1'),
}
@ -411,8 +411,8 @@ class ResNetV2(nn.Module):
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^stem',
blocks=r'^stages.(\d+)' if coarse else [
(r'^stages.(\d+).blocks.(\d+)', None),
blocks=r'^stages\.(\d+)' if coarse else [
(r'^stages\.(\d+)\.blocks\.(\d+)', None),
(r'^norm', (99999,))
]
)
@ -693,21 +693,13 @@ def resnetv2_50d_evob(pretrained=False, **kwargs):
@register_model
def resnetv2_50d_evos0(pretrained=False, **kwargs):
def resnetv2_50d_evos(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d_evos0', pretrained=pretrained,
'resnetv2_50d_evos', pretrained=pretrained,
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dS0,
stem_type='deep', avg_down=True, **kwargs)
@register_model
def resnetv2_50d_evos1(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d_evos1', pretrained=pretrained,
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=partial(EvoNorm2dS1, group_size=16),
stem_type='deep', avg_down=True, **kwargs)
@register_model
def resnetv2_50d_frn(pretrained=False, **kwargs):
return _create_resnetv2(

@ -173,7 +173,7 @@ class ReXNetV1(nn.Module):
def group_matcher(self, coarse=False):
    """Return the ``stem`` / ``blocks`` regex patterns for this model.

    ``coarse`` is accepted but not used here; the same patterns are returned
    either way.
    """
    matcher = dict(
        stem=r'^stem',
        blocks=r'^features\.(\d+)',
    )
    return matcher

@ -360,7 +360,7 @@ class SENet(nn.Module):
@torch.jit.ignore
def group_matcher(self, coarse=False):
    """Return the ``stem`` / ``blocks`` regex patterns for this model.

    Args:
        coarse: If True, ``blocks`` captures only the ``layerN`` index;
            otherwise it also captures the index following the dot.
    """
    matcher = dict(stem=r'^layer0', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+)\.(\d+)')
    return matcher
@torch.jit.ignore

@ -525,9 +525,9 @@ class SwinTransformer(nn.Module):
def group_matcher(self, coarse=False):
    """Return the ``stem`` / ``blocks`` regex patterns for this model.

    Args:
        coarse: If True, ``blocks`` is a single pattern capturing the layer
            index; otherwise it is a list of ``(pattern, suffix)`` tuples.
    """
    return dict(
        stem=r'^absolute_pos_embed|patch_embed',  # stem and embed
        blocks=r'^layers\.(\d+)' if coarse else [
            # every '.' separator is escaped, including the one before 'downsample'
            (r'^layers\.(\d+)\.downsample', (0,)),
            (r'^layers\.(\d+)\.\w+\.(\d+)', None),
            (r'^norm', (99999,)),
        ]
    )

@ -12,6 +12,7 @@ This implementation is experimental and subject to change in manners that will b
GitHub link above. It needs further investigation as throughput vs mem tradeoff doesn't appear beneficial.
* num_heads per stage is not detailed for Huge and Giant model variants
* 'Giant' is 3B params in paper but ~2.6B here despite matching paper dim + block counts
* experiments are ongoing w.r.t. 'main branch' norm layer use and weight init scheme
Noteworthy additions over official Swin v1:
* MLP relative position embedding is looking promising and adapts to different image/window sizes
@ -67,27 +68,29 @@ default_cfgs = {
'swin_v2_cr_tiny_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_tiny_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=1.0),
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_tiny_ns_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_small_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_small_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=1.0),
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_base_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_base_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=1.0),
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_large_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_large_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=1.0),
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_huge_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_huge_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=1.0),
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_giant_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_giant_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=1.0),
url="", input_size=(3, 224, 224), crop_pct=0.9),
}
@ -175,7 +178,7 @@ class WindowMultiHeadAttention(nn.Module):
hidden_features=meta_hidden_dim,
out_features=num_heads,
act_layer=nn.ReLU,
drop=0. # FIXME should we add stochasticity?
drop=0.1 # FIXME should there be stochasticity, appears to 'overfit' without?
)
self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads)))
self._make_pair_wise_relative_positions()
@ -336,7 +339,8 @@ class SwinTransformerBlock(nn.Module):
self.norm2 = norm_layer(dim)
self.drop_path2 = DropPath(drop_prob=drop_path) if drop_path > 0.0 else nn.Identity()
# extra norm layer mentioned for Huge/Giant models in V2 paper (FIXME may be in wrong spot?)
# Extra main branch norm layer mentioned for Huge/Giant models in V2 paper.
# Also being used as final network norm and optional stage ending norm while still in a C-last format.
self.norm3 = norm_layer(dim) if extra_norm else nn.Identity()
self._make_attention_mask()
@ -393,7 +397,8 @@ class SwinTransformerBlock(nn.Module):
# cyclic shift
sh, sw = self.shift_size
if any(self.shift_size):
do_shift: bool = any(self.shift_size)
if do_shift:
# FIXME PyTorch XLA needs cat impl, roll not lowered
# x = torch.cat([x[:, sh:], x[:, :sh]], dim=1)
# x = torch.cat([x[:, :, sw:], x[:, :, :sw]], dim=2)
@ -411,7 +416,7 @@ class SwinTransformerBlock(nn.Module):
x = window_reverse(attn_windows, self.window_size, self.feat_size) # B H' W' C
# reverse cyclic shift
if any(self.shift_size):
if do_shift:
# FIXME PyTorch XLA needs cat impl, roll not lowered
# x = torch.cat([x[:, -sh:], x[:, :-sh]], dim=1)
# x = torch.cat([x[:, :, -sw:], x[:, :, :-sw]], dim=2)
@ -432,7 +437,7 @@ class SwinTransformerBlock(nn.Module):
# NOTE post-norm branches (op -> norm -> drop)
x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x)))
x = x + self.drop_path2(self.norm2(self.mlp(x)))
x = self.norm3(x) # main-branch norm enabled for some blocks (every 6 for Huge/Giant)
x = self.norm3(x) # main-branch norm enabled for some blocks / stages (every 6 for Huge/Giant)
return x
@ -502,8 +507,8 @@ class SwinTransformerStage(nn.Module):
drop_attn (float): Dropout rate of attention map
drop_path (float): Dropout in main path
norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
grad_checkpointing (bool): If true checkpointing is utilized
extra_norm_period (int): Insert extra norm layer on main branch every N (period) blocks
extra_norm_stage (bool): End each stage with an extra norm layer in main branch
sequential_attn (bool): If true sequential self-attention is performed
"""
@ -520,17 +525,23 @@ class SwinTransformerStage(nn.Module):
drop_attn: float = 0.0,
drop_path: Union[List[float], float] = 0.0,
norm_layer: Type[nn.Module] = nn.LayerNorm,
grad_checkpointing: bool = False,
extra_norm_period: int = 0,
extra_norm_stage: bool = False,
sequential_attn: bool = False,
) -> None:
super(SwinTransformerStage, self).__init__()
self.downscale: bool = downscale
self.grad_checkpointing: bool = grad_checkpointing
self.grad_checkpointing: bool = False
self.feat_size: Tuple[int, int] = (feat_size[0] // 2, feat_size[1] // 2) if downscale else feat_size
self.downsample = PatchMerging(embed_dim, norm_layer=norm_layer) if downscale else nn.Identity()
def _extra_norm(index):
i = index + 1
if extra_norm_period and i % extra_norm_period == 0:
return True
return i == depth if extra_norm_stage else False
embed_dim = embed_dim * 2 if downscale else embed_dim
self.blocks = nn.Sequential(*[
SwinTransformerBlock(
@ -543,7 +554,7 @@ class SwinTransformerStage(nn.Module):
drop=drop,
drop_attn=drop_attn,
drop_path=drop_path[index] if isinstance(drop_path, list) else drop_path,
extra_norm=not (index + 1) % extra_norm_period if extra_norm_period else False,
extra_norm=_extra_norm(index),
sequential_attn=sequential_attn,
norm_layer=norm_layer,
)
@ -605,9 +616,9 @@ class SwinTransformerV2Cr(nn.Module):
attn_drop_rate (float): Dropout rate of attention map. Default: 0.0
drop_path_rate (float): Stochastic depth rate. Default: 0.0
norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
grad_checkpointing (bool): If true checkpointing is utilized. Default: False
extra_norm_period (int): Insert extra norm layer on main branch every N (period) blocks in stage
extra_norm_stage (bool): End each stage with an extra norm layer in main branch
sequential_attn (bool): If true sequential self-attention is performed. Default: False
use_deformable (bool): If true deformable block is used. Default: False
"""
def __init__(
@ -626,10 +637,11 @@ class SwinTransformerV2Cr(nn.Module):
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
norm_layer: Type[nn.Module] = nn.LayerNorm,
grad_checkpointing: bool = False,
extra_norm_period: int = 0,
extra_norm_stage: bool = False,
sequential_attn: bool = False,
global_pool: str = 'avg',
weight_init='skip',
**kwargs: Any
) -> None:
super(SwinTransformerV2Cr, self).__init__()
@ -643,7 +655,7 @@ class SwinTransformerV2Cr(nn.Module):
self.window_size: int = window_size
self.num_features: int = int(embed_dim * 2 ** (len(depths) - 1))
self.patch_embed: nn.Module = PatchEmbed(
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans,
embed_dim=embed_dim, norm_layer=norm_layer)
patch_grid_size: Tuple[int, int] = self.patch_embed.grid_size
@ -664,8 +676,8 @@ class SwinTransformerV2Cr(nn.Module):
drop=drop_rate,
drop_attn=attn_drop_rate,
drop_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])],
grad_checkpointing=grad_checkpointing,
extra_norm_period=extra_norm_period,
extra_norm_stage=extra_norm_stage or (index + 1) == len(depths), # last stage ends w/ norm
sequential_attn=sequential_attn,
norm_layer=norm_layer,
)
@ -673,12 +685,12 @@ class SwinTransformerV2Cr(nn.Module):
self.stages = nn.Sequential(*stages)
self.global_pool: str = global_pool
self.head: nn.Module = nn.Linear(
in_features=self.num_features, out_features=num_classes) if num_classes else nn.Identity()
self.head = nn.Linear(self.num_features, num_classes) if num_classes else nn.Identity()
# FIXME weight init TBD, PyTorch default init appears to be working well,
# but differs from usual ViT or Swin init.
# named_apply(init_weights, self)
# current weight init skips custom init and uses pytorch layer defaults, seems to work well
# FIXME more experiments needed
if weight_init != 'skip':
named_apply(init_weights, self)
def update_input_size(
self,
@ -709,13 +721,28 @@ class SwinTransformerV2Cr(nn.Module):
new_img_size=(new_patch_grid_size[0] // stage_scale, new_patch_grid_size[1] // stage_scale),
)
@torch.jit.ignore
def group_matcher(self, coarse=False):
    """Return the ``stem`` / ``blocks`` regex patterns for this model.

    Args:
        coarse: If True, ``blocks`` is a single pattern capturing the stage
            index; otherwise it is a list of ``(pattern, suffix)`` tuples.
    """
    return dict(
        stem=r'^patch_embed',  # stem and embed
        blocks=r'^stages\.(\d+)' if coarse else [
            # the '.' before 'downsample' is escaped like the other separators
            (r'^stages\.(\d+)\.downsample', (0,)),
            (r'^stages\.(\d+)\.\w+\.(\d+)', None),
        ]
    )
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
    """Set ``grad_checkpointing`` to ``enable`` on every item of ``self.stages``."""
    for stage in self.stages:
        stage.grad_checkpointing = enable
@torch.jit.ignore()
def get_classifier(self) -> nn.Module:
    """Method returns the classification head of the model.

    Returns:
        head (nn.Module): Current classification head
    """
    return self.head
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
"""Method results the classification head
@ -727,8 +754,7 @@ class SwinTransformerV2Cr(nn.Module):
self.num_classes: int = num_classes
if global_pool is not None:
self.global_pool = global_pool
self.head: nn.Module = nn.Linear(
in_features=self.num_features, out_features=num_classes) if num_classes > 0 else nn.Identity()
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
@ -747,12 +773,14 @@ class SwinTransformerV2Cr(nn.Module):
def init_weights(module: nn.Module, name: str = ''):
# FIXME WIP
# FIXME WIP determining if there's a better weight init
if isinstance(module, nn.Linear):
if 'qkv' in name:
# treat the weights of Q, K, V separately
val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1]))
nn.init.uniform_(module.weight, -val, val)
elif 'head' in name:
nn.init.zeros_(module.weight)
else:
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
@ -790,6 +818,21 @@ def swin_v2_cr_tiny_224(pretrained=False, **kwargs):
return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_v2_cr_tiny_ns_224(pretrained=False, **kwargs):
    """Swin-T V2 CR @ 224x224, trained ImageNet-1k w/ extra stage norms.

    Same tiny dimensions as the entry point above it in the file, plus
    ``extra_norm_stage=True``.
    ** Experimental, may make default if results are improved. **
    """
    tiny_ns_kwargs = dict(
        embed_dim=96,
        depths=(2, 2, 6, 2),
        num_heads=(3, 6, 12, 24),
        extra_norm_stage=True,
        **kwargs
    )
    model = _create_swin_transformer_v2_cr(
        'swin_v2_cr_tiny_ns_224', pretrained=pretrained, **tiny_ns_kwargs)
    return model
@register_model
def swin_v2_cr_small_384(pretrained=False, **kwargs):
"""Swin-S V2 CR @ 384x384, trained ImageNet-1k"""

@ -217,7 +217,7 @@ class TNT(nn.Module):
matcher = dict(
stem=r'^cls_token|patch_pos|pixel_pos|pixel_embed|norm[12]_proj|proj', # stem and embed / pos
blocks=[
(r'^blocks.(\d+)', None),
(r'^blocks\.(\d+)', None),
(r'^norm', (99999,)),
]
)

@ -233,7 +233,7 @@ class TResNet(nn.Module):
@torch.jit.ignore
def group_matcher(self, coarse=False):
    """Return the ``stem`` / ``blocks`` regex patterns for this model.

    Args:
        coarse: If True, ``blocks`` captures only the ``body.layerN`` index;
            otherwise it also captures the index following the dot.
    """
    matcher = dict(stem=r'^body\.conv1', blocks=r'^body\.layer(\d+)' if coarse else r'^body\.layer(\d+)\.(\d+)')
    return matcher
@torch.jit.ignore

@ -327,11 +327,11 @@ class Twins(nn.Module):
matcher = dict(
stem=r'^patch_embeds.0', # stem and embed
blocks=[
(r'^(?:blocks|patch_embeds|pos_block).(\d+)', None),
(r'^(?:blocks|patch_embeds|pos_block)\.(\d+)', None),
('^norm', (99999,))
] if coarse else [
(r'^blocks.(\d+).(\d+)', None),
(r'^(?:patch_embeds|pos_block).(\d+)', (0,)),
(r'^blocks\.(\d+)\.(\d+)', None),
(r'^(?:patch_embeds|pos_block)\.(\d+)', (0,)),
(r'^norm', (99999,))
]
)

@ -136,7 +136,7 @@ class VGG(nn.Module):
@torch.jit.ignore
def group_matcher(self, coarse=False):
    """Return the ``stem`` / ``blocks`` regex patterns for this model.

    ``coarse`` is accepted but not used here.
    """
    # this treats BN layers as separate groups for bn variants, a lot of effort to fix that
    return dict(stem=r'^features\.0', blocks=r'^features\.(\d+)')
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):

@ -271,7 +271,7 @@ class Visformer(nn.Module):
return dict(
stem=r'^patch_embed1|pos_embed1|stem', # stem and embed
blocks=[
(r'^stage(\d+).(\d+)' if coarse else r'^stage(\d+).(\d+)', None),
(r'^stage(\d+)\.(\d+)' if coarse else r'^stage(\d+)\.(\d+)', None),
(r'^(?:patch_embed|pos_embed)(\d+)', (0,)),
(r'^norm', (99999,))
]

@ -331,7 +331,7 @@ class VisionTransformer(nn.Module):
def group_matcher(self, coarse=False):
    """Return the ``stem`` / ``blocks`` regex patterns for this model.

    ``coarse`` is accepted but not used here.
    """
    return dict(
        stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
        blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
    )
@torch.jit.ignore

@ -6,7 +6,7 @@ A PyTorch implement of the Hybrid Vision Transformers as described in:
- https://arxiv.org/abs/2010.11929
`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
- https://arxiv.org/abs/2106.TODO
- https://arxiv.org/abs/2106.10270
NOTE These hybrid model definitions depend on code in vision_transformer.py.
They were moved here to keep file sizes sane.
@ -359,4 +359,4 @@ def vit_base_resnet50d_224(pretrained=False, **kwargs):
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_base_resnet50d_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
return model

@ -327,7 +327,7 @@ class VovNet(nn.Module):
def group_matcher(self, coarse=False):
    """Return the ``stem`` / ``blocks`` regex patterns for this model.

    Args:
        coarse: If True, ``blocks`` captures only the stage index; otherwise
            it captures the stage index and the block index within the stage.
    """
    return dict(
        stem=r'^stem',
        # the '.' before 'blocks' is escaped too, so it cannot match an arbitrary character
        blocks=r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.blocks\.(\d+)',
    )
@torch.jit.ignore

@ -34,12 +34,20 @@ default_cfgs = dict(
xception41=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_41-e6439c97.pth'),
xception65=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_65-c9ae96e8.pth'),
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/xception65_ra3-1447db8d.pth',
crop_pct=0.94,
),
xception71=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth'),
xception41p=_cfg(url=''),
xception65p=_cfg(url=''),
xception41p=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/xception41p_ra3-33195bc8.pth',
crop_pct=0.94,
),
xception65p=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/xception65p_ra3-3c6114e4.pth',
crop_pct=0.94,
),
)
@ -213,7 +221,7 @@ class XceptionAligned(nn.Module):
def group_matcher(self, coarse=False):
    """Return the ``stem`` / ``blocks`` regex patterns for this model.

    ``coarse`` is accepted but not used here.
    """
    return dict(
        stem=r'^stem',
        blocks=r'^blocks\.(\d+)',
    )
@torch.jit.ignore

@ -412,8 +412,8 @@ class XCiT(nn.Module):
def group_matcher(self, coarse=False):
    """Return the ``stem`` / ``blocks`` / ``cls_attn_blocks`` regex patterns.

    ``coarse`` is accepted but not used here.
    """
    return dict(
        stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
        blocks=r'^blocks\.(\d+)',
        cls_attn_blocks=[(r'^cls_attn_blocks\.(\d+)', None), (r'^norm', (99999,))]
    )
@torch.jit.ignore

@ -43,7 +43,7 @@ class PlateauLRScheduler(Scheduler):
min_lr=lr_min
)
self.noise_range = noise_range_t
self.noise_range_t = noise_range_t
self.noise_pct = noise_pct
self.noise_type = noise_type
self.noise_std = noise_std
@ -82,25 +82,12 @@ class PlateauLRScheduler(Scheduler):
self.lr_scheduler.step(metric, epoch) # step the base scheduler
if self.noise_range is not None:
if isinstance(self.noise_range, (list, tuple)):
apply_noise = self.noise_range[0] <= epoch < self.noise_range[1]
else:
apply_noise = epoch >= self.noise_range
if apply_noise:
self._apply_noise(epoch)
if self._is_apply_noise(epoch):
self._apply_noise(epoch)
def _apply_noise(self, epoch):
g = torch.Generator()
g.manual_seed(self.noise_seed + epoch)
if self.noise_type == 'normal':
while True:
# resample if noise out of percent limit, brute force but shouldn't spin much
noise = torch.randn(1, generator=g).item()
if abs(noise) < self.noise_pct:
break
else:
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
noise = self._calculate_noise(epoch)
# apply the noise on top of previous LR, cache the old value so we can restore for normal
# stepping of base scheduler

@ -90,21 +90,30 @@ class Scheduler:
param_group[self.param_group_field] = value
def _add_noise(self, lrs, t):
if self._is_apply_noise(t):
noise = self._calculate_noise(t)
lrs = [v + v * noise for v in lrs]
return lrs
def _is_apply_noise(self, t) -> bool:
"""Return True if scheduler in noise range."""
apply_noise = False
if self.noise_range_t is not None:
if isinstance(self.noise_range_t, (list, tuple)):
apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
else:
apply_noise = t >= self.noise_range_t
if apply_noise:
g = torch.Generator()
g.manual_seed(self.noise_seed + t)
if self.noise_type == 'normal':
while True:
# resample if noise out of percent limit, brute force but shouldn't spin much
noise = torch.randn(1, generator=g).item()
if abs(noise) < self.noise_pct:
break
else:
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
lrs = [v + v * noise for v in lrs]
return lrs
return apply_noise
def _calculate_noise(self, t) -> float:
g = torch.Generator()
g.manual_seed(self.noise_seed + t)
if self.noise_type == 'normal':
while True:
# resample if noise out of percent limit, brute force but shouldn't spin much
noise = torch.randn(1, generator=g).item()
if abs(noise) < self.noise_pct:
return noise
else:
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
return noise

@ -89,7 +89,7 @@ parser.add_argument('--crop-pct', default=None, type=float,
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of of dataset')
help='Override std deviation of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
parser.add_argument('-b', '--batch-size', type=int, default=256, metavar='N',

Loading…
Cancel
Save