diff --git a/timm/models/davit.py b/timm/models/davit.py index b19a89f8..8e84f7ed 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -349,12 +349,16 @@ class SpatialBlock(nn.Module): class DaViT(nn.Module): - r""" Dual Attention Transformer + r""" DaViT + A PyTorch implementation of `DaViT: Dual Attention Vision Transformers` - https://arxiv.org/abs/2204.03645 + Args: - patch_size (int | tuple(int)): Patch size. Default: 4 in_chans (int): Number of input image channels. Default: 3 - embed_dims (tuple(int)): Patch embedding dimension. Default: (64, 128, 192, 256) - num_heads (tuple(int)): Number of attention heads in different layers. Default: (4, 8, 12, 16) + num_classes (int): Number of classes for classification head. Default: 1000 + depths (tuple(int)): Number of blocks in each stage. Default: (1, 1, 3, 1) + patch_size (int | tuple(int)): Patch size. Default: 4 + embed_dims (tuple(int)): Patch embedding dimension. Default: (96, 192, 384, 768) + num_heads (tuple(int)): Number of attention heads in different layers. Default: (3, 6, 12, 24) window_size (int): Window size. Default: 7 mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True @@ -380,11 +384,10 @@ class DaViT(nn.Module): cpe_act=False, drop_rate=0., attn_drop_rate=0., - img_size=224, num_classes=1000, global_pool='avg', - #features_only = False - **kwargs): + **kwargs + ): super().__init__() architecture = [[index] * item for index, item in enumerate(depths)] @@ -399,7 +402,6 @@ class DaViT(nn.Module): self.num_features = embed_dims[-1] self.drop_rate=drop_rate self.grad_checkpointing = False - self._features_only = kwargs.get('features_only', False) self.feature_info = [] self.patch_embeds = nn.ModuleList([ @@ -409,12 +411,11 @@ class DaViT(nn.Module): overlapped=overlapped_patch) for i in range(self.num_stages)]) - #main_blocks = [] - self.main_blocks = nn.ModuleList() - for block_id, block_param in enumerate(self.architecture): - layer_offset_id = len(list(itertools.chain(*self.architecture[:block_id]))) + self.stages = nn.ModuleList() + for stage_id, stage_param in enumerate(self.architecture): + layer_offset_id = len(list(itertools.chain(*self.architecture[:stage_id]))) - block = nn.ModuleList([ + stage = nn.ModuleList([ nn.ModuleList([ ChannelBlock( dim=self.embed_dims[item], @@ -438,74 +439,18 @@ class DaViT(nn.Module): window_size=window_size, ) if attention_type == 'spatial' else None for attention_id, attention_type in enumerate(attention_types)] - ) for layer_id, item in enumerate(block_param) + ) for layer_id, item in enumerate(stage_param) ]) - self.main_blocks.add_module(f'block_{block_id}', block) + self.main_blocks.add_module(f'stage_{stage_id}', stage) - self.feature_info += [dict(num_chs=self.embed_dims[block_id], reduction = 2, module=f'block_{block_id}')] - #self.main_blocks = nn.ModuleList(main_blocks) - - ''' - # layer norms for pyramid feature extraction - # - # TODO implement pyramid feature extraction - # - # davit should be a good transformer candidate, since the only official implementation - # is for segmentation and detection - for i_layer in range(self.num_stages): - layer = norm_layer(self.embed_dims[i_layer]) - layer_name = f'norm{i_layer}' - self.add_module(layer_name, layer) - ''' - self.norms = norm_layer(self.num_features) + self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction = 2, module=f'stage_{stage_id}')] + + + self.norm = norm_layer(self.num_features) self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) self.apply(self._init_weights) - #self._update_forward_fn() - - #self.forward = self._get_forward_fn() - ''' - if self._features_only == True: - self.forward = self.forward_features_full - else: - self.forward = self.forward_classification - ''' - - ''' - def _get_forward_fn(self): - if self._features_only == True: - return self.forward_features_full - else: - return self.forward_classification - ''' - ''' - @torch.jit.ignore - def _get_forward_fn(self): - if self._features_only == True: - return self.forward_features_full - else: - return self.forward_classification - ''' - - @torch.jit.ignore - def _update_forward_fn(self): - if self._features_only == True: - self.forward = self.forward_pyramid_features - else: - self.forward = self.forward_classification - - @property - def features_only(self): - return self._features_only - - @features_only.setter - def features_only(self, new_value : bool): - self._features_only = new_value - #self.forward = self._get_forward_fn() - self._update_forward_fn() - - def _init_weights(self, m): if isinstance(m, nn.Linear): @@ -515,9 +460,7 @@ class DaViT(nn.Module): elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) - - - + @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @@ -534,15 +477,11 @@ class DaViT(nn.Module): def forward_network(self, x): - #x, size = self.patch_embeds[0](x, (x.size(2), x.size(3))) size: Tuple[int, int] = (x.size(2), x.size(3)) features = [x] sizes = [size] - #branches = [0] - - - for patch_layer, stage in zip(self.patch_embeds, self.main_blocks): + for patch_layer, stage in enumerate(zip(self.patch_embeds, self.stages)): features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1]) for _, block in enumerate(stage): for _, layer in enumerate(block): @@ -550,63 +489,15 @@ class DaViT(nn.Module): features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1]) else: features[-1], sizes[-1] = layer(features[-1], sizes[-1]) - - features.append(features[-1]) - sizes.append(sizes[-1]) + # don't append outputs of last stage, since they are already there + if(len(features) < self.num_stages): + features.append(features[-1]) + sizes.append(sizes[-1]) - - - ''' - for block_index, block_param in enumerate(self.architecture): - - branch_ids = sorted(set(block_param)) - for branch_id in branch_ids: - if branch_id not in branches: - x, size = self.patch_embeds[branch_id](features[-1], sizes[-1]) - features.append(x) - sizes.append(size) - branches.append(branch_id) - - - - block_index : int = block_index - - if block_index not in branches: - x, size = self.patch_embeds[block_index](features[-1], sizes[-1]) - features.append(x) - sizes.append(size) - branches.append(branch_id) - - - for layer_index, branch_id in enumerate(block_param): - layer_index : int = layer_index - branch_id : int = branch_id - - if self.grad_checkpointing and not torch.jit.is_scripting(): - features[branch_id], _ = checkpoint.checkpoint(self.main_blocks[block_index][layer_index], features[branch_id], sizes[branch_id]) - else: - features[branch_id], _ = self.main_blocks[block_index][layer_index](features[branch_id], sizes[branch_id]) - - - - # pyramid feature norm logic, no weights for these extra norm layers from pretrained classification model - outs = [] - for i in range(self.num_stages): - norm_layer = getattr(self, f'norm{i}') - x_out = norm_layer(features[i]) - H, W = sizes[i] - out = x_out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous() - outs.append(out) - - - - ''' - - # non-normalized pyramid features + corresponding sizes - return features[:-1], sizes[:-1] + return features, sizes def forward_pyramid_features(self, x): x, sizes = self.forward_network(x) @@ -620,22 +511,19 @@ class DaViT(nn.Module): def forward_features(self, x): x, sizes = self.forward_network(x) # take final feature and norm - x = self.norms(x[-1]) + x = self.norm(x[-1]) H, W = sizes[-1] x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous() - #print(x.shape) return x def forward_head(self, x, pre_logits: bool = False): return self.head(x, pre_logits=pre_logits) - def forward_classification(self, x): + def forward(self, x): x = self.forward_features(x) x = self.forward_head(x) return x - - def forward(self, x): - return x + def checkpoint_filter_fn(state_dict, model): @@ -645,11 +533,10 @@ def checkpoint_filter_fn(state_dict, model): if 'state_dict' in state_dict: state_dict = state_dict['state_dict'] - + out_dict = {} - import re for k, v in state_dict.items(): - + k = k.replace('main_blocks.', 'main_blocks.stage_') k = k.replace('head.', 'head.fc.') out_dict[k] = v return out_dict @@ -657,10 +544,15 @@ def checkpoint_filter_fn(state_dict, model): def _create_davit(variant, pretrained=False, **kwargs): - out_indices = (i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1)))) - feature_cfg = {'out_indices': out_indices} - model = build_model_with_cfg(DaViT, variant, pretrained, - pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=feature_cfg, **kwargs) + default_out_indices = (i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1)))) + out_indices = kwargs.pop('out_indices', default_out_indices) + model = build_model_with_cfg( + DaViT, + variant, + pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + **kwargs) return model