Update a few maxxvit comments, rename PartitionAttention -> PartitionAttenionCl for consistency with other blocks

pull/804/merge
Ross Wightman 2 years ago
parent eca6f0a25c
commit f1d2160d85

@ -651,7 +651,11 @@ class LayerScale2d(nn.Module):
class Downsample2d(nn.Module): class Downsample2d(nn.Module):
""" A downsample pooling module for Coat that handles 2d <-> 1d conversion """ A downsample pooling module supporting several maxpool and avgpool modes
* 'max' - MaxPool2d w/ kernel_size 3, stride 2, padding 1
* 'max2' - MaxPool2d w/ kernel_size = stride = 2
* 'avg' - AvgPool2d w/ kernel_size 3, stride 2, padding 1
* 'avg2' - AvgPool2d w/ kernel_size = stride = 2
""" """
def __init__( def __init__(
@ -710,6 +714,11 @@ def _init_transformer(module, name, scheme=''):
class TransformerBlock2d(nn.Module): class TransformerBlock2d(nn.Module):
""" Transformer block with 2D downsampling """ Transformer block with 2D downsampling
'2D' NCHW tensor layout '2D' NCHW tensor layout
Some gains can be seen on GPU using a 1D / CL block, BUT w/ the need to switch back/forth to NCHW
for spatial pooling, the benefit is minimal so ended up using just this variant for CoAt configs.
This impl was faster on TPU w/ PT XLA than the 1D experiment.
""" """
def __init__( def __init__(
@ -1011,9 +1020,9 @@ def get_rel_pos_cls(cfg: MaxxVitTransformerCfg, window_size):
return rel_pos_cls return rel_pos_cls
class PartitionAttention(nn.Module): class PartitionAttentionCl(nn.Module):
""" Grid or Block partition + Attn + FFN. """ Grid or Block partition + Attn + FFN.
NxC tensor layout. NxC 'channels last' tensor layout.
""" """
def __init__( def __init__(
@ -1183,6 +1192,7 @@ def grid_reverse_nchw(windows, grid_size: List[int], img_size: List[int]):
class PartitionAttention2d(nn.Module): class PartitionAttention2d(nn.Module):
""" Grid or Block partition + Attn + FFN """ Grid or Block partition + Attn + FFN
'2D' NCHW tensor layout. '2D' NCHW tensor layout.
""" """
@ -1245,7 +1255,7 @@ class PartitionAttention2d(nn.Module):
class MaxxVitBlock(nn.Module): class MaxxVitBlock(nn.Module):
""" """ MaxVit conv, window partition + FFN , grid partition + FFN
""" """
def __init__( def __init__(
@ -1264,7 +1274,7 @@ class MaxxVitBlock(nn.Module):
self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path) self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)
attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path) attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)
partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttention partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttentionCl
self.nchw_attn = use_nchw_attn self.nchw_attn = use_nchw_attn
self.attn_block = partition_layer(**attn_kwargs) self.attn_block = partition_layer(**attn_kwargs)
self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs) self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs)
@ -1288,7 +1298,8 @@ class MaxxVitBlock(nn.Module):
class ParallelMaxxVitBlock(nn.Module): class ParallelMaxxVitBlock(nn.Module):
""" """ MaxVit block with parallel cat(window + grid), one FF
Experimental timm block.
""" """
def __init__( def __init__(
@ -1427,7 +1438,9 @@ class Stem(nn.Module):
class MaxxVit(nn.Module): class MaxxVit(nn.Module):
""" """ CoaTNet + MaxVit base model.
Highly configurable for different block compositions, tensor layouts, pooling types.
""" """
def __init__( def __init__(

Loading…
Cancel
Save