From 578d52e7522bb20eba36fa1ab341a37eb088dc67 Mon Sep 17 00:00:00 2001
From: okojoalg
Date: Thu, 5 May 2022 23:22:40 +0900
Subject: [PATCH 1/5] Add Sequencer

---
 timm/models/__init__.py  |   1 +
 timm/models/sequencer.py | 389 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 390 insertions(+)
 create mode 100644 timm/models/sequencer.py

diff --git a/timm/models/__init__.py b/timm/models/__init__.py
index 45ead5dc..2b5d6031 100644
--- a/timm/models/__init__.py
+++ b/timm/models/__init__.py
@@ -39,6 +39,7 @@ from .resnetv2 import *
 from .rexnet import *
 from .selecsls import *
 from .senet import *
+from .sequencer import *
 from .sknet import *
 from .swin_transformer import *
 from .swin_transformer_v2_cr import *
diff --git a/timm/models/sequencer.py b/timm/models/sequencer.py
new file mode 100644
index 00000000..88003540
--- /dev/null
+++ b/timm/models/sequencer.py
@@ -0,0 +1,389 @@
+""" Sequencer
+
+Paper: `Sequencer: Deep LSTM for Image Classification` - https://arxiv.org/pdf/2205.01972.pdf
+
+"""
+#  Copyright (c) 2022. Yuki Tatsunami
+#  Licensed under the Apache License, Version 2.0 (the "License");
+
+
+import math
+from functools import partial
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
+from .helpers import build_model_with_cfg, named_apply
+from .layers import lecun_normal_, DropPath, Mlp, PatchEmbed as TimmPatchEmbed
+from .registry import register_model
+
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': DEFAULT_CROP_PCT, 'interpolation': 'bicubic', 'fixed_input_size': True,
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'stem.proj', 'classifier': 'head',
+        **kwargs
+    }
+
+
+default_cfgs = dict(
+    sequencer2d_s=_cfg(url="https://github.com/okojoalg/sequencer/releases/download/weights/sequencer2d_s.pth"),
+    sequencer2d_m=_cfg(url="https://github.com/okojoalg/sequencer/releases/download/weights/sequencer2d_m.pth"),
+    sequencer2d_l=_cfg(url="https://github.com/okojoalg/sequencer/releases/download/weights/sequencer2d_l.pth"),
+)
+
+
+def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False):
+    if isinstance(module, nn.Linear):
+        if name.startswith('head'):
+            nn.init.zeros_(module.weight)
+            nn.init.constant_(module.bias, head_bias)
+        else:
+            if flax:
+                # Flax defaults
+                lecun_normal_(module.weight)
+                if module.bias is not None:
+                    nn.init.zeros_(module.bias)
+            else:
+                nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    if 'mlp' in name:
+                        nn.init.normal_(module.bias, std=1e-6)
+                    else:
+                        nn.init.zeros_(module.bias)
+    elif isinstance(module, nn.Conv2d):
+        lecun_normal_(module.weight)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+    elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
+        nn.init.ones_(module.weight)
+        nn.init.zeros_(module.bias)
+    elif isinstance(module, (nn.RNN, nn.GRU, nn.LSTM)):
+        stdv = 1.0 / math.sqrt(module.hidden_size)
+        for weight in module.parameters():
+            nn.init.uniform_(weight, -stdv, stdv)
+    elif hasattr(module, 'init_weights'):
+        module.init_weights()
+
+
+def get_stage(index, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer, rnn_layer, mlp_layer,
+              norm_layer, act_layer, num_layers, bidirectional, union,
+              with_fc, drop=0., drop_path_rate=0., **kwargs):
+    assert len(layers) == len(patch_sizes) == len(embed_dims) == len(hidden_sizes) == len(mlp_ratios)
+    blocks = []
+    for block_idx in range(layers[index]):
+        drop_path = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
+        blocks.append(block_layer(embed_dims[index], hidden_sizes[index], mlp_ratio=mlp_ratios[index],
+                                  rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer,
+                                  act_layer=act_layer, num_layers=num_layers,
+                                  bidirectional=bidirectional, union=union, with_fc=with_fc,
+                                  drop=drop, drop_path=drop_path))
+
+    if index < len(embed_dims) - 1:
+        blocks.append(Downsample2D(embed_dims[index], embed_dims[index + 1], patch_sizes[index + 1]))
+
+    blocks = nn.Sequential(*blocks)
+    return blocks
+
+
+class RNNIdentity(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super(RNNIdentity, self).__init__()
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, None]:
+        return x, None
+
+
+class RNN2DBase(nn.Module):
+
+    def __init__(self, input_size: int, hidden_size: int,
+                 num_layers: int = 1, bias: bool = True, bidirectional: bool = True,
+                 union="cat", with_fc=True):
+        super().__init__()
+
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.output_size = 2 * hidden_size if bidirectional else hidden_size
+        self.union = union
+
+        self.with_vertical = True
+        self.with_horizontal = True
+        self.with_fc = with_fc
+
+        if with_fc:
+            if union == "cat":
+                self.fc = nn.Linear(2 * self.output_size, input_size)
+            elif union == "add":
+                self.fc = nn.Linear(self.output_size, input_size)
+            elif union == "vertical":
+                self.fc = nn.Linear(self.output_size, input_size)
+                self.with_horizontal = False
+            elif union == "horizontal":
+                self.fc = nn.Linear(self.output_size, input_size)
+                self.with_vertical = False
+            else:
+                raise ValueError("Unrecognized union: " + union)
+        elif union == "cat":
+            pass
+            if 2 * self.output_size != input_size:
+                raise ValueError(f"The output channel {2 * self.output_size} is different from the input channel {input_size}.")
+        elif union == "add":
+            pass
+            if self.output_size != input_size:
+                raise ValueError(f"The output channel {self.output_size} is different from the input channel {input_size}.")
+        elif union == "vertical":
+            if self.output_size != input_size:
+                raise ValueError(f"The output channel {self.output_size} is different from the input channel {input_size}.")
+            self.with_horizontal = False
+        elif union == "horizontal":
+            if self.output_size != input_size:
+                raise ValueError(f"The output channel {self.output_size} is different from the input channel {input_size}.")
+            self.with_vertical = False
+        else:
+            raise ValueError("Unrecognized union: " + union)
+
+        self.rnn_v = RNNIdentity()
+        self.rnn_h = RNNIdentity()
+
+    def forward(self, x):
+        B, H, W, C = x.shape
+
+        if self.with_vertical:
+            v = x.permute(0, 2, 1, 3)
+            v = v.reshape(-1, H, C)
+            v, _ = self.rnn_v(v)
+            v = v.reshape(B, W, H, -1)
+            v = v.permute(0, 2, 1, 3)
+
+        if self.with_horizontal:
+            h = x.reshape(-1, W, C)
+            h, _ = self.rnn_h(h)
+            h = h.reshape(B, H, W, -1)
+
+        if self.with_vertical and self.with_horizontal:
+            if self.union == "cat":
+                x = torch.cat([v, h], dim=-1)
+            else:
+                x = v + h
+        elif self.with_vertical:
+            x = v
+        elif self.with_horizontal:
+            x = h
+
+        if self.with_fc:
+            x = self.fc(x)
+
+        return x
+
+
+class LSTM2D(RNN2DBase):
+
+    def __init__(self, input_size: int, hidden_size: int,
+                 num_layers: int = 1, bias: bool = True, bidirectional: bool = True,
+                 union="cat", with_fc=True):
+        super().__init__(input_size, hidden_size, num_layers, bias, bidirectional, union, with_fc)
+        if self.with_vertical:
+            self.rnn_v = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bias=bias, bidirectional=bidirectional)
+        if self.with_horizontal:
+            self.rnn_h = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bias=bias, bidirectional=bidirectional)
+
+
+class Sequencer2DBlock(nn.Module):
+    def __init__(self, dim, hidden_size, mlp_ratio=3.0, rnn_layer=LSTM2D, mlp_layer=Mlp,
+                 norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU,
+                 num_layers=1, bidirectional=True, union="cat", with_fc=True,
+                 drop=0., drop_path=0.):
+        super().__init__()
+        channels_dim = int(mlp_ratio * dim)
+        self.norm1 = norm_layer(dim)
+        self.rnn_tokens = rnn_layer(dim, hidden_size, num_layers=num_layers, bidirectional=bidirectional,
+                                    union=union, with_fc=with_fc)
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop)
+
+    def forward(self, x):
+        x = x + self.drop_path(self.rnn_tokens(self.norm1(x)))
+        x = x + self.drop_path(self.mlp_channels(self.norm2(x)))
+        return x
+
+
+class PatchEmbed(TimmPatchEmbed):
+    def forward(self, x):
+        x = self.proj(x)
+        if self.flatten:
+            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
+        else:
+            x = x.permute(0, 2, 3, 1)  # BCHW -> BHWC
+        x = self.norm(x)
+        return x
+
+
+class Shuffle(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        if self.training:
+            B, H, W, C = x.shape
+            r = torch.randperm(H * W)
+            x = x.reshape(B, -1, C)
+            x = x[:, r, :].reshape(B, H, W, -1)
+        return x
+
+
+class Downsample2D(nn.Module):
+    def __init__(self, input_dim, output_dim, patch_size):
+        super().__init__()
+        self.down = nn.Conv2d(input_dim, output_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x):
+        x = x.permute(0, 3, 1, 2)
+        x = self.down(x)
+        x = x.permute(0, 2, 3, 1)
+        return x
+
+
+class Sequencer2D(nn.Module):
+    def __init__(
+            self,
+            num_classes=1000,
+            img_size=224,
+            in_chans=3,
+            layers=[4, 3, 8, 3],
+            patch_sizes=[7, 2, 1, 1],
+            embed_dims=[192, 384, 384, 384],
+            hidden_sizes=[48, 96, 96, 96],
+            mlp_ratios=[3.0, 3.0, 3.0, 3.0],
+            block_layer=Sequencer2DBlock,
+            rnn_layer=LSTM2D,
+            mlp_layer=Mlp,
+            norm_layer=partial(nn.LayerNorm, eps=1e-6),
+            act_layer=nn.GELU,
+            num_rnn_layers=1,
+            bidirectional=True,
+            union="cat",
+            with_fc=True,
+            drop_rate=0.,
+            drop_path_rate=0.,
+            nlhb=False,
+            stem_norm=False,
+    ):
+        super().__init__()
+        self.num_classes = num_classes
+        self.num_features = embed_dims[0]  # num_features for consistency with other models
+        self.embed_dims = embed_dims
+        self.stem = PatchEmbed(
+            img_size=img_size, patch_size=patch_sizes[0], in_chans=in_chans,
+            embed_dim=embed_dims[0], norm_layer=norm_layer if stem_norm else None,
+            flatten=False)
+
+        self.blocks = nn.Sequential(*[
+            get_stage(
+                i, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer=block_layer,
+                rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer,
+                num_layers=num_rnn_layers, bidirectional=bidirectional,
+                union=union, with_fc=with_fc, drop=drop_rate, drop_path_rate=drop_path_rate,
+            )
+            for i, _ in enumerate(embed_dims)])
+
+        self.norm = norm_layer(embed_dims[-1])
+        self.head = nn.Linear(embed_dims[-1], self.num_classes) if num_classes > 0 else nn.Identity()
+
+        self.init_weights(nlhb=nlhb)
+
+    def init_weights(self, nlhb=False):
+        head_bias = -math.log(self.num_classes) if nlhb else 0.
+        named_apply(partial(_init_weights, head_bias=head_bias), module=self)  # depth-first
+
+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes, global_pool=''):
+        self.num_classes = num_classes
+        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+    def forward_features(self, x):
+        x = self.stem(x)
+        x = self.blocks(x)
+        x = self.norm(x)
+        x = x.mean(dim=(1, 2))
+        return x
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.head(x)
+        return x
+
+
+def checkpoint_filter_fn(state_dict, model):
+    return state_dict
+
+
+def _create_sequencer2d(variant, pretrained=False, **kwargs):
+    if kwargs.get('features_only', None):
+        raise RuntimeError('features_only not implemented for Sequencer2D models.')
+
+    model = build_model_with_cfg(
+        Sequencer2D, variant, pretrained,
+        pretrained_filter_fn=checkpoint_filter_fn,
+        **kwargs)
+    return model
+
+
+# main
+
+@register_model
+def sequencer2d_s(pretrained=False, **kwargs):
+    model_args = dict(
+        layers=[4, 3, 8, 3],
+        patch_sizes=[7, 2, 1, 1],
+        embed_dims=[192, 384, 384, 384],
+        hidden_sizes=[48, 96, 96, 96],
+        mlp_ratios=[3.0, 3.0, 3.0, 3.0],
+        rnn_layer=LSTM2D,
+        bidirectional=True,
+        union="cat",
+        with_fc=True,
+        **kwargs)
+    model = _create_sequencer2d('sequencer2d_s', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def sequencer2d_m(pretrained=False, **kwargs):
+    model_args = dict(
+        layers=[4, 3, 14, 3],
+        patch_sizes=[7, 2, 1, 1],
+        embed_dims=[192, 384, 384, 384],
+        hidden_sizes=[48, 96, 96, 96],
+        mlp_ratios=[3.0, 3.0, 3.0, 3.0],
+        rnn_layer=LSTM2D,
+        bidirectional=True,
+        union="cat",
+        with_fc=True,
+        **kwargs)
+    model = _create_sequencer2d('sequencer2d_m', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def sequencer2d_l(pretrained=False, **kwargs):
+    model_args = dict(
+        layers=[8, 8, 16, 4],
+        patch_sizes=[7, 2, 1, 1],
+        embed_dims=[192, 384, 384, 384],
+        hidden_sizes=[48, 96, 96, 96],
+        mlp_ratios=[3.0, 3.0, 3.0, 3.0],
+        rnn_layer=LSTM2D,
+        bidirectional=True,
+        union="cat",
+        with_fc=True,
+        **kwargs)
+    model = _create_sequencer2d('sequencer2d_l', pretrained=pretrained, **model_args)
+    return model

From 2fec08e9232f15cee50868123067c3c9d9014014 Mon Sep 17 00:00:00 2001
From: okojoalg
Date: Fri, 6 May 2022 23:08:10 +0900
Subject: [PATCH 2/5] Add Sequencer to non std filters

---
 tests/test_models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index 4e50de6e..f06ddd95 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -25,7 +25,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
 NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*',
-    'poolformer_*', 'volo_*']
+    'poolformer_*', 'volo_*', 'sequencer2d_*']
 NUM_NON_STD = len(NON_STD_FILTERS)
 
 # exclude models that cause specific test failures

