From 09a45ab592b036c66e41726ba5e55cf52412ddac Mon Sep 17 00:00:00 2001 From: Yonghye Kwon Date: Fri, 13 Aug 2021 01:36:48 +0900 Subject: [PATCH 01/21] fix a typo in ### Select specific feature levels or limit the stride There are to additional creation arguments impacting the output features. -> There are two additional creation arguments impacting the output features. --- docs/feature_extraction.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/feature_extraction.md b/docs/feature_extraction.md index b41c1559..3d638d65 100644 --- a/docs/feature_extraction.md +++ b/docs/feature_extraction.md @@ -145,7 +145,7 @@ torch.Size([2, 1512, 7, 7]) ### Select specific feature levels or limit the stride -There are to additional creation arguments impacting the output features. +There are two additional creation arguments impacting the output features. * `out_indices` selects which indices to output * `output_stride` limits the feature output stride of the network (also works in classification mode BTW) From 41559247e9f282e0a1df9dca1f6173b7f8e86399 Mon Sep 17 00:00:00 2001 From: talrid Date: Mon, 22 Nov 2021 17:50:39 +0200 Subject: [PATCH 02/21] use_ml_decoder_head --- timm/models/factory.py | 5 ++ timm/models/layers/ml_decoder.py | 149 +++++++++++++++++++++++++++++++ train.py | 4 +- 3 files changed, 157 insertions(+), 1 deletion(-) create mode 100644 timm/models/layers/ml_decoder.py diff --git a/timm/models/factory.py b/timm/models/factory.py index d040a9ff..40453380 100644 --- a/timm/models/factory.py +++ b/timm/models/factory.py @@ -29,6 +29,7 @@ def create_model( scriptable=None, exportable=None, no_jit=None, + use_ml_decoder_head=False, **kwargs): """Create a model @@ -80,6 +81,10 @@ def create_model( with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): model = create_fn(pretrained=pretrained, **kwargs) + if use_ml_decoder_head: + from timm.models.layers.ml_decoder import add_ml_decoder_head + model = add_ml_decoder_head(model) + if checkpoint_path: load_checkpoint(model, checkpoint_path) diff --git a/timm/models/layers/ml_decoder.py b/timm/models/layers/ml_decoder.py new file mode 100644 index 00000000..87815aaa --- /dev/null +++ b/timm/models/layers/ml_decoder.py @@ -0,0 +1,149 @@ +from typing import Optional + +import torch +from torch import nn +from torch import nn, Tensor +from torch.nn.modules.transformer import _get_activation_fn + + +def add_ml_decoder_head(model): + if hasattr(model, 'global_pool') and hasattr(model, 'fc'): # resnet50 + model.global_pool = nn.Identity() + del model.fc + num_classes = model.num_classes + num_features = model.num_features + model.fc = MLDecoder(num_classes=num_classes, initial_num_features=num_features) + elif hasattr(model, 'head'): # tresnet + del model.head + num_classes = model.num_classes + num_features = model.num_features + model.head = MLDecoder(num_classes=num_classes, initial_num_features=num_features) + else: + print("model is not suited for ml-decoder") + exit(-1) + + return model + + +class TransformerDecoderLayerOptimal(nn.Module): + def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1, activation="relu", + layer_norm_eps=1e-5) -> None: + super(TransformerDecoderLayerOptimal, self).__init__() + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.dropout = nn.Dropout(dropout) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, 
dropout=dropout) + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps) + + self.activation = _get_activation_fn(activation) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = torch.nn.functional.relu + super(TransformerDecoderLayerOptimal, self).__setstate__(state) + + def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None) -> Tensor: + tgt = tgt + self.dropout1(tgt) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn(tgt, memory, memory)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + +# @torch.jit.script +# class ExtrapClasses(object): +# def __init__(self, num_queries: int, group_size: int): +# self.num_queries = num_queries +# self.group_size = group_size +# +# def __call__(self, h: torch.Tensor, class_embed_w: torch.Tensor, class_embed_b: torch.Tensor, out_extrap: +# torch.Tensor): +# # h = h.unsqueeze(-1).expand(-1, -1, -1, self.group_size) +# h = h[..., None].repeat(1, 1, 1, self.group_size) # torch.Size([bs, 5, 768, groups]) +# w = class_embed_w.view((self.num_queries, h.shape[2], self.group_size)) +# out = (h * w).sum(dim=2) + class_embed_b +# out = out.view((h.shape[0], self.group_size * self.num_queries)) +# return out + +@torch.jit.script +class GroupFC(object): + def __init__(self, embed_len_decoder: int): + self.embed_len_decoder = embed_len_decoder + + def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor): + for i in range(self.embed_len_decoder): + h_i = h[:, i, :] + w_i = duplicate_pooling[i, :, :] + out_extrap[:, i, :] = torch.matmul(h_i, w_i) + + +class MLDecoder(nn.Module): + def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial_num_features=2048): + super(MLDecoder, self).__init__() + embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups + if embed_len_decoder > num_classes: + embed_len_decoder = num_classes + + # switching to 768 initial embeddings + decoder_embedding = 768 if decoder_embedding < 0 else decoder_embedding + self.embed_standart = nn.Linear(initial_num_features, decoder_embedding) + + # decoder + decoder_dropout = 0.1 + num_layers_decoder = 1 + dim_feedforward = 2048 + layer_decode = TransformerDecoderLayerOptimal(d_model=decoder_embedding, + dim_feedforward=dim_feedforward, dropout=decoder_dropout) + self.decoder = nn.TransformerDecoder(layer_decode, num_layers=num_layers_decoder) + + # non-learnable queries + self.query_embed = nn.Embedding(embed_len_decoder, decoder_embedding) + self.query_embed.requires_grad_(False) + + # group fully-connected + self.num_classes = num_classes + self.duplicate_factor = int(num_classes / embed_len_decoder + 0.999) + self.duplicate_pooling = torch.nn.Parameter( + torch.Tensor(embed_len_decoder, decoder_embedding, self.duplicate_factor)) + self.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(num_classes)) + torch.nn.init.xavier_normal_(self.duplicate_pooling) + torch.nn.init.constant_(self.duplicate_pooling_bias, 0) + self.group_fc = GroupFC(embed_len_decoder) + + def 
forward(self, x): + if len(x.shape) == 4: # [bs,2048, 7,7] + embedding_spatial = x.flatten(2).transpose(1, 2) + else: # [bs, 197,468] + embedding_spatial = x + embedding_spatial_786 = self.embed_standart(embedding_spatial) + embedding_spatial_786 = torch.nn.functional.relu(embedding_spatial_786, inplace=True) + + bs = embedding_spatial_786.shape[0] + query_embed = self.query_embed.weight + # tgt = query_embed.unsqueeze(1).repeat(1, bs, 1) + tgt = query_embed.unsqueeze(1).expand(-1, bs, -1) # no allocation of memory with expand + h = self.decoder(tgt, embedding_spatial_786.transpose(0, 1)) # [embed_len_decoder, batch, 768] + h = h.transpose(0, 1) + + out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype) + self.group_fc(h, self.duplicate_pooling, out_extrap) + h_out = out_extrap.flatten(1)[:, :self.num_classes] + h_out += self.duplicate_pooling_bias + logits = h_out + return logits diff --git a/train.py b/train.py index 10d839be..44a0e292 100755 --- a/train.py +++ b/train.py @@ -115,6 +115,7 @@ parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N', help='input batch size for training (default: 128)') parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N', help='validation batch size override (default: None)') +parser.add_argument('--use-ml-decoder-head', type=int, default=0) # Optimizer parameters parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', @@ -379,7 +380,8 @@ def main(): bn_momentum=args.bn_momentum, bn_eps=args.bn_eps, scriptable=args.torchscript, - checkpoint_path=args.initial_checkpoint) + checkpoint_path=args.initial_checkpoint, + ml_decoder_head=args.use_ml_decoder_head) if args.num_classes is None: assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' 
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly From ab5ae32f75b64ba7df1aee5f7209f4dc127f0811 Mon Sep 17 00:00:00 2001 From: han Date: Wed, 24 Nov 2021 09:32:05 +0900 Subject: [PATCH 03/21] fix: typo of argment parser desc in train.py - Remove duplicated `of` --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 10d839be..a95fa473 100755 --- a/train.py +++ b/train.py @@ -108,7 +108,7 @@ parser.add_argument('--crop-pct', default=None, type=float, parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', help='Override mean pixel value of dataset') parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', - help='Override std deviation of of dataset') + help='Override std deviation of dataset') parser.add_argument('--interpolation', default='', type=str, metavar='NAME', help='Image resize interpolation type (overrides model)') parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N', From b6c180ef19714ae33a186b1070e991e7403d81cd Mon Sep 17 00:00:00 2001 From: talrid Date: Sun, 28 Nov 2021 13:56:55 +0200 Subject: [PATCH 04/21] use_ml_decoder_head --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 44a0e292..42985e12 100755 --- a/train.py +++ b/train.py @@ -381,7 +381,7 @@ def main(): bn_eps=args.bn_eps, scriptable=args.torchscript, checkpoint_path=args.initial_checkpoint, - ml_decoder_head=args.use_ml_decoder_head) + use_ml_decoder_head=args.use_ml_decoder_head) if args.num_classes is None: assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly From ccb3815360057f5fefaa43a6edfc7ad79bda01ac Mon Sep 17 00:00:00 2001 From: qwertyforce Date: Mon, 29 Nov 2021 21:41:00 +0300 Subject: [PATCH 05/21] update arxiv link --- timm/models/vision_transformer_hybrid.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index d5f0a537..4493dcc5 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -6,7 +6,7 @@ A PyTorch implement of the Hybrid Vision Transformers as described in: - https://arxiv.org/abs/2010.11929 `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` - - https://arxiv.org/abs/2106.TODO + - https://arxiv.org/abs/2106.10270 NOTE These hybrid model definitions depend on code in vision_transformer.py. They were moved here to keep file sizes sane. 
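[Editor's note] A minimal usage sketch for the `use_ml_decoder_head` option introduced in PATCH 02 and corrected in train.py by PATCH 04. This is not part of any patch; the model name, class count, and dataset path are illustrative assumptions only.

    # Sketch: enabling the ML-Decoder head through create_model(), assuming the
    # factory change from PATCH 02 is applied. 'resnet50' and num_classes=80 are
    # placeholder choices; any model handled by add_ml_decoder_head() would do.
    import timm

    model = timm.create_model(
        'resnet50',                # a model with global_pool + fc, per add_ml_decoder_head()
        pretrained=False,
        num_classes=80,            # illustrative multi-label class count
        use_ml_decoder_head=True,  # swaps global pooling for nn.Identity and the fc head for MLDecoder
    )

    # Roughly equivalent from the command line (the argparse flag added in PATCH 02
    # is an int, default 0), with <data-dir> standing in for a real dataset path:
    #   python train.py <data-dir> --model resnet50 --num-classes 80 --use-ml-decoder-head 1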
@@ -360,4 +360,4 @@ def vit_base_resnet50d_224(pretrained=False, **kwargs): model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer_hybrid( 'vit_base_resnet50d_224', backbone=backbone, pretrained=pretrained, **model_kwargs) - return model \ No newline at end of file + return model From c11f4c3218958599d990fdc91c00d4560aa1c5bb Mon Sep 17 00:00:00 2001 From: talrid Date: Tue, 30 Nov 2021 08:48:08 +0200 Subject: [PATCH 06/21] support CNNs --- timm/models/layers/ml_decoder.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/timm/models/layers/ml_decoder.py b/timm/models/layers/ml_decoder.py index 87815aaa..3f828c6d 100644 --- a/timm/models/layers/ml_decoder.py +++ b/timm/models/layers/ml_decoder.py @@ -7,21 +7,28 @@ from torch.nn.modules.transformer import _get_activation_fn def add_ml_decoder_head(model): - if hasattr(model, 'global_pool') and hasattr(model, 'fc'): # resnet50 + if hasattr(model, 'global_pool') and hasattr(model, 'fc'): # most CNN models, like Resnet50 model.global_pool = nn.Identity() del model.fc num_classes = model.num_classes num_features = model.num_features model.fc = MLDecoder(num_classes=num_classes, initial_num_features=num_features) - elif hasattr(model, 'head'): # tresnet + elif hasattr(model, 'global_pool') and hasattr(model, 'classifier'): # EfficientNet + model.global_pool = nn.Identity() + del model.classifier + num_classes = model.num_classes + num_features = model.num_features + model.classifier = MLDecoder(num_classes=num_classes, initial_num_features=num_features) + elif 'RegNet' in model._get_name() or 'TResNet' in model._get_name(): # hasattr(model, 'head') del model.head num_classes = model.num_classes num_features = model.num_features model.head = MLDecoder(num_classes=num_classes, initial_num_features=num_features) else: - print("model is not suited for ml-decoder") + print("Model code-writing is not aligned currently with ml-decoder") exit(-1) - + if hasattr(model, 'drop_rate'): # Ml-Decoder has inner dropout + model.drop_rate = 0 return model From cf57695938e5190f81c6d21dd15bb1ac24c4590c Mon Sep 17 00:00:00 2001 From: ayasyrev Date: Wed, 26 Jan 2022 11:53:08 +0300 Subject: [PATCH 07/21] sched noise dup code remove --- timm/scheduler/plateau_lr.py | 23 +++++------------------ timm/scheduler/scheduler.py | 32 ++++++++++++++++++++------------ 2 files changed, 25 insertions(+), 30 deletions(-) diff --git a/timm/scheduler/plateau_lr.py b/timm/scheduler/plateau_lr.py index 4f2cacb6..fbfc531f 100644 --- a/timm/scheduler/plateau_lr.py +++ b/timm/scheduler/plateau_lr.py @@ -43,7 +43,7 @@ class PlateauLRScheduler(Scheduler): min_lr=lr_min ) - self.noise_range = noise_range_t + self.noise_range_t = noise_range_t self.noise_pct = noise_pct self.noise_type = noise_type self.noise_std = noise_std @@ -82,25 +82,12 @@ class PlateauLRScheduler(Scheduler): self.lr_scheduler.step(metric, epoch) # step the base scheduler - if self.noise_range is not None: - if isinstance(self.noise_range, (list, tuple)): - apply_noise = self.noise_range[0] <= epoch < self.noise_range[1] - else: - apply_noise = epoch >= self.noise_range - if apply_noise: - self._apply_noise(epoch) + if self._is_apply_noise(epoch): + self._apply_noise(epoch) + def _apply_noise(self, epoch): - g = torch.Generator() - g.manual_seed(self.noise_seed + epoch) - if self.noise_type == 'normal': - while True: - # resample if noise out of percent limit, brute force but shouldn't spin much - noise = torch.randn(1, generator=g).item() - 
if abs(noise) < self.noise_pct: - break - else: - noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct + noise = self._calculate_noise(epoch) # apply the noise on top of previous LR, cache the old value so we can restore for normal # stepping of base scheduler diff --git a/timm/scheduler/scheduler.py b/timm/scheduler/scheduler.py index 21d51509..81af76f9 100644 --- a/timm/scheduler/scheduler.py +++ b/timm/scheduler/scheduler.py @@ -85,21 +85,29 @@ class Scheduler: param_group[self.param_group_field] = value def _add_noise(self, lrs, t): + if self._is_apply_noise(t): + noise = self._calculate_noise(t) + lrs = [v + v * noise for v in lrs] + return lrs + + def _is_apply_noise(self, t) -> bool: + """Return True if scheduler in noise range.""" if self.noise_range_t is not None: if isinstance(self.noise_range_t, (list, tuple)): apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] else: apply_noise = t >= self.noise_range_t - if apply_noise: - g = torch.Generator() - g.manual_seed(self.noise_seed + t) - if self.noise_type == 'normal': - while True: + return apply_noise + + def _calculate_noise(self, t) -> float: + g = torch.Generator() + g.manual_seed(self.noise_seed + t) + if self.noise_type == 'normal': + while True: # resample if noise out of percent limit, brute force but shouldn't spin much - noise = torch.randn(1, generator=g).item() - if abs(noise) < self.noise_pct: - break - else: - noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct - lrs = [v + v * noise for v in lrs] - return lrs + noise = torch.randn(1, generator=g).item() + if abs(noise) < self.noise_pct: + return noise + else: + noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct + return noise From 7c67d6aca992f039eece0af5f7c29a43d48c00e4 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 2 Feb 2022 09:15:20 -0800 Subject: [PATCH 08/21] Update README.md --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index 3fa9701f..69effde4 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,12 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor ## What's New +### Feb 2, 2022 +* [Chris Hughes](https://github.com/Chris-hughes10) posted an exhaustive run through of `timm` on his blog yesterday. Well worth a read. [Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) +* I'm currently prepping to merge the `norm_norm_norm` branch back to master (ver 0.6.x) in next week or so. + * The changes are more extensive than usual and may destabilize and break some model API use (aiming for full backwards compat). So, beware `pip install git+https://github.com/rwightman/pytorch-image-models` installs! + * `0.5.x` releases and a `0.5.x` branch will remain stable with a cherry pick or two until dust clears. Recommend sticking to pypi install for a bit if you want stable. + ### Jan 14, 2022 * Version 0.5.4 w/ release to be pushed to pypi. It's been a while since last pypi update and riskier changes will be merged to main branch soon.... 
* Add ConvNeXT models /w weights from official impl (https://github.com/facebookresearch/ConvNeXt), a few perf tweaks, compatible with timm features @@ -410,6 +416,8 @@ Model validation results can be found in the [documentation](https://rwightman.g My current [documentation](https://rwightman.github.io/pytorch-image-models/) for `timm` covers the basics. +[Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail. + [timmdocs](https://fastai.github.io/timmdocs/) is quickly becoming a much more comprehensive set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs. [paperswithcode](https://paperswithcode.com/lib/timm) is a good resource for browsing the models within `timm`. From 90dc74c450a0ec671af0e7f73d6b4a7b5396a7af Mon Sep 17 00:00:00 2001 From: Christoph Reich <34400551+ChristophReich1996@users.noreply.github.com> Date: Sat, 19 Feb 2022 22:12:11 +0100 Subject: [PATCH 09/21] Add code from https://github.com/ChristophReich1996/Swin-Transformer-V2 and change docstring style to match timm --- timm/models/swin_transformer_v2.py | 927 +++++++++++++++++++++++++++++ 1 file changed, 927 insertions(+) create mode 100644 timm/models/swin_transformer_v2.py diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py new file mode 100644 index 00000000..c146ecc8 --- /dev/null +++ b/timm/models/swin_transformer_v2.py @@ -0,0 +1,927 @@ +""" Swin Transformer V2 +A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution` + - https://arxiv.org/pdf/2111.09883 + +Code adapted from https://github.com/ChristophReich1996/Swin-Transformer-V2, original copyright/license info below + +Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman +""" +# -------------------------------------------------------- +# Swin Transformer V2 reimplementation +# Copyright (c) 2021 Christoph Reich +# Licensed under The MIT License [see LICENSE for details] +# Written by Christoph Reich +# -------------------------------------------------------- +from typing import Tuple, Optional, List, Union, Any + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from .layers import DropPath, Mlp + + +def bchw_to_bhwc(input: torch.Tensor) -> torch.Tensor: + """ Permutes a tensor from the shape (B, C, H, W) to (B, H, W, C). + + Args: + input (torch.Tensor): Input tensor of the shape (B, C, H, W) + + Returns: + output (torch.Tensor): Permuted tensor of the shape (B, H, W, C) + """ + output: torch.Tensor = input.permute(0, 2, 3, 1) + return output + + +def bhwc_to_bchw(input: torch.Tensor) -> torch.Tensor: + """ Permutes a tensor from the shape (B, H, W, C) to (B, C, H, W). + + Args: + input (torch.Tensor): Input tensor of the shape (B, H, W, C) + + Returns: + output (torch.Tensor): Permuted tensor of the shape (B, C, H, W) + """ + output: torch.Tensor = input.permute(0, 3, 1, 2) + return output + + +def unfold(input: torch.Tensor, + window_size: int) -> torch.Tensor: + """ Unfolds (non-overlapping) a given feature map by the given window size (stride = window size). 
+ + Args: + input (torch.Tensor): Input feature map of the shape (B, C, H, W) + window_size (int): Window size to be applied + + Returns: + output (torch.Tensor): Unfolded tensor of the shape [B * windows, C, window size, window size] + """ + # Get original shape + _, channels, height, width = input.shape # type: int, int, int, int + # Unfold input + output: torch.Tensor = input.unfold(dimension=3, size=window_size, step=window_size) \ + .unfold(dimension=2, size=window_size, step=window_size) + # Reshape to (B * windows, C, window size, window size) + output: torch.Tensor = output.permute(0, 2, 3, 1, 5, 4).reshape(-1, channels, window_size, window_size) + return output + + +def fold(input: torch.Tensor, + window_size: int, + height: int, + width: int) -> torch.Tensor: + """ Folds a tensor of windows again to a 4D feature map. + + Args: + input (torch.Tensor): Input feature map of the shape (B, C, H, W) + window_size (int): Window size of the unfold operation + height (int): Height of the feature map + width (int): Width of the feature map + + Returns: + output (torch.Tensor): Folded output tensor of the shape (B, C, H, W) + """ + # Get channels of windows + channels: int = input.shape[1] + # Get original batch size + batch_size: int = int(input.shape[0] // (height * width // window_size // window_size)) + # Reshape input to (B, C, H, W) + output: torch.Tensor = input.view(batch_size, height // window_size, width // window_size, channels, + window_size, window_size) + output: torch.Tensor = output.permute(0, 3, 1, 4, 2, 5).reshape(batch_size, channels, height, width) + return output + + +class WindowMultiHeadAttention(nn.Module): + r""" This class implements window-based Multi-Head-Attention with log-spaced continuous position bias. + + Args: + in_features (int): Number of input features + window_size (int): Window size + number_of_heads (int): Number of attention heads + dropout_attention (float): Dropout rate of attention map + dropout_projection (float): Dropout rate after projection + meta_network_hidden_features (int): Number of hidden features in the two layer MLP meta network + sequential_self_attention (bool): If true sequential self-attention is performed + """ + + def __init__(self, + in_features: int, + window_size: int, + number_of_heads: int, + dropout_attention: float = 0., + dropout_projection: float = 0., + meta_network_hidden_features: int = 256, + sequential_self_attention: bool = False) -> None: + # Call super constructor + super(WindowMultiHeadAttention, self).__init__() + # Check parameter + assert (in_features % number_of_heads) == 0, \ + "The number of input features (in_features) are not divisible by the number of heads (number_of_heads)." 
+ # Save parameters + self.in_features: int = in_features + self.window_size: int = window_size + self.number_of_heads: int = number_of_heads + self.sequential_self_attention: bool = sequential_self_attention + # Init query, key and value mapping as a single layer + self.mapping_qkv: nn.Module = nn.Linear(in_features=in_features, out_features=in_features * 3, bias=True) + # Init attention dropout + self.attention_dropout: nn.Module = nn.Dropout(dropout_attention) + # Init projection mapping + self.projection: nn.Module = nn.Linear(in_features=in_features, out_features=in_features, bias=True) + # Init projection dropout + self.projection_dropout: nn.Module = nn.Dropout(dropout_projection) + # Init meta network for positional encodings + self.meta_network: nn.Module = nn.Sequential( + nn.Linear(in_features=2, out_features=meta_network_hidden_features, bias=True), + nn.ReLU(inplace=True), + nn.Linear(in_features=meta_network_hidden_features, out_features=number_of_heads, bias=True)) + # Init tau + self.register_parameter("tau", torch.nn.Parameter(torch.ones(1, number_of_heads, 1, 1))) + # Init pair-wise relative positions (log-spaced) + self.__make_pair_wise_relative_positions() + + def __make_pair_wise_relative_positions(self) -> None: + """ Method initializes the pair-wise relative positions to compute the positional biases.""" + indexes: torch.Tensor = torch.arange(self.window_size, device=self.tau.device) + coordinates: torch.Tensor = torch.stack(torch.meshgrid([indexes, indexes]), dim=0) + coordinates: torch.Tensor = torch.flatten(coordinates, start_dim=1) + relative_coordinates: torch.Tensor = coordinates[:, :, None] - coordinates[:, None, :] + relative_coordinates: torch.Tensor = relative_coordinates.permute(1, 2, 0).reshape(-1, 2).float() + relative_coordinates_log: torch.Tensor = torch.sign(relative_coordinates) \ + * torch.log(1. + relative_coordinates.abs()) + self.register_buffer("relative_coordinates_log", relative_coordinates_log) + + def update_resolution(self, + new_window_size: int, + **kwargs: Any) -> None: + """ Method updates the window size and so the pair-wise relative positions + + Args: + new_window_size (int): New window size + kwargs (Any): Unused + """ + # Set new window size + self.window_size: int = new_window_size + # Make new pair-wise relative positions + self.__make_pair_wise_relative_positions() + + def __get_relative_positional_encodings(self) -> torch.Tensor: + """ Method computes the relative positional encodings + + Returns: + relative_position_bias (torch.Tensor): Relative positional encodings + (1, number of heads, window size ** 2, window size ** 2) + """ + relative_position_bias: torch.Tensor = self.meta_network(self.relative_coordinates_log) + relative_position_bias: torch.Tensor = relative_position_bias.permute(1, 0) + relative_position_bias: torch.Tensor = relative_position_bias.reshape(self.number_of_heads, + self.window_size * self.window_size, + self.window_size * self.window_size) + relative_position_bias: torch.Tensor = relative_position_bias.unsqueeze(0) + return relative_position_bias + + def __self_attention(self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + batch_size_windows: int, + tokens: int, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """ This function performs standard (non-sequential) scaled cosine self-attention. 
+ + Args: + query (torch.Tensor): Query tensor of the shape [B * windows, heads, tokens, C // heads] + key (torch.Tensor): Key tensor of the shape [B * windows, heads, tokens, C // heads] + value (torch.Tensor): Value tensor of the shape (B * windows, heads, tokens, C // heads) + batch_size_windows (int): Size of the first dimension of the input tensor batch size * windows + tokens (int): Number of tokens in the input + mask (Optional[torch.Tensor]): Attention mask for the shift case + + Returns: + output (torch.Tensor): Output feature map of the shape [B * windows, tokens, C] + """ + # Compute attention map with scaled cosine attention + attention_map: torch.Tensor = torch.einsum("bhqd, bhkd -> bhqk", query, key) \ + / torch.maximum(torch.norm(query, dim=-1, keepdim=True) + * torch.norm(key, dim=-1, keepdim=True).transpose(-2, -1), + torch.tensor(1e-06, device=query.device, dtype=query.dtype)) + attention_map: torch.Tensor = attention_map / self.tau.clamp(min=0.01) + # Apply relative positional encodings + attention_map: torch.Tensor = attention_map + self.__get_relative_positional_encodings() + # Apply mask if utilized + if mask is not None: + number_of_windows: int = mask.shape[0] + attention_map: torch.Tensor = attention_map.view(batch_size_windows // number_of_windows, number_of_windows, + self.number_of_heads, tokens, tokens) + attention_map: torch.Tensor = attention_map + mask.unsqueeze(1).unsqueeze(0) + attention_map: torch.Tensor = attention_map.view(-1, self.number_of_heads, tokens, tokens) + attention_map: torch.Tensor = attention_map.softmax(dim=-1) + # Perform attention dropout + attention_map: torch.Tensor = self.attention_dropout(attention_map) + # Apply attention map and reshape + output: torch.Tensor = torch.einsum("bhal, bhlv -> bhav", attention_map, value) + output: torch.Tensor = output.permute(0, 2, 1, 3).reshape(batch_size_windows, tokens, -1) + return output + + def __sequential_self_attention(self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + batch_size_windows: int, + tokens: int, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """ This function performs sequential scaled cosine self-attention. 
+ + Args: + query (torch.Tensor): Query tensor of the shape [B * windows, heads, tokens, C // heads] + key (torch.Tensor): Key tensor of the shape [B * windows, heads, tokens, C // heads] + value (torch.Tensor): Value tensor of the shape (B * windows, heads, tokens, C // heads) + batch_size_windows (int): Size of the first dimension of the input tensor batch size * windows + tokens (int): Number of tokens in the input + mask (Optional[torch.Tensor]): Attention mask for the shift case + + Returns: + output (torch.Tensor): Output feature map of the shape [B * windows, tokens, C] + """ + # Init output tensor + output: torch.Tensor = torch.ones_like(query) + # Compute relative positional encodings fist + relative_position_bias: torch.Tensor = self.__get_relative_positional_encodings() + # Iterate over query and key tokens + for token_index_query in range(tokens): + # Compute attention map with scaled cosine attention + attention_map: torch.Tensor = \ + torch.einsum("bhd, bhkd -> bhk", query[:, :, token_index_query], key) \ + / torch.maximum(torch.norm(query[:, :, token_index_query], dim=-1, keepdim=True) + * torch.norm(key, dim=-1, keepdim=False), + torch.tensor(1e-06, device=query.device, dtype=query.dtype)) + attention_map: torch.Tensor = attention_map / self.tau.clamp(min=0.01)[..., 0] + # Apply positional encodings + attention_map: torch.Tensor = attention_map + relative_position_bias[..., token_index_query, :] + # Apply mask if utilized + if mask is not None: + number_of_windows: int = mask.shape[0] + attention_map: torch.Tensor = attention_map.view(batch_size_windows // number_of_windows, + number_of_windows, self.number_of_heads, 1, + tokens) + attention_map: torch.Tensor = attention_map \ + + mask.unsqueeze(1).unsqueeze(0)[..., token_index_query, :].unsqueeze(3) + attention_map: torch.Tensor = attention_map.view(-1, self.number_of_heads, tokens) + attention_map: torch.Tensor = attention_map.softmax(dim=-1) + # Perform attention dropout + attention_map: torch.Tensor = self.attention_dropout(attention_map) + # Apply attention map and reshape + output[:, :, token_index_query] = torch.einsum("bhl, bhlv -> bhv", attention_map, value) + output: torch.Tensor = output.permute(0, 2, 1, 3).reshape(batch_size_windows, tokens, -1) + return output + + def forward(self, + input: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """ Forward pass. 
+ + Args: + input (torch.Tensor): Input tensor of the shape (B * windows, C, H, W) + mask (Optional[torch.Tensor]): Attention mask for the shift case + + Returns: + output (torch.Tensor): Output tensor of the shape [B * windows, C, H, W] + """ + # Save original shape + batch_size_windows, channels, height, width = input.shape # type: int, int, int, int + tokens: int = height * width + # Reshape input to (B * windows, tokens (height * width), C) + input: torch.Tensor = input.reshape(batch_size_windows, channels, tokens).permute(0, 2, 1) + # Perform query, key, and value mapping + query_key_value: torch.Tensor = self.mapping_qkv(input) + query_key_value: torch.Tensor = query_key_value.view(batch_size_windows, tokens, 3, self.number_of_heads, + channels // self.number_of_heads).permute(2, 0, 3, 1, 4) + query, key, value = query_key_value[0], query_key_value[1], query_key_value[2] + # Perform attention + if self.sequential_self_attention: + output: torch.Tensor = self.__sequential_self_attention(query=query, key=key, value=value, + batch_size_windows=batch_size_windows, + tokens=tokens, + mask=mask) + else: + output: torch.Tensor = self.__self_attention(query=query, key=key, value=value, + batch_size_windows=batch_size_windows, tokens=tokens, + mask=mask) + # Perform linear mapping and dropout + output: torch.Tensor = self.projection_dropout(self.projection(output)) + # Reshape output to original shape [B * windows, C, H, W] + output: torch.Tensor = output.permute(0, 2, 1).view(batch_size_windows, channels, height, width) + return output + + +class SwinTransformerBlock(nn.Module): + r""" This class implements the Swin transformer block. + + Args: + in_channels (int): Number of input channels + input_resolution (Tuple[int, int]): Input resolution + number_of_heads (int): Number of attention heads to be utilized + window_size (int): Window size to be utilized + shift_size (int): Shifting size to be used + ff_feature_ratio (int): Ratio of the hidden dimension in the FFN to the input channels + dropout (float): Dropout in input mapping + dropout_attention (float): Dropout rate of attention map + dropout_path (float): Dropout in main path + sequential_self_attention (bool): If true sequential self-attention is performed + """ + + def __init__(self, + in_channels: int, + input_resolution: Tuple[int, int], + number_of_heads: int, + window_size: int = 7, + shift_size: int = 0, + ff_feature_ratio: int = 4, + dropout: float = 0.0, + dropout_attention: float = 0.0, + dropout_path: float = 0.0, + sequential_self_attention: bool = False) -> None: + # Call super constructor + super(SwinTransformerBlock, self).__init__() + # Save parameters + self.in_channels: int = in_channels + self.input_resolution: Tuple[int, int] = input_resolution + # Catch case if resolution is smaller than the window size + if min(self.input_resolution) <= window_size: + self.window_size: int = min(self.input_resolution) + self.shift_size: int = 0 + self.make_windows: bool = False + else: + self.window_size: int = window_size + self.shift_size: int = shift_size + self.make_windows: bool = True + # Init normalization layers + self.normalization_1: nn.Module = nn.LayerNorm(normalized_shape=in_channels) + self.normalization_2: nn.Module = nn.LayerNorm(normalized_shape=in_channels) + # Init window attention module + self.window_attention: WindowMultiHeadAttention = WindowMultiHeadAttention( + in_features=in_channels, + window_size=self.window_size, + number_of_heads=number_of_heads, + dropout_attention=dropout_attention, + 
dropout_projection=dropout, + sequential_self_attention=sequential_self_attention) + # Init dropout layer + self.dropout: nn.Module = DropPath(drop_prob=dropout_path) if dropout_path > 0. else nn.Identity() + # Init feed-forward network + self.feed_forward_network: nn.Module = Mlp(in_features=in_channels, + hidden_features=int(in_channels * ff_feature_ratio), + drop=dropout, + out_features=in_channels) + # Make attention mask + self.__make_attention_mask() + + def __make_attention_mask(self) -> None: + """ Method generates the attention mask used in shift case. """ + # Make masks for shift case + if self.shift_size > 0: + height, width = self.input_resolution # type: int, int + mask: torch.Tensor = torch.zeros(height, width, device=self.window_attention.tau.device) + height_slices: Tuple = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + width_slices: Tuple = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + counter: int = 0 + for height_slice in height_slices: + for width_slice in width_slices: + mask[height_slice, width_slice] = counter + counter += 1 + mask_windows: torch.Tensor = unfold(mask[None, None], self.window_size) + mask_windows: torch.Tensor = mask_windows.reshape(-1, self.window_size * self.window_size) + attention_mask: Optional[torch.Tensor] = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attention_mask: Optional[torch.Tensor] = attention_mask.masked_fill(attention_mask != 0, float(-100.0)) + attention_mask: Optional[torch.Tensor] = attention_mask.masked_fill(attention_mask == 0, float(0.0)) + else: + attention_mask: Optional[torch.Tensor] = None + # Save mask + self.register_buffer("attention_mask", attention_mask) + + def update_resolution(self, + new_window_size: int, + new_input_resolution: Tuple[int, int]) -> None: + """ Method updates the image resolution to be processed and window size and so the pair-wise relative positions. + + Args: + new_window_size (int): New window size + new_input_resolution (Tuple[int, int]): New input resolution + """ + # Update input resolution + self.input_resolution: Tuple[int, int] = new_input_resolution + # Catch case if resolution is smaller than the window size + if min(self.input_resolution) <= new_window_size: + self.window_size: int = min(self.input_resolution) + self.shift_size: int = 0 + self.make_windows: bool = False + else: + self.window_size: int = new_window_size + self.shift_size: int = self.shift_size + self.make_windows: bool = True + # Update attention mask + self.__make_attention_mask() + # Update attention module + self.window_attention.update_resolution(new_window_size=new_window_size) + + def forward(self, + input: torch.Tensor) -> torch.Tensor: + """ Forward pass. 
+ + Args: + input (torch.Tensor): Input tensor of the shape [B, C, H, W] + + Returns: + output (torch.Tensor): Output tensor of the shape [B, C, H, W] + """ + # Save shape + batch_size, channels, height, width = input.shape # type: int, int, int, int + # Shift input if utilized + if self.shift_size > 0: + output_shift: torch.Tensor = torch.roll(input=input, shifts=(-self.shift_size, -self.shift_size), + dims=(-1, -2)) + else: + output_shift: torch.Tensor = input + # Make patches + output_patches: torch.Tensor = unfold(input=output_shift, window_size=self.window_size) \ + if self.make_windows else output_shift + # Perform window attention + output_attention: torch.Tensor = self.window_attention(output_patches, mask=self.attention_mask) + # Merge patches + output_merge: torch.Tensor = fold(input=output_attention, window_size=self.window_size, height=height, + width=width) if self.make_windows else output_attention + # Reverse shift if utilized + if self.shift_size > 0: + output_shift: torch.Tensor = torch.roll(input=output_merge, shifts=(self.shift_size, self.shift_size), + dims=(-1, -2)) + else: + output_shift: torch.Tensor = output_merge + # Perform normalization + output_normalize: torch.Tensor = self.normalization_1(output_shift.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + # Skip connection + output_skip: torch.Tensor = self.dropout(output_normalize) + input + # Feed forward network, normalization and skip connection + output_feed_forward: torch.Tensor = self.feed_forward_network( + output_skip.view(batch_size, channels, -1).permute(0, 2, 1)).permute(0, 2, 1) + output_feed_forward: torch.Tensor = output_feed_forward.view(batch_size, channels, height, width) + output_normalize: torch.Tensor = bhwc_to_bchw(self.normalization_2(bchw_to_bhwc(output_feed_forward))) + output: torch.Tensor = output_skip + self.dropout(output_normalize) + return output + + +class DeformableSwinTransformerBlock(SwinTransformerBlock): + r""" This class implements a deformable version of the Swin Transformer block. 
+ Inspired by: https://arxiv.org/pdf/2201.00520 + + Args: + in_channels (int): Number of input channels + input_resolution (Tuple[int, int]): Input resolution + number_of_heads (int): Number of attention heads to be utilized + window_size (int): Window size to be utilized + shift_size (int): Shifting size to be used + ff_feature_ratio (int): Ratio of the hidden dimension in the FFN to the input channels + dropout (float): Dropout in input mapping + dropout_attention (float): Dropout rate of attention map + dropout_path (float): Dropout in main path + sequential_self_attention (bool): If true sequential self-attention is performed + offset_downscale_factor (int): Downscale factor of offset network + """ + + def __init__(self, + in_channels: int, + input_resolution: Tuple[int, int], + number_of_heads: int, + window_size: int = 7, + shift_size: int = 0, + ff_feature_ratio: int = 4, + dropout: float = 0.0, + dropout_attention: float = 0.0, + dropout_path: float = 0.0, + sequential_self_attention: bool = False, + offset_downscale_factor: int = 2) -> None: + # Call super constructor + super(DeformableSwinTransformerBlock, self).__init__( + in_channels=in_channels, + input_resolution=input_resolution, + number_of_heads=number_of_heads, + window_size=window_size, + shift_size=shift_size, + ff_feature_ratio=ff_feature_ratio, + dropout=dropout, + dropout_attention=dropout_attention, + dropout_path=dropout_path, + sequential_self_attention=sequential_self_attention + ) + # Save parameter + self.offset_downscale_factor: int = offset_downscale_factor + self.number_of_heads: int = number_of_heads + # Make default offsets + self.__make_default_offsets() + # Init offset network + self.offset_network: nn.Module = nn.Sequential( + nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=5, stride=offset_downscale_factor, + padding=3, groups=in_channels, bias=True), + nn.GELU(), + nn.Conv2d(in_channels=in_channels, out_channels=2 * self.number_of_heads, kernel_size=1, stride=1, + padding=0, bias=True) + ) + + def __make_default_offsets(self) -> None: + """ Method generates the default sampling grid (inspired by kornia) """ + # Init x and y coordinates + x: torch.Tensor = torch.linspace(0, self.input_resolution[1] - 1, self.input_resolution[1], + device=self.window_attention.tau.device) + y: torch.Tensor = torch.linspace(0, self.input_resolution[0] - 1, self.input_resolution[0], + device=self.window_attention.tau.device) + # Normalize coordinates to a range of [-1, 1] + x: torch.Tensor = (x / (self.input_resolution[1] - 1) - 0.5) * 2 + y: torch.Tensor = (y / (self.input_resolution[0] - 1) - 0.5) * 2 + # Make grid [2, height, width] + grid: torch.Tensor = torch.stack(torch.meshgrid([x, y])).transpose(1, 2) + # Reshape grid to [1, height, width, 2] + grid: torch.Tensor = grid.unsqueeze(dim=0).permute(0, 2, 3, 1) + # Register in module + self.register_buffer("default_grid", grid) + + def update_resolution(self, + new_window_size: int, + new_input_resolution: Tuple[int, int]) -> None: + """ Method updates the window size and so the pair-wise relative positions. 
+ + Args: + new_window_size (int): New window size + new_input_resolution (Tuple[int, int]): New input resolution + """ + # Update resolution and window size + super(DeformableSwinTransformerBlock, self).update_resolution(new_window_size=new_window_size, + new_input_resolution=new_input_resolution) + # Update default sampling grid + self.__make_default_offsets() + + def forward(self, + input: torch.Tensor) -> torch.Tensor: + """ Forward pass + Args: + input (torch.Tensor): Input tensor of the shape [B, C, H, W] + + Returns: + output (torch.Tensor): Output tensor of the shape [B, C, H, W] + """ + # Get input shape + batch_size, channels, height, width = input.shape + # Compute offsets of the shape [batch size, 2, height / r, width / r] + offsets: torch.Tensor = self.offset_network(input) + # Upscale offsets to the shape [batch size, 2 * number of heads, height, width] + offsets: torch.Tensor = F.interpolate(input=offsets, + size=(height, width), mode="bilinear", align_corners=True) + # Reshape offsets to [batch size, number of heads, height, width, 2] + offsets: torch.Tensor = offsets.reshape(batch_size, -1, 2, height, width).permute(0, 1, 3, 4, 2) + # Flatten batch size and number of heads and apply tanh + offsets: torch.Tensor = offsets.view(-1, height, width, 2).tanh() + # Cast offset grid to input data type + if input.dtype != self.default_grid.dtype: + self.default_grid = self.default_grid.type(input.dtype) + # Construct offset grid + offset_grid: torch.Tensor = self.default_grid.repeat_interleave(repeats=offsets.shape[0], dim=0) + offsets + # Reshape input to [batch size * number of heads, channels / number of heads, height, width] + input: torch.Tensor = input.view(batch_size, self.number_of_heads, channels // self.number_of_heads, height, + width).flatten(start_dim=0, end_dim=1) + # Apply sampling grid + input_resampled: torch.Tensor = F.grid_sample(input=input, grid=offset_grid.clip(min=-1, max=1), + mode="bilinear", align_corners=True, padding_mode="reflection") + # Reshape resampled tensor again to [batch size, channels, height, width] + input_resampled: torch.Tensor = input_resampled.view(batch_size, channels, height, width) + output: torch.Tensor = super(DeformableSwinTransformerBlock, self).forward(input=input_resampled) + return output + + +class PatchMerging(nn.Module): + """ This class implements the patch merging as a strided convolution with a normalization before. + + Args: + in_channels (int): Number of input channels + """ + + def __init__(self, + in_channels: int) -> None: + # Call super constructor + super(PatchMerging, self).__init__() + # Init normalization + self.normalization: nn.Module = nn.LayerNorm(normalized_shape=4 * in_channels) + # Init linear mapping + self.linear_mapping: nn.Module = nn.Linear(in_features=4 * in_channels, out_features=2 * in_channels, + bias=False) + + def forward(self, + input: torch.Tensor) -> torch.Tensor: + """ Forward pass. 
+ + Args: + input (torch.Tensor): Input tensor of the shape [B, C, H, W] + + Returns: + output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2] + """ + # Get original shape + batch_size, channels, height, width = input.shape # type: int, int, int, int + # Reshape input to [batch size, in channels, height, width] + input: torch.Tensor = bchw_to_bhwc(input) + # Unfold input + input: torch.Tensor = input.unfold(dimension=1, size=2, step=2).unfold(dimension=2, size=2, step=2) + input: torch.Tensor = input.reshape(batch_size, input.shape[1], input.shape[2], -1) + # Normalize input + input: torch.Tensor = self.normalization(input) + # Perform linear mapping + output: torch.Tensor = bhwc_to_bchw(self.linear_mapping(input)) + return output + + +class PatchEmbedding(nn.Module): + """ Module embeds a given image into patch embeddings. + + Args: + in_channels (int): Number of input channels + out_channels (int): Number of output channels + patch_size (int): Patch size to be utilized + image_size (int): Image size to be used + """ + + def __init__(self, + in_channels: int = 3, + out_channels: int = 96, + patch_size: int = 4) -> None: + # Call super constructor + super(PatchEmbedding, self).__init__() + # Save parameters + self.out_channels: int = out_channels + # Init linear embedding as a convolution + self.linear_embedding: nn.Module = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size)) + # Init layer normalization + self.normalization: nn.Module = nn.LayerNorm(normalized_shape=out_channels) + + def forward(self, + input: torch.Tensor) -> torch.Tensor: + """ Forward pass. + + Args: + input (torch.Tensor): Input image of the shape (B, C_in, H, W) + + Returns: + embedding (torch.Tensor): Embedding of the shape (B, C_out, H / patch size, W / patch size) + """ + # Perform linear embedding + embedding: torch.Tensor = self.linear_embedding(input) + # Perform normalization + embedding: torch.Tensor = bhwc_to_bchw(self.normalization(bchw_to_bhwc(embedding))) + return embedding + + +class SwinTransformerStage(nn.Module): + r""" This class implements a stage of the Swin transformer including multiple layers. + + Args: + in_channels (int): Number of input channels + depth (int): Depth of the stage (number of layers) + downscale (bool): If true input is downsampled (see Fig. 
3 or V1 paper) + input_resolution (Tuple[int, int]): Input resolution + number_of_heads (int): Number of attention heads to be utilized + window_size (int): Window size to be utilized + shift_size (int): Shifting size to be used + ff_feature_ratio (int): Ratio of the hidden dimension in the FFN to the input channels + dropout (float): Dropout in input mapping + dropout_attention (float): Dropout rate of attention map + dropout_path (float): Dropout in main path + use_checkpoint (bool): If true checkpointing is utilized + sequential_self_attention (bool): If true sequential self-attention is performed + use_deformable_block (bool): If true deformable block is used + """ + + def __init__(self, + in_channels: int, + depth: int, + downscale: bool, + input_resolution: Tuple[int, int], + number_of_heads: int, + window_size: int = 7, + ff_feature_ratio: int = 4, + dropout: float = 0.0, + dropout_attention: float = 0.0, + dropout_path: Union[List[float], float] = 0.0, + use_checkpoint: bool = False, + sequential_self_attention: bool = False, + use_deformable_block: bool = False) -> None: + # Call super constructor + super(SwinTransformerStage, self).__init__() + # Save parameters + self.use_checkpoint: bool = use_checkpoint + self.downscale: bool = downscale + # Init downsampling + self.downsample: nn.Module = PatchMerging(in_channels=in_channels) if downscale else nn.Identity() + # Update resolution and channels + self.input_resolution: Tuple[int, int] = (input_resolution[0] // 2, input_resolution[1] // 2) \ + if downscale else input_resolution + in_channels = in_channels * 2 if downscale else in_channels + # Get block + block = DeformableSwinTransformerBlock if use_deformable_block else SwinTransformerBlock + # Init blocks + self.blocks: nn.ModuleList = nn.ModuleList([ + block(in_channels=in_channels, + input_resolution=self.input_resolution, + number_of_heads=number_of_heads, + window_size=window_size, + shift_size=0 if ((index % 2) == 0) else window_size // 2, + ff_feature_ratio=ff_feature_ratio, + dropout=dropout, + dropout_attention=dropout_attention, + dropout_path=dropout_path[index] if isinstance(dropout_path, list) else dropout_path, + sequential_self_attention=sequential_self_attention) + for index in range(depth)]) + + def update_resolution(self, + new_window_size: int, + new_input_resolution: Tuple[int, int]) -> None: + """ Method updates the resolution to utilize and the window size and so the pair-wise relative positions. + + Args: + new_window_size (int): New window size + new_input_resolution (Tuple[int, int]): New input resolution + """ + # Update resolution + self.input_resolution: Tuple[int, int] = (new_input_resolution[0] // 2, new_input_resolution[1] // 2) \ + if self.downscale else new_input_resolution + # Update resolution of each block + for block in self.blocks: # type: SwinTransformerBlock + block.update_resolution(new_window_size=new_window_size, new_input_resolution=self.input_resolution) + + def forward(self, + input: torch.Tensor) -> torch.Tensor: + """ Forward pass. 
+ + Args: + input (torch.Tensor): Input tensor of the shape [B, C, H, W] + + Returns: + output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2] + """ + # Downscale input tensor + output: torch.Tensor = self.downsample(input) + # Forward pass of each block + for block in self.blocks: # type: nn.Module + # Perform checkpointing if utilized + if self.use_checkpoint: + output: torch.Tensor = checkpoint.checkpoint(block, output) + else: + output: torch.Tensor = block(output) + return output + + +class SwinTransformerV2(nn.Module): + r""" Swin Transformer V2 + A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution` - + https://arxiv.org/pdf/2111.09883 + + Args: + in_channels (int): Number of input channels + depth (int): Depth of the stage (number of layers) + downscale (bool): If true input is downsampled (see Fig. 3 or V1 paper) + input_resolution (Tuple[int, int]): Input resolution + number_of_heads (int): Number of attention heads to be utilized + num_classes (int): Number of output classes + window_size (int): Window size to be utilized + shift_size (int): Shifting size to be used + ff_feature_ratio (int): Ratio of the hidden dimension in the FFN to the input channels + dropout (float): Dropout in input mapping + dropout_attention (float): Dropout rate of attention map + dropout_path (float): Dropout in main path + use_checkpoint (bool): If true checkpointing is utilized + sequential_self_attention (bool): If true sequential self-attention is performed + use_deformable_block (bool): If true deformable block is used + """ + + def __init__(self, + in_channels: int, + embedding_channels: int, + depths: Tuple[int, ...], + input_resolution: Tuple[int, int], + number_of_heads: Tuple[int, ...], + num_classes: int = 1000, + window_size: int = 7, + patch_size: int = 4, + ff_feature_ratio: int = 4, + dropout: float = 0.0, + dropout_attention: float = 0.0, + dropout_path: float = 0.2, + use_checkpoint: bool = False, + sequential_self_attention: bool = False, + use_deformable_block: bool = False) -> None: + # Call super constructor + super(SwinTransformerV2, self).__init__() + # Save parameters + self.patch_size: int = patch_size + self.input_resolution: Tuple[int, int] = input_resolution + self.window_size: int = window_size + # Init patch embedding + self.patch_embedding: nn.Module = PatchEmbedding(in_channels=in_channels, out_channels=embedding_channels, + patch_size=patch_size) + # Compute patch resolution + patch_resolution: Tuple[int, int] = (input_resolution[0] // patch_size, input_resolution[1] // patch_size) + # Path dropout dependent on depth + dropout_path = torch.linspace(0., dropout_path, sum(depths)).tolist() + # Init stages + self.stages: nn.ModuleList = nn.ModuleList() + for index, (depth, number_of_head) in enumerate(zip(depths, number_of_heads)): + self.stages.append( + SwinTransformerStage( + in_channels=embedding_channels * (2 ** max(index - 1, 0)), + depth=depth, + downscale=index != 0, + input_resolution=(patch_resolution[0] // (2 ** max(index - 1, 0)), + patch_resolution[1] // (2 ** max(index - 1, 0))), + number_of_heads=number_of_head, + window_size=window_size, + ff_feature_ratio=ff_feature_ratio, + dropout=dropout, + dropout_attention=dropout_attention, + dropout_path=dropout_path[sum(depths[:index]):sum(depths[:index + 1])], + use_checkpoint=use_checkpoint, + sequential_self_attention=sequential_self_attention, + use_deformable_block=use_deformable_block and (index > 0) + )) + # Init final adaptive average pooling, and classification 
head + self.average_pool: nn.Module = nn.AdaptiveAvgPool2d(1) + self.head: nn.Module = nn.Linear(in_features=embedding_channels * (2 ** len(depths) - 1), + out_features=num_classes) + + def update_resolution(self, + new_input_resolution: Optional[Tuple[int, int]] = None, + new_window_size: Optional[int] = None) -> None: + """ Method updates the image resolution to be processed and window size and so the pair-wise relative positions. + + Args: + new_window_size (Optional[int]): New window size if None current window size is used + new_input_resolution (Optional[Tuple[int, int]]): New input resolution if None current resolution is used + """ + # Check parameters + if new_input_resolution is None: + new_input_resolution = self.input_resolution + if new_window_size is None: + new_window_size = self.window_size + # Compute new patch resolution + new_patch_resolution: Tuple[int, int] = (new_input_resolution[0] // self.patch_size, + new_input_resolution[1] // self.patch_size) + # Update resolution of each stage + for index, stage in enumerate(self.stages): # type: int, SwinTransformerStage + stage.update_resolution(new_window_size=new_window_size, + new_input_resolution=(new_patch_resolution[0] // (2 ** max(index - 1, 0)), + new_patch_resolution[1] // (2 ** max(index - 1, 0)))) + + def forward_features(self, + input: torch.Tensor) -> List[torch.Tensor]: + """ Forward pass to extract feature maps of each stage. + + Args: + input (torch.Tensor): Input images of the shape (B, C, H, W) + + Returns: + features (List[torch.Tensor]): List of feature maps from each stage + """ + # Perform patch embedding + output: torch.Tensor = self.patch_embedding(input) + # Init list to store feature + features: List[torch.Tensor] = [] + # Forward pass of each stage + for stage in self.stages: + output: torch.Tensor = stage(output) + features.append(output) + return features + + def forward(self, + input: torch.Tensor) -> torch.Tensor: + """ Forward pass. 
+ + Args: + input (torch.Tensor): Input images of the shape (B, C, H, W) + + Returns: + classification (torch.Tensor): Classification of the shape (B, num_classes) + """ + # Perform patch embedding + output: torch.Tensor = self.patch_embedding(input) + # Forward pass of each stage + for stage in self.stages: + output: torch.Tensor = stage(output) + # Perform average pooling + output: torch.Tensor = self.average_pool(output) + # Predict classification + classification: torch.Tensor = self.head(output) + return classification From f227b88831b55ba0bec754c92c2a279f2c5ecd3e Mon Sep 17 00:00:00 2001 From: Christoph Reich <34400551+ChristophReich1996@users.noreply.github.com> Date: Sat, 19 Feb 2022 22:14:38 +0100 Subject: [PATCH 10/21] Add initials (CR) to model and file --- .../{swin_transformer_v2.py => swin_transformer_v2_cr.py} | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) rename timm/models/{swin_transformer_v2.py => swin_transformer_v2_cr.py} (99%) diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2_cr.py similarity index 99% rename from timm/models/swin_transformer_v2.py rename to timm/models/swin_transformer_v2_cr.py index c146ecc8..d7c36519 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -785,7 +785,7 @@ class SwinTransformerStage(nn.Module): return output -class SwinTransformerV2(nn.Module): +class SwinTransformerV2CR(nn.Module): r""" Swin Transformer V2 A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution` - https://arxiv.org/pdf/2111.09883 @@ -825,7 +825,7 @@ class SwinTransformerV2(nn.Module): sequential_self_attention: bool = False, use_deformable_block: bool = False) -> None: # Call super constructor - super(SwinTransformerV2, self).__init__() + super(SwinTransformerV2CR, self).__init__() # Save parameters self.patch_size: int = patch_size self.input_resolution: Tuple[int, int] = input_resolution From 81bf0b403384b45680ecc303b50d482c3d1dcd5f Mon Sep 17 00:00:00 2001 From: Christoph Reich <34400551+ChristophReich1996@users.noreply.github.com> Date: Sat, 19 Feb 2022 22:37:22 +0100 Subject: [PATCH 11/21] Change parameter names to match Swin V1 --- timm/models/swin_transformer_v2_cr.py | 76 ++++++++++++++------------- 1 file changed, 40 insertions(+), 36 deletions(-) diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index d7c36519..9659b5ec 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -12,7 +12,7 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W # Licensed under The MIT License [see LICENSE for details] # Written by Christoph Reich # -------------------------------------------------------- -from typing import Tuple, Optional, List, Union, Any +from typing import Tuple, Optional, List, Union, Any, Type import torch import torch.nn as nn @@ -717,6 +717,7 @@ class SwinTransformerStage(nn.Module): dropout: float = 0.0, dropout_attention: float = 0.0, dropout_path: Union[List[float], float] = 0.0, + norm_layer: Type[nn.Module] = nn.LayerNorm, use_checkpoint: bool = False, sequential_self_attention: bool = False, use_deformable_block: bool = False) -> None: @@ -791,75 +792,78 @@ class SwinTransformerV2CR(nn.Module): https://arxiv.org/pdf/2111.09883 Args: - in_channels (int): Number of input channels - depth (int): Depth of the stage (number of layers) - downscale (bool): If true input is downsampled (see Fig. 
3 or V1 paper) - input_resolution (Tuple[int, int]): Input resolution - number_of_heads (int): Number of attention heads to be utilized - num_classes (int): Number of output classes - window_size (int): Window size to be utilized - shift_size (int): Shifting size to be used - ff_feature_ratio (int): Ratio of the hidden dimension in the FFN to the input channels - dropout (float): Dropout in input mapping - dropout_attention (float): Dropout rate of attention map - dropout_path (float): Dropout in main path - use_checkpoint (bool): If true checkpointing is utilized - sequential_self_attention (bool): If true sequential self-attention is performed - use_deformable_block (bool): If true deformable block is used + img_size (Tuple[int, int]): Input resolution. + in_chans (int): Number of input channels. + depths (int): Depth of the stage (number of layers). + num_heads (int): Number of attention heads to be utilized. + embed_dim (int): Patch embedding dimension. Default: 96 + num_classes (int): Number of output classes. Default: 1000 + window_size (int): Window size to be utilized. Default: 7 + patch_size (int | tuple(int)): Patch size. Default: 4 + mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels. Default: 4 + drop_rate (float): Dropout rate. Default: 0.0 + attn_drop_rate (float): Dropout rate of attention map. Default: 0.0 + drop_path_rate (float): Stochastic depth rate. Default: 0.0 + norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm + use_checkpoint (bool): If true checkpointing is utilized. Default: False + sequential_self_attention (bool): If true sequential self-attention is performed. Default: False + use_deformable_block (bool): If true deformable block is used. Default: False """ def __init__(self, - in_channels: int, - embedding_channels: int, + img_size: Tuple[int, int], + in_chans: int, depths: Tuple[int, ...], - input_resolution: Tuple[int, int], - number_of_heads: Tuple[int, ...], + num_heads: Tuple[int, ...], + embed_dim: int = 96, num_classes: int = 1000, window_size: int = 7, patch_size: int = 4, - ff_feature_ratio: int = 4, - dropout: float = 0.0, - dropout_attention: float = 0.0, - dropout_path: float = 0.2, + mlp_ratio: int = 4, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + norm_layer: Type[nn.Module] = nn.LayerNorm, use_checkpoint: bool = False, sequential_self_attention: bool = False, - use_deformable_block: bool = False) -> None: + use_deformable_block: bool = False, + **kwargs: Any) -> None: # Call super constructor super(SwinTransformerV2CR, self).__init__() # Save parameters self.patch_size: int = patch_size - self.input_resolution: Tuple[int, int] = input_resolution + self.input_resolution: Tuple[int, int] = img_size self.window_size: int = window_size # Init patch embedding - self.patch_embedding: nn.Module = PatchEmbedding(in_channels=in_channels, out_channels=embedding_channels, + self.patch_embedding: nn.Module = PatchEmbedding(in_channels=in_chans, out_channels=embed_dim, patch_size=patch_size) # Compute patch resolution - patch_resolution: Tuple[int, int] = (input_resolution[0] // patch_size, input_resolution[1] // patch_size) + patch_resolution: Tuple[int, int] = (img_size[0] // patch_size, img_size[1] // patch_size) # Path dropout dependent on depth - dropout_path = torch.linspace(0., dropout_path, sum(depths)).tolist() + drop_path_rate = torch.linspace(0., drop_path_rate, sum(depths)).tolist() # Init stages self.stages: nn.ModuleList = 
nn.ModuleList() - for index, (depth, number_of_head) in enumerate(zip(depths, number_of_heads)): + for index, (depth, number_of_head) in enumerate(zip(depths, num_heads)): self.stages.append( SwinTransformerStage( - in_channels=embedding_channels * (2 ** max(index - 1, 0)), + in_channels=embed_dim * (2 ** max(index - 1, 0)), depth=depth, downscale=index != 0, input_resolution=(patch_resolution[0] // (2 ** max(index - 1, 0)), patch_resolution[1] // (2 ** max(index - 1, 0))), number_of_heads=number_of_head, window_size=window_size, - ff_feature_ratio=ff_feature_ratio, - dropout=dropout, - dropout_attention=dropout_attention, - dropout_path=dropout_path[sum(depths[:index]):sum(depths[:index + 1])], + ff_feature_ratio=mlp_ratio, + dropout=drop_rate, + dropout_attention=attn_drop_rate, + dropout_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])], use_checkpoint=use_checkpoint, sequential_self_attention=sequential_self_attention, use_deformable_block=use_deformable_block and (index > 0) )) # Init final adaptive average pooling, and classification head self.average_pool: nn.Module = nn.AdaptiveAvgPool2d(1) - self.head: nn.Module = nn.Linear(in_features=embedding_channels * (2 ** len(depths) - 1), + self.head: nn.Module = nn.Linear(in_features=embed_dim * (2 ** len(depths) - 1), out_features=num_classes) def update_resolution(self, From ff5f6bcd6cb451418e93164ca43c18adfd92f36f Mon Sep 17 00:00:00 2001 From: Christoph Reich <34400551+ChristophReich1996@users.noreply.github.com> Date: Sat, 19 Feb 2022 22:42:02 +0100 Subject: [PATCH 12/21] Check input resolution --- timm/models/swin_transformer_v2_cr.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index 9659b5ec..6b0ca1c1 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -899,6 +899,10 @@ class SwinTransformerV2CR(nn.Module): Returns: features (List[torch.Tensor]): List of feature maps from each stage """ + # Check input resolution + assert input.shape[2:] == self.input_resolution, \ + "Input resolution and utilized resolution does not match. Please update the models resolution by calling " \ + "update_resolution the provided method." # Perform patch embedding output: torch.Tensor = self.patch_embedding(input) # Init list to store feature @@ -919,6 +923,10 @@ class SwinTransformerV2CR(nn.Module): Returns: classification (torch.Tensor): Classification of the shape (B, num_classes) """ + # Check input resolution + assert input.shape[2:] == self.input_resolution, \ + "Input resolution and utilized resolution does not match. Please update the models resolution by calling " \ + "update_resolution the provided method." 
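The stochastic-depth handling above builds a single torch.linspace schedule over all blocks and then hands each stage its contiguous slice. A minimal standalone illustration of that bookkeeping, using illustrative depth and rate values:

    import torch

    depths = (2, 2, 6, 2)        # illustrative per-stage block counts
    drop_path_rate = 0.2
    # One rate per block, increasing linearly from 0 to drop_path_rate
    dpr = torch.linspace(0., drop_path_rate, sum(depths)).tolist()

    # Each stage receives the slice dpr[sum(depths[:i]):sum(depths[:i + 1])]
    per_stage = [dpr[sum(depths[:i]):sum(depths[:i + 1])] for i in range(len(depths))]
    assert [len(rates) for rates in per_stage] == list(depths)
    print([round(r, 3) for r in per_stage[2]])  # the six rates used by the third stage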
# Perform patch embedding output: torch.Tensor = self.patch_embedding(input) # Forward pass of each stage From 87b4d7a29af80d7c1334af36392f744e9382cdf0 Mon Sep 17 00:00:00 2001 From: Christoph Reich <34400551+ChristophReich1996@users.noreply.github.com> Date: Sat, 19 Feb 2022 22:47:02 +0100 Subject: [PATCH 13/21] Add get and reset classifier method --- timm/models/swin_transformer_v2_cr.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index 6b0ca1c1..7adf1ec0 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -831,9 +831,11 @@ class SwinTransformerV2CR(nn.Module): # Call super constructor super(SwinTransformerV2CR, self).__init__() # Save parameters + self.num_classes: int = num_classes self.patch_size: int = patch_size self.input_resolution: Tuple[int, int] = img_size self.window_size: int = window_size + self.num_features: int = int(embed_dim * (2 ** len(depths) - 1)) # Init patch embedding self.patch_embedding: nn.Module = PatchEmbedding(in_channels=in_chans, out_channels=embed_dim, patch_size=patch_size) @@ -863,7 +865,7 @@ class SwinTransformerV2CR(nn.Module): )) # Init final adaptive average pooling, and classification head self.average_pool: nn.Module = nn.AdaptiveAvgPool2d(1) - self.head: nn.Module = nn.Linear(in_features=embed_dim * (2 ** len(depths) - 1), + self.head: nn.Module = nn.Linear(in_features=self.num_features, out_features=num_classes) def update_resolution(self, @@ -889,6 +891,25 @@ class SwinTransformerV2CR(nn.Module): new_input_resolution=(new_patch_resolution[0] // (2 ** max(index - 1, 0)), new_patch_resolution[1] // (2 ** max(index - 1, 0)))) + def get_classifier(self) -> nn.Module: + """ Method returns the classification head of the model. + Returns: + head (nn.Module): Current classification head + """ + head: nn.Module = self.head + return head + + def reset_classifier(self, num_classes: int, global_pool: str = '') -> None: + """ Method results the classification head + + Args: + num_classes (int): Number of classes to be predicted + global_pool (str): Unused + """ + self.num_classes: int = num_classes + self.head: nn.Module = nn.Linear(in_features=self.num_features, out_features=num_classes) \ + if num_classes > 0 else nn.Identity() + def forward_features(self, input: torch.Tensor) -> List[torch.Tensor]: """ Forward pass to extract feature maps of each stage. 
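The two patches above add an input-resolution check and the standard timm classifier accessors (get_classifier / reset_classifier). A rough usage sketch follows, assuming the module is importable under its renamed path and using illustrative Swin-T-like hyperparameters; it sticks to forward_features() because the classification-head width is only corrected by a later patch in this series:

    import torch
    from timm.models.swin_transformer_v2_cr import SwinTransformerV2CR

    model = SwinTransformerV2CR(
        img_size=(224, 224), in_chans=3, embed_dim=96,
        depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
        window_size=7, num_classes=1000)

    # Swap the 1000-class head for a 10-class one (num_classes=0 would give nn.Identity)
    model.reset_classifier(num_classes=10)
    print(model.get_classifier())  # Linear(in_features=model.num_features, out_features=10)

    # Per-stage feature maps; H and W must match img_size or the new assert fires
    features = model.forward_features(torch.randn(1, 3, 224, 224))
    for f in features:
        print(f.shape)  # (1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)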
From 2a4f6c13dd75df4a2439e9ba26dc60b38926a62b Mon Sep 17 00:00:00 2001 From: Christoph Reich <34400551+ChristophReich1996@users.noreply.github.com> Date: Sun, 20 Feb 2022 00:40:22 +0100 Subject: [PATCH 14/21] Create model functions --- timm/models/swin_transformer_v2_cr.py | 223 +++++++++++++++++++++++++- 1 file changed, 222 insertions(+), 1 deletion(-) diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index 7adf1ec0..9aad19c0 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -12,6 +12,8 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W # Licensed under The MIT License [see LICENSE for details] # Written by Christoph Reich # -------------------------------------------------------- +import logging +from copy import deepcopy from typing import Tuple, Optional, List, Union, Any, Type import torch @@ -19,7 +21,81 @@ import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint -from .layers import DropPath, Mlp +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +# from .helpers import build_model_with_cfg, overlay_external_default_cfg +# from .vision_transformer import checkpoint_filter_fn +# from .registry import register_model +# from .layers import DropPath, Mlp + +from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg +from timm.models.vision_transformer import checkpoint_filter_fn +from timm.models.registry import register_model +from timm.models.layers import DropPath, Mlp + +_logger = logging.getLogger(__name__) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # patch models (my experiments) + 'swin_v2_cr_tiny_patch4_window12_384': _cfg( + url="", + input_size=(3, 384, 384), crop_pct=1.0), + + 'swin_v2_cr_tiny_patch4_window7_224': _cfg( + url="", + input_size=(3, 224, 224), crop_pct=1.0), + + 'swin_v2_cr_small_patch4_window12_384': _cfg( + url="", + input_size=(3, 384, 384), crop_pct=1.0), + + 'swin_v2_cr_small_patch4_window7_224': _cfg( + url="", + input_size=(3, 224, 224), crop_pct=1.0), + + 'swin_v2_cr_base_patch4_window12_384': _cfg( + url="", + input_size=(3, 384, 384), crop_pct=1.0), + + 'swin_v2_cr_base_patch4_window7_224': _cfg( + url="", + input_size=(3, 224, 224), crop_pct=1.0), + + 'swin_v2_cr_large_patch4_window12_384': _cfg( + url="", + input_size=(3, 384, 384), crop_pct=1.0), + + 'swin_v2_cr_large_patch4_window7_224': _cfg( + url="", + input_size=(3, 224, 224), crop_pct=1.0), + + 'swin_v2_cr_huge_patch4_window12_384': _cfg( + url="", + input_size=(3, 384, 384), crop_pct=1.0), + + 'swin_v2_cr_huge_patch4_window7_224': _cfg( + url="", + input_size=(3, 224, 224), crop_pct=1.0), + + 'swin_v2_cr_giant_patch4_window12_384': _cfg( + url="", + input_size=(3, 384, 384), crop_pct=1.0), + + 'swin_v2_cr_giant_patch4_window7_224': _cfg( + url="", + input_size=(3, 224, 224), crop_pct=1.0), +} def bchw_to_bhwc(input: torch.Tensor) -> torch.Tensor: @@ -958,3 +1034,148 @@ class SwinTransformerV2CR(nn.Module): # Predict classification classification: torch.Tensor = self.head(output) return classification + + +def _create_swin_transformer_v2_cr(variant, pretrained=False, 
default_cfg=None, **kwargs): + if default_cfg is None: + default_cfg = deepcopy(default_cfgs[variant]) + overlay_external_default_cfg(default_cfg, kwargs) + default_num_classes = default_cfg['num_classes'] + default_img_size = default_cfg['input_size'][-2:] + + num_classes = kwargs.pop('num_classes', default_num_classes) + img_size = kwargs.pop('img_size', default_img_size) + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + SwinTransformerV2CR, variant, pretrained, + default_cfg=default_cfg, + img_size=img_size, + num_classes=num_classes, + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + + return model + + +@register_model +def swin_v2_cr_tiny_patch4_window12_384(pretrained=False, **kwargs): + """ Swin-T V2 CR @ 384x384, trained ImageNet-1k + """ + model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=96, depths=(2, 2, 6, 2), + num_heads=(3, 6, 12, 24), **kwargs) + return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_patch4_window12_384', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_v2_cr_tiny_patch4_window7_224(pretrained=False, **kwargs): + """ Swin-T V2 CR @ 224x224, trained ImageNet-1k + """ + model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 6, 2), + num_heads=(3, 6, 12, 24), **kwargs) + return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_patch4_window7_224', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_v2_cr_small_patch4_window12_384(pretrained=False, **kwargs): + """ Swin-S V2 CR @ 384x384, trained ImageNet-1k + """ + model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), + num_heads=(3, 6, 12, 24), **kwargs) + return _create_swin_transformer_v2_cr('swin_v2_cr_small_patch4_window12_384', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_v2_cr_small_patch4_window7_224(pretrained=False, **kwargs): + """ Swin-S V2 CR @ 224x224, trained ImageNet-1k + """ + model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), + num_heads=(3, 6, 12, 24), **kwargs) + return _create_swin_transformer_v2_cr('swin_v2_cr_small_patch4_window7_224', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_v2_cr_base_patch4_window12_384(pretrained=False, **kwargs): + """ Swin-B V2 CR @ 384x384, trained ImageNet-1k + """ + model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), + num_heads=(4, 8, 16, 32), **kwargs) + return _create_swin_transformer_v2_cr('swin_v2_cr_base_patch4_window12_384', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_v2_cr_base_patch4_window7_224(pretrained=False, **kwargs): + """ Swin-B V2 CR @ 224x224, trained ImageNet-1k + """ + model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), + num_heads=(4, 8, 16, 32), **kwargs) + return _create_swin_transformer_v2_cr('swin_v2_cr_base_patch4_window7_224', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_v2_cr_large_patch4_window12_384(pretrained=False, **kwargs): + """ Swin-L V2 CR @ 384x384, trained ImageNet-1k + """ + model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), + num_heads=(6, 12, 24, 48), **kwargs) + return 
_create_swin_transformer_v2_cr('swin_v2_cr_large_patch4_window12_384', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_v2_cr_large_patch4_window7_224(pretrained=False, **kwargs): + """ Swin-L V2 CR @ 224x224, trained ImageNet-1k + """ + model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), + num_heads=(6, 12, 24, 48), **kwargs) + return _create_swin_transformer_v2_cr('swin_v2_cr_large_patch4_window7_224', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_v2_cr_huge_patch4_window12_384(pretrained=False, **kwargs): + """ Swin-H V2 CR @ 384x384, trained ImageNet-1k + """ + model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=352, depths=(2, 2, 18, 2), + num_heads=(6, 12, 24, 48), **kwargs) + return _create_swin_transformer_v2_cr('swin_v2_cr_huge_patch4_window12_384', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_v2_cr_huge_patch4_window7_224(pretrained=False, **kwargs): + """ Swin-H V2 CR @ 224x224, trained ImageNet-1k + """ + model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=352, depths=(2, 2, 18, 2), + num_heads=(11, 22, 44, 88), **kwargs) + return _create_swin_transformer_v2_cr('swin_v2_cr_huge_patch4_window7_224', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_v2_cr_giant_patch4_window12_384(pretrained=False, **kwargs): + """ Swin-G V2 CR @ 384x384, trained ImageNet-1k + """ + model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=512, depths=(2, 2, 18, 2), + num_heads=(16, 32, 64, 128), **kwargs) + return _create_swin_transformer_v2_cr('swin_v2_cr_giant_patch4_window12_384', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_v2_cr_giant_patch4_window7_224(pretrained=False, **kwargs): + """ Swin-G V2 CR @ 224x224, trained ImageNet-1k + """ + model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=512, depths=(2, 2, 42, 2), + num_heads=(16, 32, 64, 128), **kwargs) + return _create_swin_transformer_v2_cr('swin_v2_cr_giant_patch4_window7_224', pretrained=pretrained, **model_kwargs) + + +if __name__ == '__main__': + model = swin_v2_cr_tiny_patch4_window12_384(pretrained=False) + model = swin_v2_cr_tiny_patch4_window7_224(pretrained=False) + + model = swin_v2_cr_small_patch4_window12_384(pretrained=False) + model = swin_v2_cr_small_patch4_window7_224(pretrained=False) + + model = swin_v2_cr_base_patch4_window12_384(pretrained=False) + model = swin_v2_cr_base_patch4_window7_224(pretrained=False) + + model = swin_v2_cr_large_patch4_window12_384(pretrained=False) + model = swin_v2_cr_large_patch4_window7_224(pretrained=False) From 74a04e0016f6f3e3b6484236ab8815561e3c5a86 Mon Sep 17 00:00:00 2001 From: Christoph Reich <34400551+ChristophReich1996@users.noreply.github.com> Date: Sun, 20 Feb 2022 00:46:00 +0100 Subject: [PATCH 15/21] Add parameter to change normalization type --- timm/models/swin_transformer_v2_cr.py | 53 +++++++++++++-------------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index 9aad19c0..6df1ff9b 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -408,6 +408,7 @@ class SwinTransformerBlock(nn.Module): dropout_attention (float): Dropout rate of attention map dropout_path (float): Dropout in main path sequential_self_attention (bool): If true sequential self-attention is 
performed + norm_layer (Type[nn.Module]): Type of normalization layer to be utilized """ def __init__(self, @@ -420,7 +421,8 @@ class SwinTransformerBlock(nn.Module): dropout: float = 0.0, dropout_attention: float = 0.0, dropout_path: float = 0.0, - sequential_self_attention: bool = False) -> None: + sequential_self_attention: bool = False, + norm_layer: Type[nn.Module] = nn.LayerNorm) -> None: # Call super constructor super(SwinTransformerBlock, self).__init__() # Save parameters @@ -436,8 +438,8 @@ class SwinTransformerBlock(nn.Module): self.shift_size: int = shift_size self.make_windows: bool = True # Init normalization layers - self.normalization_1: nn.Module = nn.LayerNorm(normalized_shape=in_channels) - self.normalization_2: nn.Module = nn.LayerNorm(normalized_shape=in_channels) + self.normalization_1: nn.Module = norm_layer(normalized_shape=in_channels) + self.normalization_2: nn.Module = norm_layer(normalized_shape=in_channels) # Init window attention module self.window_attention: WindowMultiHeadAttention = WindowMultiHeadAttention( in_features=in_channels, @@ -569,6 +571,7 @@ class DeformableSwinTransformerBlock(SwinTransformerBlock): dropout_path (float): Dropout in main path sequential_self_attention (bool): If true sequential self-attention is performed offset_downscale_factor (int): Downscale factor of offset network + norm_layer (Type[nn.Module]): Type of normalization layer to be utilized """ def __init__(self, @@ -582,7 +585,8 @@ class DeformableSwinTransformerBlock(SwinTransformerBlock): dropout_attention: float = 0.0, dropout_path: float = 0.0, sequential_self_attention: bool = False, - offset_downscale_factor: int = 2) -> None: + offset_downscale_factor: int = 2, + norm_layer: Type[nn.Module] = nn.LayerNorm) -> None: # Call super constructor super(DeformableSwinTransformerBlock, self).__init__( in_channels=in_channels, @@ -594,7 +598,8 @@ class DeformableSwinTransformerBlock(SwinTransformerBlock): dropout=dropout, dropout_attention=dropout_attention, dropout_path=dropout_path, - sequential_self_attention=sequential_self_attention + sequential_self_attention=sequential_self_attention, + norm_layer=norm_layer ) # Save parameter self.offset_downscale_factor: int = offset_downscale_factor @@ -684,14 +689,16 @@ class PatchMerging(nn.Module): Args: in_channels (int): Number of input channels + norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. 
""" def __init__(self, - in_channels: int) -> None: + in_channels: int, + norm_layer: Type[nn.Module] = nn.LayerNorm) -> None: # Call super constructor super(PatchMerging, self).__init__() # Init normalization - self.normalization: nn.Module = nn.LayerNorm(normalized_shape=4 * in_channels) + self.normalization: nn.Module = norm_layer(normalized_shape=4 * in_channels) # Init linear mapping self.linear_mapping: nn.Module = nn.Linear(in_features=4 * in_channels, out_features=2 * in_channels, bias=False) @@ -728,12 +735,14 @@ class PatchEmbedding(nn.Module): out_channels (int): Number of output channels patch_size (int): Patch size to be utilized image_size (int): Image size to be used + norm_layer (Type[nn.Module]): Type of normalization layer to be utilized """ def __init__(self, in_channels: int = 3, out_channels: int = 96, - patch_size: int = 4) -> None: + patch_size: int = 4, + norm_layer: Type[nn.Module] = nn.LayerNorm) -> None: # Call super constructor super(PatchEmbedding, self).__init__() # Save parameters @@ -743,7 +752,7 @@ class PatchEmbedding(nn.Module): kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size)) # Init layer normalization - self.normalization: nn.Module = nn.LayerNorm(normalized_shape=out_channels) + self.normalization: nn.Module = norm_layer(normalized_shape=out_channels) def forward(self, input: torch.Tensor) -> torch.Tensor: @@ -777,6 +786,7 @@ class SwinTransformerStage(nn.Module): dropout (float): Dropout in input mapping dropout_attention (float): Dropout rate of attention map dropout_path (float): Dropout in main path + norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm use_checkpoint (bool): If true checkpointing is utilized sequential_self_attention (bool): If true sequential self-attention is performed use_deformable_block (bool): If true deformable block is used @@ -803,7 +813,8 @@ class SwinTransformerStage(nn.Module): self.use_checkpoint: bool = use_checkpoint self.downscale: bool = downscale # Init downsampling - self.downsample: nn.Module = PatchMerging(in_channels=in_channels) if downscale else nn.Identity() + self.downsample: nn.Module = PatchMerging(in_channels=in_channels, norm_layer=norm_layer) \ + if downscale else nn.Identity() # Update resolution and channels self.input_resolution: Tuple[int, int] = (input_resolution[0] // 2, input_resolution[1] // 2) \ if downscale else input_resolution @@ -821,7 +832,8 @@ class SwinTransformerStage(nn.Module): dropout=dropout, dropout_attention=dropout_attention, dropout_path=dropout_path[index] if isinstance(dropout_path, list) else dropout_path, - sequential_self_attention=sequential_self_attention) + sequential_self_attention=sequential_self_attention, + norm_layer=norm_layer) for index in range(depth)]) def update_resolution(self, @@ -914,7 +926,7 @@ class SwinTransformerV2CR(nn.Module): self.num_features: int = int(embed_dim * (2 ** len(depths) - 1)) # Init patch embedding self.patch_embedding: nn.Module = PatchEmbedding(in_channels=in_chans, out_channels=embed_dim, - patch_size=patch_size) + patch_size=patch_size, norm_layer=norm_layer) # Compute patch resolution patch_resolution: Tuple[int, int] = (img_size[0] // patch_size, img_size[1] // patch_size) # Path dropout dependent on depth @@ -937,7 +949,8 @@ class SwinTransformerV2CR(nn.Module): dropout_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])], use_checkpoint=use_checkpoint, sequential_self_attention=sequential_self_attention, - use_deformable_block=use_deformable_block and 
(index > 0) + use_deformable_block=use_deformable_block and (index > 0), + norm_layer=norm_layer )) # Init final adaptive average pooling, and classification head self.average_pool: nn.Module = nn.AdaptiveAvgPool2d(1) @@ -1165,17 +1178,3 @@ def swin_v2_cr_giant_patch4_window7_224(pretrained=False, **kwargs): model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=512, depths=(2, 2, 42, 2), num_heads=(16, 32, 64, 128), **kwargs) return _create_swin_transformer_v2_cr('swin_v2_cr_giant_patch4_window7_224', pretrained=pretrained, **model_kwargs) - - -if __name__ == '__main__': - model = swin_v2_cr_tiny_patch4_window12_384(pretrained=False) - model = swin_v2_cr_tiny_patch4_window7_224(pretrained=False) - - model = swin_v2_cr_small_patch4_window12_384(pretrained=False) - model = swin_v2_cr_small_patch4_window7_224(pretrained=False) - - model = swin_v2_cr_base_patch4_window12_384(pretrained=False) - model = swin_v2_cr_base_patch4_window7_224(pretrained=False) - - model = swin_v2_cr_large_patch4_window12_384(pretrained=False) - model = swin_v2_cr_large_patch4_window7_224(pretrained=False) From 29add820ac0a7f0a11fc3b2648ab697a891f4e00 Mon Sep 17 00:00:00 2001 From: Christoph Reich <34400551+ChristophReich1996@users.noreply.github.com> Date: Sun, 20 Feb 2022 00:46:48 +0100 Subject: [PATCH 16/21] Refactor (back to relative imports) --- timm/models/swin_transformer_v2_cr.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index 6df1ff9b..aeb713a5 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -22,15 +22,10 @@ import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -# from .helpers import build_model_with_cfg, overlay_external_default_cfg -# from .vision_transformer import checkpoint_filter_fn -# from .registry import register_model -# from .layers import DropPath, Mlp - -from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg -from timm.models.vision_transformer import checkpoint_filter_fn -from timm.models.registry import register_model -from timm.models.layers import DropPath, Mlp +from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .vision_transformer import checkpoint_filter_fn +from .registry import register_model +from .layers import DropPath, Mlp _logger = logging.getLogger(__name__) From 67d140446bd0a23fcb433819902cf0c51a1bfb80 Mon Sep 17 00:00:00 2001 From: Christoph Reich <34400551+ChristophReich1996@users.noreply.github.com> Date: Sun, 20 Feb 2022 22:28:05 +0100 Subject: [PATCH 17/21] Fix bug in classification head --- timm/models/swin_transformer_v2_cr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index aeb713a5..033ad694 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -918,7 +918,7 @@ class SwinTransformerV2CR(nn.Module): self.patch_size: int = patch_size self.input_resolution: Tuple[int, int] = img_size self.window_size: int = window_size - self.num_features: int = int(embed_dim * (2 ** len(depths) - 1)) + self.num_features: int = int(embed_dim * 2 ** (len(depths) - 1)) # Init patch embedding self.patch_embedding: nn.Module = PatchEmbedding(in_channels=in_chans, out_channels=embed_dim, patch_size=patch_size, 
norm_layer=norm_layer) @@ -1038,7 +1038,7 @@ class SwinTransformerV2CR(nn.Module): for stage in self.stages: output: torch.Tensor = stage(output) # Perform average pooling - output: torch.Tensor = self.average_pool(output) + output: torch.Tensor = self.average_pool(output).flatten(start_dim=1) # Predict classification classification: torch.Tensor = self.head(output) return classification From c6e4b7895a7dbcd9b98396cbef383dd1c72b0ad3 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 23 Feb 2022 17:28:52 -0800 Subject: [PATCH 18/21] Swin V2 CR impl refactor. * reformat and change some naming so closer to existing timm vision transformers * remove typing that wasn't adding clarity (or causing torchscript issues) * support non-square windows * auto window size adjust from image size * post-norm + main-branch no --- timm/models/__init__.py | 1 + timm/models/swin_transformer_v2_cr.py | 1583 +++++++++++-------------- 2 files changed, 672 insertions(+), 912 deletions(-) diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 2ef4918a..5bdb0867 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -38,6 +38,7 @@ from .selecsls import * from .senet import * from .sknet import * from .swin_transformer import * +from .swin_transformer_v2_cr import * from .tnt import * from .tresnet import * from .twins import * diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index 033ad694..bad5488d 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -1,10 +1,24 @@ """ Swin Transformer V2 + A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution` - https://arxiv.org/pdf/2111.09883 Code adapted from https://github.com/ChristophReich1996/Swin-Transformer-V2, original copyright/license info below -Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman +This implementation is experimental and subject to change in manners that will break weight compat: +* Size of the pos embed MLP are not spelled out in paper in terms of dim, fixed for all models? vary with num_heads? + * currently dim is fixed, I feel it may make sense to scale with num_heads (dim per head) +* The specifics of the memory saving 'sequential attention' are not detailed, Christoph Reich has an impl at + GitHub link above. It needs further investigation as throughput vs mem tradeoff doesn't appear beneficial. 
+* num_heads per stage is not detailed for Huge and Giant model variants +* 'Giant' is 3B params in paper but ~2.6B here despite matching paper dim + block counts + +Noteworthy additions over official Swin v1: +* MLP relative position embedding is looking promising and adapts to different image/window sizes +* This impl has been designed to allow easy change of image size with matching window size changes +* Non-square image size and window size are supported + +Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman """ # -------------------------------------------------------- # Swin Transformer V2 reimplementation @@ -13,19 +27,20 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W # Written by Christoph Reich # -------------------------------------------------------- import logging +import math from copy import deepcopy from typing import Tuple, Optional, List, Union, Any, Type import torch import torch.nn as nn -import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, overlay_external_default_cfg -from .vision_transformer import checkpoint_filter_fn +from .fx_features import register_notrace_function +from .helpers import build_model_with_cfg, overlay_external_default_cfg, named_apply +from .layers import DropPath, Mlp, to_2tuple, _assert from .registry import register_model -from .layers import DropPath, Mlp +from .vision_transformer import checkpoint_filter_fn _logger = logging.getLogger(__name__) @@ -33,1015 +48,710 @@ _logger = logging.getLogger(__name__) def _cfg(url='', **kwargs): return { 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embed.proj', 'classifier': 'head', - **kwargs + 'num_classes': 1000, + 'input_size': (3, 224, 224), + 'pool_size': None, + 'crop_pct': 0.9, + 'interpolation': 'bicubic', + 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, + 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', + 'classifier': 'head', + **kwargs, } default_cfgs = { # patch models (my experiments) - 'swin_v2_cr_tiny_patch4_window12_384': _cfg( - url="", - input_size=(3, 384, 384), crop_pct=1.0), - - 'swin_v2_cr_tiny_patch4_window7_224': _cfg( - url="", - input_size=(3, 224, 224), crop_pct=1.0), - - 'swin_v2_cr_small_patch4_window12_384': _cfg( - url="", - input_size=(3, 384, 384), crop_pct=1.0), - - 'swin_v2_cr_small_patch4_window7_224': _cfg( - url="", - input_size=(3, 224, 224), crop_pct=1.0), - - 'swin_v2_cr_base_patch4_window12_384': _cfg( - url="", - input_size=(3, 384, 384), crop_pct=1.0), - - 'swin_v2_cr_base_patch4_window7_224': _cfg( - url="", - input_size=(3, 224, 224), crop_pct=1.0), - - 'swin_v2_cr_large_patch4_window12_384': _cfg( - url="", - input_size=(3, 384, 384), crop_pct=1.0), - - 'swin_v2_cr_large_patch4_window7_224': _cfg( - url="", - input_size=(3, 224, 224), crop_pct=1.0), - - 'swin_v2_cr_huge_patch4_window12_384': _cfg( - url="", - input_size=(3, 384, 384), crop_pct=1.0), - - 'swin_v2_cr_huge_patch4_window7_224': _cfg( - url="", - input_size=(3, 224, 224), crop_pct=1.0), - - 'swin_v2_cr_giant_patch4_window12_384': _cfg( - url="", - input_size=(3, 384, 384), crop_pct=1.0), - - 'swin_v2_cr_giant_patch4_window7_224': _cfg( - url="", - input_size=(3, 224, 224), 
crop_pct=1.0), + 'swin_v2_cr_tiny_384': _cfg( + url="", input_size=(3, 384, 384), crop_pct=1.0), + 'swin_v2_cr_tiny_224': _cfg( + url="", input_size=(3, 224, 224), crop_pct=1.0), + 'swin_v2_cr_small_384': _cfg( + url="", input_size=(3, 384, 384), crop_pct=1.0), + 'swin_v2_cr_small_224': _cfg( + url="", input_size=(3, 224, 224), crop_pct=1.0), + 'swin_v2_cr_base_384': _cfg( + url="", input_size=(3, 384, 384), crop_pct=1.0), + 'swin_v2_cr_base_224': _cfg( + url="", input_size=(3, 224, 224), crop_pct=1.0), + 'swin_v2_cr_large_384': _cfg( + url="", input_size=(3, 384, 384), crop_pct=1.0), + 'swin_v2_cr_large_224': _cfg( + url="", input_size=(3, 224, 224), crop_pct=1.0), + 'swin_v2_cr_huge_384': _cfg( + url="", input_size=(3, 384, 384), crop_pct=1.0), + 'swin_v2_cr_huge_224': _cfg( + url="", input_size=(3, 224, 224), crop_pct=1.0), + 'swin_v2_cr_giant_384': _cfg( + url="", input_size=(3, 384, 384), crop_pct=1.0), + 'swin_v2_cr_giant_224': _cfg( + url="", input_size=(3, 224, 224), crop_pct=1.0), } -def bchw_to_bhwc(input: torch.Tensor) -> torch.Tensor: - """ Permutes a tensor from the shape (B, C, H, W) to (B, H, W, C). +def bchw_to_bhwc(x: torch.Tensor) -> torch.Tensor: + """Permutes a tensor from the shape (B, C, H, W) to (B, H, W, C). """ + return x.permute(0, 2, 3, 1) - Args: - input (torch.Tensor): Input tensor of the shape (B, C, H, W) - - Returns: - output (torch.Tensor): Permuted tensor of the shape (B, H, W, C) - """ - output: torch.Tensor = input.permute(0, 2, 3, 1) - return output +def bhwc_to_bchw(x: torch.Tensor) -> torch.Tensor: + """Permutes a tensor from the shape (B, H, W, C) to (B, C, H, W). """ + return x.permute(0, 3, 1, 2) -def bhwc_to_bchw(input: torch.Tensor) -> torch.Tensor: - """ Permutes a tensor from the shape (B, H, W, C) to (B, C, H, W). +def window_partition(x, window_size: Tuple[int, int]): + """ Args: - input (torch.Tensor): Input tensor of the shape (B, H, W, C) + x: (B, H, W, C) + window_size (int): window size Returns: - output (torch.Tensor): Permuted tensor of the shape (B, C, H, W) + windows: (num_windows*B, window_size, window_size, C) """ - output: torch.Tensor = input.permute(0, 3, 1, 2) - return output - + B, H, W, C = x.shape + x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C) + return windows -def unfold(input: torch.Tensor, - window_size: int) -> torch.Tensor: - """ Unfolds (non-overlapping) a given feature map by the given window size (stride = window size). 
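window_partition above, together with window_reverse added just below, replaces the earlier unfold/fold helpers and works on (B, H, W, C) tensors with possibly non-square windows. A quick shape round-trip, assuming both helpers are importable from the module as defined in this patch and using illustrative sizes (224 input, patch size 4, embed_dim 96):

    import torch
    from timm.models.swin_transformer_v2_cr import window_partition, window_reverse

    B, H, W, C = 2, 56, 56, 96
    window_size = (7, 7)

    x = torch.randn(B, H, W, C)
    windows = window_partition(x, window_size)
    assert windows.shape == (B * (H // 7) * (W // 7), 7, 7, C)   # (128, 7, 7, 96)

    # window_reverse undoes the partitioning exactly (pure reshape/permute, no arithmetic)
    x2 = window_reverse(windows, window_size, (H, W))
    assert torch.equal(x, x2)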
- Args: - input (torch.Tensor): Input feature map of the shape (B, C, H, W) - window_size (int): Window size to be applied - - Returns: - output (torch.Tensor): Unfolded tensor of the shape [B * windows, C, window size, window size] +@register_notrace_function # reason: int argument is a Proxy +def window_reverse(windows, window_size: tuple[int, int], img_size: tuple[int, int]): """ - # Get original shape - _, channels, height, width = input.shape # type: int, int, int, int - # Unfold input - output: torch.Tensor = input.unfold(dimension=3, size=window_size, step=window_size) \ - .unfold(dimension=2, size=window_size, step=window_size) - # Reshape to (B * windows, C, window size, window size) - output: torch.Tensor = output.permute(0, 2, 3, 1, 5, 4).reshape(-1, channels, window_size, window_size) - return output - - -def fold(input: torch.Tensor, - window_size: int, - height: int, - width: int) -> torch.Tensor: - """ Folds a tensor of windows again to a 4D feature map. - Args: - input (torch.Tensor): Input feature map of the shape (B, C, H, W) - window_size (int): Window size of the unfold operation - height (int): Height of the feature map - width (int): Width of the feature map + windows: (num_windows * B, window_size[0], window_size[1], C) + window_size (Tuple[int, int]): Window size + img_size (Tuple[int, int]): Image size Returns: - output (torch.Tensor): Folded output tensor of the shape (B, C, H, W) + x: (B, H, W, C) """ - # Get channels of windows - channels: int = input.shape[1] - # Get original batch size - batch_size: int = int(input.shape[0] // (height * width // window_size // window_size)) - # Reshape input to (B, C, H, W) - output: torch.Tensor = input.view(batch_size, height // window_size, width // window_size, channels, - window_size, window_size) - output: torch.Tensor = output.permute(0, 3, 1, 4, 2, 5).reshape(batch_size, channels, height, width) - return output + H, W = img_size + B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1])) + x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x class WindowMultiHeadAttention(nn.Module): - r""" This class implements window-based Multi-Head-Attention with log-spaced continuous position bias. + r"""This class implements window-based Multi-Head-Attention with log-spaced continuous position bias. 
Args: - in_features (int): Number of input features + dim (int): Number of input features window_size (int): Window size - number_of_heads (int): Number of attention heads - dropout_attention (float): Dropout rate of attention map - dropout_projection (float): Dropout rate after projection - meta_network_hidden_features (int): Number of hidden features in the two layer MLP meta network - sequential_self_attention (bool): If true sequential self-attention is performed + num_heads (int): Number of attention heads + drop_attn (float): Dropout rate of attention map + drop_proj (float): Dropout rate after projection + meta_hidden_dim (int): Number of hidden features in the two layer MLP meta network + sequential_attn (bool): If true sequential self-attention is performed """ - def __init__(self, - in_features: int, - window_size: int, - number_of_heads: int, - dropout_attention: float = 0., - dropout_projection: float = 0., - meta_network_hidden_features: int = 256, - sequential_self_attention: bool = False) -> None: - # Call super constructor + def __init__( + self, + dim: int, + num_heads: int, + window_size: Tuple[int, int], + drop_attn: float = 0.0, + drop_proj: float = 0.0, + meta_hidden_dim: int = 384, # FIXME what's the optimal value? + sequential_attn: bool = False, + ) -> None: super(WindowMultiHeadAttention, self).__init__() - # Check parameter - assert (in_features % number_of_heads) == 0, \ - "The number of input features (in_features) are not divisible by the number of heads (number_of_heads)." - # Save parameters - self.in_features: int = in_features - self.window_size: int = window_size - self.number_of_heads: int = number_of_heads - self.sequential_self_attention: bool = sequential_self_attention - # Init query, key and value mapping as a single layer - self.mapping_qkv: nn.Module = nn.Linear(in_features=in_features, out_features=in_features * 3, bias=True) - # Init attention dropout - self.attention_dropout: nn.Module = nn.Dropout(dropout_attention) - # Init projection mapping - self.projection: nn.Module = nn.Linear(in_features=in_features, out_features=in_features, bias=True) - # Init projection dropout - self.projection_dropout: nn.Module = nn.Dropout(dropout_projection) - # Init meta network for positional encodings - self.meta_network: nn.Module = nn.Sequential( - nn.Linear(in_features=2, out_features=meta_network_hidden_features, bias=True), - nn.ReLU(inplace=True), - nn.Linear(in_features=meta_network_hidden_features, out_features=number_of_heads, bias=True)) - # Init tau - self.register_parameter("tau", torch.nn.Parameter(torch.ones(1, number_of_heads, 1, 1))) - # Init pair-wise relative positions (log-spaced) - self.__make_pair_wise_relative_positions() - - def __make_pair_wise_relative_positions(self) -> None: - """ Method initializes the pair-wise relative positions to compute the positional biases.""" - indexes: torch.Tensor = torch.arange(self.window_size, device=self.tau.device) - coordinates: torch.Tensor = torch.stack(torch.meshgrid([indexes, indexes]), dim=0) - coordinates: torch.Tensor = torch.flatten(coordinates, start_dim=1) - relative_coordinates: torch.Tensor = coordinates[:, :, None] - coordinates[:, None, :] - relative_coordinates: torch.Tensor = relative_coordinates.permute(1, 2, 0).reshape(-1, 2).float() - relative_coordinates_log: torch.Tensor = torch.sign(relative_coordinates) \ - * torch.log(1. 
+ relative_coordinates.abs()) - self.register_buffer("relative_coordinates_log", relative_coordinates_log) - - def update_resolution(self, - new_window_size: int, - **kwargs: Any) -> None: - """ Method updates the window size and so the pair-wise relative positions + assert dim % num_heads == 0, \ + "The number of input features (in_features) are not divisible by the number of heads (num_heads)." + self.in_features: int = dim + self.window_size: Tuple[int, int] = window_size + self.num_heads: int = num_heads + self.sequential_attn: bool = sequential_attn + + self.qkv = nn.Linear(in_features=dim, out_features=dim * 3, bias=True) + self.attn_drop = nn.Dropout(drop_attn) + self.proj = nn.Linear(in_features=dim, out_features=dim, bias=True) + self.proj_drop = nn.Dropout(drop_proj) + # meta network for positional encodings + self.meta_mlp = Mlp( + 2, # x, y + hidden_features=meta_hidden_dim, + out_features=num_heads, + act_layer=nn.ReLU, + drop=0. # FIXME should we add stochasticity? + ) + self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads))) + self._make_pair_wise_relative_positions() + + def _make_pair_wise_relative_positions(self) -> None: + """Method initializes the pair-wise relative positions to compute the positional biases.""" + device = self.tau.device + coordinates = torch.stack(torch.meshgrid([ + torch.arange(self.window_size[0], device=device), + torch.arange(self.window_size[1], device=device)]), dim=0).flatten(1) + relative_coordinates = coordinates[:, :, None] - coordinates[:, None, :] + relative_coordinates = relative_coordinates.permute(1, 2, 0).reshape(-1, 2).float() + relative_coordinates_log = torch.sign(relative_coordinates) * torch.log( + 1.0 + relative_coordinates.abs()) + self.register_buffer("relative_coordinates_log", relative_coordinates_log, persistent=False) + + def update_input_size(self, new_window_size: int, **kwargs: Any) -> None: + """Method updates the window size and so the pair-wise relative positions Args: new_window_size (int): New window size kwargs (Any): Unused """ - # Set new window size + # Set new window size and new pair-wise relative positions self.window_size: int = new_window_size - # Make new pair-wise relative positions - self.__make_pair_wise_relative_positions() + self._make_pair_wise_relative_positions() - def __get_relative_positional_encodings(self) -> torch.Tensor: - """ Method computes the relative positional encodings + def _relative_positional_encodings(self) -> torch.Tensor: + """Method computes the relative positional encodings Returns: relative_position_bias (torch.Tensor): Relative positional encodings (1, number of heads, window size ** 2, window size ** 2) """ - relative_position_bias: torch.Tensor = self.meta_network(self.relative_coordinates_log) - relative_position_bias: torch.Tensor = relative_position_bias.permute(1, 0) - relative_position_bias: torch.Tensor = relative_position_bias.reshape(self.number_of_heads, - self.window_size * self.window_size, - self.window_size * self.window_size) - relative_position_bias: torch.Tensor = relative_position_bias.unsqueeze(0) + window_area = self.window_size[0] * self.window_size[1] + relative_position_bias = self.meta_mlp(self.relative_coordinates_log) + relative_position_bias = relative_position_bias.transpose(1, 0).reshape( + self.num_heads, window_area, window_area + ) + relative_position_bias = relative_position_bias.unsqueeze(0) return relative_position_bias - def __self_attention(self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - 
batch_size_windows: int, - tokens: int, - mask: Optional[torch.Tensor] = None) -> torch.Tensor: - """ This function performs standard (non-sequential) scaled cosine self-attention. - - Args: - query (torch.Tensor): Query tensor of the shape [B * windows, heads, tokens, C // heads] - key (torch.Tensor): Key tensor of the shape [B * windows, heads, tokens, C // heads] - value (torch.Tensor): Value tensor of the shape (B * windows, heads, tokens, C // heads) - batch_size_windows (int): Size of the first dimension of the input tensor batch size * windows - tokens (int): Number of tokens in the input - mask (Optional[torch.Tensor]): Attention mask for the shift case - - Returns: - output (torch.Tensor): Output feature map of the shape [B * windows, tokens, C] + def _forward_sequential( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: """ - # Compute attention map with scaled cosine attention - attention_map: torch.Tensor = torch.einsum("bhqd, bhkd -> bhqk", query, key) \ - / torch.maximum(torch.norm(query, dim=-1, keepdim=True) - * torch.norm(key, dim=-1, keepdim=True).transpose(-2, -1), - torch.tensor(1e-06, device=query.device, dtype=query.dtype)) - attention_map: torch.Tensor = attention_map / self.tau.clamp(min=0.01) - # Apply relative positional encodings - attention_map: torch.Tensor = attention_map + self.__get_relative_positional_encodings() - # Apply mask if utilized - if mask is not None: - number_of_windows: int = mask.shape[0] - attention_map: torch.Tensor = attention_map.view(batch_size_windows // number_of_windows, number_of_windows, - self.number_of_heads, tokens, tokens) - attention_map: torch.Tensor = attention_map + mask.unsqueeze(1).unsqueeze(0) - attention_map: torch.Tensor = attention_map.view(-1, self.number_of_heads, tokens, tokens) - attention_map: torch.Tensor = attention_map.softmax(dim=-1) - # Perform attention dropout - attention_map: torch.Tensor = self.attention_dropout(attention_map) - # Apply attention map and reshape - output: torch.Tensor = torch.einsum("bhal, bhlv -> bhav", attention_map, value) - output: torch.Tensor = output.permute(0, 2, 1, 3).reshape(batch_size_windows, tokens, -1) - return output - - def __sequential_self_attention(self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - batch_size_windows: int, - tokens: int, - mask: Optional[torch.Tensor] = None) -> torch.Tensor: - """ This function performs sequential scaled cosine self-attention. + """ + # FIXME TODO figure out 'sequential' attention mentioned in paper (should reduce GPU memory) + assert False, "not implemented" + + def _forward_batch( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """This function performs standard (non-sequential) scaled cosine self-attention. 
+ """ + Bw, L, C = x.shape - Args: - query (torch.Tensor): Query tensor of the shape [B * windows, heads, tokens, C // heads] - key (torch.Tensor): Key tensor of the shape [B * windows, heads, tokens, C // heads] - value (torch.Tensor): Value tensor of the shape (B * windows, heads, tokens, C // heads) - batch_size_windows (int): Size of the first dimension of the input tensor batch size * windows - tokens (int): Number of tokens in the input - mask (Optional[torch.Tensor]): Attention mask for the shift case + qkv = self.qkv(x).view(Bw, L, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + query, key, value = qkv.unbind(0) - Returns: - output (torch.Tensor): Output feature map of the shape [B * windows, tokens, C] - """ - # Init output tensor - output: torch.Tensor = torch.ones_like(query) - # Compute relative positional encodings fist - relative_position_bias: torch.Tensor = self.__get_relative_positional_encodings() - # Iterate over query and key tokens - for token_index_query in range(tokens): - # Compute attention map with scaled cosine attention - attention_map: torch.Tensor = \ - torch.einsum("bhd, bhkd -> bhk", query[:, :, token_index_query], key) \ - / torch.maximum(torch.norm(query[:, :, token_index_query], dim=-1, keepdim=True) - * torch.norm(key, dim=-1, keepdim=False), - torch.tensor(1e-06, device=query.device, dtype=query.dtype)) - attention_map: torch.Tensor = attention_map / self.tau.clamp(min=0.01)[..., 0] - # Apply positional encodings - attention_map: torch.Tensor = attention_map + relative_position_bias[..., token_index_query, :] + # compute attention map with scaled cosine attention + denom = torch.norm(query, dim=-1, keepdim=True) @ torch.norm(key, dim=-1, keepdim=True).transpose(-2, -1) + attn = query @ key.transpose(-2, -1) / denom.clamp(min=1e-6) + attn = attn / self.tau.clamp(min=0.01).reshape(1, self.num_heads, 1, 1) + attn = attn + self._relative_positional_encodings() + if mask is not None: # Apply mask if utilized - if mask is not None: - number_of_windows: int = mask.shape[0] - attention_map: torch.Tensor = attention_map.view(batch_size_windows // number_of_windows, - number_of_windows, self.number_of_heads, 1, - tokens) - attention_map: torch.Tensor = attention_map \ - + mask.unsqueeze(1).unsqueeze(0)[..., token_index_query, :].unsqueeze(3) - attention_map: torch.Tensor = attention_map.view(-1, self.number_of_heads, tokens) - attention_map: torch.Tensor = attention_map.softmax(dim=-1) - # Perform attention dropout - attention_map: torch.Tensor = self.attention_dropout(attention_map) - # Apply attention map and reshape - output[:, :, token_index_query] = torch.einsum("bhl, bhlv -> bhv", attention_map, value) - output: torch.Tensor = output.permute(0, 2, 1, 3).reshape(batch_size_windows, tokens, -1) - return output - - def forward(self, - input: torch.Tensor, - mask: Optional[torch.Tensor] = None) -> torch.Tensor: + num_win: int = mask.shape[0] + attn = attn.view(Bw // num_win, num_win, self.num_heads, L, L) + attn = attn + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, L, L) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ value).transpose(1, 2).reshape(Bw, L, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: """ Forward pass. 
- Args: - input (torch.Tensor): Input tensor of the shape (B * windows, C, H, W) + x (torch.Tensor): Input tensor of the shape (B * windows, N, C) mask (Optional[torch.Tensor]): Attention mask for the shift case Returns: - output (torch.Tensor): Output tensor of the shape [B * windows, C, H, W] + Output tensor of the shape [B * windows, N, C] """ - # Save original shape - batch_size_windows, channels, height, width = input.shape # type: int, int, int, int - tokens: int = height * width - # Reshape input to (B * windows, tokens (height * width), C) - input: torch.Tensor = input.reshape(batch_size_windows, channels, tokens).permute(0, 2, 1) - # Perform query, key, and value mapping - query_key_value: torch.Tensor = self.mapping_qkv(input) - query_key_value: torch.Tensor = query_key_value.view(batch_size_windows, tokens, 3, self.number_of_heads, - channels // self.number_of_heads).permute(2, 0, 3, 1, 4) - query, key, value = query_key_value[0], query_key_value[1], query_key_value[2] - # Perform attention - if self.sequential_self_attention: - output: torch.Tensor = self.__sequential_self_attention(query=query, key=key, value=value, - batch_size_windows=batch_size_windows, - tokens=tokens, - mask=mask) + if self.sequential_attn: + return self._forward_sequential(x, mask) else: - output: torch.Tensor = self.__self_attention(query=query, key=key, value=value, - batch_size_windows=batch_size_windows, tokens=tokens, - mask=mask) - # Perform linear mapping and dropout - output: torch.Tensor = self.projection_dropout(self.projection(output)) - # Reshape output to original shape [B * windows, C, H, W] - output: torch.Tensor = output.permute(0, 2, 1).view(batch_size_windows, channels, height, width) - return output + return self._forward_batch(x, mask) class SwinTransformerBlock(nn.Module): - r""" This class implements the Swin transformer block. + r"""This class implements the Swin transformer block. 
Args: - in_channels (int): Number of input channels - input_resolution (Tuple[int, int]): Input resolution - number_of_heads (int): Number of attention heads to be utilized - window_size (int): Window size to be utilized + dim (int): Number of input channels + num_heads (int): Number of attention heads to be utilized + feat_size (Tuple[int, int]): Input resolution + window_size (Tuple[int, int]): Window size to be utilized shift_size (int): Shifting size to be used - ff_feature_ratio (int): Ratio of the hidden dimension in the FFN to the input channels - dropout (float): Dropout in input mapping - dropout_attention (float): Dropout rate of attention map - dropout_path (float): Dropout in main path - sequential_self_attention (bool): If true sequential self-attention is performed + mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels + drop (float): Dropout in input mapping + drop_attn (float): Dropout rate of attention map + drop_path (float): Dropout in main path + extra_norm (bool): Insert extra norm on 'main' branch if True + sequential_attn (bool): If true sequential self-attention is performed norm_layer (Type[nn.Module]): Type of normalization layer to be utilized """ - def __init__(self, - in_channels: int, - input_resolution: Tuple[int, int], - number_of_heads: int, - window_size: int = 7, - shift_size: int = 0, - ff_feature_ratio: int = 4, - dropout: float = 0.0, - dropout_attention: float = 0.0, - dropout_path: float = 0.0, - sequential_self_attention: bool = False, - norm_layer: Type[nn.Module] = nn.LayerNorm) -> None: - # Call super constructor + def __init__( + self, + dim: int, + num_heads: int, + feat_size: Tuple[int, int], + window_size: Tuple[int, int], + shift_size: Tuple[int, int] = (0, 0), + mlp_ratio: float = 4.0, + drop: float = 0.0, + drop_attn: float = 0.0, + drop_path: float = 0.0, + extra_norm: bool = False, + sequential_attn: bool = False, + norm_layer: Type[nn.Module] = nn.LayerNorm, + ) -> None: super(SwinTransformerBlock, self).__init__() - # Save parameters - self.in_channels: int = in_channels - self.input_resolution: Tuple[int, int] = input_resolution - # Catch case if resolution is smaller than the window size - if min(self.input_resolution) <= window_size: - self.window_size: int = min(self.input_resolution) - self.shift_size: int = 0 - self.make_windows: bool = False - else: - self.window_size: int = window_size - self.shift_size: int = shift_size - self.make_windows: bool = True - # Init normalization layers - self.normalization_1: nn.Module = norm_layer(normalized_shape=in_channels) - self.normalization_2: nn.Module = norm_layer(normalized_shape=in_channels) - # Init window attention module - self.window_attention: WindowMultiHeadAttention = WindowMultiHeadAttention( - in_features=in_channels, + self.dim: int = dim + self.feat_size: Tuple[int, int] = feat_size + self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size) + self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size)) + self.window_area = self.window_size[0] * self.window_size[1] + + # attn branch + self.attn = WindowMultiHeadAttention( + dim=dim, + num_heads=num_heads, window_size=self.window_size, - number_of_heads=number_of_heads, - dropout_attention=dropout_attention, - dropout_projection=dropout, - sequential_self_attention=sequential_self_attention) - # Init dropout layer - self.dropout: nn.Module = DropPath(drop_prob=dropout_path) if dropout_path > 0. 
else nn.Identity() - # Init feed-forward network - self.feed_forward_network: nn.Module = Mlp(in_features=in_channels, - hidden_features=int(in_channels * ff_feature_ratio), - drop=dropout, - out_features=in_channels) - # Make attention mask - self.__make_attention_mask() - - def __make_attention_mask(self) -> None: - """ Method generates the attention mask used in shift case. """ + drop_attn=drop_attn, + drop_proj=drop, + sequential_attn=sequential_attn, + ) + self.norm1 = norm_layer(dim) + self.drop_path1 = DropPath(drop_prob=drop_path) if drop_path > 0.0 else nn.Identity() + + # mlp branch + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + drop=drop, + out_features=dim, + ) + self.norm2 = norm_layer(dim) + self.drop_path2 = DropPath(drop_prob=drop_path) if drop_path > 0.0 else nn.Identity() + + # extra norm layer mentioned for Huge/Giant models in V2 paper (FIXME may be in wrong spot?) + self.norm3 = norm_layer(dim) if extra_norm else nn.Identity() + + self._make_attention_mask() + + def _calc_window_shift(self, target_window_size): + window_size = [f if f <= w else w for f, w in zip(self.feat_size, target_window_size)] + shift_size = [0 if f <= w else s for f, w, s in zip(self.feat_size, window_size, self.target_shift_size)] + return tuple(window_size), tuple(shift_size) + + def _make_attention_mask(self) -> None: + """Method generates the attention mask used in shift case.""" # Make masks for shift case - if self.shift_size > 0: - height, width = self.input_resolution # type: int, int - mask: torch.Tensor = torch.zeros(height, width, device=self.window_attention.tau.device) - height_slices: Tuple = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - width_slices: Tuple = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - counter: int = 0 - for height_slice in height_slices: - for width_slice in width_slices: - mask[height_slice, width_slice] = counter - counter += 1 - mask_windows: torch.Tensor = unfold(mask[None, None], self.window_size) - mask_windows: torch.Tensor = mask_windows.reshape(-1, self.window_size * self.window_size) - attention_mask: Optional[torch.Tensor] = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attention_mask: Optional[torch.Tensor] = attention_mask.masked_fill(attention_mask != 0, float(-100.0)) - attention_mask: Optional[torch.Tensor] = attention_mask.masked_fill(attention_mask == 0, float(0.0)) + if any(self.shift_size): + # calculate attention mask for SW-MSA + H, W = self.feat_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + cnt = 0 + for h in ( + slice(0, -self.window_size[0]), + slice(-self.window_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None)): + for w in ( + slice(0, -self.window_size[1]), + slice(-self.window_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None)): + img_mask[:, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, self.window_size) # num_windows, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_area) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) else: - attention_mask: Optional[torch.Tensor] = None - # Save mask - self.register_buffer("attention_mask", attention_mask) + attn_mask = None + self.register_buffer("attn_mask", attn_mask, persistent=False) - def 
update_resolution(self, - new_window_size: int, - new_input_resolution: Tuple[int, int]) -> None: - """ Method updates the image resolution to be processed and window size and so the pair-wise relative positions. + def update_input_size(self, new_window_size: Tuple[int, int], new_feat_size: Tuple[int, int]) -> None: + """Method updates the image resolution to be processed and window size and so the pair-wise relative positions. Args: new_window_size (int): New window size - new_input_resolution (Tuple[int, int]): New input resolution + new_feat_size (Tuple[int, int]): New input resolution """ # Update input resolution - self.input_resolution: Tuple[int, int] = new_input_resolution - # Catch case if resolution is smaller than the window size - if min(self.input_resolution) <= new_window_size: - self.window_size: int = min(self.input_resolution) - self.shift_size: int = 0 - self.make_windows: bool = False + self.feat_size: Tuple[int, int] = new_feat_size + self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(new_window_size)) + self.window_area = self.window_size[0] * self.window_size[1] + self.attn.update_input_size(new_window_size=self.window_size) + self._make_attention_mask() + + def _shifted_window_attn(self, x): + H, W = self.feat_size + B, L, C = x.shape + x = x.view(B, H, W, C) + + # cyclic shift + if any(self.shift_size): + shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2)) else: - self.window_size: int = new_window_size - self.shift_size: int = self.shift_size - self.make_windows: bool = True - # Update attention mask - self.__make_attention_mask() - # Update attention module - self.window_attention.update_resolution(new_window_size=new_window_size) - - def forward(self, - input: torch.Tensor) -> torch.Tensor: - """ Forward pass. 
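The attention mask registered by _make_attention_mask above only matters for the shifted (SW-MSA) blocks: after the cyclic roll, a single window can contain tokens that were not spatially adjacent in the original image, and those pairs receive a -100 bias so the softmax effectively zeroes them out. A toy end-to-end computation on a 4x4 map with 2x2 windows and shift 1 (the window_partition helper here is an assumed minimal version matching the shape comments above, not the one used by this file):

import torch

def window_partition(x, window_size):
    # assumed helper: (B, H, W, C) -> (num_windows * B, window_size[0], window_size[1], C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size[0], window_size[1], C)

H = W = 4
window_size, shift_size = (2, 2), (1, 1)
img_mask = torch.zeros((1, H, W, 1))
cnt = 0
for h in (slice(0, -window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None)):
    for w in (slice(0, -window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None)):
        img_mask[:, h, w, :] = cnt   # label regions that must not mix after the shift
        cnt += 1
mask_windows = window_partition(img_mask, window_size).view(-1, window_size[0] * window_size[1])
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, 0.0)
print(attn_mask.shape)   # torch.Size([4, 4, 4]): per window, -100 wherever two tokens came from different regions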
+ shifted_x = x - Args: - input (torch.Tensor): Input tensor of the shape [B, C, H, W] + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # num_windows * B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C) - Returns: - output (torch.Tensor): Output tensor of the shape [B, C, H, W] - """ - # Save shape - batch_size, channels, height, width = input.shape # type: int, int, int, int - # Shift input if utilized - if self.shift_size > 0: - output_shift: torch.Tensor = torch.roll(input=input, shifts=(-self.shift_size, -self.shift_size), - dims=(-1, -2)) - else: - output_shift: torch.Tensor = input - # Make patches - output_patches: torch.Tensor = unfold(input=output_shift, window_size=self.window_size) \ - if self.make_windows else output_shift - # Perform window attention - output_attention: torch.Tensor = self.window_attention(output_patches, mask=self.attention_mask) - # Merge patches - output_merge: torch.Tensor = fold(input=output_attention, window_size=self.window_size, height=height, - width=width) if self.make_windows else output_attention - # Reverse shift if utilized - if self.shift_size > 0: - output_shift: torch.Tensor = torch.roll(input=output_merge, shifts=(self.shift_size, self.shift_size), - dims=(-1, -2)) - else: - output_shift: torch.Tensor = output_merge - # Perform normalization - output_normalize: torch.Tensor = self.normalization_1(output_shift.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - # Skip connection - output_skip: torch.Tensor = self.dropout(output_normalize) + input - # Feed forward network, normalization and skip connection - output_feed_forward: torch.Tensor = self.feed_forward_network( - output_skip.view(batch_size, channels, -1).permute(0, 2, 1)).permute(0, 2, 1) - output_feed_forward: torch.Tensor = output_feed_forward.view(batch_size, channels, height, width) - output_normalize: torch.Tensor = bhwc_to_bchw(self.normalization_2(bchw_to_bhwc(output_feed_forward))) - output: torch.Tensor = output_skip + self.dropout(output_normalize) - return output - - -class DeformableSwinTransformerBlock(SwinTransformerBlock): - r""" This class implements a deformable version of the Swin Transformer block. 
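For reference, window_reverse undoes window_partition, and torch.roll is its own inverse with negated shifts, so the shifted-window pipeline (roll, partition, attend, merge, roll back) restores the original token layout exactly. A quick round-trip check, using assumed minimal versions of the two helpers consistent with the shape comments in this diff:

import torch

def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size[0], window_size[1], C)

def window_reverse(windows, window_size, img_size):
    H, W = img_size
    C = windows.shape[-1]
    x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, H, W, C)

x = torch.randn(2, 8, 8, 96)                          # (B, H, W, C)
shift = (2, 2)
shifted = torch.roll(x, shifts=(-shift[0], -shift[1]), dims=(1, 2))
windows = window_partition(shifted, (4, 4))           # (2 * 4, 4, 4, 96)
restored = window_reverse(windows, (4, 4), (8, 8))    # (2, 8, 8, 96)
restored = torch.roll(restored, shifts=shift, dims=(1, 2))
assert torch.equal(x, restored)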
- Inspired by: https://arxiv.org/pdf/2201.00520 + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # num_windows * B, window_size * window_size, C - Args: - in_channels (int): Number of input channels - input_resolution (Tuple[int, int]): Input resolution - number_of_heads (int): Number of attention heads to be utilized - window_size (int): Window size to be utilized - shift_size (int): Shifting size to be used - ff_feature_ratio (int): Ratio of the hidden dimension in the FFN to the input channels - dropout (float): Dropout in input mapping - dropout_attention (float): Dropout rate of attention map - dropout_path (float): Dropout in main path - sequential_self_attention (bool): If true sequential self-attention is performed - offset_downscale_factor (int): Downscale factor of offset network - norm_layer (Type[nn.Module]): Type of normalization layer to be utilized - """ + # merge windows + attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C) + shifted_x = window_reverse(attn_windows, self.window_size, self.feat_size) # B H' W' C - def __init__(self, - in_channels: int, - input_resolution: Tuple[int, int], - number_of_heads: int, - window_size: int = 7, - shift_size: int = 0, - ff_feature_ratio: int = 4, - dropout: float = 0.0, - dropout_attention: float = 0.0, - dropout_path: float = 0.0, - sequential_self_attention: bool = False, - offset_downscale_factor: int = 2, - norm_layer: Type[nn.Module] = nn.LayerNorm) -> None: - # Call super constructor - super(DeformableSwinTransformerBlock, self).__init__( - in_channels=in_channels, - input_resolution=input_resolution, - number_of_heads=number_of_heads, - window_size=window_size, - shift_size=shift_size, - ff_feature_ratio=ff_feature_ratio, - dropout=dropout, - dropout_attention=dropout_attention, - dropout_path=dropout_path, - sequential_self_attention=sequential_self_attention, - norm_layer=norm_layer - ) - # Save parameter - self.offset_downscale_factor: int = offset_downscale_factor - self.number_of_heads: int = number_of_heads - # Make default offsets - self.__make_default_offsets() - # Init offset network - self.offset_network: nn.Module = nn.Sequential( - nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=5, stride=offset_downscale_factor, - padding=3, groups=in_channels, bias=True), - nn.GELU(), - nn.Conv2d(in_channels=in_channels, out_channels=2 * self.number_of_heads, kernel_size=1, stride=1, - padding=0, bias=True) - ) + # reverse cyclic shift + if any(self.shift_size): + x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2)) + else: + x = shifted_x - def __make_default_offsets(self) -> None: - """ Method generates the default sampling grid (inspired by kornia) """ - # Init x and y coordinates - x: torch.Tensor = torch.linspace(0, self.input_resolution[1] - 1, self.input_resolution[1], - device=self.window_attention.tau.device) - y: torch.Tensor = torch.linspace(0, self.input_resolution[0] - 1, self.input_resolution[0], - device=self.window_attention.tau.device) - # Normalize coordinates to a range of [-1, 1] - x: torch.Tensor = (x / (self.input_resolution[1] - 1) - 0.5) * 2 - y: torch.Tensor = (y / (self.input_resolution[0] - 1) - 0.5) * 2 - # Make grid [2, height, width] - grid: torch.Tensor = torch.stack(torch.meshgrid([x, y])).transpose(1, 2) - # Reshape grid to [1, height, width, 2] - grid: torch.Tensor = grid.unsqueeze(dim=0).permute(0, 2, 3, 1) - # Register in module - self.register_buffer("default_grid", grid) - - def update_resolution(self, 
- new_window_size: int, - new_input_resolution: Tuple[int, int]) -> None: - """ Method updates the window size and so the pair-wise relative positions. + x = x.view(B, L, C) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. Args: - new_window_size (int): New window size - new_input_resolution (Tuple[int, int]): New input resolution - """ - # Update resolution and window size - super(DeformableSwinTransformerBlock, self).update_resolution(new_window_size=new_window_size, - new_input_resolution=new_input_resolution) - # Update default sampling grid - self.__make_default_offsets() - - def forward(self, - input: torch.Tensor) -> torch.Tensor: - """ Forward pass - Args: - input (torch.Tensor): Input tensor of the shape [B, C, H, W] + x (torch.Tensor): Input tensor of the shape [B, C, H, W] Returns: output (torch.Tensor): Output tensor of the shape [B, C, H, W] """ - # Get input shape - batch_size, channels, height, width = input.shape - # Compute offsets of the shape [batch size, 2, height / r, width / r] - offsets: torch.Tensor = self.offset_network(input) - # Upscale offsets to the shape [batch size, 2 * number of heads, height, width] - offsets: torch.Tensor = F.interpolate(input=offsets, - size=(height, width), mode="bilinear", align_corners=True) - # Reshape offsets to [batch size, number of heads, height, width, 2] - offsets: torch.Tensor = offsets.reshape(batch_size, -1, 2, height, width).permute(0, 1, 3, 4, 2) - # Flatten batch size and number of heads and apply tanh - offsets: torch.Tensor = offsets.view(-1, height, width, 2).tanh() - # Cast offset grid to input data type - if input.dtype != self.default_grid.dtype: - self.default_grid = self.default_grid.type(input.dtype) - # Construct offset grid - offset_grid: torch.Tensor = self.default_grid.repeat_interleave(repeats=offsets.shape[0], dim=0) + offsets - # Reshape input to [batch size * number of heads, channels / number of heads, height, width] - input: torch.Tensor = input.view(batch_size, self.number_of_heads, channels // self.number_of_heads, height, - width).flatten(start_dim=0, end_dim=1) - # Apply sampling grid - input_resampled: torch.Tensor = F.grid_sample(input=input, grid=offset_grid.clip(min=-1, max=1), - mode="bilinear", align_corners=True, padding_mode="reflection") - # Reshape resampled tensor again to [batch size, channels, height, width] - input_resampled: torch.Tensor = input_resampled.view(batch_size, channels, height, width) - output: torch.Tensor = super(DeformableSwinTransformerBlock, self).forward(input=input_resampled) - return output + # NOTE post-norm branches (op -> norm -> drop) + x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x))) + x = x + self.drop_path2(self.norm2(self.mlp(x))) + x = self.norm3(x) # main-branch norm enabled for some blocks (every 6 for Huge/Giant) + return x class PatchMerging(nn.Module): """ This class implements the patch merging as a strided convolution with a normalization before. - Args: - in_channels (int): Number of input channels + dim (int): Number of input channels norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. 
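The block forward rewritten above uses post-norm residual branches: the operation runs first, then LayerNorm, then DropPath, and only then is the skip connection added (plus an optional third norm on the main branch). That ordering is the Swin V2 change relative to the pre-norm V1 block. A schematic comparison with made-up module names (nn.Linear stands in for the attention/MLP branch; this is not the block's real structure):

import torch
from torch import nn

class PreNormBlock(nn.Module):          # Swin V1 style: norm before the operation
    def __init__(self, dim):
        super().__init__()
        self.norm, self.op = nn.LayerNorm(dim), nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.op(self.norm(x))

class PostNormBlock(nn.Module):         # Swin V2 style, as in the forward above
    def __init__(self, dim):
        super().__init__()
        self.norm, self.op = nn.LayerNorm(dim), nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.norm(self.op(x))

x = torch.randn(2, 49, 96)
print(PreNormBlock(96)(x).shape, PostNormBlock(96)(x).shape)   # both keep (2, 49, 96)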
""" - def __init__(self, - in_channels: int, - norm_layer: Type[nn.Module] = nn.LayerNorm) -> None: - # Call super constructor + def __init__(self, dim: int, norm_layer: Type[nn.Module] = nn.LayerNorm) -> None: super(PatchMerging, self).__init__() - # Init normalization - self.normalization: nn.Module = norm_layer(normalized_shape=4 * in_channels) - # Init linear mapping - self.linear_mapping: nn.Module = nn.Linear(in_features=4 * in_channels, out_features=2 * in_channels, - bias=False) - - def forward(self, - input: torch.Tensor) -> torch.Tensor: - """ Forward pass. + self.norm = norm_layer(4 * dim) + self.reduction = nn.Linear(in_features=4 * dim, out_features=2 * dim, bias=False) - Args: - input (torch.Tensor): Input tensor of the shape [B, C, H, W] - - Returns: - output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2] - """ - # Get original shape - batch_size, channels, height, width = input.shape # type: int, int, int, int - # Reshape input to [batch size, in channels, height, width] - input: torch.Tensor = bchw_to_bhwc(input) - # Unfold input - input: torch.Tensor = input.unfold(dimension=1, size=2, step=2).unfold(dimension=2, size=2, step=2) - input: torch.Tensor = input.reshape(batch_size, input.shape[1], input.shape[2], -1) - # Normalize input - input: torch.Tensor = self.normalization(input) - # Perform linear mapping - output: torch.Tensor = bhwc_to_bchw(self.linear_mapping(input)) - return output - - -class PatchEmbedding(nn.Module): - """ Module embeds a given image into patch embeddings. - - Args: - in_channels (int): Number of input channels - out_channels (int): Number of output channels - patch_size (int): Patch size to be utilized - image_size (int): Image size to be used - norm_layer (Type[nn.Module]): Type of normalization layer to be utilized - """ - - def __init__(self, - in_channels: int = 3, - out_channels: int = 96, - patch_size: int = 4, - norm_layer: Type[nn.Module] = nn.LayerNorm) -> None: - # Call super constructor - super(PatchEmbedding, self).__init__() - # Save parameters - self.out_channels: int = out_channels - # Init linear embedding as a convolution - self.linear_embedding: nn.Module = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, - kernel_size=(patch_size, patch_size), - stride=(patch_size, patch_size)) - # Init layer normalization - self.normalization: nn.Module = norm_layer(normalized_shape=out_channels) - - def forward(self, - input: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass. 
- Args: - input (torch.Tensor): Input image of the shape (B, C_in, H, W) - + x (torch.Tensor): Input tensor of the shape [B, C, H, W] Returns: - embedding (torch.Tensor): Embedding of the shape (B, C_out, H / patch size, W / patch size) + output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2] """ - # Perform linear embedding - embedding: torch.Tensor = self.linear_embedding(input) - # Perform normalization - embedding: torch.Tensor = bhwc_to_bchw(self.normalization(bchw_to_bhwc(embedding))) - return embedding + x = bchw_to_bhwc(x).unfold(dimension=1, size=2, step=2).unfold(dimension=2, size=2, step=2) + x = x.permute(0, 1, 2, 5, 4, 3).flatten(3) # permute maintains compat with ch order in official swin impl + x = self.norm(x) + x = bhwc_to_bchw(self.reduction(x)) + return x + + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + B, C, H, W = x.shape + _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") + _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") + x = self.proj(x) + x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x class SwinTransformerStage(nn.Module): - r""" This class implements a stage of the Swin transformer including multiple layers. + r"""This class implements a stage of the Swin transformer including multiple layers. Args: - in_channels (int): Number of input channels + embed_dim (int): Number of input channels depth (int): Depth of the stage (number of layers) downscale (bool): If true input is downsampled (see Fig. 3 or V1 paper) - input_resolution (Tuple[int, int]): Input resolution - number_of_heads (int): Number of attention heads to be utilized + feat_size (Tuple[int, int]): input feature map size (H, W) + num_heads (int): Number of attention heads to be utilized window_size (int): Window size to be utilized - shift_size (int): Shifting size to be used - ff_feature_ratio (int): Ratio of the hidden dimension in the FFN to the input channels - dropout (float): Dropout in input mapping - dropout_attention (float): Dropout rate of attention map - dropout_path (float): Dropout in main path + mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels + drop (float): Dropout in input mapping + drop_attn (float): Dropout rate of attention map + drop_path (float): Dropout in main path norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. 
Default: nn.LayerNorm - use_checkpoint (bool): If true checkpointing is utilized - sequential_self_attention (bool): If true sequential self-attention is performed - use_deformable_block (bool): If true deformable block is used + grad_checkpointing (bool): If true checkpointing is utilized + extra_norm_period (int): Insert extra norm layer on main branch every N (period) blocks + sequential_attn (bool): If true sequential self-attention is performed """ - def __init__(self, - in_channels: int, - depth: int, - downscale: bool, - input_resolution: Tuple[int, int], - number_of_heads: int, - window_size: int = 7, - ff_feature_ratio: int = 4, - dropout: float = 0.0, - dropout_attention: float = 0.0, - dropout_path: Union[List[float], float] = 0.0, - norm_layer: Type[nn.Module] = nn.LayerNorm, - use_checkpoint: bool = False, - sequential_self_attention: bool = False, - use_deformable_block: bool = False) -> None: - # Call super constructor + def __init__( + self, + embed_dim: int, + depth: int, + downscale: bool, + num_heads: int, + feat_size: Tuple[int, int], + window_size: Tuple[int, int], + mlp_ratio: float = 4.0, + drop: float = 0.0, + drop_attn: float = 0.0, + drop_path: Union[List[float], float] = 0.0, + norm_layer: Type[nn.Module] = nn.LayerNorm, + grad_checkpointing: bool = False, + extra_norm_period: int = 0, + sequential_attn: bool = False, + ) -> None: super(SwinTransformerStage, self).__init__() - # Save parameters - self.use_checkpoint: bool = use_checkpoint self.downscale: bool = downscale - # Init downsampling - self.downsample: nn.Module = PatchMerging(in_channels=in_channels, norm_layer=norm_layer) \ - if downscale else nn.Identity() - # Update resolution and channels - self.input_resolution: Tuple[int, int] = (input_resolution[0] // 2, input_resolution[1] // 2) \ - if downscale else input_resolution - in_channels = in_channels * 2 if downscale else in_channels - # Get block - block = DeformableSwinTransformerBlock if use_deformable_block else SwinTransformerBlock - # Init blocks - self.blocks: nn.ModuleList = nn.ModuleList([ - block(in_channels=in_channels, - input_resolution=self.input_resolution, - number_of_heads=number_of_heads, - window_size=window_size, - shift_size=0 if ((index % 2) == 0) else window_size // 2, - ff_feature_ratio=ff_feature_ratio, - dropout=dropout, - dropout_attention=dropout_attention, - dropout_path=dropout_path[index] if isinstance(dropout_path, list) else dropout_path, - sequential_self_attention=sequential_self_attention, - norm_layer=norm_layer) - for index in range(depth)]) - - def update_resolution(self, - new_window_size: int, - new_input_resolution: Tuple[int, int]) -> None: - """ Method updates the resolution to utilize and the window size and so the pair-wise relative positions. 
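In the stage constructor that follows, two per-block settings encode the stage layout: blocks alternate between plain W-MSA (zero shift) and SW-MSA (shift of half the window size), and extra_norm_period switches on the extra main-branch norm on every N-th block, as used for the Huge/Giant configs. A tiny sketch making that pattern concrete (values are hypothetical):

window_size = (7, 7)
extra_norm_period = 6
depth = 12
for index in range(depth):
    shift_size = tuple(0 if index % 2 == 0 else w // 2 for w in window_size)
    extra_norm = not (index + 1) % extra_norm_period if extra_norm_period else False
    print(index, shift_size, extra_norm)
# even blocks: shift (0, 0); odd blocks: shift (3, 3); extra_norm is True on blocks 5 and 11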
+ self.grad_checkpointing: bool = grad_checkpointing + self.feat_size: Tuple[int, int] = (feat_size[0] // 2, feat_size[1] // 2) if downscale else feat_size + + self.downsample = PatchMerging(embed_dim, norm_layer=norm_layer) if downscale else nn.Identity() + + embed_dim = embed_dim * 2 if downscale else embed_dim + self.blocks = nn.Sequential(*[ + SwinTransformerBlock( + dim=embed_dim, + num_heads=num_heads, + feat_size=self.feat_size, + window_size=window_size, + shift_size=tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size]), + mlp_ratio=mlp_ratio, + drop=drop, + drop_attn=drop_attn, + drop_path=drop_path[index] if isinstance(drop_path, list) else drop_path, + extra_norm=not (index + 1) % extra_norm_period if extra_norm_period else False, + sequential_attn=sequential_attn, + norm_layer=norm_layer, + ) + for index in range(depth)] + ) + + def update_input_size(self, new_window_size: int, new_feat_size: Tuple[int, int]) -> None: + """Method updates the resolution to utilize and the window size and so the pair-wise relative positions. Args: new_window_size (int): New window size - new_input_resolution (Tuple[int, int]): New input resolution + new_feat_size (Tuple[int, int]): New input resolution """ - # Update resolution - self.input_resolution: Tuple[int, int] = (new_input_resolution[0] // 2, new_input_resolution[1] // 2) \ - if self.downscale else new_input_resolution - # Update resolution of each block - for block in self.blocks: # type: SwinTransformerBlock - block.update_resolution(new_window_size=new_window_size, new_input_resolution=self.input_resolution) - - def forward(self, - input: torch.Tensor) -> torch.Tensor: - """ Forward pass. + self.feat_size: Tuple[int, int] = ( + (new_feat_size[0] // 2, new_feat_size[1] // 2) if self.downscale else new_feat_size + ) + for block in self.blocks: + block.update_input_size(new_window_size=new_window_size, new_feat_size=self.feat_size) + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. Args: - input (torch.Tensor): Input tensor of the shape [B, C, H, W] - + x (torch.Tensor): Input tensor of the shape [B, C, H, W] or [B, L, C] Returns: output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2] """ - # Downscale input tensor - output: torch.Tensor = self.downsample(input) - # Forward pass of each block - for block in self.blocks: # type: nn.Module + x = self.downsample(x) + B, C, H, W = x.shape + L = H * W + + x = bchw_to_bhwc(x).reshape(B, L, C) + for block in self.blocks: # Perform checkpointing if utilized - if self.use_checkpoint: - output: torch.Tensor = checkpoint.checkpoint(block, output) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint.checkpoint(block, x) else: - output: torch.Tensor = block(output) - return output + x = block(x) + x = bhwc_to_bchw(x.reshape(B, H, W, -1)) + return x -class SwinTransformerV2CR(nn.Module): +class SwinTransformerV2Cr(nn.Module): r""" Swin Transformer V2 A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution` - https://arxiv.org/pdf/2111.09883 Args: img_size (Tuple[int, int]): Input resolution. + window_size (Optional[int]): Window size. If None, img_size // window_div. Default: None + img_window_ratio (int): Window size to image size ratio. Default: 32 + patch_size (int | tuple(int)): Patch size. Default: 4 in_chans (int): Number of input channels. depths (int): Depth of the stage (number of layers). num_heads (int): Number of attention heads to be utilized. embed_dim (int): Patch embedding dimension. 
Default: 96 num_classes (int): Number of output classes. Default: 1000 - window_size (int): Window size to be utilized. Default: 7 - patch_size (int | tuple(int)): Patch size. Default: 4 mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels. Default: 4 drop_rate (float): Dropout rate. Default: 0.0 attn_drop_rate (float): Dropout rate of attention map. Default: 0.0 drop_path_rate (float): Stochastic depth rate. Default: 0.0 norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm - use_checkpoint (bool): If true checkpointing is utilized. Default: False - sequential_self_attention (bool): If true sequential self-attention is performed. Default: False - use_deformable_block (bool): If true deformable block is used. Default: False + grad_checkpointing (bool): If true checkpointing is utilized. Default: False + sequential_attn (bool): If true sequential self-attention is performed. Default: False + use_deformable (bool): If true deformable block is used. Default: False """ - def __init__(self, - img_size: Tuple[int, int], - in_chans: int, - depths: Tuple[int, ...], - num_heads: Tuple[int, ...], - embed_dim: int = 96, - num_classes: int = 1000, - window_size: int = 7, - patch_size: int = 4, - mlp_ratio: int = 4, - drop_rate: float = 0.0, - attn_drop_rate: float = 0.0, - drop_path_rate: float = 0.0, - norm_layer: Type[nn.Module] = nn.LayerNorm, - use_checkpoint: bool = False, - sequential_self_attention: bool = False, - use_deformable_block: bool = False, - **kwargs: Any) -> None: - # Call super constructor - super(SwinTransformerV2CR, self).__init__() - # Save parameters + def __init__( + self, + img_size: Tuple[int, int] = (224, 224), + patch_size: int = 4, + window_size: Optional[int] = None, + img_window_ratio: int = 32, + in_chans: int = 3, + num_classes: int = 1000, + embed_dim: int = 96, + depths: Tuple[int, ...] = (2, 2, 6, 2), + num_heads: Tuple[int, ...] 
= (3, 6, 12, 24), + mlp_ratio: float = 4.0, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + norm_layer: Type[nn.Module] = nn.LayerNorm, + grad_checkpointing: bool = False, + extra_norm_period: int = 0, + sequential_attn: bool = False, + global_pool: str = 'avg', + **kwargs: Any + ) -> None: + super(SwinTransformerV2Cr, self).__init__() + img_size = to_2tuple(img_size) + window_size = tuple([ + s // img_window_ratio for s in img_size]) if window_size is None else to_2tuple(window_size) + self.num_classes: int = num_classes self.patch_size: int = patch_size - self.input_resolution: Tuple[int, int] = img_size + self.img_size: Tuple[int, int] = img_size self.window_size: int = window_size self.num_features: int = int(embed_dim * 2 ** (len(depths) - 1)) - # Init patch embedding - self.patch_embedding: nn.Module = PatchEmbedding(in_channels=in_chans, out_channels=embed_dim, - patch_size=patch_size, norm_layer=norm_layer) - # Compute patch resolution - patch_resolution: Tuple[int, int] = (img_size[0] // patch_size, img_size[1] // patch_size) - # Path dropout dependent on depth - drop_path_rate = torch.linspace(0., drop_path_rate, sum(depths)).tolist() - # Init stages - self.stages: nn.ModuleList = nn.ModuleList() - for index, (depth, number_of_head) in enumerate(zip(depths, num_heads)): - self.stages.append( + + self.patch_embed: nn.Module = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, + embed_dim=embed_dim, norm_layer=norm_layer) + patch_grid_size: Tuple[int, int] = self.patch_embed.grid_size + + drop_path_rate = torch.linspace(0.0, drop_path_rate, sum(depths)).tolist() + stages = [] + for index, (depth, num_heads) in enumerate(zip(depths, num_heads)): + stage_scale = 2 ** max(index - 1, 0) + stages.append( SwinTransformerStage( - in_channels=embed_dim * (2 ** max(index - 1, 0)), + embed_dim=embed_dim * stage_scale, depth=depth, downscale=index != 0, - input_resolution=(patch_resolution[0] // (2 ** max(index - 1, 0)), - patch_resolution[1] // (2 ** max(index - 1, 0))), - number_of_heads=number_of_head, + feat_size=(patch_grid_size[0] // stage_scale, patch_grid_size[1] // stage_scale), + num_heads=num_heads, window_size=window_size, - ff_feature_ratio=mlp_ratio, - dropout=drop_rate, - dropout_attention=attn_drop_rate, - dropout_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])], - use_checkpoint=use_checkpoint, - sequential_self_attention=sequential_self_attention, - use_deformable_block=use_deformable_block and (index > 0), - norm_layer=norm_layer - )) - # Init final adaptive average pooling, and classification head - self.average_pool: nn.Module = nn.AdaptiveAvgPool2d(1) - self.head: nn.Module = nn.Linear(in_features=self.num_features, - out_features=num_classes) - - def update_resolution(self, - new_input_resolution: Optional[Tuple[int, int]] = None, - new_window_size: Optional[int] = None) -> None: - """ Method updates the image resolution to be processed and window size and so the pair-wise relative positions. 
+ mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_attn=attn_drop_rate, + drop_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])], + grad_checkpointing=grad_checkpointing, + extra_norm_period=extra_norm_period, + sequential_attn=sequential_attn, + norm_layer=norm_layer, + ) + ) + self.stages = nn.Sequential(*stages) + + self.global_pool: str = global_pool + self.head: nn.Module = nn.Linear( + in_features=self.num_features, out_features=num_classes) if num_classes else nn.Identity() + + # FIXME weight init TBD, PyTorch default init appears to be working well, + # but differs from usual ViT or Swin init. + # named_apply(init_weights, self) + + def update_input_size( + self, + new_img_size: Optional[Tuple[int, int]] = None, + new_window_size: Optional[int] = None, + img_window_ratio: int = 32, + ) -> None: + """Method updates the image resolution to be processed and window size and so the pair-wise relative positions. Args: - new_window_size (Optional[int]): New window size if None current window size is used - new_input_resolution (Optional[Tuple[int, int]]): New input resolution if None current resolution is used + new_window_size (Optional[int]): New window size, if None based on new_img_size // window_div + new_img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used + img_window_ratio (int): divisor for calculating window size from image size """ # Check parameters - if new_input_resolution is None: - new_input_resolution = self.input_resolution + if new_img_size is None: + new_img_size = self.img_size + else: + new_img_size = to_2tuple(new_img_size) if new_window_size is None: - new_window_size = self.window_size - # Compute new patch resolution - new_patch_resolution: Tuple[int, int] = (new_input_resolution[0] // self.patch_size, - new_input_resolution[1] // self.patch_size) - # Update resolution of each stage - for index, stage in enumerate(self.stages): # type: int, SwinTransformerStage - stage.update_resolution(new_window_size=new_window_size, - new_input_resolution=(new_patch_resolution[0] // (2 ** max(index - 1, 0)), - new_patch_resolution[1] // (2 ** max(index - 1, 0)))) + new_window_size = tuple([s // img_window_ratio for s in new_img_size]) + # Compute new patch resolution & update resolution of each stage + new_patch_grid_size = (new_img_size[0] // self.patch_size, new_img_size[1] // self.patch_size) + for index, stage in enumerate(self.stages): + stage_scale = 2 ** max(index - 1, 0) + stage.update_input_size( + new_window_size=new_window_size, + new_img_size=(new_patch_grid_size[0] // stage_scale, new_patch_grid_size[1] // stage_scale), + ) def get_classifier(self) -> nn.Module: - """ Method returns the classification head of the model. + """Method returns the classification head of the model. 
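Two bits of constructor arithmetic above are easy to sanity-check in isolation: with window_size=None the window is derived from the image size via img_window_ratio (224 // 32 = 7 and 384 // 32 = 12, matching the old fixed-window variants), and drop_path_rate is a single linspace over all blocks that each stage slices by its depth. A quick standalone check (helper name is illustrative):

import torch

def derive_window_size(img_size, img_window_ratio=32):
    return tuple(s // img_window_ratio for s in img_size)

assert derive_window_size((224, 224)) == (7, 7)
assert derive_window_size((384, 384)) == (12, 12)

depths = (2, 2, 6, 2)
dpr = torch.linspace(0.0, 0.2, sum(depths)).tolist()                        # one rate per block
per_stage = [dpr[sum(depths[:i]):sum(depths[:i + 1])] for i in range(len(depths))]
assert [len(p) for p in per_stage] == list(depths)                          # each stage gets its own slice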
Returns: head (nn.Module): Current classification head """ head: nn.Module = self.head return head - def reset_classifier(self, num_classes: int, global_pool: str = '') -> None: - """ Method results the classification head + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None: + """Method results the classification head Args: num_classes (int): Number of classes to be predicted global_pool (str): Unused """ self.num_classes: int = num_classes - self.head: nn.Module = nn.Linear(in_features=self.num_features, out_features=num_classes) \ - if num_classes > 0 else nn.Identity() - - def forward_features(self, - input: torch.Tensor) -> List[torch.Tensor]: - """ Forward pass to extract feature maps of each stage. - - Args: - input (torch.Tensor): Input images of the shape (B, C, H, W) - - Returns: - features (List[torch.Tensor]): List of feature maps from each stage - """ - # Check input resolution - assert input.shape[2:] == self.input_resolution, \ - "Input resolution and utilized resolution does not match. Please update the models resolution by calling " \ - "update_resolution the provided method." - # Perform patch embedding - output: torch.Tensor = self.patch_embedding(input) - # Init list to store feature - features: List[torch.Tensor] = [] - # Forward pass of each stage - for stage in self.stages: - output: torch.Tensor = stage(output) - features.append(output) - return features - - def forward(self, - input: torch.Tensor) -> torch.Tensor: - """ Forward pass. - - Args: - input (torch.Tensor): Input images of the shape (B, C, H, W) - - Returns: - classification (torch.Tensor): Classification of the shape (B, num_classes) - """ - # Check input resolution - assert input.shape[2:] == self.input_resolution, \ - "Input resolution and utilized resolution does not match. Please update the models resolution by calling " \ - "update_resolution the provided method." - # Perform patch embedding - output: torch.Tensor = self.patch_embedding(input) - # Forward pass of each stage - for stage in self.stages: - output: torch.Tensor = stage(output) - # Perform average pooling - output: torch.Tensor = self.average_pool(output).flatten(start_dim=1) - # Predict classification - classification: torch.Tensor = self.head(output) - return classification + if global_pool is not None: + self.global_pool = global_pool + self.head: nn.Module = nn.Linear( + in_features=self.num_features, out_features=num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + x = self.stages(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool == 'avg': + x = x.mean(dim=(2, 3)) + return x if pre_logits else self.head(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def init_weights(module: nn.Module, name: str = ''): + # FIXME WIP + if isinstance(module, nn.Linear): + if 'qkv' in name: + # treat the weights of Q, K, V separately + val = math.sqrt(6. 
/ float(module.weight.shape[0] // 3 + module.weight.shape[1])) + nn.init.uniform_(module.weight, -val, val) + else: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) def _create_swin_transformer_v2_cr(variant, pretrained=False, default_cfg=None, **kwargs): @@ -1057,119 +767,168 @@ def _create_swin_transformer_v2_cr(variant, pretrained=False, default_cfg=None, raise RuntimeError('features_only not implemented for Vision Transformer models.') model = build_model_with_cfg( - SwinTransformerV2CR, variant, pretrained, + SwinTransformerV2Cr, + variant, + pretrained, default_cfg=default_cfg, img_size=img_size, num_classes=num_classes, pretrained_filter_fn=checkpoint_filter_fn, - **kwargs) + **kwargs + ) return model @register_model -def swin_v2_cr_tiny_patch4_window12_384(pretrained=False, **kwargs): - """ Swin-T V2 CR @ 384x384, trained ImageNet-1k - """ - model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=96, depths=(2, 2, 6, 2), - num_heads=(3, 6, 12, 24), **kwargs) - return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_patch4_window12_384', pretrained=pretrained, **model_kwargs) +def swin_v2_cr_tiny_384(pretrained=False, **kwargs): + """Swin-T V2 CR @ 384x384, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=96, + depths=(2, 2, 6, 2), + num_heads=(3, 6, 12, 24), + **kwargs + ) + return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_384', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_tiny_patch4_window7_224(pretrained=False, **kwargs): - """ Swin-T V2 CR @ 224x224, trained ImageNet-1k - """ - model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 6, 2), - num_heads=(3, 6, 12, 24), **kwargs) - return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_patch4_window7_224', pretrained=pretrained, **model_kwargs) +def swin_v2_cr_tiny_224(pretrained=False, **kwargs): + """Swin-T V2 CR @ 224x224, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=96, + depths=(2, 2, 6, 2), + num_heads=(3, 6, 12, 24), + **kwargs + ) + return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_224', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_small_patch4_window12_384(pretrained=False, **kwargs): - """ Swin-S V2 CR @ 384x384, trained ImageNet-1k - """ - model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), - num_heads=(3, 6, 12, 24), **kwargs) - return _create_swin_transformer_v2_cr('swin_v2_cr_small_patch4_window12_384', pretrained=pretrained, **model_kwargs) +def swin_v2_cr_small_384(pretrained=False, **kwargs): + """Swin-S V2 CR @ 384x384, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=96, + depths=(2, 2, 18, 2), + num_heads=(3, 6, 12, 24), + **kwargs + ) + return _create_swin_transformer_v2_cr( + 'swin_v2_cr_small_384', pretrained=pretrained, **model_kwargs + ) @register_model -def swin_v2_cr_small_patch4_window7_224(pretrained=False, **kwargs): - """ Swin-S V2 CR @ 224x224, trained ImageNet-1k - """ - model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), - num_heads=(3, 6, 12, 24), **kwargs) - return _create_swin_transformer_v2_cr('swin_v2_cr_small_patch4_window7_224', pretrained=pretrained, **model_kwargs) +def swin_v2_cr_small_224(pretrained=False, **kwargs): + """Swin-S V2 CR @ 224x224, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=96, + depths=(2, 2, 18, 2), + num_heads=(3, 6, 12, 24), + **kwargs + ) 
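Since @register_model registers each of these functions under its name, the renamed variants are available through the normal factory path. A usage sketch (pretrained weight availability depends on what has actually been released):

import timm
import torch

model = timm.create_model('swin_v2_cr_tiny_224', pretrained=False, num_classes=1000)
logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)   # torch.Size([1, 1000])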
+ return _create_swin_transformer_v2_cr('swin_v2_cr_small_224', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_base_patch4_window12_384(pretrained=False, **kwargs): - """ Swin-B V2 CR @ 384x384, trained ImageNet-1k - """ - model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), - num_heads=(4, 8, 16, 32), **kwargs) - return _create_swin_transformer_v2_cr('swin_v2_cr_base_patch4_window12_384', pretrained=pretrained, **model_kwargs) +def swin_v2_cr_base_384(pretrained=False, **kwargs): + """Swin-B V2 CR @ 384x384, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=128, + depths=(2, 2, 18, 2), + num_heads=(4, 8, 16, 32), + **kwargs + ) + return _create_swin_transformer_v2_cr('swin_v2_cr_base_384', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_base_patch4_window7_224(pretrained=False, **kwargs): - """ Swin-B V2 CR @ 224x224, trained ImageNet-1k - """ - model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), - num_heads=(4, 8, 16, 32), **kwargs) - return _create_swin_transformer_v2_cr('swin_v2_cr_base_patch4_window7_224', pretrained=pretrained, **model_kwargs) +def swin_v2_cr_base_224(pretrained=False, **kwargs): + """Swin-B V2 CR @ 224x224, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=128, + depths=(2, 2, 18, 2), + num_heads=(4, 8, 16, 32), + **kwargs + ) + return _create_swin_transformer_v2_cr('swin_v2_cr_base_224', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_large_patch4_window12_384(pretrained=False, **kwargs): - """ Swin-L V2 CR @ 384x384, trained ImageNet-1k - """ - model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), - num_heads=(6, 12, 24, 48), **kwargs) - return _create_swin_transformer_v2_cr('swin_v2_cr_large_patch4_window12_384', pretrained=pretrained, **model_kwargs) +def swin_v2_cr_large_384(pretrained=False, **kwargs): + """Swin-L V2 CR @ 384x384, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=192, + depths=(2, 2, 18, 2), + num_heads=(6, 12, 24, 48), + **kwargs + ) + return _create_swin_transformer_v2_cr( + 'swin_v2_cr_large_384', pretrained=pretrained, **model_kwargs + ) @register_model -def swin_v2_cr_large_patch4_window7_224(pretrained=False, **kwargs): - """ Swin-L V2 CR @ 224x224, trained ImageNet-1k - """ - model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), - num_heads=(6, 12, 24, 48), **kwargs) +def swin_v2_cr_large_224(pretrained=False, **kwargs): + """Swin-L V2 CR @ 224x224, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=192, + depths=(2, 2, 18, 2), + num_heads=(6, 12, 24, 48), + **kwargs + ) return _create_swin_transformer_v2_cr('swin_v2_cr_large_patch4_window7_224', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_huge_patch4_window12_384(pretrained=False, **kwargs): - """ Swin-H V2 CR @ 384x384, trained ImageNet-1k - """ - model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=352, depths=(2, 2, 18, 2), - num_heads=(6, 12, 24, 48), **kwargs) +def swin_v2_cr_huge_384(pretrained=False, **kwargs): + """Swin-H V2 CR @ 384x384, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=352, + depths=(2, 2, 18, 2), + num_heads=(11, 22, 44, 88), # head count not certain for Huge, 384 & 224 trying diff values + extra_norm_period=6, + **kwargs + ) return 
_create_swin_transformer_v2_cr('swin_v2_cr_huge_patch4_window12_384', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_huge_patch4_window7_224(pretrained=False, **kwargs): - """ Swin-H V2 CR @ 224x224, trained ImageNet-1k - """ - model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=352, depths=(2, 2, 18, 2), - num_heads=(11, 22, 44, 88), **kwargs) +def swin_v2_cr_huge_224(pretrained=False, **kwargs): + """Swin-H V2 CR @ 224x224, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=352, + depths=(2, 2, 18, 2), + num_heads=(8, 16, 32, 64), # head count not certain for Huge, 384 & 224 trying diff values + extra_norm_period=6, + **kwargs + ) return _create_swin_transformer_v2_cr('swin_v2_cr_huge_patch4_window7_224', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_giant_patch4_window12_384(pretrained=False, **kwargs): - """ Swin-G V2 CR @ 384x384, trained ImageNet-1k - """ - model_kwargs = dict(img_size=(384, 384), patch_size=4, window_size=12, embed_dim=512, depths=(2, 2, 18, 2), - num_heads=(16, 32, 64, 128), **kwargs) - return _create_swin_transformer_v2_cr('swin_v2_cr_giant_patch4_window12_384', pretrained=pretrained, **model_kwargs) +def swin_v2_cr_giant_384(pretrained=False, **kwargs): + """Swin-G V2 CR @ 384x384, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=512, + depths=(2, 2, 42, 2), + num_heads=(16, 32, 64, 128), + extra_norm_period=6, + **kwargs + ) + return _create_swin_transformer_v2_cr( + 'swin_v2_cr_giant_384', pretrained=pretrained, **model_kwargs + ) @register_model -def swin_v2_cr_giant_patch4_window7_224(pretrained=False, **kwargs): - """ Swin-G V2 CR @ 224x224, trained ImageNet-1k - """ - model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=512, depths=(2, 2, 42, 2), - num_heads=(16, 32, 64, 128), **kwargs) - return _create_swin_transformer_v2_cr('swin_v2_cr_giant_patch4_window7_224', pretrained=pretrained, **model_kwargs) +def swin_v2_cr_giant_224(pretrained=False, **kwargs): + """Swin-G V2 CR @ 224x224, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=512, + depths=(2, 2, 42, 2), + num_heads=(16, 32, 64, 128), + extra_norm_period=6, + **kwargs + ) + return _create_swin_transformer_v2_cr('swin_v2_cr_giant_224', pretrained=pretrained, **model_kwargs) From 1420c118dfa1ba151a9cbd76f08db2701da23bbe Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 23 Feb 2022 19:50:26 -0800 Subject: [PATCH 19/21] Missed comitting outstanding changes to default_cfg keys and test exclusions for swin v2 --- tests/test_models.py | 5 +++-- timm/models/swin_transformer_v2_cr.py | 15 ++++++--------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index bb98d43e..4b9f3428 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -33,8 +33,9 @@ if 'GITHUB_ACTIONS' in os.environ: EXCLUDE_FILTERS = [ '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm', '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*', - '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*'] - NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*'] + '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*', + 'swin*giant*'] + NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*'] else: EXCLUDE_FILTERS = [] NON_STD_EXCLUDE_FILTERS = ['vit_gi*'] diff --git a/timm/models/swin_transformer_v2_cr.py 
b/timm/models/swin_transformer_v2_cr.py index bad5488d..b2915bf8 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -813,8 +813,7 @@ def swin_v2_cr_small_384(pretrained=False, **kwargs): num_heads=(3, 6, 12, 24), **kwargs ) - return _create_swin_transformer_v2_cr( - 'swin_v2_cr_small_384', pretrained=pretrained, **model_kwargs + return _create_swin_transformer_v2_cr('swin_v2_cr_small_384', pretrained=pretrained, **model_kwargs ) @@ -863,8 +862,7 @@ def swin_v2_cr_large_384(pretrained=False, **kwargs): num_heads=(6, 12, 24, 48), **kwargs ) - return _create_swin_transformer_v2_cr( - 'swin_v2_cr_large_384', pretrained=pretrained, **model_kwargs + return _create_swin_transformer_v2_cr('swin_v2_cr_large_384', pretrained=pretrained, **model_kwargs ) @@ -877,7 +875,7 @@ def swin_v2_cr_large_224(pretrained=False, **kwargs): num_heads=(6, 12, 24, 48), **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_large_patch4_window7_224', pretrained=pretrained, **model_kwargs) + return _create_swin_transformer_v2_cr('swin_v2_cr_large_224', pretrained=pretrained, **model_kwargs) @register_model @@ -890,7 +888,7 @@ def swin_v2_cr_huge_384(pretrained=False, **kwargs): extra_norm_period=6, **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_huge_patch4_window12_384', pretrained=pretrained, **model_kwargs) + return _create_swin_transformer_v2_cr('swin_v2_cr_huge_384', pretrained=pretrained, **model_kwargs) @register_model @@ -903,7 +901,7 @@ def swin_v2_cr_huge_224(pretrained=False, **kwargs): extra_norm_period=6, **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_huge_patch4_window7_224', pretrained=pretrained, **model_kwargs) + return _create_swin_transformer_v2_cr('swin_v2_cr_huge_224', pretrained=pretrained, **model_kwargs) @register_model @@ -916,8 +914,7 @@ def swin_v2_cr_giant_384(pretrained=False, **kwargs): extra_norm_period=6, **kwargs ) - return _create_swin_transformer_v2_cr( - 'swin_v2_cr_giant_384', pretrained=pretrained, **model_kwargs + return _create_swin_transformer_v2_cr('swin_v2_cr_giant_384', pretrained=pretrained, **model_kwargs ) From d98aa47d12d27d941fe019fb1ab9b52cde670056 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 21 Mar 2022 12:29:02 -0700 Subject: [PATCH 20/21] Revert ml-decoder changes to model factory and train script --- timm/models/factory.py | 5 ----- train.py | 4 +--- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/timm/models/factory.py b/timm/models/factory.py index 40453380..d040a9ff 100644 --- a/timm/models/factory.py +++ b/timm/models/factory.py @@ -29,7 +29,6 @@ def create_model( scriptable=None, exportable=None, no_jit=None, - use_ml_decoder_head=False, **kwargs): """Create a model @@ -81,10 +80,6 @@ def create_model( with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): model = create_fn(pretrained=pretrained, **kwargs) - if use_ml_decoder_head: - from timm.models.layers.ml_decoder import add_ml_decoder_head - model = add_ml_decoder_head(model) - if checkpoint_path: load_checkpoint(model, checkpoint_path) diff --git a/train.py b/train.py index 42985e12..10d839be 100755 --- a/train.py +++ b/train.py @@ -115,7 +115,6 @@ parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N', help='input batch size for training (default: 128)') parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N', help='validation batch size override (default: None)') -parser.add_argument('--use-ml-decoder-head', 
type=int, default=0) # Optimizer parameters parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', @@ -380,8 +379,7 @@ def main(): bn_momentum=args.bn_momentum, bn_eps=args.bn_eps, scriptable=args.torchscript, - checkpoint_path=args.initial_checkpoint, - use_ml_decoder_head=args.use_ml_decoder_head) + checkpoint_path=args.initial_checkpoint) if args.num_classes is None: assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly From 7cdd164d7775c444ee50afa6b825c843038f1b07 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 21 Mar 2022 13:35:45 -0700 Subject: [PATCH 21/21] Fix #1184, scheduler noise bug during merge madness --- timm/scheduler/scheduler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/timm/scheduler/scheduler.py b/timm/scheduler/scheduler.py index 81af76f9..e7a6d2a7 100644 --- a/timm/scheduler/scheduler.py +++ b/timm/scheduler/scheduler.py @@ -92,6 +92,7 @@ class Scheduler: def _is_apply_noise(self, t) -> bool: """Return True if scheduler in noise range.""" + apply_noise = False if self.noise_range_t is not None: if isinstance(self.noise_range_t, (list, tuple)): apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] @@ -104,7 +105,7 @@ class Scheduler: g.manual_seed(self.noise_seed + t) if self.noise_type == 'normal': while True: - # resample if noise out of percent limit, brute force but shouldn't spin much + # resample if noise out of percent limit, brute force but shouldn't spin much noise = torch.randn(1, generator=g).item() if abs(noise) < self.noise_pct: return noise
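The one-line fix above initialises apply_noise, so the noise-range check has a defined result when noise_range_t is None. A standalone sketch of the intended behaviour (a plain function rather than the Scheduler method; the scalar branch is assumed to mean "apply noise from step t onward", consistent with the tuple branch shown in the hunk):

def is_apply_noise(t, noise_range_t):
    """Return True if step t falls inside the configured noise range."""
    apply_noise = False
    if noise_range_t is not None:
        if isinstance(noise_range_t, (list, tuple)):
            apply_noise = noise_range_t[0] <= t < noise_range_t[1]
        else:
            apply_noise = t >= noise_range_t   # assumed scalar semantics
    return apply_noise

assert is_apply_noise(10, None) is False   # no noise configured -> simply False
assert is_apply_noise(10, (5, 20)) is True
assert is_apply_noise(3, 5) is False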