From d79f3d9d1ed1916549240a765f8cc6f958426878 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 9 May 2022 12:09:39 -0700
Subject: [PATCH] Fix torchscript use for sequencer, add group_matcher,
 forward_head support, minor formatting

---
 timm/models/sequencer.py | 93 ++++++++++++++++++++++++++-------------
 1 file changed, 60 insertions(+), 33 deletions(-)

diff --git a/timm/models/sequencer.py b/timm/models/sequencer.py
index 3ffaf02b..5fff04d1 100644
--- a/timm/models/sequencer.py
+++ b/timm/models/sequencer.py
@@ -71,18 +71,19 @@ def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=Fals
         module.init_weights()
 
 
-def get_stage(index, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer, rnn_layer, mlp_layer,
-              norm_layer, act_layer, num_layers, bidirectional, union,
-              with_fc, drop=0., drop_path_rate=0., **kwargs):
+def get_stage(
+        index, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer, rnn_layer, mlp_layer,
+        norm_layer, act_layer, num_layers, bidirectional, union,
+        with_fc, drop=0., drop_path_rate=0., **kwargs):
     assert len(layers) == len(patch_sizes) == len(embed_dims) == len(hidden_sizes) == len(mlp_ratios)
     blocks = []
     for block_idx in range(layers[index]):
         drop_path = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
-        blocks.append(block_layer(embed_dims[index], hidden_sizes[index], mlp_ratio=mlp_ratios[index],
-                                  rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer,
-                                  act_layer=act_layer, num_layers=num_layers,
-                                  bidirectional=bidirectional, union=union, with_fc=with_fc,
-                                  drop=drop, drop_path=drop_path))
+        blocks.append(block_layer(
+            embed_dims[index], hidden_sizes[index], mlp_ratio=mlp_ratios[index],
+            rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer,
+            num_layers=num_layers, bidirectional=bidirectional, union=union, with_fc=with_fc,
+            drop=drop, drop_path=drop_path))
 
     if index < len(embed_dims) - 1:
         blocks.append(Downsample2D(embed_dims[index], embed_dims[index + 1], patch_sizes[index + 1]))
@@ -101,9 +102,10 @@ class RNNIdentity(nn.Module):
 
 class RNN2DBase(nn.Module):
 
-    def __init__(self, input_size: int, hidden_size: int,
-                 num_layers: int = 1, bias: bool = True, bidirectional: bool = True,
-                 union="cat", with_fc=True):
+    def __init__(
+            self, input_size: int, hidden_size: int,
+            num_layers: int = 1, bias: bool = True, bidirectional: bool = True,
+            union="cat", with_fc=True):
         super().__init__()
 
         self.input_size = input_size
@@ -115,6 +117,7 @@ class RNN2DBase(nn.Module):
         self.with_horizontal = True
         self.with_fc = with_fc
 
+        self.fc = None
         if with_fc:
             if union == "cat":
                 self.fc = nn.Linear(2 * self.output_size, input_size)
@@ -159,23 +162,27 @@ class RNN2DBase(nn.Module):
             v, _ = self.rnn_v(v)
             v = v.reshape(B, W, H, -1)
             v = v.permute(0, 2, 1, 3)
+        else:
+            v = None
 
         if self.with_horizontal:
             h = x.reshape(-1, W, C)
            h, _ = self.rnn_h(h)
             h = h.reshape(B, H, W, -1)
+        else:
+            h = None
 
-        if self.with_vertical and self.with_horizontal:
+        if v is not None and h is not None:
            if self.union == "cat":
                 x = torch.cat([v, h], dim=-1)
             else:
                 x = v + h
-        elif self.with_vertical:
+        elif v is not None:
             x = v
-        elif self.with_horizontal:
+        elif h is not None:
             x = h
 
-        if self.with_fc:
+        if self.fc is not None:
             x = self.fc(x)
 
         return x
@@ -183,9 +190,10 @@ class RNN2DBase(nn.Module):
 
 class LSTM2D(RNN2DBase):
 
-    def __init__(self, input_size: int, hidden_size: int,
-                 num_layers: int = 1, bias: bool = True, bidirectional: bool = True,
-                 union="cat", with_fc=True):
+    def __init__(
+            self, input_size: int, hidden_size: int,
+            num_layers: int = 1, bias: bool = True, bidirectional: bool = True,
+            union="cat", with_fc=True):
         super().__init__(input_size, hidden_size, num_layers, bias, bidirectional, union, with_fc)
         if self.with_vertical:
             self.rnn_v = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bias=bias, bidirectional=bidirectional)
@@ -194,10 +202,10 @@ class LSTM2D(RNN2DBase):
 
 class Sequencer2DBlock(nn.Module):
 
-    def __init__(self, dim, hidden_size, mlp_ratio=3.0, rnn_layer=LSTM2D, mlp_layer=Mlp,
-                 norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU,
-                 num_layers=1, bidirectional=True, union="cat", with_fc=True,
-                 drop=0., drop_path=0.):
+    def __init__(
+            self, dim, hidden_size, mlp_ratio=3.0, rnn_layer=LSTM2D, mlp_layer=Mlp,
+            norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU,
+            num_layers=1, bidirectional=True, union="cat", with_fc=True, drop=0., drop_path=0.):
         super().__init__()
         channels_dim = int(mlp_ratio * dim)
         self.norm1 = norm_layer(dim)
@@ -255,6 +263,7 @@ class Sequencer2D(nn.Module):
             num_classes=1000,
             img_size=224,
             in_chans=3,
+            global_pool='avg',
             layers=[4, 3, 8, 3],
             patch_sizes=[7, 2, 1, 1],
             embed_dims=[192, 384, 384, 384],
@@ -275,7 +284,9 @@ class Sequencer2D(nn.Module):
             stem_norm=False,
     ):
         super().__init__()
+        assert global_pool in ('', 'avg')
         self.num_classes = num_classes
+        self.global_pool = global_pool
         self.num_features = embed_dims[-1]  # num_features for consistency with other models
         self.embed_dims = embed_dims
         self.stem = PatchEmbed(
@@ -301,38 +312,54 @@
         head_bias = -math.log(self.num_classes) if nlhb else 0.
         named_apply(partial(_init_weights, head_bias=head_bias), module=self)  # depth-first
 
+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        return dict(
+            stem=r'^stem',
+            blocks=[
+                (r'^blocks\.(\d+)\..*\.down', (99999,)),
+                (r'^blocks\.(\d+)', None) if coarse else (r'^blocks\.(\d+)\.(\d+)', None),
+                (r'^norm', (99999,))
+            ]
+        )
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        assert not enable, 'gradient checkpointing not supported'
+
+    @torch.jit.ignore
     def get_classifier(self):
         return self.head
 
-    def reset_classifier(self, num_classes, global_pool=''):
+    def reset_classifier(self, num_classes, global_pool=None):
         self.num_classes = num_classes
+        if global_pool is not None:
+            assert global_pool in ('', 'avg')
+            self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
     def forward_features(self, x):
         x = self.stem(x)
         x = self.blocks(x)
         x = self.norm(x)
-        x = x.mean(dim=(1, 2))
         return x
 
+    def forward_head(self, x, pre_logits: bool = False):
+        if self.global_pool == 'avg':
+            x = x.mean(dim=(1, 2))
+        return x if pre_logits else self.head(x)
+
     def forward(self, x):
         x = self.forward_features(x)
-        x = self.head(x)
+        x = self.forward_head(x)
        return x
 
 
-def checkpoint_filter_fn(state_dict, model):
-    return state_dict
-
-
 def _create_sequencer2d(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Sequencer2D models.')
 
-    model = build_model_with_cfg(
-        Sequencer2D, variant, pretrained,
-        pretrained_filter_fn=checkpoint_filter_fn,
-        **kwargs)
+    model = build_model_with_cfg(Sequencer2D, variant, pretrained, **kwargs)
     return model
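
Note on the TorchScript changes above: the optional pieces are now always defined (`self.fc = None`, `v = None`, `h = None`) and the forward path branches on `is not None` rather than on attributes or locals that only exist when the corresponding flag is set, which is what torch.jit.script requires. A minimal, self-contained sketch of the same pattern; the ToyUnion module below is illustrative only and not part of the patch:

from typing import Optional

import torch
import torch.nn as nn


class ToyUnion(nn.Module):
    # Mirrors RNN2DBase: `fc` is always assigned (possibly None) so scripting
    # sees the attribute and can drop the dead branch when it is None.
    def __init__(self, dim: int, with_fc: bool = True, with_double: bool = True):
        super().__init__()
        self.with_double = with_double
        self.fc = nn.Linear(dim, dim) if with_fc else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # A local used after a conditional block must be assigned on every
        # path, hence the explicit None default (mirrors the new v/h handling).
        y: Optional[torch.Tensor] = None
        if self.with_double:
            y = x * 2.0
        if y is not None:
            x = x + y
        if self.fc is not None:  # scriptable; branching on a flag while fc may not exist is not
            x = self.fc(x)
        return x


scripted = torch.jit.script(ToyUnion(8, with_fc=False))
print(scripted(torch.randn(2, 8)).shape)  # torch.Size([2, 8])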
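
For reference, a short usage sketch of the head API this patch introduces; `timm.create_model` and the `sequencer2d_s` variant name are assumed from the rest of sequencer.py and are not part of the diff:

import torch
import timm

# assumes the patch is applied and 'sequencer2d_s' is registered by this file
model = timm.create_model('sequencer2d_s', pretrained=False, num_classes=10).eval()

x = torch.randn(1, 3, 224, 224)
feats = model.forward_features(x)                    # unpooled (B, H, W, C) feature grid
pooled = model.forward_head(feats, pre_logits=True)  # 'avg' pooling only
logits = model.forward_head(feats)                   # pooling + classifier
assert torch.allclose(logits, model(x))

# the None-based checks above are what allow scripting the model end to end
scripted = torch.jit.script(model)
assert torch.allclose(scripted(x), logits, atol=1e-5)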