update twins.py to support segmentation task

pull/659/head
李鑫杰 4 years ago
parent 23c18a33e4
commit 43180b1341

@ -274,7 +274,7 @@ class Twins(nn.Module):
self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=(64, 128, 256, 512),
num_heads=(1, 2, 4, 8), mlp_ratios=(4, 4, 4, 4), drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=(3, 4, 6, 3), sr_ratios=(8, 4, 2, 1), wss=None,
block_cls=Block):
block_cls=Block, F4=False, extra_norm=True, strides=(2, 2, 2), task='cls'):
super().__init__()
self.num_classes = num_classes
self.depths = depths
@ -290,6 +290,28 @@ class Twins(nn.Module):
img_size = tuple(t // patch_size for t in img_size)
patch_size = 2
self.task = task
if self.task == 'seg':
self.F4=F4
self.extra_norm = extra_norm
self.strides = strides
if self.extra_norm:
self.norm_list = nn.ModuleList()
for dim in embed_dims:
self.norm_list.append(norm_layer(dim))
if strides != (2, 2, 2):
del self.patch_embeds
self.patch_embeds = nn.ModuleList()
s = 1
for i in range(len(depths)):
if i == 0:
self.patch_embeds.append(PatchEmbed(img_size, patch_size, in_chans, embed_dims[i]))
else:
self.patch_embeds.append(
PatchEmbed(img_size // patch_size // s, strides[i-1], embed_dims[i - 1], embed_dims[i]))
s = s * strides[i-1]
self.blocks = nn.ModuleList()
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
cur = 0
@ -341,24 +363,48 @@ class Twins(nn.Module):
m.bias.data.zero_()
def forward_features(self, x):
B = x.shape[0]
for i, (embed, drop, blocks, pos_blk) in enumerate(
zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)):
x, size = embed(x)
x = drop(x)
for j, blk in enumerate(blocks):
x = blk(x, size)
if j == 0:
x = pos_blk(x, size) # PEG here
if i < len(self.depths) - 1:
if self.task == 'cls':
B = x.shape[0]
for i, (embed, drop, blocks, pos_blk) in enumerate(
zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)):
x, size = embed(x)
x = drop(x)
for j, blk in enumerate(blocks):
x = blk(x, size)
if j == 0:
x = pos_blk(x, size) # PEG here
if i < len(self.depths) - 1:
x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
x = self.norm(x)
return x.mean(dim=1) # GAP here
if self.task == 'seg':
outputs = list()
B = x.shape[0]
for i, (embed, drop, blocks, pos_blk) in enumerate(
zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)):
x, size = embed(x)
x = drop(x)
for j, blk in enumerate(blocks):
x = blk(x, size)
if j == 0:
x = pos_blk(x, size) # PEG here
if self.extra_norm:
x = self.norm_list[i](x)
x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
x = self.norm(x)
return x.mean(dim=1) # GAP here
outputs.append(x)
return outputs
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
if self.task == 'cls':
return self.head(x)
if self.task == 'seg':
if self.F4:
x = x[3:4]
return x
def _create_twins(variant, pretrained=False, default_cfg=None, **kwargs):
@ -372,12 +418,15 @@ def _create_twins(variant, pretrained=False, default_cfg=None, **kwargs):
img_size = kwargs.pop('img_size', default_img_size)
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
if kwargs.get('task', 'cls') not in ['cls', 'seg']:
raise RuntimeError('twins in timm only supports "cls" and "seg" task now.')
model = build_model_with_cfg(
Twins, variant, pretrained,
default_cfg=default_cfg,
img_size=img_size,
num_classes=num_classes,
pretrained_strict=False,
**kwargs)
return model
@ -428,4 +477,4 @@ def twins_svt_large(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4],
depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs)
return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs)
return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs)
Loading…
Cancel
Save