starting point

pull/1630/head
Fredo Guan 3 years ago
parent 3a0f3cec10
commit fb074f89ba

@@ -33,16 +33,6 @@ from .registry import register_model
 __all__ = ['DaViT']
-class MySequential(nn.Sequential):
-    def forward(self, inputs : Tensor, size : Tuple[int, int]):
-        for module in self:
-            output = module(inputs, size)
-            inputs : Tensor = output[0]
-            size : Tuple[int, int] = output[1]
-        return inputs, size
 class ConvPosEnc(nn.Module):
     def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
@@ -113,7 +103,7 @@ class PatchEmbed(nn.Module):
         self.norm = nn.LayerNorm(in_chans)
-    def forward(self, x : Tensor, size: Tuple[int, int]):
+    def forward(self, x, size: Tuple[int, int]):
         H, W = size
         dim = len(x.shape)
         if dim == 3:
@@ -187,7 +177,7 @@ class ChannelBlock(nn.Module):
             act_layer=act_layer)
-    def forward(self, x : Tensor, size: Tuple[int, int]):
+    def forward(self, x, size: Tuple[int, int]):
         x = self.cpe[0](x, size)
         cur = self.norm1(x)
         cur = self.attn(cur)
@@ -312,7 +302,7 @@ class SpatialBlock(nn.Module):
             act_layer=act_layer)
-    def forward(self, x : Tensor, size: Tuple[int, int]):
+    def forward(self, x, size: Tuple[int, int]):
         H, W = size
         B, L, C = x.shape
@@ -421,8 +411,8 @@ class DaViT(nn.Module):
         for stage_id, stage_param in enumerate(self.architecture):
             layer_offset_id = len(list(itertools.chain(*self.architecture[:stage_id])))
-            stage = MySequential(*[
-                MySequential(*[
+            stage = nn.ModuleList([
+                nn.ModuleList([
                     ChannelBlock(
                         dim=self.embed_dims[item],
                         num_heads=self.num_heads[item],
@@ -453,7 +443,7 @@ class DaViT(nn.Module):
             self.feature_info += [dict(
                 num_chs=self.embed_dims[stage_id],
                 reduction = 2,
-                module=f'stages.stage_{stage_id}')]#.{depths[stage_id] - 1}.{len(attention_types) - 1}.mlp.drop2')]
+                module=f'stages.stage_{stage_id}.{depths[stage_id] - 1}.{len(attention_types) - 1}.mlp')]
         self.norms = norm_layer(self.num_features)
@@ -492,8 +482,8 @@ class DaViT(nn.Module):
         for patch_layer, stage in zip(self.patch_embeds, self.stages):
             features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1])
-            for block in stage:
-                for layer in block:
+            for _, block in enumerate(stage):
+                for _, layer in enumerate(block):
                     if self.grad_checkpointing and not torch.jit.is_scripting():
                         features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1])
                     else:
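Note on the change above: the removed MySequential overrides nn.Sequential.forward so that each block receives and returns an (inputs, size) pair, which lets a whole stage be called as a single module; with the switch back to nn.ModuleList, forward_features has to iterate the nested lists itself, as the last hunk shows. A minimal standalone sketch of the two patterns follows (the Block class, tensor shapes, and variable names are illustrative placeholders, not the repository code):

import torch
from torch import nn, Tensor
from typing import Tuple


class MySequential(nn.Sequential):
    # As in the deleted class above: threads an (inputs, size) pair through
    # every submodule, which plain nn.Sequential.forward cannot do.
    def forward(self, inputs: Tensor, size: Tuple[int, int]):
        for module in self:
            inputs, size = module(inputs, size)
        return inputs, size


class Block(nn.Module):
    # Hypothetical stand-in for ChannelBlock / SpatialBlock: accepts and
    # returns the (x, size) pair.
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, size: Tuple[int, int]):
        return self.proj(x), size


dim, size = 64, (14, 14)
x = torch.randn(2, size[0] * size[1], dim)

# Old pattern: the stage is one callable module.
stage = MySequential(Block(dim), Block(dim))
out, _ = stage(x, size)

# New pattern: nn.ModuleList holds the blocks and the caller loops over them
# explicitly (flattened to one level here for brevity).
stage_list = nn.ModuleList([Block(dim), Block(dim)])
out2 = x
for block in stage_list:
    out2, size = block(out2, size)

The nn.ModuleList form keeps per-block iteration visible in forward_features, which is where the grad_checkpointing branch wraps each layer in checkpoint.checkpoint.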
