@@ -17,9 +17,11 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
# This source code is licensed under the MIT license
import itertools
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -35,35 +37,9 @@ from .registry import register_model
__all__ = ['DaViT']
'''
class MySequential(nn.Sequential):
    def forward(self, inputs : Tuple[Tensor, Tuple[int, int]]):
        for module in self:
            inputs = module(*inputs)
            #if type(inputs) == tuple:
            #    inputs = module(*inputs)
            #else:
            #    inputs = module(inputs)
        return inputs
'''
'''
class MySequential(nn.Sequential):
    @overload
    def forward(self, inputs : Tensor):
        for module in self._modules.values():
            inputs = module(inputs)
        return inputs

    @overload
    def forward(self, inputs : Tuple[Tensor, Tensor]):
        for module in self._modules.values():
            inputs = module(*inputs)
        return inputs
'''
class ConvPosEnc(nn.Module):
    def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
        super(ConvPosEnc, self).__init__()
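        # the conv below acts as a convolutional position encoding applied over the 2D token grid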
        self.proj = nn.Conv2d(dim,
                              dim,
@@ -130,6 +106,7 @@ class PatchEmbed(nn.Module):
                padding=to_2tuple(pad))
            self.norm = nn.LayerNorm(in_chans)

    def forward(self, x, size: Tuple[int, int]):
        H, W = size
        dim = len(x.shape)
@@ -203,6 +180,7 @@ class ChannelBlock(nn.Module):
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer)

    def forward(self, x, size: Tuple[int, int]):
        x = self.cpe[0](x, size)
        cur = self.norm1(x)
@@ -327,7 +305,9 @@ class SpatialBlock(nn.Module):
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer)

    def forward(self, x, size: Tuple[int, int]):

        H, W = size
        B, L, C = x.shape
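        # the flattened token sequence must map back onto the (H, W) grid before window partitioning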
        assert L == H * W, "input feature has wrong size"
@@ -538,6 +518,7 @@ class DaViT(nn.Module):
                    features.append(x)
                    sizes.append(size)
                    branches.append(branch_id)

            block_index : int = block_index
@@ -552,11 +533,14 @@ class DaViT(nn.Module):
            for layer_index, branch_id in enumerate(block_param):
                layer_index : int = layer_index
                branch_id : int = branch_id
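
                # NOTE: the re-assignments above look like no-ops in eager mode; they appear
                # intended to pin the loop variable types to int for TorchScript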
                if self.grad_checkpointing and not torch.jit.is_scripting():
                    features[branch_id], _ = checkpoint.checkpoint(self.main_blocks[block_index][layer_index], features[branch_id], sizes[branch_id])
                else:
                    features[branch_id], _ = self.main_blocks[block_index][layer_index](features[branch_id], sizes[branch_id])

        # pyramid feature norm logic, no weights for these extra norm layers from pretrained classification model
        outs = []
        for i in range(self.num_stages):
@@ -565,6 +549,7 @@ class DaViT(nn.Module):
            H, W = sizes[i]
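            # reshape flattened tokens (B, H*W, C) back into an NCHW feature map for the pyramid output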
            out = x_out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()
            outs.append(out)
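
For context, a minimal usage sketch: assuming the register_model entries in this file expose a variant name such as 'davit_tiny', a model built from these blocks can be created through the standard timm factory (the name and defaults below are assumptions, not taken from this diff).

    import timm
    import torch

    # build an untrained DaViT variant (the name 'davit_tiny' is assumed)
    model = timm.create_model('davit_tiny', pretrained=False)
    # optionally enable activation checkpointing, matching the grad_checkpointing branch above (useful when training)
    model.set_grad_checkpointing(True)
    model.eval()

    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # (batch, num_classes) classification logits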