diff --git a/timm/models/davit.py b/timm/models/davit.py index 1f0055b6..22a12470 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -17,9 +17,11 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig # This source code is licensed under the MIT license import itertools + from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union + import torch import torch.nn as nn import torch.nn.functional as F @@ -35,35 +37,9 @@ from .registry import register_model __all__ = ['DaViT'] -''' -class MySequential(nn.Sequential): - def forward(self, inputs : Tuple[Tensor, Tuple[int, int]]): - for module in self: - inputs = module(*inputs) - #if type(inputs) == tuple: - # inputs = module(*inputs) - #else: - # inputs = module(inputs) - return inputs - -''' -''' -class MySequential(nn.Sequential): - @overload - def forward(self, inputs : Tensor): - for module in self._modules.values(): - inputs = module(inputs) - return inputs - - @overload - def forward(self, inputs : Tuple[Tensor, Tensor]): - for module in self._modules.values(): - inputs = module(*inputs) - return inputs - -''' class ConvPosEnc(nn.Module): def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'): + super(ConvPosEnc, self).__init__() self.proj = nn.Conv2d(dim, dim, @@ -130,6 +106,7 @@ class PatchEmbed(nn.Module): padding=to_2tuple(pad)) self.norm = nn.LayerNorm(in_chans) + def forward(self, x, size: Tuple[int, int]): H, W = size dim = len(x.shape) @@ -203,6 +180,7 @@ class ChannelBlock(nn.Module): hidden_features=mlp_hidden_dim, act_layer=act_layer) + def forward(self, x, size: Tuple[int, int]): x = self.cpe[0](x, size) cur = self.norm1(x) @@ -327,7 +305,9 @@ class SpatialBlock(nn.Module): hidden_features=mlp_hidden_dim, act_layer=act_layer) + def forward(self, x, size: Tuple[int, int]): + H, W = size B, L, C = x.shape assert L == H * W, "input feature has wrong size" @@ -538,6 +518,7 @@ class DaViT(nn.Module): features.append(x) sizes.append(size) branches.append(branch_id) + block_index : int = block_index @@ -552,11 +533,14 @@ class DaViT(nn.Module): for layer_index, branch_id in enumerate(block_param): layer_index : int = layer_index branch_id : int = branch_id + if self.grad_checkpointing and not torch.jit.is_scripting(): features[branch_id], _ = checkpoint.checkpoint(self.main_blocks[block_index][layer_index], features[branch_id], sizes[branch_id]) else: features[branch_id], _ = self.main_blocks[block_index][layer_index](features[branch_id], sizes[branch_id]) + + # pyramid feature norm logic, no weights for these extra norm layers from pretrained classification model outs = [] for i in range(self.num_stages): @@ -565,6 +549,7 @@ class DaViT(nn.Module): H, W = sizes[i] out = x_out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous() outs.append(out) +