|
|
@ -651,7 +651,11 @@ class LayerScale2d(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Downsample2d(nn.Module):
|
|
|
|
class Downsample2d(nn.Module):
|
|
|
|
""" A downsample pooling module for Coat that handles 2d <-> 1d conversion
|
|
|
|
""" A downsample pooling module supporting several maxpool and avgpool modes
|
|
|
|
|
|
|
|
* 'max' - MaxPool2d w/ kernel_size 3, stride 2, padding 1
|
|
|
|
|
|
|
|
* 'max2' - MaxPool2d w/ kernel_size = stride = 2
|
|
|
|
|
|
|
|
* 'avg' - AvgPool2d w/ kernel_size 3, stride 2, padding 1
|
|
|
|
|
|
|
|
* 'avg2' - AvgPool2d w/ kernel_size = stride = 2
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
def __init__(
|
|
|
@ -710,6 +714,11 @@ def _init_transformer(module, name, scheme=''):
|
|
|
|
class TransformerBlock2d(nn.Module):
|
|
|
|
class TransformerBlock2d(nn.Module):
|
|
|
|
""" Transformer block with 2D downsampling
|
|
|
|
""" Transformer block with 2D downsampling
|
|
|
|
'2D' NCHW tensor layout
|
|
|
|
'2D' NCHW tensor layout
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Some gains can be seen on GPU using a 1D / CL block, BUT w/ the need to switch back/forth to NCHW
|
|
|
|
|
|
|
|
for spatial pooling, the benefit is minimal so ended up using just this variant for CoAt configs.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
This impl was faster on TPU w/ PT XLA than the 1D experiment.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
def __init__(
|
|
|
@ -1011,9 +1020,9 @@ def get_rel_pos_cls(cfg: MaxxVitTransformerCfg, window_size):
|
|
|
|
return rel_pos_cls
|
|
|
|
return rel_pos_cls
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PartitionAttention(nn.Module):
|
|
|
|
class PartitionAttentionCl(nn.Module):
|
|
|
|
""" Grid or Block partition + Attn + FFN.
|
|
|
|
""" Grid or Block partition + Attn + FFN.
|
|
|
|
NxC tensor layout.
|
|
|
|
NxC 'channels last' tensor layout.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
def __init__(
|
|
|
@ -1183,6 +1192,7 @@ def grid_reverse_nchw(windows, grid_size: List[int], img_size: List[int]):
|
|
|
|
|
|
|
|
|
|
|
|
class PartitionAttention2d(nn.Module):
|
|
|
|
class PartitionAttention2d(nn.Module):
|
|
|
|
""" Grid or Block partition + Attn + FFN
|
|
|
|
""" Grid or Block partition + Attn + FFN
|
|
|
|
|
|
|
|
|
|
|
|
'2D' NCHW tensor layout.
|
|
|
|
'2D' NCHW tensor layout.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
@ -1245,7 +1255,7 @@ class PartitionAttention2d(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MaxxVitBlock(nn.Module):
|
|
|
|
class MaxxVitBlock(nn.Module):
|
|
|
|
"""
|
|
|
|
""" MaxVit conv, window partition + FFN , grid partition + FFN
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
def __init__(
|
|
|
@ -1264,7 +1274,7 @@ class MaxxVitBlock(nn.Module):
|
|
|
|
self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)
|
|
|
|
self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)
|
|
|
|
|
|
|
|
|
|
|
|
attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)
|
|
|
|
attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)
|
|
|
|
partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttention
|
|
|
|
partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttentionCl
|
|
|
|
self.nchw_attn = use_nchw_attn
|
|
|
|
self.nchw_attn = use_nchw_attn
|
|
|
|
self.attn_block = partition_layer(**attn_kwargs)
|
|
|
|
self.attn_block = partition_layer(**attn_kwargs)
|
|
|
|
self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs)
|
|
|
|
self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs)
|
|
|
@ -1288,7 +1298,8 @@ class MaxxVitBlock(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ParallelMaxxVitBlock(nn.Module):
|
|
|
|
class ParallelMaxxVitBlock(nn.Module):
|
|
|
|
"""
|
|
|
|
""" MaxVit block with parallel cat(window + grid), one FF
|
|
|
|
|
|
|
|
Experimental timm block.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
def __init__(
|
|
|
@ -1427,7 +1438,9 @@ class Stem(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MaxxVit(nn.Module):
|
|
|
|
class MaxxVit(nn.Module):
|
|
|
|
"""
|
|
|
|
""" CoaTNet + MaxVit base model.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Highly configurable for different block compositions, tensor layouts, pooling types.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
def __init__(
|
|
|
|