From 93a79a3dd99f53e759525a5be945e8ba93678009 Mon Sep 17 00:00:00 2001
From: okojoalg
Date: Fri, 6 May 2022 23:16:32 +0900
Subject: [PATCH 3/5] Fix num_features in Sequencer

---
 timm/models/sequencer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/timm/models/sequencer.py b/timm/models/sequencer.py
index 88003540..3ffaf02b 100644
--- a/timm/models/sequencer.py
+++ b/timm/models/sequencer.py
@@ -276,7 +276,7 @@ class Sequencer2D(nn.Module):
     ):
         super().__init__()
         self.num_classes = num_classes
-        self.num_features = embed_dims[0]  # num_features for consistency with other models
+        self.num_features = embed_dims[-1]  # num_features for consistency with other models
         self.embed_dims = embed_dims
         self.stem = PatchEmbed(
             img_size=img_size, patch_size=patch_sizes[0], in_chans=in_chans,

From d79f3d9d1ed1916549240a765f8cc6f958426878 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 9 May 2022 12:09:39 -0700
Subject: [PATCH 4/5] Fix torchscript use for sequencer, add group_matcher,
 forward_head support, minor formatting

---
 timm/models/sequencer.py | 93 ++++++++++++++++++++++++++--------------
 1 file changed, 60 insertions(+), 33 deletions(-)

diff --git a/timm/models/sequencer.py b/timm/models/sequencer.py
index 3ffaf02b..5fff04d1 100644
--- a/timm/models/sequencer.py
+++ b/timm/models/sequencer.py
@@ -71,18 +71,19 @@ def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=Fals
         module.init_weights()
 
 
-def get_stage(index, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer, rnn_layer, mlp_layer,
-              norm_layer, act_layer, num_layers, bidirectional, union,
-              with_fc, drop=0., drop_path_rate=0., **kwargs):
+def get_stage(
+        index, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer, rnn_layer, mlp_layer,
+        norm_layer, act_layer, num_layers, bidirectional, union,
+        with_fc, drop=0., drop_path_rate=0., **kwargs):
     assert len(layers) == len(patch_sizes) == len(embed_dims) == len(hidden_sizes) == len(mlp_ratios)
     blocks = []
     for block_idx in range(layers[index]):
         drop_path = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
-        blocks.append(block_layer(embed_dims[index], hidden_sizes[index], mlp_ratio=mlp_ratios[index],
-                                  rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer,
-                                  act_layer=act_layer, num_layers=num_layers,
-                                  bidirectional=bidirectional, union=union, with_fc=with_fc,
-                                  drop=drop, drop_path=drop_path))
+        blocks.append(block_layer(
+            embed_dims[index], hidden_sizes[index], mlp_ratio=mlp_ratios[index],
+            rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer,
+            num_layers=num_layers, bidirectional=bidirectional, union=union, with_fc=with_fc,
+            drop=drop, drop_path=drop_path))
 
     if index < len(embed_dims) - 1:
         blocks.append(Downsample2D(embed_dims[index], embed_dims[index + 1], patch_sizes[index + 1]))
@@ -101,9 +102,10 @@ class RNNIdentity(nn.Module):
 
 class RNN2DBase(nn.Module):
 
-    def __init__(self, input_size: int, hidden_size: int,
-                 num_layers: int = 1, bias: bool = True, bidirectional: bool = True,
-                 union="cat", with_fc=True):
+    def __init__(
+            self, input_size: int, hidden_size: int,
+            num_layers: int = 1, bias: bool = True, bidirectional: bool = True,
+            union="cat", with_fc=True):
         super().__init__()
 
         self.input_size = input_size
@@ -115,6 +117,7 @@ class RNN2DBase(nn.Module):
         self.with_horizontal = True
         self.with_fc = with_fc
 
+        self.fc = None
         if with_fc:
             if union == "cat":
                 self.fc = nn.Linear(2 * self.output_size, input_size)
@@ -159,23 +162,27 @@ class RNN2DBase(nn.Module):
             v, _ = self.rnn_v(v)
             v = v.reshape(B, W, H, -1)
             v = v.permute(0, 2, 1, 3)
+        else:
+            v = None
 
         if self.with_horizontal:
             h = x.reshape(-1, W, C)
             h, _ = self.rnn_h(h)
             h = h.reshape(B, H, W, -1)
+        else:
+            h = None
 
-        if self.with_vertical and self.with_horizontal:
+        if v is not None and h is not None:
             if self.union == "cat":
                 x = torch.cat([v, h], dim=-1)
             else:
                 x = v + h
-        elif self.with_vertical:
+        elif v is not None:
             x = v
-        elif self.with_horizontal:
+        elif h is not None:
             x = h
 
-        if self.with_fc:
+        if self.fc is not None:
             x = self.fc(x)
 
         return x
@@ -183,9 +190,10 @@ class RNN2DBase(nn.Module):
 
 class LSTM2D(RNN2DBase):
 
-    def __init__(self, input_size: int, hidden_size: int,
-                 num_layers: int = 1, bias: bool = True, bidirectional: bool = True,
-                 union="cat", with_fc=True):
+    def __init__(
+            self, input_size: int, hidden_size: int,
+            num_layers: int = 1, bias: bool = True, bidirectional: bool = True,
+            union="cat", with_fc=True):
         super().__init__(input_size, hidden_size, num_layers, bias, bidirectional, union, with_fc)
         if self.with_vertical:
             self.rnn_v = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bias=bias, bidirectional=bidirectional)
