diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py
index 2475e809..de2c9fb8 100644
--- a/timm/models/maxxvit.py
+++ b/timm/models/maxxvit.py
@@ -651,7 +651,11 @@ class LayerScale2d(nn.Module):
 
 
 class Downsample2d(nn.Module):
-    """ A downsample pooling module for Coat that handles 2d <-> 1d conversion
+    """ A downsample pooling module supporting several maxpool and avgpool modes
+    * 'max' - MaxPool2d w/ kernel_size 3, stride 2, padding 1
+    * 'max2' - MaxPool2d w/ kernel_size = stride = 2
+    * 'avg' - AvgPool2d w/ kernel_size 3, stride 2, padding 1
+    * 'avg2' - AvgPool2d w/ kernel_size = stride = 2
     """
 
     def __init__(
@@ -710,6 +714,11 @@ def _init_transformer(module, name, scheme=''):
 class TransformerBlock2d(nn.Module):
     """ Transformer block with 2D downsampling
     '2D' NCHW tensor layout
+
+    Some gains can be seen on GPU using a 1D / CL block, BUT w/ the need to switch back/forth to NCHW
+    for spatial pooling, the benefit is minimal, so just this variant is used for the CoAt configs.
+
+    This impl was faster on TPU w/ PT XLA than the 1D experiment.
     """
 
     def __init__(
@@ -1011,9 +1020,9 @@ def get_rel_pos_cls(cfg: MaxxVitTransformerCfg, window_size):
     return rel_pos_cls
 
 
-class PartitionAttention(nn.Module):
+class PartitionAttentionCl(nn.Module):
     """ Grid or Block partition + Attn + FFN.
-    NxC tensor layout.
+    NxC 'channels last' tensor layout.
     """
 
     def __init__(
@@ -1183,6 +1192,7 @@ def grid_reverse_nchw(windows, grid_size: List[int], img_size: List[int]):
 
 class PartitionAttention2d(nn.Module):
     """ Grid or Block partition + Attn + FFN
+    '2D' NCHW tensor layout.
     """
 
     def __init__(
@@ -1245,7 +1255,7 @@ class PartitionAttention2d(nn.Module):
 
 
 class MaxxVitBlock(nn.Module):
-    """
+    """ MaxVit conv, window partition + FFN, grid partition + FFN
     """
 
     def __init__(
@@ -1264,7 +1274,7 @@ class MaxxVitBlock(nn.Module):
         self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)
 
         attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)
-        partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttention
+        partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttentionCl
         self.nchw_attn = use_nchw_attn
         self.attn_block = partition_layer(**attn_kwargs)
         self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs)
@@ -1288,7 +1298,8 @@ class MaxxVitBlock(nn.Module):
 
 
 class ParallelMaxxVitBlock(nn.Module):
-    """
+    """ MaxVit block with parallel cat(window + grid), one FFN
+    Experimental timm block.
     """
 
     def __init__(
@@ -1427,7 +1438,9 @@ class Stem(nn.Module):
 
 
 class MaxxVit(nn.Module):
-    """
+    """ CoAtNet + MaxVit base model.
+
+    Highly configurable for different block compositions, tensor layouts, pooling types.
     """
 
     def __init__(
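
For reference, the four pool modes documented on Downsample2d map onto stock PyTorch pooling layers roughly as below. This is a sketch of the docstring only, not the full module (the real Downsample2d has additional constructor logic not shown here), and the helper name is made up for illustration.

import torch
import torch.nn as nn

def make_pool2d(pool_type: str) -> nn.Module:
    # hypothetical helper mirroring the Downsample2d docstring modes
    if pool_type == 'max':
        return nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    if pool_type == 'max2':
        return nn.MaxPool2d(kernel_size=2, stride=2)
    if pool_type == 'avg':
        return nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False)
    if pool_type == 'avg2':
        return nn.AvgPool2d(kernel_size=2, stride=2)
    raise ValueError(f'unknown pool_type {pool_type}')

x = torch.randn(1, 64, 56, 56)
for mode in ('max', 'max2', 'avg', 'avg2'):
    print(mode, make_pool2d(mode)(x).shape)  # every mode halves H and W: (1, 64, 28, 28)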
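The 'Cl' suffix on PartitionAttentionCl is worth spelling out: the block consumes NHWC ('channels last') tensors, while PartitionAttention2d consumes NCHW. A minimal illustration of the layout round-trip that the TransformerBlock2d docstring is weighing against follows; tensor names and the LayerNorm stand-in are hypothetical, not the timm forward path.

import torch
import torch.nn as nn

x_nchw = torch.randn(2, 96, 56, 56)       # conv / pooling friendly: N, C, H, W
x_nhwc = x_nchw.permute(0, 2, 3, 1)       # N, H, W, C: what a *Cl block consumes
norm = nn.LayerNorm(96)                   # norm/attn/FFN ops act on the trailing C dim
x_nhwc = norm(x_nhwc)
x_nchw = x_nhwc.permute(0, 3, 1, 2).contiguous()  # back to NCHW for spatial pooling

This permute-each-way cost is the "switch back/forth to NCHW" the docstring cites as the reason the 2D variant was kept for CoAt configs.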
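Since the MaxxVitBlock docstring compresses the whole recipe into one line, here is a sketch of the two partition schemes it alternates between. The reshape/permute pattern follows the partition helpers in this file (cf. grid_reverse_nchw above), but shown in NHWC for brevity with a hypothetical partition size P.

import torch

P = 4                                    # partition size (assumed)
x = torch.randn(1, 8, 8, 3)              # N, H, W, C
N, H, W, C = x.shape

# block/window partition: contiguous P x P tiles -> local attention
windows = x.view(N, H // P, P, W // P, P, C).permute(0, 1, 3, 2, 4, 5).reshape(-1, P, P, C)

# grid partition: strided sampling, each P x P group spans the full image -> global attention
grid = x.view(N, P, H // P, P, W // P, C).permute(0, 2, 4, 1, 3, 5).reshape(-1, P, P, C)

print(windows.shape, grid.shape)         # both (N * H * W / P**2, P, P, C) == (4, 4, 4, 3)

MaxxVitBlock runs conv, window attn + FFN, then grid attn + FFN sequentially; per its new docstring, ParallelMaxxVitBlock instead concatenates the window and grid branches and shares one FFN.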