Make kernel_size = stride = 2 ('avg2') pooling the default for coatnet/maxvit. Add weight links. Rename 'combined' partition to 'parallel'.

pull/1415/head v0.1-weights-maxx
Ross Wightman 2 years ago
parent 837c68263b
commit b2e8426fca
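
The cfg entries below gain download URLs pointing at the v0.1-weights-maxx release, so the affected models can be built with pretrained weights. A minimal usage sketch (not part of this commit), assuming a timm build that includes these changes:

import timm
import torch

# 'coatnet_nano_rw_224' is one of the cfgs that gains a weight URL in this commit;
# pretrained=True fetches the checkpoint from the v0.1-weights-maxx release URL in its cfg.
model = timm.create_model('coatnet_nano_rw_224', pretrained=True)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000]) -- ImageNet-1k classification head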

@@ -74,26 +74,26 @@ default_cfgs = {
     # Fiddling with configs / defaults / still pretraining
     'coatnet_pico_rw_224': _cfg(url=''),
     'coatnet_nano_rw_224': _cfg(
-        url='',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth',
         crop_pct=0.9),
     'coatnet_0_rw_224': _cfg(
-        url=''),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'),
     'coatnet_1_rw_224': _cfg(
-        url=''
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth'
     ),
     'coatnet_2_rw_224': _cfg(url=''),
 
     # Highly experimental configs
     'coatnet_bn_0_rw_224': _cfg(
-        url='',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
         crop_pct=0.95),
     'coatnet_rmlp_nano_rw_224': _cfg(
-        url='',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth',
         crop_pct=0.9),
     'coatnet_rmlp_0_rw_224': _cfg(url=''),
     'coatnet_rmlp_1_rw_224': _cfg(
-        url=''),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'),
     'coatnet_nano_cc_224': _cfg(url=''),
     'coatnext_nano_rw_224': _cfg(url=''),
@@ -107,10 +107,12 @@ default_cfgs = {
     # Experimental configs
     'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxvit_nano_rw_256': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-3e790ce3.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxvit_tiny_rw_224': _cfg(url=''),
     'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxvit_tiny_cm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
 
     # Trying to be like the MaxViT paper configs
@@ -131,7 +133,7 @@ class MaxxVitTransformerCfg:
     attn_bias: bool = True
     attn_drop: float = 0.
     proj_drop: float = 0.
-    pool_type: str = 'avg'
+    pool_type: str = 'avg2'
     rel_pos_type: str = 'bias'
     rel_pos_dim: int = 512  # for relative position types w/ MLP
     window_size: Tuple[int, int] = (7, 7)
@@ -153,7 +155,7 @@ class MaxxVitConvCfg:
     pre_norm_act: bool = False  # activation after pre-norm
     output_bias: bool = True  # bias for shortcut + final 1x1 projection conv
     stride_mode: str = 'dw'  # stride done via one of 'pool', '1x1', 'dw'
-    pool_type: str = 'avg'
+    pool_type: str = 'avg2'
     downsample_pool_type: str = 'avg2'
     attn_early: bool = False  # apply attn between conv2 and norm2, instead of after norm2
     attn_layer: str = 'se'
@@ -241,7 +243,7 @@ def _rw_coat_cfg(
 
 def _rw_max_cfg(
         stride_mode='dw',
-        pool_type='avg',
+        pool_type='avg2',
         conv_output_bias=False,
         conv_attn_ratio=1 / 16,
         conv_norm_layer='',
@@ -325,7 +327,6 @@ model_cfgs = dict(
         depths=(2, 3, 5, 2),
         stem_width=(32, 64),
         **_rw_max_cfg(  # using newer max defaults here
-            pool_type='avg2',
             conv_output_bias=True,
             conv_attn_ratio=0.25,
         ),
@@ -336,7 +337,6 @@ model_cfgs = dict(
         stem_width=(32, 64),
         **_rw_max_cfg(  # using newer max defaults here
             stride_mode='pool',
-            pool_type='avg2',
             conv_output_bias=True,
             conv_attn_ratio=0.25,
         ),
@@ -384,7 +384,6 @@ model_cfgs = dict(
         depths=(3, 4, 6, 3),
         stem_width=(32, 64),
         **_rw_max_cfg(
-            pool_type='avg2',
             conv_output_bias=True,
             conv_attn_ratio=0.25,
             rel_pos_type='mlp',
@@ -487,10 +486,10 @@ model_cfgs = dict(
         stem_width=(32, 64),
         **_rw_max_cfg(window_size=8),
     ),
-    maxvit_tiny_cm_256=MaxxVitCfg(
+    maxvit_tiny_pm_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(2, 2, 5, 2),
-        block_type=('CM',) * 4,
+        block_type=('PM',) * 4,
         stem_width=(32, 64),
         **_rw_max_cfg(window_size=8),
     ),
@@ -663,13 +662,15 @@ class Downsample2d(nn.Module):
             bias: bool = True,
     ):
         super().__init__()
-        assert pool_type in ('max', 'avg', 'avg2')
+        assert pool_type in ('max', 'max2', 'avg', 'avg2')
         if pool_type == 'max':
             self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        elif pool_type == 'max2':
+            self.pool = nn.MaxPool2d(2)  # kernel_size == stride == 2
         elif pool_type == 'avg':
             self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False)
         else:
-            self.pool = nn.AvgPool2d(2)
+            self.pool = nn.AvgPool2d(2)  # kernel_size == stride == 2
 
         if dim != dim_out:
             self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias)
@@ -1073,7 +1074,7 @@ class PartitionAttention(nn.Module):
         return x
 
 
-class CombinedPartitionAttention(nn.Module):
+class ParallelPartitionAttention(nn.Module):
     """ Experimental. Grid and Block partition + single FFN
     NxC tensor layout.
     """
@@ -1286,7 +1287,7 @@ class MaxxVitBlock(nn.Module):
         return x
 
 
-class CombinedMaxxVitBlock(nn.Module):
+class ParallelMaxxVitBlock(nn.Module):
     """
     """
@@ -1309,7 +1310,7 @@ class CombinedMaxxVitBlock(nn.Module):
             self.conv = nn.Sequential(*convs)
         else:
             self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)
-        self.attn = CombinedPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)
+        self.attn = ParallelPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)
 
     def init_weights(self, scheme=''):
         named_apply(partial(_init_transformer, scheme=scheme), self.attn)
@@ -1343,7 +1344,7 @@ class MaxxVitStage(nn.Module):
         blocks = []
         for i, t in enumerate(block_types):
             block_stride = stride if i == 0 else 1
-            assert t in ('C', 'T', 'M', 'CM')
+            assert t in ('C', 'T', 'M', 'PM')
             if t == 'C':
                 conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
                 blocks += [conv_cls(
@@ -1372,8 +1373,8 @@ class MaxxVitStage(nn.Module):
                     transformer_cfg=transformer_cfg,
                     drop_path=drop_path[i],
                 )]
-            elif t == 'CM':
-                blocks += [CombinedMaxxVitBlock(
+            elif t == 'PM':
+                blocks += [ParallelMaxxVitBlock(
                     in_chs,
                     out_chs,
                     stride=block_stride,
@@ -1415,7 +1416,6 @@ class Stem(nn.Module):
         self.norm1 = norm_act_layer(out_chs[0])
         self.conv2 = create_conv2d(out_chs[0], out_chs[1], kernel_size, stride=1)
 
-    @torch.jit.ignore
     def init_weights(self, scheme=''):
         named_apply(partial(_init_conv, scheme=scheme), self)
@@ -1659,8 +1659,8 @@ def maxvit_tiny_rw_256(pretrained=False, **kwargs):
 
 
 @register_model
-def maxvit_tiny_cm_256(pretrained=False, **kwargs):
-    return _create_maxxvit('maxvit_tiny_cm_256', pretrained=pretrained, **kwargs)
+def maxvit_tiny_pm_256(pretrained=False, **kwargs):
+    return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)
 
 
 @register_model
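
For reference, the 'avg2' default this commit switches to is a plain kernel_size = stride = 2 average pool, versus the overlapping 3x3, stride-2, padded window used by 'avg'. A standalone sketch of the two (assumption: the bare nn modules here mirror what Downsample2d builds internally for each pool_type):

import torch
import torch.nn as nn

x = torch.randn(1, 64, 56, 56)

# 'avg': overlapping 3x3 window, stride 2, padding 1
avg = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False)
# 'avg2': non-overlapping 2x2 window, kernel_size == stride == 2, no padding
avg2 = nn.AvgPool2d(2)

print(avg(x).shape)   # torch.Size([1, 64, 28, 28])
print(avg2(x).shape)  # torch.Size([1, 64, 28, 28]) -- same output size, cheaper non-overlapping windows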