@@ -194,10 +202,10 @@ class LSTM2D(RNN2DBase):
 
 
 class Sequencer2DBlock(nn.Module):
-    def __init__(self, dim, hidden_size, mlp_ratio=3.0, rnn_layer=LSTM2D, mlp_layer=Mlp,
-                 norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU,
-                 num_layers=1, bidirectional=True, union="cat", with_fc=True,
-                 drop=0., drop_path=0.):
+    def __init__(
+            self, dim, hidden_size, mlp_ratio=3.0, rnn_layer=LSTM2D, mlp_layer=Mlp,
+            norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU,
+            num_layers=1, bidirectional=True, union="cat", with_fc=True, drop=0., drop_path=0.):
         super().__init__()
         channels_dim = int(mlp_ratio * dim)
         self.norm1 = norm_layer(dim)
@@ -255,6 +263,7 @@ class Sequencer2D(nn.Module):
             num_classes=1000,
             img_size=224,
             in_chans=3,
+            global_pool='avg',
             layers=[4, 3, 8, 3],
             patch_sizes=[7, 2, 1, 1],
             embed_dims=[192, 384, 384, 384],
@@ -275,7 +284,9 @@ class Sequencer2D(nn.Module):
             stem_norm=False,
     ):
         super().__init__()
+        assert global_pool in ('', 'avg')
         self.num_classes = num_classes
