diff --git a/timm/models/davit.py b/timm/models/davit.py
index 8291ed84..0ccadd79 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -35,51 +35,6 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 
 __all__ = ['DaViT']
 
 
-# modified nn.Sequential that includes a size tuple in the forward function
-# FIXME doesn't work with torchscript/JIT
-# Module 'SequentialWithSize' has no attribute '_modules'
-class SequentialWithSize(nn.Sequential):
-    def forward(self, x : Tensor, size: Tuple[int, int]):
-        for module in self._modules.values():
-            x, size = module(x, size)
-        return x, size
-
-
-class ConvPosEncOld(nn.Module):
-    def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
-
-        super(ConvPosEnc, self).__init__()
-        self.proj = nn.Conv2d(dim,
-                              dim,
-                              to_2tuple(k),
-                              to_2tuple(1),
-                              to_2tuple(k // 2),
-                              groups=dim)
-        self.normtype = normtype
-        self.norm = nn.Identity()
-        if self.normtype == 'batch':
-            self.norm = nn.BatchNorm2d(dim)
-        elif self.normtype == 'layer':
-            self.norm = nn.LayerNorm(dim)
-        self.activation = nn.GELU() if act else nn.Identity()
-
-    def forward(self, x : Tensor, size: Tuple[int, int]):
-        B, N, C = x.shape
-        H, W = size
-
-        feat = x.transpose(1, 2).view(B, C, H, W)
-        feat = self.proj(feat)
-        if self.normtype == 'batch':
-            feat = self.norm(feat).flatten(2).transpose(1, 2)
-        elif self.normtype == 'layer':
-            feat = self.norm(feat.flatten(2).transpose(1, 2))
-        else:
-            feat = feat.flatten(2).transpose(1, 2)
-        x = x + self.activation(feat)
-        return x
-
-
-
 class ConvPosEnc(nn.Module):
     def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
@@ -113,69 +68,7 @@ class ConvPosEnc(nn.Module):
         return x
 
 
-# reason: dim in control sequence
-# FIXME reimplement to allow tracing
-@register_notrace_module
-class PatchEmbedOld(nn.Module):
-    """ Size-agnostic implementation of 2D image to patch embedding,
-        allowing input size to be adjusted during model forward operation
-    """
-
-    def __init__(
-            self,
-            patch_size=16,
-            in_chans=3,
-            embed_dim=96,
-            overlapped=False):
-        super().__init__()
-        patch_size = to_2tuple(patch_size)
-        self.patch_size = patch_size
-
-        if patch_size[0] == 4:
-            self.proj = nn.Conv2d(
-                in_chans,
-                embed_dim,
-                kernel_size=(7, 7),
-                stride=patch_size,
-                padding=(3, 3))
-            self.norm = nn.LayerNorm(embed_dim)
-        if patch_size[0] == 2:
-            kernel = 3 if overlapped else 2
-            pad = 1 if overlapped else 0
-            self.proj = nn.Conv2d(
-                in_chans,
-                embed_dim,
-                kernel_size=to_2tuple(kernel),
-                stride=patch_size,
-                padding=to_2tuple(pad))
-            self.norm = nn.LayerNorm(in_chans)
-
-
-    def forward(self, x : Tensor, size: Tuple[int, int]):
-        H, W = size
-        dim = x.dim()
-        if dim == 3:
-            B, HW, C = x.shape
-            x = self.norm(x)
-            x = x.reshape(B,
-                          H,
-                          W,
-                          C).permute(0, 3, 1, 2).contiguous()
-
-        B, C, H, W = x.shape
-        if W % self.patch_size[1] != 0:
-            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
-        if H % self.patch_size[0] != 0:
-            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
-
-        x = self.proj(x)
-        newsize = (x.size(2), x.size(3))
-        x = x.flatten(2).transpose(1, 2)
-        if dim == 4:
-            x = self.norm(x)
-        return x, newsize
-@register_notrace_module
 class PatchEmbed(nn.Module):
     """ Size-agnostic implementation of 2D image to patch embedding,
         allowing input size to be adjusted during model forward operation
     """
@@ -255,70 +148,6 @@ class ChannelAttention(nn.Module):
         x = self.proj(x)
         return x
 
-class ChannelAttentionNew(nn.Module):
-
-    def __init__(self, dim, num_heads=8, qkv_bias=False):
-        super().__init__()
-        self.num_heads = num_heads
-        head_dim = dim // num_heads
-        self.scale = head_dim ** -0.5
-
-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-        self.proj = nn.Linear(dim, dim)
-
-    def forward(self, x : Tensor):
-        B, C, H, W = x.shape
-        x = x.flatten(2).transpose(1, 2)
-        B, N, C = x.shape
-
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]
-
-        k = k * self.scale
-        attention = k.transpose(-1, -2) @ v
-        attention = attention.softmax(dim=-1)
-        x = (attention @ q.transpose(-1, -2)).transpose(-1, -2)
-        x = x.transpose(1, 2).reshape(B, N, C)
-        x = self.proj(x)
-        x = x.transpose(1, 2).view(B, C, H, W)
-        return x
-
-
-class ChannelBlockOld(nn.Module):
-
-    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False,
-                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
-                 ffn=True, cpe_act=False):
-        super().__init__()
-
-        self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
-        self.ffn = ffn
-        self.norm1 = norm_layer(dim)
-        self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-        self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
-
-        if self.ffn:
-            self.norm2 = norm_layer(dim)
-            mlp_hidden_dim = int(dim * mlp_ratio)
-            self.mlp = Mlp(
-                in_features=dim,
-                hidden_features=mlp_hidden_dim,
-                act_layer=act_layer)
-
-
-    def forward(self, x : Tensor, size: Tuple[int, int]):
-
-        x = self.cpe1(x, size)
-
-        cur = self.norm1(x)
-        cur = self.attn(cur)
-        x = x + self.drop_path(cur)
-
-        x = self.cpe2(x, size)
-        if self.ffn:
-            x = x + self.drop_path(self.mlp(self.norm2(x)))
-        return x, size
 
 class ChannelBlock(nn.Module):
 
@@ -788,34 +617,7 @@ class DaViT(nn.Module):
         if global_pool is None:
             global_pool = self.head.global_pool.pool_type
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
-    '''
-    def forward_network(self, x : Tensor):
-        size: Tuple[int, int] = (x.size(2), x.size(3))
-        features = [x]
-        #sizes = [size]
-
-        for stage in self.stages:
-            features[-1] = stage(features[-1])
-
-            # don't append outputs of last stage, since they are already there
-            if(len(features) < self.num_stages):
-                features.append(features[-1])
-
-
-        # non-normalized pyramid features + corresponding sizes
-        return features
-
-    def forward_pyramid_features(self, x) -> List[Tensor]:
-        x = self.forward_network(x)
-
-        outs = []
-        for i, out in enumerate(x):
-            H, W = sizes[i]
-            outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous())
-
-        return x
-    '''
 
     def forward_features(self, x):
         x = self.stages(x)
         # take final feature and norm
@@ -834,16 +636,7 @@ class DaViT(nn.Module):
     def forward(self, x):
         return self.forward_classifier(x)
 
-'''
-class DaViTFeatures(DaViT):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_indices', (0, 1, 2, 3)))
-
-    def forward(self, x) -> List[Tensor]:
-        return self.forward_pyramid_features(x)
-'''
 
 def checkpoint_filter_fn(state_dict, model):
     """ Remap MSFT checkpoints -> timm """
     if 'head.norm.weight' in state_dict: