@@ -74,26 +74,26 @@ default_cfgs = {
     # Fiddling with configs / defaults / still pretraining
     'coatnet_pico_rw_224': _cfg(url=''),
     'coatnet_nano_rw_224': _cfg(
-        url='',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth',
         crop_pct=0.9),
     'coatnet_0_rw_224': _cfg(
-        url=''),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'),
     'coatnet_1_rw_224': _cfg(
-        url=''
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth'
     ),
     'coatnet_2_rw_224': _cfg(url=''),

     # Highly experimental configs
     'coatnet_bn_0_rw_224': _cfg(
-        url='',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
         crop_pct=0.95),
     'coatnet_rmlp_nano_rw_224': _cfg(
-        url='',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth',
         crop_pct=0.9),
     'coatnet_rmlp_0_rw_224': _cfg(url=''),
     'coatnet_rmlp_1_rw_224': _cfg(
-        url=''),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'),
     'coatnet_nano_cc_224': _cfg(url=''),
     'coatnext_nano_rw_224': _cfg(url=''),

@@ -107,10 +107,12 @@ default_cfgs = {

     # Experimental configs
     'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxvit_nano_rw_256': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-3e790ce3.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxvit_tiny_rw_224': _cfg(url=''),
     'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxvit_tiny_cm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),

     # Trying to be like the MaxViT paper configs
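Aside, not part of the patch: once a default_cfgs entry carries a weight URL, the corresponding model can be built with pretrained weights through the usual timm factory. A minimal usage sketch, assuming a timm build that includes this change:

import timm
import torch

# pretrained=True pulls the checkpoint from the URL registered in default_cfgs above
model = timm.create_model('coatnet_nano_rw_224', pretrained=True).eval()

x = torch.randn(1, 3, 224, 224)  # cfg input_size is (3, 224, 224), crop_pct 0.9
with torch.no_grad():
    out = model(x)
print(out.shape)  # torch.Size([1, 1000])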
@@ -131,7 +133,7 @@ class MaxxVitTransformerCfg:
     attn_bias: bool = True
     attn_drop: float = 0.
     proj_drop: float = 0.
-    pool_type: str = 'avg'
+    pool_type: str = 'avg2'
     rel_pos_type: str = 'bias'
     rel_pos_dim: int = 512  # for relative position types w/ MLP
     window_size: Tuple[int, int] = (7, 7)
@@ -153,7 +155,7 @@ class MaxxVitConvCfg:
     pre_norm_act: bool = False  # activation after pre-norm
     output_bias: bool = True  # bias for shortcut + final 1x1 projection conv
     stride_mode: str = 'dw'  # stride done via one of 'pool', '1x1', 'dw'
-    pool_type: str = 'avg'
+    pool_type: str = 'avg2'
     downsample_pool_type: str = 'avg2'
     attn_early: bool = False  # apply attn between conv2 and norm2, instead of after norm2
     attn_layer: str = 'se'
@@ -241,7 +243,7 @@ def _rw_coat_cfg(

 def _rw_max_cfg(
         stride_mode='dw',
-        pool_type='avg',
+        pool_type='avg2',
         conv_output_bias=False,
         conv_attn_ratio=1 / 16,
         conv_norm_layer='',
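Aside, not part of the patch: the 'avg' vs 'avg2' distinction behind the new default is simply overlapping 3x3/stride-2 average pooling versus non-overlapping 2x2/stride-2 pooling (see the Downsample2d hunk further down). A quick shape check with plain torch.nn layers:

import torch
import torch.nn as nn

x = torch.randn(1, 8, 56, 56)

avg = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False)  # 'avg'
avg2 = nn.AvgPool2d(2)  # 'avg2', kernel_size == stride == 2

print(avg(x).shape)   # torch.Size([1, 8, 28, 28])
print(avg2(x).shape)  # torch.Size([1, 8, 28, 28])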
@@ -325,7 +327,6 @@ model_cfgs = dict(
         depths=(2, 3, 5, 2),
         stem_width=(32, 64),
         **_rw_max_cfg(  # using newer max defaults here
-            pool_type='avg2',
             conv_output_bias=True,
             conv_attn_ratio=0.25,
         ),
@@ -336,7 +337,6 @@ model_cfgs = dict(
         stem_width=(32, 64),
         **_rw_max_cfg(  # using newer max defaults here
             stride_mode='pool',
-            pool_type='avg2',
             conv_output_bias=True,
             conv_attn_ratio=0.25,
         ),
@@ -384,7 +384,6 @@ model_cfgs = dict(
         depths=(3, 4, 6, 3),
         stem_width=(32, 64),
         **_rw_max_cfg(
-            pool_type='avg2',
             conv_output_bias=True,
             conv_attn_ratio=0.25,
             rel_pos_type='mlp',
@@ -487,10 +486,10 @@ model_cfgs = dict(
         stem_width=(32, 64),
         **_rw_max_cfg(window_size=8),
     ),
-    maxvit_tiny_cm_256=MaxxVitCfg(
+    maxvit_tiny_pm_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(2, 2, 5, 2),
-        block_type=('CM',) * 4,
+        block_type=('PM',) * 4,
         stem_width=(32, 64),
         **_rw_max_cfg(window_size=8),
     ),
@@ -663,13 +662,15 @@ class Downsample2d(nn.Module):
             bias: bool = True,
     ):
         super().__init__()
-        assert pool_type in ('max', 'avg', 'avg2')
+        assert pool_type in ('max', 'max2', 'avg', 'avg2')
         if pool_type == 'max':
             self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        elif pool_type == 'max2':
+            self.pool = nn.MaxPool2d(2)  # kernel_size == stride == 2
         elif pool_type == 'avg':
             self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False)
         else:
-            self.pool = nn.AvgPool2d(2)
+            self.pool = nn.AvgPool2d(2)  # kernel_size == stride == 2

         if dim != dim_out:
             self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias)
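Aside, not part of the patch: consolidating the branches above, the pooling selection after this change looks roughly like the standalone sketch below (the real Downsample2d also adds the 1x1 expand conv when dim != dim_out):

import torch.nn as nn

def make_pool(pool_type: str = 'avg2') -> nn.Module:
    # 'max' / 'avg' are overlapping 3x3 stride-2 pools; 'max2' / 'avg2' are the
    # non-overlapping kernel_size == stride == 2 variants
    assert pool_type in ('max', 'max2', 'avg', 'avg2')
    if pool_type == 'max':
        return nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    elif pool_type == 'max2':
        return nn.MaxPool2d(2)
    elif pool_type == 'avg':
        return nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False)
    return nn.AvgPool2d(2)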
@@ -1073,7 +1074,7 @@ class PartitionAttention(nn.Module):
         return x


-class CombinedPartitionAttention(nn.Module):
+class ParallelPartitionAttention(nn.Module):
     """ Experimental. Grid and Block partition + single FFN
     NxC tensor layout.
     """
@@ -1286,7 +1287,7 @@ class MaxxVitBlock(nn.Module):
         return x


-class CombinedMaxxVitBlock(nn.Module):
+class ParallelMaxxVitBlock(nn.Module):
     """
     """

@@ -1309,7 +1310,7 @@ class CombinedMaxxVitBlock(nn.Module):
             self.conv = nn.Sequential(*convs)
         else:
             self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)
-        self.attn = CombinedPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)
+        self.attn = ParallelPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)

     def init_weights(self, scheme=''):
         named_apply(partial(_init_transformer, scheme=scheme), self.attn)
@@ -1343,7 +1344,7 @@ class MaxxVitStage(nn.Module):
         blocks = []
         for i, t in enumerate(block_types):
             block_stride = stride if i == 0 else 1
-            assert t in ('C', 'T', 'M', 'CM')
+            assert t in ('C', 'T', 'M', 'PM')
             if t == 'C':
                 conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
                 blocks += [conv_cls(
@@ -1372,8 +1373,8 @@ class MaxxVitStage(nn.Module):
                     transformer_cfg=transformer_cfg,
                     drop_path=drop_path[i],
                 )]
-            elif t == 'CM':
-                blocks += [CombinedMaxxVitBlock(
+            elif t == 'PM':
+                blocks += [ParallelMaxxVitBlock(
                     in_chs,
                     out_chs,
                     stride=block_stride,
@@ -1415,7 +1416,6 @@ class Stem(nn.Module):
         self.norm1 = norm_act_layer(out_chs[0])
         self.conv2 = create_conv2d(out_chs[0], out_chs[1], kernel_size, stride=1)

-    @torch.jit.ignore
     def init_weights(self, scheme=''):
         named_apply(partial(_init_conv, scheme=scheme), self)

@@ -1659,8 +1659,8 @@ def maxvit_tiny_rw_256(pretrained=False, **kwargs):


 @register_model
-def maxvit_tiny_cm_256(pretrained=False, **kwargs):
-    return _create_maxxvit('maxvit_tiny_cm_256', pretrained=pretrained, **kwargs)
+def maxvit_tiny_pm_256(pretrained=False, **kwargs):
+    return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)


 @register_model
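Aside, not part of the patch: after the cm -> pm rename the parallel-block variant is reachable under its new name. No pretrained weights are registered for it, so it builds randomly initialized. A small sketch, assuming a timm build with this change:

import timm
import torch

# 'maxvit_tiny_pm_256' uses block_type=('PM',) * 4, i.e. stages built from
# ParallelMaxxVitBlock (block + grid partition attention in parallel, single FFN)
model = timm.create_model('maxvit_tiny_pm_256', pretrained=False)

x = torch.randn(1, 3, 256, 256)  # cfg input_size is (3, 256, 256)
print(model(x).shape)  # torch.Size([1, 1000])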