From 43180b1341201eefc051e3dbbfafff1d3a0390ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E9=91=AB=E6=9D=B0?= Date: Tue, 25 May 2021 18:07:50 +0800 Subject: [PATCH] update twins.py to support segmentation task --- timm/models/twins.py | 81 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 65 insertions(+), 16 deletions(-) diff --git a/timm/models/twins.py b/timm/models/twins.py index a534d174..dfe0208d 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -274,7 +274,7 @@ class Twins(nn.Module): self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=(64, 128, 256, 512), num_heads=(1, 2, 4, 8), mlp_ratios=(4, 4, 4, 4), drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=(3, 4, 6, 3), sr_ratios=(8, 4, 2, 1), wss=None, - block_cls=Block): + block_cls=Block, F4=False, extra_norm=True, strides=(2, 2, 2), task='cls'): super().__init__() self.num_classes = num_classes self.depths = depths @@ -290,6 +290,28 @@ class Twins(nn.Module): img_size = tuple(t // patch_size for t in img_size) patch_size = 2 + self.task = task + if self.task == 'seg': + self.F4=F4 + self.extra_norm = extra_norm + self.strides = strides + if self.extra_norm: + self.norm_list = nn.ModuleList() + for dim in embed_dims: + self.norm_list.append(norm_layer(dim)) + + if strides != (2, 2, 2): + del self.patch_embeds + self.patch_embeds = nn.ModuleList() + s = 1 + for i in range(len(depths)): + if i == 0: + self.patch_embeds.append(PatchEmbed(img_size, patch_size, in_chans, embed_dims[i])) + else: + self.patch_embeds.append( + PatchEmbed(img_size // patch_size // s, strides[i-1], embed_dims[i - 1], embed_dims[i])) + s = s * strides[i-1] + self.blocks = nn.ModuleList() dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule cur = 0 @@ -341,24 +363,48 @@ class Twins(nn.Module): m.bias.data.zero_() def forward_features(self, x): - B = x.shape[0] - for i, (embed, drop, blocks, pos_blk) in enumerate( - zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)): - x, size = embed(x) - x = drop(x) - for j, blk in enumerate(blocks): - x = blk(x, size) - if j == 0: - x = pos_blk(x, size) # PEG here - if i < len(self.depths) - 1: + + if self.task == 'cls': + B = x.shape[0] + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)): + x, size = embed(x) + x = drop(x) + for j, blk in enumerate(blocks): + x = blk(x, size) + if j == 0: + x = pos_blk(x, size) # PEG here + if i < len(self.depths) - 1: + x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + x = self.norm(x) + return x.mean(dim=1) # GAP here + + if self.task == 'seg': + outputs = list() + B = x.shape[0] + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)): + x, size = embed(x) + x = drop(x) + for j, blk in enumerate(blocks): + x = blk(x, size) + if j == 0: + x = pos_blk(x, size) # PEG here + if self.extra_norm: + x = self.norm_list[i](x) x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() - x = self.norm(x) - return x.mean(dim=1) # GAP here + outputs.append(x) + return outputs def forward(self, x): x = self.forward_features(x) - x = self.head(x) - return x + if self.task == 'cls': + return self.head(x) + if self.task == 'seg': + if self.F4: + x = x[3:4] + return x + def _create_twins(variant, pretrained=False, default_cfg=None, **kwargs): @@ -372,12 +418,15 @@ def _create_twins(variant, pretrained=False, default_cfg=None, **kwargs): img_size = kwargs.pop('img_size', default_img_size) if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') + if kwargs.get('task', 'cls') not in ['cls', 'seg']: + raise RuntimeError('twins in timm only supports "cls" and "seg" task now.') model = build_model_with_cfg( Twins, variant, pretrained, default_cfg=default_cfg, img_size=img_size, num_classes=num_classes, + pretrained_strict=False, **kwargs) return model @@ -428,4 +477,4 @@ def twins_svt_large(pretrained=False, **kwargs): model_kwargs = dict( patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4], depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs) - return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs) + return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs) \ No newline at end of file