+        self.global_pool = global_pool
         self.num_features = embed_dims[-1]  # num_features for consistency with other models
         self.embed_dims = embed_dims
         self.stem = PatchEmbed(
@@ -301,38 +312,54 @@ class Sequencer2D(nn.Module):
         head_bias = -math.log(self.num_classes) if nlhb else 0.
         named_apply(partial(_init_weights, head_bias=head_bias), module=self)  # depth-first
 
+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        return dict(
+            stem=r'^stem',
+            blocks=[
+                (r'^blocks\.(\d+)\..*\.down', (99999,)),
+                (r'^blocks\.(\d+)', None) if coarse else (r'^blocks\.(\d+)\.(\d+)', None),
+                (r'^norm', (99999,))
+            ]
+        )
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        assert not enable, 'gradient checkpointing not supported'
+
+    @torch.jit.ignore
     def get_classifier(self):
         return self.head
 
-    def reset_classifier(self, num_classes, global_pool=''):
+    def reset_classifier(self, num_classes, global_pool=None):
         self.num_classes = num_classes
+        if self.global_pool is not None:
+            assert global_pool in ('', 'avg')
+            self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
     def forward_features(self, x):
         x = self.stem(x)
         x = self.blocks(x)
         x = self.norm(x)
-        x = x.mean(dim=(1, 2))
         return x
 
+    def forward_head(self, x, pre_logits: bool = False):
+        if self.global_pool == 'avg':
+            x = x.mean(dim=(1, 2))
+        return x if pre_logits else self.head(x)
+
     def forward(self, x):
         x = self.forward_features(x)
-        x = self.head(x)
+        x = self.forward_head(x)
         return x
 
 
-def checkpoint_filter_fn(state_dict, model):
-    return state_dict
-
-
 def _create_sequencer2d(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Sequencer2D models.')
 
-    model = build_model_with_cfg(
-        Sequencer2D, variant, pretrained,
-        pretrained_filter_fn=checkpoint_filter_fn,
-        **kwargs)
+    model = build_model_with_cfg(Sequencer2D, variant, pretrained, **kwargs)
     return model
 
 

From 39b725e1c90dd5902f48b2eef2a800a3b221ca47 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 9 May 2022 15:20:24 -0700
Subject: [PATCH 5/5] Fix tests for rank-4 output where feature channels dim is
 -1 (3) and not 1

Also fix Sequencer2D.reset_classifier(): check the `global_pool` argument
instead of self.global_pool, and size the replacement head from
self.num_features. The model sets `embed_dims` and `num_features` but no
`embed_dim`, so the previous expression raised AttributeError on every call.
---
Reviewer notes (text in this section is not part of the commit message):
 * The reset_classifier() head fix was added to this patch during review.
   The hunk was widened by three trailing context lines so it is not treated
   as anchored at end of file.
 * The post-image blob id on the sequencer.py `index` line below was
   generated before that one-line change and is left as recorded.

 tests/test_models.py     | 10 +++++++---
 timm/models/sequencer.py |  5 +++--
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index f06ddd95..6489892c 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -202,13 +202,15 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
         pytest.skip("Fixed input size model > limit.")
 
     input_tensor = torch.randn((batch_size, *input_size))
+    feat_dim = getattr(model, 'feature_dim', None)
 
     outputs = model.forward_features(input_tensor)
     if isinstance(outputs, (tuple, list)):
         # cannot currently verify multi-tensor output.
         pass
     else:
-        feat_dim = -1 if outputs.ndim == 3 else 1
+        if feat_dim is None:
+            feat_dim = -1 if outputs.ndim == 3 else 1
         assert outputs.shape[feat_dim] == model.num_features
 
     # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
@@ -216,14 +218,16 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
     outputs = model.forward(input_tensor)
     if isinstance(outputs, (tuple, list)):
         outputs = outputs[0]
-    feat_dim = -1 if outputs.ndim == 3 else 1
+    if feat_dim is None:
+        feat_dim = -1 if outputs.ndim == 3 else 1
     assert outputs.shape[feat_dim] == model.num_features, 'pooled num_features != config'
 
     model = create_model(model_name, pretrained=False, num_classes=0).eval()
     outputs = model.forward(input_tensor)
     if isinstance(outputs, (tuple, list)):
         outputs = outputs[0]
-    feat_dim = -1 if outputs.ndim == 3 else 1
+    if feat_dim is None:
+        feat_dim = -1 if outputs.ndim == 3 else 1
     assert outputs.shape[feat_dim] == model.num_features
 
     # check classifier name matches default_cfg
diff --git a/timm/models/sequencer.py b/timm/models/sequencer.py
index 5fff04d1..b1ae92a4 100644
--- a/timm/models/sequencer.py
+++ b/timm/models/sequencer.py
@@ -288,6 +288,7 @@ class Sequencer2D(nn.Module):
         self.num_classes = num_classes
         self.global_pool = global_pool
         self.num_features = embed_dims[-1]  # num_features for consistency with other models
+        self.feature_dim = -1  # channel dim index for feature outputs (rank 4, NHWC)
         self.embed_dims = embed_dims
         self.stem = PatchEmbed(
             img_size=img_size, patch_size=patch_sizes[0], in_chans=in_chans,
@@ -333,10 +334,10 @@ class Sequencer2D(nn.Module):
 
     def reset_classifier(self, num_classes, global_pool=None):
         self.num_classes = num_classes
-        if self.global_pool is not None:
+        if global_pool is not None:
             assert global_pool in ('', 'avg')
             self.global_pool = global_pool
-        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
     def forward_features(self, x):
         x = self.stem(x)