Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 8408551195
commit a828ccaf88

@@ -35,51 +35,6 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['DaViT']


# modified nn.Sequential that includes a size tuple in the forward function
# FIXME doesn't work with torchscript/JIT
# Module 'SequentialWithSize' has no attribute '_modules'
class SequentialWithSize(nn.Sequential):
    def forward(self, x: Tensor, size: Tuple[int, int]):
        for module in self._modules.values():
            x, size = module(x, size)
        return x, size
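A minimal usage sketch (not part of the diff), assuming the hypothetical _NoopWithSize module defined here purely for illustration: any child whose forward accepts and returns an (x, size) pair can be chained, which is how the size tuple is threaded alongside token-shaped tensors.

import torch
import torch.nn as nn
from typing import Tuple
from torch import Tensor

class _NoopWithSize(nn.Module):
    # toy (x, size) -> (x, size) module, standing in for ConvPosEnc / PatchEmbed style blocks
    def forward(self, x: Tensor, size: Tuple[int, int]):
        return x, size

seq = SequentialWithSize(_NoopWithSize(), _NoopWithSize())
tokens = torch.randn(2, 56 * 56, 96)      # (B, N, C) token sequence
out, out_size = seq(tokens, (56, 56))     # size tuple is passed along with the tensor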
class ConvPosEncOld(nn.Module):
    def __init__(self, dim: int, k: int = 3, act: bool = False, normtype: str = 'none'):
        super(ConvPosEncOld, self).__init__()
        self.proj = nn.Conv2d(
            dim,
            dim,
            to_2tuple(k),
            to_2tuple(1),
            to_2tuple(k // 2),
            groups=dim)
        self.normtype = normtype
        self.norm = nn.Identity()
        if self.normtype == 'batch':
            self.norm = nn.BatchNorm2d(dim)
        elif self.normtype == 'layer':
            self.norm = nn.LayerNorm(dim)
        self.activation = nn.GELU() if act else nn.Identity()

    def forward(self, x: Tensor, size: Tuple[int, int]):
        B, N, C = x.shape
        H, W = size
        feat = x.transpose(1, 2).view(B, C, H, W)
        feat = self.proj(feat)
        if self.normtype == 'batch':
            feat = self.norm(feat).flatten(2).transpose(1, 2)
        elif self.normtype == 'layer':
            feat = self.norm(feat.flatten(2).transpose(1, 2))
        else:
            feat = feat.flatten(2).transpose(1, 2)
        x = x + self.activation(feat)
        return x
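A minimal shape sketch (illustration only, not part of the diff): the block takes (B, N, C) tokens with N == H * W, adds a depthwise-conv positional bias, and returns tokens of the same shape.

import torch

cpe = ConvPosEncOld(dim=96, k=3)
tokens = torch.randn(2, 14 * 14, 96)      # (B, N, C) with N = H * W
out = cpe(tokens, (14, 14))               # size supplies the (H, W) needed to fold tokens back to a map
assert out.shape == tokens.shape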
class ConvPosEnc(nn.Module):
    def __init__(self, dim: int, k: int = 3, act: bool = False, normtype: str = 'none'):
@@ -113,69 +68,7 @@ class ConvPosEnc(nn.Module):
        return x
# reason: dim in control sequence
# FIXME reimplement to allow tracing
@register_notrace_module
class PatchEmbedOld(nn.Module):
    """ Size-agnostic implementation of 2D image to patch embedding,
    allowing input size to be adjusted during model forward operation
    """

    def __init__(
            self,
            patch_size=16,
            in_chans=3,
            embed_dim=96,
            overlapped=False):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        if patch_size[0] == 4:
            self.proj = nn.Conv2d(
                in_chans,
                embed_dim,
                kernel_size=(7, 7),
                stride=patch_size,
                padding=(3, 3))
            self.norm = nn.LayerNorm(embed_dim)
        if patch_size[0] == 2:
            kernel = 3 if overlapped else 2
            pad = 1 if overlapped else 0
            self.proj = nn.Conv2d(
                in_chans,
                embed_dim,
                kernel_size=to_2tuple(kernel),
                stride=patch_size,
                padding=to_2tuple(pad))
            self.norm = nn.LayerNorm(in_chans)

    def forward(self, x: Tensor, size: Tuple[int, int]):
        H, W = size
        dim = x.dim()
        if dim == 3:
            B, HW, C = x.shape
            x = self.norm(x)
            x = x.reshape(B,
                          H,
                          W,
                          C).permute(0, 3, 1, 2).contiguous()

        B, C, H, W = x.shape
        if W % self.patch_size[1] != 0:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))

        x = self.proj(x)
        newsize = (x.size(2), x.size(3))
        x = x.flatten(2).transpose(1, 2)
        if dim == 4:
            x = self.norm(x)
        return x, newsize
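A minimal sketch of why the module is wrapped with @register_notrace_module (the FIXME above): symbolic tracing, as used by FX-based feature extraction, cannot follow control flow that depends on tensor properties such as x.dim() or the padding checks. The _ShapeBranch module here is hypothetical and only reproduces the failure mode.

import torch
import torch.fx

class _ShapeBranch(torch.nn.Module):
    def forward(self, x):
        if x.dim() == 3:          # branch on a runtime tensor property
            x = x.unsqueeze(1)
        return x

try:
    torch.fx.symbolic_trace(_ShapeBranch())
except Exception as e:            # raises a TraceError: proxies cannot drive control flow
    print('symbolic_trace failed:', e)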
@register_notrace_module
class PatchEmbed(nn.Module):
    """ Size-agnostic implementation of 2D image to patch embedding,
    allowing input size to be adjusted during model forward operation
@@ -255,70 +148,6 @@ class ChannelAttention(nn.Module):
        x = self.proj(x)
        return x
class ChannelAttentionNew(nn.Module):

    def __init__(self, dim, num_heads=8, qkv_bias=False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor):
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        B, N, C = x.shape

        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        k = k * self.scale
        attention = k.transpose(-1, -2) @ v
        attention = attention.softmax(dim=-1)
        x = (attention @ q.transpose(-1, -2)).transpose(-1, -2)
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = x.transpose(1, 2).view(B, C, H, W)
        return x
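A minimal shape check (illustration only, not part of the diff): the NCHW variant above flattens the feature map to tokens, applies channel-wise ("transposed") attention, and folds the result back, so the output shape matches the input.

import torch

attn = ChannelAttentionNew(dim=96, num_heads=3)
x = torch.randn(2, 96, 14, 14)       # (B, C, H, W) feature map
out = attn(x)
assert out.shape == x.shape          # channel attention preserves the spatial layout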
class ChannelBlockOld(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 ffn=True, cpe_act=False):
        super().__init__()

        self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
        self.ffn = ffn
        self.norm1 = norm_layer(dim)
        self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)

        if self.ffn:
            self.norm2 = norm_layer(dim)
            mlp_hidden_dim = int(dim * mlp_ratio)
            self.mlp = Mlp(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                act_layer=act_layer)

    def forward(self, x: Tensor, size: Tuple[int, int]):
        x = self.cpe1(x, size)
        cur = self.norm1(x)
        cur = self.attn(cur)
        x = x + self.drop_path(cur)

        x = self.cpe2(x, size)
        if self.ffn:
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x, size


class ChannelBlock(nn.Module):
@@ -788,34 +617,7 @@ class DaViT(nn.Module):
        if global_pool is None:
            global_pool = self.head.global_pool.pool_type
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

    '''
    def forward_network(self, x: Tensor):
        size: Tuple[int, int] = (x.size(2), x.size(3))
        features = [x]
        #sizes = [size]

        for stage in self.stages:
            features[-1] = stage(features[-1])
            # don't append outputs of last stage, since they are already there
            if len(features) < self.num_stages:
                features.append(features[-1])

        # non-normalized pyramid features + corresponding sizes
        return features

    def forward_pyramid_features(self, x) -> List[Tensor]:
        x = self.forward_network(x)
        outs = []
        for i, out in enumerate(x):
            H, W = sizes[i]
            outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous())
        return x
    '''

    def forward_features(self, x):
        x = self.stages(x)
        # take final feature and norm
@@ -834,16 +636,7 @@ class DaViT(nn.Module):
    def forward(self, x):
        return self.forward_classifier(x)


'''
class DaViTFeatures(DaViT):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_indices', (0, 1, 2, 3)))

    def forward(self, x) -> List[Tensor]:
        return self.forward_pyramid_features(x)
'''


def checkpoint_filter_fn(state_dict, model):
    """ Remap MSFT checkpoints -> timm """
    if 'head.norm.weight' in state_dict:
