|
|
|
@ -274,7 +274,7 @@ class Twins(nn.Module):
|
|
|
|
|
self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=(64, 128, 256, 512),
|
|
|
|
|
num_heads=(1, 2, 4, 8), mlp_ratios=(4, 4, 4, 4), drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
|
|
|
|
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=(3, 4, 6, 3), sr_ratios=(8, 4, 2, 1), wss=None,
|
|
|
|
|
block_cls=Block):
|
|
|
|
|
block_cls=Block, F4=False, extra_norm=True, strides=(2, 2, 2), task='cls'):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.num_classes = num_classes
|
|
|
|
|
self.depths = depths
|
|
|
|
@ -290,6 +290,28 @@ class Twins(nn.Module):
|
|
|
|
|
img_size = tuple(t // patch_size for t in img_size)
|
|
|
|
|
patch_size = 2
|
|
|
|
|
|
|
|
|
|
self.task = task
|
|
|
|
|
if self.task == 'seg':
|
|
|
|
|
self.F4=F4
|
|
|
|
|
self.extra_norm = extra_norm
|
|
|
|
|
self.strides = strides
|
|
|
|
|
if self.extra_norm:
|
|
|
|
|
self.norm_list = nn.ModuleList()
|
|
|
|
|
for dim in embed_dims:
|
|
|
|
|
self.norm_list.append(norm_layer(dim))
|
|
|
|
|
|
|
|
|
|
if strides != (2, 2, 2):
|
|
|
|
|
del self.patch_embeds
|
|
|
|
|
self.patch_embeds = nn.ModuleList()
|
|
|
|
|
s = 1
|
|
|
|
|
for i in range(len(depths)):
|
|
|
|
|
if i == 0:
|
|
|
|
|
self.patch_embeds.append(PatchEmbed(img_size, patch_size, in_chans, embed_dims[i]))
|
|
|
|
|
else:
|
|
|
|
|
self.patch_embeds.append(
|
|
|
|
|
PatchEmbed(img_size // patch_size // s, strides[i-1], embed_dims[i - 1], embed_dims[i]))
|
|
|
|
|
s = s * strides[i-1]
|
|
|
|
|
|
|
|
|
|
self.blocks = nn.ModuleList()
|
|
|
|
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
|
|
|
|
cur = 0
|
|
|
|
@ -341,24 +363,48 @@ class Twins(nn.Module):
|
|
|
|
|
m.bias.data.zero_()
|
|
|
|
|
|
|
|
|
|
def forward_features(self, x):
    """Run every backbone stage and return task-specific features.

    Each stage applies its patch embedding, positional dropout and
    transformer blocks; the stage's positional block (PEG) is applied
    right after the first transformer block of that stage.

    Args:
        x: Input batch; ``x.shape[0]`` is taken as the batch size.

    Returns:
        For ``self.task == 'cls'``: the output of ``self.norm`` averaged
        over dim 1 (global average pooling over tokens).
        For ``self.task == 'seg'``: a list with one channel-first feature
        map per stage, each optionally passed through ``self.norm_list[i]``
        first when ``self.extra_norm`` is set.

    Raises:
        ValueError: If ``self.task`` is neither ``'cls'`` nor ``'seg'``
            (previously this fell through and returned ``None``).
    """
    task = self.task
    if task not in ('cls', 'seg'):
        raise ValueError(f"Unsupported task {task!r}; expected 'cls' or 'seg'.")

    dense = task == 'seg'
    last_stage = len(self.depths) - 1
    B = x.shape[0]
    outputs = []

    stages = zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)
    for i, (embed, drop, blocks, pos_blk) in enumerate(stages):
        x, size = embed(x)
        x = drop(x)
        for j, blk in enumerate(blocks):
            x = blk(x, size)
            if j == 0:
                x = pos_blk(x, size)  # PEG here

        if dense:
            # Dense prediction keeps every stage's map, including the last.
            if self.extra_norm:
                x = self.norm_list[i](x)
            # Tokens -> channel-first map; it is both this stage's output
            # and the next stage's input.
            x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
            outputs.append(x)
        elif i < last_stage:
            # Classification leaves the final stage as a token sequence so
            # it can be normalised and pooled below.
            x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()

    if dense:
        return outputs
    x = self.norm(x)
    return x.mean(dim=1)  # GAP here
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
    """Compute the model output for the configured task.

    Args:
        x: Input batch, passed unchanged to ``self.forward_features``.

    Returns:
        For ``self.task == 'cls'``: ``self.head`` applied to the features.
        For ``self.task == 'seg'``: the feature list from
        ``forward_features``; when ``self.F4`` is set only the slice
        ``[3:4]`` is returned, so the result is still a list.

    Raises:
        ValueError: If ``self.task`` is neither ``'cls'`` nor ``'seg'``
            (previously this fell through and returned ``None``).
    """
    x = self.forward_features(x)
    if self.task == 'cls':
        return self.head(x)
    if self.task == 'seg':
        # The classification head is never applied to the per-stage list.
        # F4 keeps only the entry at index 3, sliced so callers still
        # receive a list.
        return x[3:4] if self.F4 else x
    raise ValueError(f"Unsupported task {self.task!r}; expected 'cls' or 'seg'.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_twins(variant, pretrained=False, default_cfg=None, **kwargs):
|
|
|
|
@ -372,12 +418,15 @@ def _create_twins(variant, pretrained=False, default_cfg=None, **kwargs):
|
|
|
|
|
img_size = kwargs.pop('img_size', default_img_size)
|
|
|
|
|
if kwargs.get('features_only', None):
|
|
|
|
|
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
|
|
|
|
if kwargs.get('task', 'cls') not in ['cls', 'seg']:
|
|
|
|
|
raise RuntimeError('twins in timm only supports "cls" and "seg" task now.')
|
|
|
|
|
|
|
|
|
|
model = build_model_with_cfg(
|
|
|
|
|
Twins, variant, pretrained,
|
|
|
|
|
default_cfg=default_cfg,
|
|
|
|
|
img_size=img_size,
|
|
|
|
|
num_classes=num_classes,
|
|
|
|
|
pretrained_strict=False,
|
|
|
|
|
**kwargs)
|
|
|
|
|
|
|
|
|
|
return model
|
|
|
|
@ -428,4 +477,4 @@ def twins_svt_large(pretrained=False, **kwargs):
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4],
|
|
|
|
|
depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs)
|
|
|
|
|
return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs)
|