diff --git a/README.md b/README.md
index 3fa9701f..69effde4 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,12 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
## What's New
+### Feb 2, 2022
+* [Chris Hughes](https://github.com/Chris-hughes10) posted an exhaustive run-through of `timm` on his blog yesterday. Well worth a read. [Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055)
+* I'm currently prepping to merge the `norm_norm_norm` branch back to master (ver 0.6.x) in the next week or so.
+  * The changes are more extensive than usual and may destabilize and break some model API use (aiming for full backwards compat). So, beware `pip install git+https://github.com/rwightman/pytorch-image-models` installs!
+  * `0.5.x` releases and a `0.5.x` branch will remain stable with a cherry pick or two until the dust clears. Recommend sticking to the pypi install for a bit if you want stability.
+
### Jan 14, 2022
* Version 0.5.4 w/ release to be pushed to pypi. It's been a while since last pypi update and riskier changes will be merged to main branch soon....
* Add ConvNeXT models /w weights from official impl (https://github.com/facebookresearch/ConvNeXt), a few perf tweaks, compatible with timm features
@@ -410,6 +416,8 @@ Model validation results can be found in the [documentation](https://rwightman.g
My current [documentation](https://rwightman.github.io/pytorch-image-models/) for `timm` covers the basics.
+[Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail.
+
[timmdocs](https://fastai.github.io/timmdocs/) is quickly becoming a much more comprehensive set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs.
[paperswithcode](https://paperswithcode.com/lib/timm) is a good resource for browsing the models within `timm`.
diff --git a/docs/feature_extraction.md b/docs/feature_extraction.md
index b41c1559..3d638d65 100644
--- a/docs/feature_extraction.md
+++ b/docs/feature_extraction.md
@@ -145,7 +145,7 @@ torch.Size([2, 1512, 7, 7])
### Select specific feature levels or limit the stride
-There are to additional creation arguments impacting the output features.
+There are two additional creation arguments impacting the output features.
* `out_indices` selects which indices to output * `output_stride` limits the feature output stride of the network (also works in classification mode BTW) diff --git a/tests/test_models.py b/tests/test_models.py index 11e84027..4e50de6e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -34,8 +34,9 @@ if 'GITHUB_ACTIONS' in os.environ: EXCLUDE_FILTERS = [ '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm', '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*', - '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*'] - NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*'] + '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*', + 'swin*giant*'] + NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*'] else: EXCLUDE_FILTERS = [] NON_STD_EXCLUDE_FILTERS = ['vit_gi*'] diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 497aa53d..45ead5dc 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -41,6 +41,7 @@ from .selecsls import * from .senet import * from .sknet import * from .swin_transformer import * +from .swin_transformer_v2_cr import * from .tnt import * from .tresnet import * from .twins import * diff --git a/timm/models/layers/ml_decoder.py b/timm/models/layers/ml_decoder.py new file mode 100644 index 00000000..3f828c6d --- /dev/null +++ b/timm/models/layers/ml_decoder.py @@ -0,0 +1,156 @@ +from typing import Optional + +import torch +from torch import nn +from torch import nn, Tensor +from torch.nn.modules.transformer import _get_activation_fn + + +def add_ml_decoder_head(model): + if hasattr(model, 'global_pool') and hasattr(model, 'fc'): # most CNN models, like Resnet50 + model.global_pool = nn.Identity() + del model.fc + num_classes = model.num_classes + num_features = model.num_features + model.fc = MLDecoder(num_classes=num_classes, initial_num_features=num_features) + elif hasattr(model, 'global_pool') and hasattr(model, 'classifier'): # EfficientNet + model.global_pool = nn.Identity() + del model.classifier + num_classes = model.num_classes + num_features = model.num_features + model.classifier = MLDecoder(num_classes=num_classes, initial_num_features=num_features) + elif 'RegNet' in model._get_name() or 'TResNet' in model._get_name(): # hasattr(model, 'head') + del model.head + num_classes = model.num_classes + num_features = model.num_features + model.head = MLDecoder(num_classes=num_classes, initial_num_features=num_features) + else: + print("Model code-writing is not aligned currently with ml-decoder") + exit(-1) + if hasattr(model, 'drop_rate'): # Ml-Decoder has inner dropout + model.drop_rate = 0 + return model + + +class TransformerDecoderLayerOptimal(nn.Module): + def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1, activation="relu", + layer_norm_eps=1e-5) -> None: + super(TransformerDecoderLayerOptimal, self).__init__() + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.dropout = nn.Dropout(dropout) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm3 = nn.LayerNorm(d_model, 
eps=layer_norm_eps) + + self.activation = _get_activation_fn(activation) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = torch.nn.functional.relu + super(TransformerDecoderLayerOptimal, self).__setstate__(state) + + def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None) -> Tensor: + tgt = tgt + self.dropout1(tgt) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn(tgt, memory, memory)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + +# @torch.jit.script +# class ExtrapClasses(object): +# def __init__(self, num_queries: int, group_size: int): +# self.num_queries = num_queries +# self.group_size = group_size +# +# def __call__(self, h: torch.Tensor, class_embed_w: torch.Tensor, class_embed_b: torch.Tensor, out_extrap: +# torch.Tensor): +# # h = h.unsqueeze(-1).expand(-1, -1, -1, self.group_size) +# h = h[..., None].repeat(1, 1, 1, self.group_size) # torch.Size([bs, 5, 768, groups]) +# w = class_embed_w.view((self.num_queries, h.shape[2], self.group_size)) +# out = (h * w).sum(dim=2) + class_embed_b +# out = out.view((h.shape[0], self.group_size * self.num_queries)) +# return out + +@torch.jit.script +class GroupFC(object): + def __init__(self, embed_len_decoder: int): + self.embed_len_decoder = embed_len_decoder + + def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor): + for i in range(self.embed_len_decoder): + h_i = h[:, i, :] + w_i = duplicate_pooling[i, :, :] + out_extrap[:, i, :] = torch.matmul(h_i, w_i) + + +class MLDecoder(nn.Module): + def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial_num_features=2048): + super(MLDecoder, self).__init__() + embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups + if embed_len_decoder > num_classes: + embed_len_decoder = num_classes + + # switching to 768 initial embeddings + decoder_embedding = 768 if decoder_embedding < 0 else decoder_embedding + self.embed_standart = nn.Linear(initial_num_features, decoder_embedding) + + # decoder + decoder_dropout = 0.1 + num_layers_decoder = 1 + dim_feedforward = 2048 + layer_decode = TransformerDecoderLayerOptimal(d_model=decoder_embedding, + dim_feedforward=dim_feedforward, dropout=decoder_dropout) + self.decoder = nn.TransformerDecoder(layer_decode, num_layers=num_layers_decoder) + + # non-learnable queries + self.query_embed = nn.Embedding(embed_len_decoder, decoder_embedding) + self.query_embed.requires_grad_(False) + + # group fully-connected + self.num_classes = num_classes + self.duplicate_factor = int(num_classes / embed_len_decoder + 0.999) + self.duplicate_pooling = torch.nn.Parameter( + torch.Tensor(embed_len_decoder, decoder_embedding, self.duplicate_factor)) + self.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(num_classes)) + torch.nn.init.xavier_normal_(self.duplicate_pooling) + torch.nn.init.constant_(self.duplicate_pooling_bias, 0) + self.group_fc = GroupFC(embed_len_decoder) + + def forward(self, x): + if len(x.shape) == 4: # [bs,2048, 7,7] + embedding_spatial = x.flatten(2).transpose(1, 2) + else: # [bs, 197,468] + embedding_spatial = x + embedding_spatial_786 = self.embed_standart(embedding_spatial) + embedding_spatial_786 = 
torch.nn.functional.relu(embedding_spatial_786, inplace=True) + + bs = embedding_spatial_786.shape[0] + query_embed = self.query_embed.weight + # tgt = query_embed.unsqueeze(1).repeat(1, bs, 1) + tgt = query_embed.unsqueeze(1).expand(-1, bs, -1) # no allocation of memory with expand + h = self.decoder(tgt, embedding_spatial_786.transpose(0, 1)) # [embed_len_decoder, batch, 768] + h = h.transpose(0, 1) + + out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype) + self.group_fc(h, self.duplicate_pooling, out_extrap) + h_out = out_extrap.flatten(1)[:, :self.num_classes] + h_out += self.duplicate_pooling_bias + logits = h_out + return logits diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py new file mode 100644 index 00000000..b2915bf8 --- /dev/null +++ b/timm/models/swin_transformer_v2_cr.py @@ -0,0 +1,931 @@ +""" Swin Transformer V2 + +A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution` + - https://arxiv.org/pdf/2111.09883 + +Code adapted from https://github.com/ChristophReich1996/Swin-Transformer-V2, original copyright/license info below + +This implementation is experimental and subject to change in manners that will break weight compat: +* Size of the pos embed MLP are not spelled out in paper in terms of dim, fixed for all models? vary with num_heads? + * currently dim is fixed, I feel it may make sense to scale with num_heads (dim per head) +* The specifics of the memory saving 'sequential attention' are not detailed, Christoph Reich has an impl at + GitHub link above. It needs further investigation as throughput vs mem tradeoff doesn't appear beneficial. +* num_heads per stage is not detailed for Huge and Giant model variants +* 'Giant' is 3B params in paper but ~2.6B here despite matching paper dim + block counts + +Noteworthy additions over official Swin v1: +* MLP relative position embedding is looking promising and adapts to different image/window sizes +* This impl has been designed to allow easy change of image size with matching window size changes +* Non-square image size and window size are supported + +Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman +""" +# -------------------------------------------------------- +# Swin Transformer V2 reimplementation +# Copyright (c) 2021 Christoph Reich +# Licensed under The MIT License [see LICENSE for details] +# Written by Christoph Reich +# -------------------------------------------------------- +import logging +import math +from copy import deepcopy +from typing import Tuple, Optional, List, Union, Any, Type + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .fx_features import register_notrace_function +from .helpers import build_model_with_cfg, overlay_external_default_cfg, named_apply +from .layers import DropPath, Mlp, to_2tuple, _assert +from .registry import register_model +from .vision_transformer import checkpoint_filter_fn + +_logger = logging.getLogger(__name__) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, + 'input_size': (3, 224, 224), + 'pool_size': None, + 'crop_pct': 0.9, + 'interpolation': 'bicubic', + 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, + 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', + 'classifier': 'head', + **kwargs, + } + + +default_cfgs = { + # patch models (my experiments) + 
    'swin_v2_cr_tiny_384': _cfg(
+        url="", input_size=(3, 384, 384), crop_pct=1.0),
+    'swin_v2_cr_tiny_224': _cfg(
+        url="", input_size=(3, 224, 224), crop_pct=1.0),
+    'swin_v2_cr_small_384': _cfg(
+        url="", input_size=(3, 384, 384), crop_pct=1.0),
+    'swin_v2_cr_small_224': _cfg(
+        url="", input_size=(3, 224, 224), crop_pct=1.0),
+    'swin_v2_cr_base_384': _cfg(
+        url="", input_size=(3, 384, 384), crop_pct=1.0),
+    'swin_v2_cr_base_224': _cfg(
+        url="", input_size=(3, 224, 224), crop_pct=1.0),
+    'swin_v2_cr_large_384': _cfg(
+        url="", input_size=(3, 384, 384), crop_pct=1.0),
+    'swin_v2_cr_large_224': _cfg(
+        url="", input_size=(3, 224, 224), crop_pct=1.0),
+    'swin_v2_cr_huge_384': _cfg(
+        url="", input_size=(3, 384, 384), crop_pct=1.0),
+    'swin_v2_cr_huge_224': _cfg(
+        url="", input_size=(3, 224, 224), crop_pct=1.0),
+    'swin_v2_cr_giant_384': _cfg(
+        url="", input_size=(3, 384, 384), crop_pct=1.0),
+    'swin_v2_cr_giant_224': _cfg(
+        url="", input_size=(3, 224, 224), crop_pct=1.0),
+}
+
+
+def bchw_to_bhwc(x: torch.Tensor) -> torch.Tensor:
+    """Permutes a tensor from the shape (B, C, H, W) to (B, H, W, C). """
+    return x.permute(0, 2, 3, 1)
+
+
+def bhwc_to_bchw(x: torch.Tensor) -> torch.Tensor:
+    """Permutes a tensor from the shape (B, H, W, C) to (B, C, H, W). """
+    return x.permute(0, 3, 1, 2)
+
+
+def window_partition(x, window_size: Tuple[int, int]):
+    """
+    Args:
+        x: (B, H, W, C)
+        window_size (Tuple[int, int]): window size
+
+    Returns:
+        windows: (num_windows * B, window_size[0], window_size[1], C)
+    """
+    B, H, W, C = x.shape
+    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
+    return windows
+
+
+@register_notrace_function  # reason: int argument is a Proxy
+def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
+    """
+    Args:
+        windows: (num_windows * B, window_size[0], window_size[1], C)
+        window_size (Tuple[int, int]): Window size
+        img_size (Tuple[int, int]): Image size
+
+    Returns:
+        x: (B, H, W, C)
+    """
+    H, W = img_size
+    B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
+    x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    return x
+
+
+class WindowMultiHeadAttention(nn.Module):
+    r"""This class implements window-based Multi-Head-Attention with log-spaced continuous position bias.
+
+    Args:
+        dim (int): Number of input features
+        window_size (Tuple[int, int]): Window size
+        num_heads (int): Number of attention heads
+        drop_attn (float): Dropout rate of attention map
+        drop_proj (float): Dropout rate after projection
+        meta_hidden_dim (int): Number of hidden features in the two layer MLP meta network
+        sequential_attn (bool): If true sequential self-attention is performed
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        window_size: Tuple[int, int],
+        drop_attn: float = 0.0,
+        drop_proj: float = 0.0,
+        meta_hidden_dim: int = 384,  # FIXME what's the optimal value?
+        sequential_attn: bool = False,
+    ) -> None:
+        super(WindowMultiHeadAttention, self).__init__()
+        assert dim % num_heads == 0, \
+            "The number of input features (in_features) is not divisible by the number of heads (num_heads)."
+ self.in_features: int = dim + self.window_size: Tuple[int, int] = window_size + self.num_heads: int = num_heads + self.sequential_attn: bool = sequential_attn + + self.qkv = nn.Linear(in_features=dim, out_features=dim * 3, bias=True) + self.attn_drop = nn.Dropout(drop_attn) + self.proj = nn.Linear(in_features=dim, out_features=dim, bias=True) + self.proj_drop = nn.Dropout(drop_proj) + # meta network for positional encodings + self.meta_mlp = Mlp( + 2, # x, y + hidden_features=meta_hidden_dim, + out_features=num_heads, + act_layer=nn.ReLU, + drop=0. # FIXME should we add stochasticity? + ) + self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads))) + self._make_pair_wise_relative_positions() + + def _make_pair_wise_relative_positions(self) -> None: + """Method initializes the pair-wise relative positions to compute the positional biases.""" + device = self.tau.device + coordinates = torch.stack(torch.meshgrid([ + torch.arange(self.window_size[0], device=device), + torch.arange(self.window_size[1], device=device)]), dim=0).flatten(1) + relative_coordinates = coordinates[:, :, None] - coordinates[:, None, :] + relative_coordinates = relative_coordinates.permute(1, 2, 0).reshape(-1, 2).float() + relative_coordinates_log = torch.sign(relative_coordinates) * torch.log( + 1.0 + relative_coordinates.abs()) + self.register_buffer("relative_coordinates_log", relative_coordinates_log, persistent=False) + + def update_input_size(self, new_window_size: int, **kwargs: Any) -> None: + """Method updates the window size and so the pair-wise relative positions + + Args: + new_window_size (int): New window size + kwargs (Any): Unused + """ + # Set new window size and new pair-wise relative positions + self.window_size: int = new_window_size + self._make_pair_wise_relative_positions() + + def _relative_positional_encodings(self) -> torch.Tensor: + """Method computes the relative positional encodings + + Returns: + relative_position_bias (torch.Tensor): Relative positional encodings + (1, number of heads, window size ** 2, window size ** 2) + """ + window_area = self.window_size[0] * self.window_size[1] + relative_position_bias = self.meta_mlp(self.relative_coordinates_log) + relative_position_bias = relative_position_bias.transpose(1, 0).reshape( + self.num_heads, window_area, window_area + ) + relative_position_bias = relative_position_bias.unsqueeze(0) + return relative_position_bias + + def _forward_sequential( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + """ + # FIXME TODO figure out 'sequential' attention mentioned in paper (should reduce GPU memory) + assert False, "not implemented" + + def _forward_batch( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """This function performs standard (non-sequential) scaled cosine self-attention. 
+ """ + Bw, L, C = x.shape + + qkv = self.qkv(x).view(Bw, L, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + query, key, value = qkv.unbind(0) + + # compute attention map with scaled cosine attention + denom = torch.norm(query, dim=-1, keepdim=True) @ torch.norm(key, dim=-1, keepdim=True).transpose(-2, -1) + attn = query @ key.transpose(-2, -1) / denom.clamp(min=1e-6) + attn = attn / self.tau.clamp(min=0.01).reshape(1, self.num_heads, 1, 1) + attn = attn + self._relative_positional_encodings() + if mask is not None: + # Apply mask if utilized + num_win: int = mask.shape[0] + attn = attn.view(Bw // num_win, num_win, self.num_heads, L, L) + attn = attn + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, L, L) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ value).transpose(1, 2).reshape(Bw, L, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """ Forward pass. + Args: + x (torch.Tensor): Input tensor of the shape (B * windows, N, C) + mask (Optional[torch.Tensor]): Attention mask for the shift case + + Returns: + Output tensor of the shape [B * windows, N, C] + """ + if self.sequential_attn: + return self._forward_sequential(x, mask) + else: + return self._forward_batch(x, mask) + + +class SwinTransformerBlock(nn.Module): + r"""This class implements the Swin transformer block. + + Args: + dim (int): Number of input channels + num_heads (int): Number of attention heads to be utilized + feat_size (Tuple[int, int]): Input resolution + window_size (Tuple[int, int]): Window size to be utilized + shift_size (int): Shifting size to be used + mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels + drop (float): Dropout in input mapping + drop_attn (float): Dropout rate of attention map + drop_path (float): Dropout in main path + extra_norm (bool): Insert extra norm on 'main' branch if True + sequential_attn (bool): If true sequential self-attention is performed + norm_layer (Type[nn.Module]): Type of normalization layer to be utilized + """ + + def __init__( + self, + dim: int, + num_heads: int, + feat_size: Tuple[int, int], + window_size: Tuple[int, int], + shift_size: Tuple[int, int] = (0, 0), + mlp_ratio: float = 4.0, + drop: float = 0.0, + drop_attn: float = 0.0, + drop_path: float = 0.0, + extra_norm: bool = False, + sequential_attn: bool = False, + norm_layer: Type[nn.Module] = nn.LayerNorm, + ) -> None: + super(SwinTransformerBlock, self).__init__() + self.dim: int = dim + self.feat_size: Tuple[int, int] = feat_size + self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size) + self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size)) + self.window_area = self.window_size[0] * self.window_size[1] + + # attn branch + self.attn = WindowMultiHeadAttention( + dim=dim, + num_heads=num_heads, + window_size=self.window_size, + drop_attn=drop_attn, + drop_proj=drop, + sequential_attn=sequential_attn, + ) + self.norm1 = norm_layer(dim) + self.drop_path1 = DropPath(drop_prob=drop_path) if drop_path > 0.0 else nn.Identity() + + # mlp branch + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + drop=drop, + out_features=dim, + ) + self.norm2 = norm_layer(dim) + self.drop_path2 = DropPath(drop_prob=drop_path) if drop_path > 0.0 else nn.Identity() + + # extra norm layer mentioned for Huge/Giant models in V2 paper (FIXME may be in wrong spot?) 
+ self.norm3 = norm_layer(dim) if extra_norm else nn.Identity() + + self._make_attention_mask() + + def _calc_window_shift(self, target_window_size): + window_size = [f if f <= w else w for f, w in zip(self.feat_size, target_window_size)] + shift_size = [0 if f <= w else s for f, w, s in zip(self.feat_size, window_size, self.target_shift_size)] + return tuple(window_size), tuple(shift_size) + + def _make_attention_mask(self) -> None: + """Method generates the attention mask used in shift case.""" + # Make masks for shift case + if any(self.shift_size): + # calculate attention mask for SW-MSA + H, W = self.feat_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + cnt = 0 + for h in ( + slice(0, -self.window_size[0]), + slice(-self.window_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None)): + for w in ( + slice(0, -self.window_size[1]), + slice(-self.window_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None)): + img_mask[:, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, self.window_size) # num_windows, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_area) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + self.register_buffer("attn_mask", attn_mask, persistent=False) + + def update_input_size(self, new_window_size: Tuple[int, int], new_feat_size: Tuple[int, int]) -> None: + """Method updates the image resolution to be processed and window size and so the pair-wise relative positions. + + Args: + new_window_size (int): New window size + new_feat_size (Tuple[int, int]): New input resolution + """ + # Update input resolution + self.feat_size: Tuple[int, int] = new_feat_size + self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(new_window_size)) + self.window_area = self.window_size[0] * self.window_size[1] + self.attn.update_input_size(new_window_size=self.window_size) + self._make_attention_mask() + + def _shifted_window_attn(self, x): + H, W = self.feat_size + B, L, C = x.shape + x = x.view(B, H, W, C) + + # cyclic shift + if any(self.shift_size): + shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # num_windows * B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C) + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # num_windows * B, window_size * window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C) + shifted_x = window_reverse(attn_windows, self.window_size, self.feat_size) # B H' W' C + + # reverse cyclic shift + if any(self.shift_size): + x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2)) + else: + x = shifted_x + + x = x.view(B, L, C) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. 
+ + Args: + x (torch.Tensor): Input tensor of the shape [B, C, H, W] + + Returns: + output (torch.Tensor): Output tensor of the shape [B, C, H, W] + """ + # NOTE post-norm branches (op -> norm -> drop) + x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x))) + x = x + self.drop_path2(self.norm2(self.mlp(x))) + x = self.norm3(x) # main-branch norm enabled for some blocks (every 6 for Huge/Giant) + return x + + +class PatchMerging(nn.Module): + """ This class implements the patch merging as a strided convolution with a normalization before. + Args: + dim (int): Number of input channels + norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. + """ + + def __init__(self, dim: int, norm_layer: Type[nn.Module] = nn.LayerNorm) -> None: + super(PatchMerging, self).__init__() + self.norm = norm_layer(4 * dim) + self.reduction = nn.Linear(in_features=4 * dim, out_features=2 * dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ Forward pass. + Args: + x (torch.Tensor): Input tensor of the shape [B, C, H, W] + Returns: + output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2] + """ + x = bchw_to_bhwc(x).unfold(dimension=1, size=2, step=2).unfold(dimension=2, size=2, step=2) + x = x.permute(0, 1, 2, 5, 4, 3).flatten(3) # permute maintains compat with ch order in official swin impl + x = self.norm(x) + x = bhwc_to_bchw(self.reduction(x)) + return x + + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + B, C, H, W = x.shape + _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") + _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") + x = self.proj(x) + x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + + +class SwinTransformerStage(nn.Module): + r"""This class implements a stage of the Swin transformer including multiple layers. + + Args: + embed_dim (int): Number of input channels + depth (int): Depth of the stage (number of layers) + downscale (bool): If true input is downsampled (see Fig. 3 or V1 paper) + feat_size (Tuple[int, int]): input feature map size (H, W) + num_heads (int): Number of attention heads to be utilized + window_size (int): Window size to be utilized + mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels + drop (float): Dropout in input mapping + drop_attn (float): Dropout rate of attention map + drop_path (float): Dropout in main path + norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. 
Default: nn.LayerNorm + grad_checkpointing (bool): If true checkpointing is utilized + extra_norm_period (int): Insert extra norm layer on main branch every N (period) blocks + sequential_attn (bool): If true sequential self-attention is performed + """ + + def __init__( + self, + embed_dim: int, + depth: int, + downscale: bool, + num_heads: int, + feat_size: Tuple[int, int], + window_size: Tuple[int, int], + mlp_ratio: float = 4.0, + drop: float = 0.0, + drop_attn: float = 0.0, + drop_path: Union[List[float], float] = 0.0, + norm_layer: Type[nn.Module] = nn.LayerNorm, + grad_checkpointing: bool = False, + extra_norm_period: int = 0, + sequential_attn: bool = False, + ) -> None: + super(SwinTransformerStage, self).__init__() + self.downscale: bool = downscale + self.grad_checkpointing: bool = grad_checkpointing + self.feat_size: Tuple[int, int] = (feat_size[0] // 2, feat_size[1] // 2) if downscale else feat_size + + self.downsample = PatchMerging(embed_dim, norm_layer=norm_layer) if downscale else nn.Identity() + + embed_dim = embed_dim * 2 if downscale else embed_dim + self.blocks = nn.Sequential(*[ + SwinTransformerBlock( + dim=embed_dim, + num_heads=num_heads, + feat_size=self.feat_size, + window_size=window_size, + shift_size=tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size]), + mlp_ratio=mlp_ratio, + drop=drop, + drop_attn=drop_attn, + drop_path=drop_path[index] if isinstance(drop_path, list) else drop_path, + extra_norm=not (index + 1) % extra_norm_period if extra_norm_period else False, + sequential_attn=sequential_attn, + norm_layer=norm_layer, + ) + for index in range(depth)] + ) + + def update_input_size(self, new_window_size: int, new_feat_size: Tuple[int, int]) -> None: + """Method updates the resolution to utilize and the window size and so the pair-wise relative positions. + + Args: + new_window_size (int): New window size + new_feat_size (Tuple[int, int]): New input resolution + """ + self.feat_size: Tuple[int, int] = ( + (new_feat_size[0] // 2, new_feat_size[1] // 2) if self.downscale else new_feat_size + ) + for block in self.blocks: + block.update_input_size(new_window_size=new_window_size, new_feat_size=self.feat_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass. + Args: + x (torch.Tensor): Input tensor of the shape [B, C, H, W] or [B, L, C] + Returns: + output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2] + """ + x = self.downsample(x) + B, C, H, W = x.shape + L = H * W + + x = bchw_to_bhwc(x).reshape(B, L, C) + for block in self.blocks: + # Perform checkpointing if utilized + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint.checkpoint(block, x) + else: + x = block(x) + x = bhwc_to_bchw(x.reshape(B, H, W, -1)) + return x + + +class SwinTransformerV2Cr(nn.Module): + r""" Swin Transformer V2 + A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution` - + https://arxiv.org/pdf/2111.09883 + + Args: + img_size (Tuple[int, int]): Input resolution. + window_size (Optional[int]): Window size. If None, img_size // window_div. Default: None + img_window_ratio (int): Window size to image size ratio. Default: 32 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input channels. + depths (int): Depth of the stage (number of layers). + num_heads (int): Number of attention heads to be utilized. + embed_dim (int): Patch embedding dimension. Default: 96 + num_classes (int): Number of output classes. 
Default: 1000 + mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels. Default: 4 + drop_rate (float): Dropout rate. Default: 0.0 + attn_drop_rate (float): Dropout rate of attention map. Default: 0.0 + drop_path_rate (float): Stochastic depth rate. Default: 0.0 + norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm + grad_checkpointing (bool): If true checkpointing is utilized. Default: False + sequential_attn (bool): If true sequential self-attention is performed. Default: False + use_deformable (bool): If true deformable block is used. Default: False + """ + + def __init__( + self, + img_size: Tuple[int, int] = (224, 224), + patch_size: int = 4, + window_size: Optional[int] = None, + img_window_ratio: int = 32, + in_chans: int = 3, + num_classes: int = 1000, + embed_dim: int = 96, + depths: Tuple[int, ...] = (2, 2, 6, 2), + num_heads: Tuple[int, ...] = (3, 6, 12, 24), + mlp_ratio: float = 4.0, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + norm_layer: Type[nn.Module] = nn.LayerNorm, + grad_checkpointing: bool = False, + extra_norm_period: int = 0, + sequential_attn: bool = False, + global_pool: str = 'avg', + **kwargs: Any + ) -> None: + super(SwinTransformerV2Cr, self).__init__() + img_size = to_2tuple(img_size) + window_size = tuple([ + s // img_window_ratio for s in img_size]) if window_size is None else to_2tuple(window_size) + + self.num_classes: int = num_classes + self.patch_size: int = patch_size + self.img_size: Tuple[int, int] = img_size + self.window_size: int = window_size + self.num_features: int = int(embed_dim * 2 ** (len(depths) - 1)) + + self.patch_embed: nn.Module = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, + embed_dim=embed_dim, norm_layer=norm_layer) + patch_grid_size: Tuple[int, int] = self.patch_embed.grid_size + + drop_path_rate = torch.linspace(0.0, drop_path_rate, sum(depths)).tolist() + stages = [] + for index, (depth, num_heads) in enumerate(zip(depths, num_heads)): + stage_scale = 2 ** max(index - 1, 0) + stages.append( + SwinTransformerStage( + embed_dim=embed_dim * stage_scale, + depth=depth, + downscale=index != 0, + feat_size=(patch_grid_size[0] // stage_scale, patch_grid_size[1] // stage_scale), + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_attn=attn_drop_rate, + drop_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])], + grad_checkpointing=grad_checkpointing, + extra_norm_period=extra_norm_period, + sequential_attn=sequential_attn, + norm_layer=norm_layer, + ) + ) + self.stages = nn.Sequential(*stages) + + self.global_pool: str = global_pool + self.head: nn.Module = nn.Linear( + in_features=self.num_features, out_features=num_classes) if num_classes else nn.Identity() + + # FIXME weight init TBD, PyTorch default init appears to be working well, + # but differs from usual ViT or Swin init. + # named_apply(init_weights, self) + + def update_input_size( + self, + new_img_size: Optional[Tuple[int, int]] = None, + new_window_size: Optional[int] = None, + img_window_ratio: int = 32, + ) -> None: + """Method updates the image resolution to be processed and window size and so the pair-wise relative positions. 
+
+        Args:
+            new_window_size (Optional[int]): New window size, if None computed as new_img_size // img_window_ratio
+            new_img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used
+            img_window_ratio (int): divisor for calculating window size from image size
+        """
+        # Check parameters
+        if new_img_size is None:
+            new_img_size = self.img_size
+        else:
+            new_img_size = to_2tuple(new_img_size)
+        if new_window_size is None:
+            new_window_size = tuple([s // img_window_ratio for s in new_img_size])
+        # Compute new patch resolution & update resolution of each stage
+        new_patch_grid_size = (new_img_size[0] // self.patch_size, new_img_size[1] // self.patch_size)
+        for index, stage in enumerate(self.stages):
+            stage_scale = 2 ** max(index - 1, 0)
+            stage.update_input_size(
+                new_window_size=new_window_size,
+                new_feat_size=(new_patch_grid_size[0] // stage_scale, new_patch_grid_size[1] // stage_scale),
+            )
+
+    def get_classifier(self) -> nn.Module:
+        """Method returns the classification head of the model.
+        Returns:
+            head (nn.Module): Current classification head
+        """
+        head: nn.Module = self.head
+        return head
+
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
+        """Method resets the classification head.
+
+        Args:
+            num_classes (int): Number of classes to be predicted
+            global_pool (str): Unused
+        """
+        self.num_classes: int = num_classes
+        if global_pool is not None:
+            self.global_pool = global_pool
+        self.head: nn.Module = nn.Linear(
+            in_features=self.num_features, out_features=num_classes) if num_classes > 0 else nn.Identity()
+
+    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.patch_embed(x)
+        x = self.stages(x)
+        return x
+
+    def forward_head(self, x, pre_logits: bool = False):
+        if self.global_pool == 'avg':
+            x = x.mean(dim=(2, 3))
+        return x if pre_logits else self.head(x)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.forward_features(x)
+        x = self.forward_head(x)
+        return x
+
+
+def init_weights(module: nn.Module, name: str = ''):
+    # FIXME WIP
+    if isinstance(module, nn.Linear):
+        if 'qkv' in name:
+            # treat the weights of Q, K, V separately
+            val = math.sqrt(6.
/ float(module.weight.shape[0] // 3 + module.weight.shape[1])) + nn.init.uniform_(module.weight, -val, val) + else: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def _create_swin_transformer_v2_cr(variant, pretrained=False, default_cfg=None, **kwargs): + if default_cfg is None: + default_cfg = deepcopy(default_cfgs[variant]) + overlay_external_default_cfg(default_cfg, kwargs) + default_num_classes = default_cfg['num_classes'] + default_img_size = default_cfg['input_size'][-2:] + + num_classes = kwargs.pop('num_classes', default_num_classes) + img_size = kwargs.pop('img_size', default_img_size) + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + SwinTransformerV2Cr, + variant, + pretrained, + default_cfg=default_cfg, + img_size=img_size, + num_classes=num_classes, + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs + ) + + return model + + +@register_model +def swin_v2_cr_tiny_384(pretrained=False, **kwargs): + """Swin-T V2 CR @ 384x384, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=96, + depths=(2, 2, 6, 2), + num_heads=(3, 6, 12, 24), + **kwargs + ) + return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_384', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_v2_cr_tiny_224(pretrained=False, **kwargs): + """Swin-T V2 CR @ 224x224, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=96, + depths=(2, 2, 6, 2), + num_heads=(3, 6, 12, 24), + **kwargs + ) + return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_224', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_v2_cr_small_384(pretrained=False, **kwargs): + """Swin-S V2 CR @ 384x384, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=96, + depths=(2, 2, 18, 2), + num_heads=(3, 6, 12, 24), + **kwargs + ) + return _create_swin_transformer_v2_cr('swin_v2_cr_small_384', pretrained=pretrained, **model_kwargs + ) + + +@register_model +def swin_v2_cr_small_224(pretrained=False, **kwargs): + """Swin-S V2 CR @ 224x224, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=96, + depths=(2, 2, 18, 2), + num_heads=(3, 6, 12, 24), + **kwargs + ) + return _create_swin_transformer_v2_cr('swin_v2_cr_small_224', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_v2_cr_base_384(pretrained=False, **kwargs): + """Swin-B V2 CR @ 384x384, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=128, + depths=(2, 2, 18, 2), + num_heads=(4, 8, 16, 32), + **kwargs + ) + return _create_swin_transformer_v2_cr('swin_v2_cr_base_384', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_v2_cr_base_224(pretrained=False, **kwargs): + """Swin-B V2 CR @ 224x224, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=128, + depths=(2, 2, 18, 2), + num_heads=(4, 8, 16, 32), + **kwargs + ) + return _create_swin_transformer_v2_cr('swin_v2_cr_base_224', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_v2_cr_large_384(pretrained=False, **kwargs): + """Swin-L V2 CR @ 384x384, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=192, + depths=(2, 2, 18, 2), + num_heads=(6, 12, 24, 48), + **kwargs + ) + return _create_swin_transformer_v2_cr('swin_v2_cr_large_384', pretrained=pretrained, **model_kwargs + ) + + +@register_model +def swin_v2_cr_large_224(pretrained=False, **kwargs): + """Swin-L V2 CR @ 224x224, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=192, + 
depths=(2, 2, 18, 2), + num_heads=(6, 12, 24, 48), + **kwargs + ) + return _create_swin_transformer_v2_cr('swin_v2_cr_large_224', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_v2_cr_huge_384(pretrained=False, **kwargs): + """Swin-H V2 CR @ 384x384, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=352, + depths=(2, 2, 18, 2), + num_heads=(11, 22, 44, 88), # head count not certain for Huge, 384 & 224 trying diff values + extra_norm_period=6, + **kwargs + ) + return _create_swin_transformer_v2_cr('swin_v2_cr_huge_384', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_v2_cr_huge_224(pretrained=False, **kwargs): + """Swin-H V2 CR @ 224x224, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=352, + depths=(2, 2, 18, 2), + num_heads=(8, 16, 32, 64), # head count not certain for Huge, 384 & 224 trying diff values + extra_norm_period=6, + **kwargs + ) + return _create_swin_transformer_v2_cr('swin_v2_cr_huge_224', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_v2_cr_giant_384(pretrained=False, **kwargs): + """Swin-G V2 CR @ 384x384, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=512, + depths=(2, 2, 42, 2), + num_heads=(16, 32, 64, 128), + extra_norm_period=6, + **kwargs + ) + return _create_swin_transformer_v2_cr('swin_v2_cr_giant_384', pretrained=pretrained, **model_kwargs + ) + + +@register_model +def swin_v2_cr_giant_224(pretrained=False, **kwargs): + """Swin-G V2 CR @ 224x224, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=512, + depths=(2, 2, 42, 2), + num_heads=(16, 32, 64, 128), + extra_norm_period=6, + **kwargs + ) + return _create_swin_transformer_v2_cr('swin_v2_cr_giant_224', pretrained=pretrained, **model_kwargs) diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index cefc8116..0eee2044 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -6,7 +6,7 @@ A PyTorch implement of the Hybrid Vision Transformers as described in: - https://arxiv.org/abs/2010.11929 `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` - - https://arxiv.org/abs/2106.TODO + - https://arxiv.org/abs/2106.10270 NOTE These hybrid model definitions depend on code in vision_transformer.py. They were moved here to keep file sizes sane. 
@@ -359,4 +359,4 @@ def vit_base_resnet50d_224(pretrained=False, **kwargs): model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer_hybrid( 'vit_base_resnet50d_224', backbone=backbone, pretrained=pretrained, **model_kwargs) - return model \ No newline at end of file + return model diff --git a/timm/scheduler/plateau_lr.py b/timm/scheduler/plateau_lr.py index 4f2cacb6..fbfc531f 100644 --- a/timm/scheduler/plateau_lr.py +++ b/timm/scheduler/plateau_lr.py @@ -43,7 +43,7 @@ class PlateauLRScheduler(Scheduler): min_lr=lr_min ) - self.noise_range = noise_range_t + self.noise_range_t = noise_range_t self.noise_pct = noise_pct self.noise_type = noise_type self.noise_std = noise_std @@ -82,25 +82,12 @@ class PlateauLRScheduler(Scheduler): self.lr_scheduler.step(metric, epoch) # step the base scheduler - if self.noise_range is not None: - if isinstance(self.noise_range, (list, tuple)): - apply_noise = self.noise_range[0] <= epoch < self.noise_range[1] - else: - apply_noise = epoch >= self.noise_range - if apply_noise: - self._apply_noise(epoch) + if self._is_apply_noise(epoch): + self._apply_noise(epoch) + def _apply_noise(self, epoch): - g = torch.Generator() - g.manual_seed(self.noise_seed + epoch) - if self.noise_type == 'normal': - while True: - # resample if noise out of percent limit, brute force but shouldn't spin much - noise = torch.randn(1, generator=g).item() - if abs(noise) < self.noise_pct: - break - else: - noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct + noise = self._calculate_noise(epoch) # apply the noise on top of previous LR, cache the old value so we can restore for normal # stepping of base scheduler diff --git a/timm/scheduler/scheduler.py b/timm/scheduler/scheduler.py index 226d0e76..af20be9b 100644 --- a/timm/scheduler/scheduler.py +++ b/timm/scheduler/scheduler.py @@ -88,21 +88,30 @@ class Scheduler: param_group[self.param_group_field] = value def _add_noise(self, lrs, t): + if self._is_apply_noise(t): + noise = self._calculate_noise(t) + lrs = [v + v * noise for v in lrs] + return lrs + + def _is_apply_noise(self, t) -> bool: + """Return True if scheduler in noise range.""" + apply_noise = False if self.noise_range_t is not None: if isinstance(self.noise_range_t, (list, tuple)): apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] else: apply_noise = t >= self.noise_range_t - if apply_noise: - g = torch.Generator() - g.manual_seed(self.noise_seed + t) - if self.noise_type == 'normal': - while True: - # resample if noise out of percent limit, brute force but shouldn't spin much - noise = torch.randn(1, generator=g).item() - if abs(noise) < self.noise_pct: - break - else: - noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct - lrs = [v + v * noise for v in lrs] - return lrs + return apply_noise + + def _calculate_noise(self, t) -> float: + g = torch.Generator() + g.manual_seed(self.noise_seed + t) + if self.noise_type == 'normal': + while True: + # resample if noise out of percent limit, brute force but shouldn't spin much + noise = torch.randn(1, generator=g).item() + if abs(noise) < self.noise_pct: + return noise + else: + noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct + return noise diff --git a/train.py b/train.py index 0efef787..cc2b2b4d 100755 --- a/train.py +++ b/train.py @@ -108,7 +108,7 @@ parser.add_argument('--crop-pct', default=None, type=float, parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', 
help='Override mean pixel value of dataset') parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', - help='Override std deviation of of dataset') + help='Override std deviation of dataset') parser.add_argument('--interpolation', default='', type=str, metavar='NAME', help='Image resize interpolation type (overrides model)') parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
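
A minimal usage sketch of the two additions in this patch — the `swin_v2_cr_*` model registrations and `timm/models/layers/ml_decoder.py` — assuming this branch is installed as `timm` with a PyTorch version contemporary with it. The Swin V2 CR configs ship with empty `url` fields, so only random init is available; the shape comments are illustrative.

import torch
import timm
from timm.models.layers.ml_decoder import add_ml_decoder_head

# Swin V2 CR model registered in swin_transformer_v2_cr.py above; randomly initialized (no weights yet).
model = timm.create_model('swin_v2_cr_tiny_224', pretrained=False, num_classes=1000)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])

# ML-Decoder: swap the pooled linear classifier of a ResNet-style model for the query-based
# decoder head from ml_decoder.py (global_pool becomes Identity, fc becomes MLDecoder).
cnn = timm.create_model('resnet50', pretrained=False, num_classes=1000)
cnn = add_ml_decoder_head(cnn)
cnn.eval()
with torch.no_grad():
    out = cnn(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])

Note that `add_ml_decoder_head` only handles models exposing `global_pool` plus `fc` or `classifier`, or RegNet/TResNet style `head` modules; other architectures fall through to the unsupported-model branch.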