Merge remote-tracking branch 'origin/master' into norm_norm_norm

pull/1014/head
Ross Wightman 3 years ago
commit b049a5c5c6

@ -23,6 +23,12 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
## What's New ## What's New
### Feb 2, 2022
* [Chris Hughes](https://github.com/Chris-hughes10) posted an exhaustive run through of `timm` on his blog yesterday. Well worth a read. [Getting Started with PyTorch Image Models (timm): A Practitioners Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055)
* I'm currently prepping to merge the `norm_norm_norm` branch back to master (ver 0.6.x) in next week or so.
* The changes are more extensive than usual and may destabilize and break some model API use (aiming for full backwards compat). So, beware `pip install git+https://github.com/rwightman/pytorch-image-models` installs!
* `0.5.x` releases and a `0.5.x` branch will remain stable with a cherry pick or two until dust clears. Recommend sticking to pypi install for a bit if you want stable.
### Jan 14, 2022 ### Jan 14, 2022
* Version 0.5.4 w/ release to be pushed to pypi. It's been a while since last pypi update and riskier changes will be merged to main branch soon.... * Version 0.5.4 w/ release to be pushed to pypi. It's been a while since last pypi update and riskier changes will be merged to main branch soon....
* Add ConvNeXT models /w weights from official impl (https://github.com/facebookresearch/ConvNeXt), a few perf tweaks, compatible with timm features * Add ConvNeXT models /w weights from official impl (https://github.com/facebookresearch/ConvNeXt), a few perf tweaks, compatible with timm features
@ -410,6 +416,8 @@ Model validation results can be found in the [documentation](https://rwightman.g
My current [documentation](https://rwightman.github.io/pytorch-image-models/) for `timm` covers the basics. My current [documentation](https://rwightman.github.io/pytorch-image-models/) for `timm` covers the basics.
[Getting Started with PyTorch Image Models (timm): A Practitioners Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail.
[timmdocs](https://fastai.github.io/timmdocs/) is quickly becoming a much more comprehensive set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs. [timmdocs](https://fastai.github.io/timmdocs/) is quickly becoming a much more comprehensive set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs.
[paperswithcode](https://paperswithcode.com/lib/timm) is a good resource for browsing the models within `timm`. [paperswithcode](https://paperswithcode.com/lib/timm) is a good resource for browsing the models within `timm`.

@ -145,7 +145,7 @@ torch.Size([2, 1512, 7, 7])
### Select specific feature levels or limit the stride ### Select specific feature levels or limit the stride
There are to additional creation arguments impacting the output features. There are two additional creation arguments impacting the output features.
* `out_indices` selects which indices to output * `out_indices` selects which indices to output
* `output_stride` limits the feature output stride of the network (also works in classification mode BTW) * `output_stride` limits the feature output stride of the network (also works in classification mode BTW)

@ -34,8 +34,9 @@ if 'GITHUB_ACTIONS' in os.environ:
EXCLUDE_FILTERS = [ EXCLUDE_FILTERS = [
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm', '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*', '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*'] '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*',
NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*'] 'swin*giant*']
NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*']
else: else:
EXCLUDE_FILTERS = [] EXCLUDE_FILTERS = []
NON_STD_EXCLUDE_FILTERS = ['vit_gi*'] NON_STD_EXCLUDE_FILTERS = ['vit_gi*']

@ -41,6 +41,7 @@ from .selecsls import *
from .senet import * from .senet import *
from .sknet import * from .sknet import *
from .swin_transformer import * from .swin_transformer import *
from .swin_transformer_v2_cr import *
from .tnt import * from .tnt import *
from .tresnet import * from .tresnet import *
from .twins import * from .twins import *

@ -0,0 +1,156 @@
from typing import Optional
import torch
from torch import nn
from torch import nn, Tensor
from torch.nn.modules.transformer import _get_activation_fn
def add_ml_decoder_head(model):
if hasattr(model, 'global_pool') and hasattr(model, 'fc'): # most CNN models, like Resnet50
model.global_pool = nn.Identity()
del model.fc
num_classes = model.num_classes
num_features = model.num_features
model.fc = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
elif hasattr(model, 'global_pool') and hasattr(model, 'classifier'): # EfficientNet
model.global_pool = nn.Identity()
del model.classifier
num_classes = model.num_classes
num_features = model.num_features
model.classifier = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
elif 'RegNet' in model._get_name() or 'TResNet' in model._get_name(): # hasattr(model, 'head')
del model.head
num_classes = model.num_classes
num_features = model.num_features
model.head = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
else:
print("Model code-writing is not aligned currently with ml-decoder")
exit(-1)
if hasattr(model, 'drop_rate'): # Ml-Decoder has inner dropout
model.drop_rate = 0
return model
class TransformerDecoderLayerOptimal(nn.Module):
def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1, activation="relu",
layer_norm_eps=1e-5) -> None:
super(TransformerDecoderLayerOptimal, self).__init__()
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.dropout = nn.Dropout(dropout)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.activation = _get_activation_fn(activation)
def __setstate__(self, state):
if 'activation' not in state:
state['activation'] = torch.nn.functional.relu
super(TransformerDecoderLayerOptimal, self).__setstate__(state)
def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
tgt = tgt + self.dropout1(tgt)
tgt = self.norm1(tgt)
tgt2 = self.multihead_attn(tgt, memory, memory)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
return tgt
# @torch.jit.script
# class ExtrapClasses(object):
# def __init__(self, num_queries: int, group_size: int):
# self.num_queries = num_queries
# self.group_size = group_size
#
# def __call__(self, h: torch.Tensor, class_embed_w: torch.Tensor, class_embed_b: torch.Tensor, out_extrap:
# torch.Tensor):
# # h = h.unsqueeze(-1).expand(-1, -1, -1, self.group_size)
# h = h[..., None].repeat(1, 1, 1, self.group_size) # torch.Size([bs, 5, 768, groups])
# w = class_embed_w.view((self.num_queries, h.shape[2], self.group_size))
# out = (h * w).sum(dim=2) + class_embed_b
# out = out.view((h.shape[0], self.group_size * self.num_queries))
# return out
@torch.jit.script
class GroupFC(object):
def __init__(self, embed_len_decoder: int):
self.embed_len_decoder = embed_len_decoder
def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor):
for i in range(self.embed_len_decoder):
h_i = h[:, i, :]
w_i = duplicate_pooling[i, :, :]
out_extrap[:, i, :] = torch.matmul(h_i, w_i)
class MLDecoder(nn.Module):
def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial_num_features=2048):
super(MLDecoder, self).__init__()
embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups
if embed_len_decoder > num_classes:
embed_len_decoder = num_classes
# switching to 768 initial embeddings
decoder_embedding = 768 if decoder_embedding < 0 else decoder_embedding
self.embed_standart = nn.Linear(initial_num_features, decoder_embedding)
# decoder
decoder_dropout = 0.1
num_layers_decoder = 1
dim_feedforward = 2048
layer_decode = TransformerDecoderLayerOptimal(d_model=decoder_embedding,
dim_feedforward=dim_feedforward, dropout=decoder_dropout)
self.decoder = nn.TransformerDecoder(layer_decode, num_layers=num_layers_decoder)
# non-learnable queries
self.query_embed = nn.Embedding(embed_len_decoder, decoder_embedding)
self.query_embed.requires_grad_(False)
# group fully-connected
self.num_classes = num_classes
self.duplicate_factor = int(num_classes / embed_len_decoder + 0.999)
self.duplicate_pooling = torch.nn.Parameter(
torch.Tensor(embed_len_decoder, decoder_embedding, self.duplicate_factor))
self.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(num_classes))
torch.nn.init.xavier_normal_(self.duplicate_pooling)
torch.nn.init.constant_(self.duplicate_pooling_bias, 0)
self.group_fc = GroupFC(embed_len_decoder)
def forward(self, x):
if len(x.shape) == 4: # [bs,2048, 7,7]
embedding_spatial = x.flatten(2).transpose(1, 2)
else: # [bs, 197,468]
embedding_spatial = x
embedding_spatial_786 = self.embed_standart(embedding_spatial)
embedding_spatial_786 = torch.nn.functional.relu(embedding_spatial_786, inplace=True)
bs = embedding_spatial_786.shape[0]
query_embed = self.query_embed.weight
# tgt = query_embed.unsqueeze(1).repeat(1, bs, 1)
tgt = query_embed.unsqueeze(1).expand(-1, bs, -1) # no allocation of memory with expand
h = self.decoder(tgt, embedding_spatial_786.transpose(0, 1)) # [embed_len_decoder, batch, 768]
h = h.transpose(0, 1)
out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype)
self.group_fc(h, self.duplicate_pooling, out_extrap)
h_out = out_extrap.flatten(1)[:, :self.num_classes]
h_out += self.duplicate_pooling_bias
logits = h_out
return logits

@ -0,0 +1,931 @@
""" Swin Transformer V2
A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution`
- https://arxiv.org/pdf/2111.09883
Code adapted from https://github.com/ChristophReich1996/Swin-Transformer-V2, original copyright/license info below
This implementation is experimental and subject to change in manners that will break weight compat:
* Size of the pos embed MLP are not spelled out in paper in terms of dim, fixed for all models? vary with num_heads?
* currently dim is fixed, I feel it may make sense to scale with num_heads (dim per head)
* The specifics of the memory saving 'sequential attention' are not detailed, Christoph Reich has an impl at
GitHub link above. It needs further investigation as throughput vs mem tradeoff doesn't appear beneficial.
* num_heads per stage is not detailed for Huge and Giant model variants
* 'Giant' is 3B params in paper but ~2.6B here despite matching paper dim + block counts
Noteworthy additions over official Swin v1:
* MLP relative position embedding is looking promising and adapts to different image/window sizes
* This impl has been designed to allow easy change of image size with matching window size changes
* Non-square image size and window size are supported
Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
"""
# --------------------------------------------------------
# Swin Transformer V2 reimplementation
# Copyright (c) 2021 Christoph Reich
# Licensed under The MIT License [see LICENSE for details]
# Written by Christoph Reich
# --------------------------------------------------------
import logging
import math
from copy import deepcopy
from typing import Tuple, Optional, List, Union, Any, Type
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg, overlay_external_default_cfg, named_apply
from .layers import DropPath, Mlp, to_2tuple, _assert
from .registry import register_model
from .vision_transformer import checkpoint_filter_fn
_logger = logging.getLogger(__name__)
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000,
'input_size': (3, 224, 224),
'pool_size': None,
'crop_pct': 0.9,
'interpolation': 'bicubic',
'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN,
'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj',
'classifier': 'head',
**kwargs,
}
default_cfgs = {
# patch models (my experiments)
'swin_v2_cr_tiny_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_tiny_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=1.0),
'swin_v2_cr_small_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_small_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=1.0),
'swin_v2_cr_base_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_base_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=1.0),
'swin_v2_cr_large_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_large_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=1.0),
'swin_v2_cr_huge_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_huge_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=1.0),
'swin_v2_cr_giant_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_giant_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=1.0),
}
def bchw_to_bhwc(x: torch.Tensor) -> torch.Tensor:
"""Permutes a tensor from the shape (B, C, H, W) to (B, H, W, C). """
return x.permute(0, 2, 3, 1)
def bhwc_to_bchw(x: torch.Tensor) -> torch.Tensor:
"""Permutes a tensor from the shape (B, H, W, C) to (B, C, H, W). """
return x.permute(0, 3, 1, 2)
def window_partition(x, window_size: Tuple[int, int]):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
return windows
@register_notrace_function # reason: int argument is a Proxy
def window_reverse(windows, window_size: tuple[int, int], img_size: tuple[int, int]):
"""
Args:
windows: (num_windows * B, window_size[0], window_size[1], C)
window_size (Tuple[int, int]): Window size
img_size (Tuple[int, int]): Image size
Returns:
x: (B, H, W, C)
"""
H, W = img_size
B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowMultiHeadAttention(nn.Module):
r"""This class implements window-based Multi-Head-Attention with log-spaced continuous position bias.
Args:
dim (int): Number of input features
window_size (int): Window size
num_heads (int): Number of attention heads
drop_attn (float): Dropout rate of attention map
drop_proj (float): Dropout rate after projection
meta_hidden_dim (int): Number of hidden features in the two layer MLP meta network
sequential_attn (bool): If true sequential self-attention is performed
"""
def __init__(
self,
dim: int,
num_heads: int,
window_size: Tuple[int, int],
drop_attn: float = 0.0,
drop_proj: float = 0.0,
meta_hidden_dim: int = 384, # FIXME what's the optimal value?
sequential_attn: bool = False,
) -> None:
super(WindowMultiHeadAttention, self).__init__()
assert dim % num_heads == 0, \
"The number of input features (in_features) are not divisible by the number of heads (num_heads)."
self.in_features: int = dim
self.window_size: Tuple[int, int] = window_size
self.num_heads: int = num_heads
self.sequential_attn: bool = sequential_attn
self.qkv = nn.Linear(in_features=dim, out_features=dim * 3, bias=True)
self.attn_drop = nn.Dropout(drop_attn)
self.proj = nn.Linear(in_features=dim, out_features=dim, bias=True)
self.proj_drop = nn.Dropout(drop_proj)
# meta network for positional encodings
self.meta_mlp = Mlp(
2, # x, y
hidden_features=meta_hidden_dim,
out_features=num_heads,
act_layer=nn.ReLU,
drop=0. # FIXME should we add stochasticity?
)
self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads)))
self._make_pair_wise_relative_positions()
def _make_pair_wise_relative_positions(self) -> None:
"""Method initializes the pair-wise relative positions to compute the positional biases."""
device = self.tau.device
coordinates = torch.stack(torch.meshgrid([
torch.arange(self.window_size[0], device=device),
torch.arange(self.window_size[1], device=device)]), dim=0).flatten(1)
relative_coordinates = coordinates[:, :, None] - coordinates[:, None, :]
relative_coordinates = relative_coordinates.permute(1, 2, 0).reshape(-1, 2).float()
relative_coordinates_log = torch.sign(relative_coordinates) * torch.log(
1.0 + relative_coordinates.abs())
self.register_buffer("relative_coordinates_log", relative_coordinates_log, persistent=False)
def update_input_size(self, new_window_size: int, **kwargs: Any) -> None:
"""Method updates the window size and so the pair-wise relative positions
Args:
new_window_size (int): New window size
kwargs (Any): Unused
"""
# Set new window size and new pair-wise relative positions
self.window_size: int = new_window_size
self._make_pair_wise_relative_positions()
def _relative_positional_encodings(self) -> torch.Tensor:
"""Method computes the relative positional encodings
Returns:
relative_position_bias (torch.Tensor): Relative positional encodings
(1, number of heads, window size ** 2, window size ** 2)
"""
window_area = self.window_size[0] * self.window_size[1]
relative_position_bias = self.meta_mlp(self.relative_coordinates_log)
relative_position_bias = relative_position_bias.transpose(1, 0).reshape(
self.num_heads, window_area, window_area
)
relative_position_bias = relative_position_bias.unsqueeze(0)
return relative_position_bias
def _forward_sequential(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
"""
# FIXME TODO figure out 'sequential' attention mentioned in paper (should reduce GPU memory)
assert False, "not implemented"
def _forward_batch(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""This function performs standard (non-sequential) scaled cosine self-attention.
"""
Bw, L, C = x.shape
qkv = self.qkv(x).view(Bw, L, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
query, key, value = qkv.unbind(0)
# compute attention map with scaled cosine attention
denom = torch.norm(query, dim=-1, keepdim=True) @ torch.norm(key, dim=-1, keepdim=True).transpose(-2, -1)
attn = query @ key.transpose(-2, -1) / denom.clamp(min=1e-6)
attn = attn / self.tau.clamp(min=0.01).reshape(1, self.num_heads, 1, 1)
attn = attn + self._relative_positional_encodings()
if mask is not None:
# Apply mask if utilized
num_win: int = mask.shape[0]
attn = attn.view(Bw // num_win, num_win, self.num_heads, L, L)
attn = attn + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, L, L)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ value).transpose(1, 2).reshape(Bw, L, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
""" Forward pass.
Args:
x (torch.Tensor): Input tensor of the shape (B * windows, N, C)
mask (Optional[torch.Tensor]): Attention mask for the shift case
Returns:
Output tensor of the shape [B * windows, N, C]
"""
if self.sequential_attn:
return self._forward_sequential(x, mask)
else:
return self._forward_batch(x, mask)
class SwinTransformerBlock(nn.Module):
r"""This class implements the Swin transformer block.
Args:
dim (int): Number of input channels
num_heads (int): Number of attention heads to be utilized
feat_size (Tuple[int, int]): Input resolution
window_size (Tuple[int, int]): Window size to be utilized
shift_size (int): Shifting size to be used
mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels
drop (float): Dropout in input mapping
drop_attn (float): Dropout rate of attention map
drop_path (float): Dropout in main path
extra_norm (bool): Insert extra norm on 'main' branch if True
sequential_attn (bool): If true sequential self-attention is performed
norm_layer (Type[nn.Module]): Type of normalization layer to be utilized
"""
def __init__(
self,
dim: int,
num_heads: int,
feat_size: Tuple[int, int],
window_size: Tuple[int, int],
shift_size: Tuple[int, int] = (0, 0),
mlp_ratio: float = 4.0,
drop: float = 0.0,
drop_attn: float = 0.0,
drop_path: float = 0.0,
extra_norm: bool = False,
sequential_attn: bool = False,
norm_layer: Type[nn.Module] = nn.LayerNorm,
) -> None:
super(SwinTransformerBlock, self).__init__()
self.dim: int = dim
self.feat_size: Tuple[int, int] = feat_size
self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size)
self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size))
self.window_area = self.window_size[0] * self.window_size[1]
# attn branch
self.attn = WindowMultiHeadAttention(
dim=dim,
num_heads=num_heads,
window_size=self.window_size,
drop_attn=drop_attn,
drop_proj=drop,
sequential_attn=sequential_attn,
)
self.norm1 = norm_layer(dim)
self.drop_path1 = DropPath(drop_prob=drop_path) if drop_path > 0.0 else nn.Identity()
# mlp branch
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
drop=drop,
out_features=dim,
)
self.norm2 = norm_layer(dim)
self.drop_path2 = DropPath(drop_prob=drop_path) if drop_path > 0.0 else nn.Identity()
# extra norm layer mentioned for Huge/Giant models in V2 paper (FIXME may be in wrong spot?)
self.norm3 = norm_layer(dim) if extra_norm else nn.Identity()
self._make_attention_mask()
def _calc_window_shift(self, target_window_size):
window_size = [f if f <= w else w for f, w in zip(self.feat_size, target_window_size)]
shift_size = [0 if f <= w else s for f, w, s in zip(self.feat_size, window_size, self.target_shift_size)]
return tuple(window_size), tuple(shift_size)
def _make_attention_mask(self) -> None:
"""Method generates the attention mask used in shift case."""
# Make masks for shift case
if any(self.shift_size):
# calculate attention mask for SW-MSA
H, W = self.feat_size
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
cnt = 0
for h in (
slice(0, -self.window_size[0]),
slice(-self.window_size[0], -self.shift_size[0]),
slice(-self.shift_size[0], None)):
for w in (
slice(0, -self.window_size[1]),
slice(-self.window_size[1], -self.shift_size[1]),
slice(-self.shift_size[1], None)):
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # num_windows, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_area)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask, persistent=False)
def update_input_size(self, new_window_size: Tuple[int, int], new_feat_size: Tuple[int, int]) -> None:
"""Method updates the image resolution to be processed and window size and so the pair-wise relative positions.
Args:
new_window_size (int): New window size
new_feat_size (Tuple[int, int]): New input resolution
"""
# Update input resolution
self.feat_size: Tuple[int, int] = new_feat_size
self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(new_window_size))
self.window_area = self.window_size[0] * self.window_size[1]
self.attn.update_input_size(new_window_size=self.window_size)
self._make_attention_mask()
def _shifted_window_attn(self, x):
H, W = self.feat_size
B, L, C = x.shape
x = x.view(B, H, W, C)
# cyclic shift
if any(self.shift_size):
shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # num_windows * B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C)
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask) # num_windows * B, window_size * window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
shifted_x = window_reverse(attn_windows, self.window_size, self.feat_size) # B H' W' C
# reverse cyclic shift
if any(self.shift_size):
x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))
else:
x = shifted_x
x = x.view(B, L, C)
return x
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Args:
x (torch.Tensor): Input tensor of the shape [B, C, H, W]
Returns:
output (torch.Tensor): Output tensor of the shape [B, C, H, W]
"""
# NOTE post-norm branches (op -> norm -> drop)
x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x)))
x = x + self.drop_path2(self.norm2(self.mlp(x)))
x = self.norm3(x) # main-branch norm enabled for some blocks (every 6 for Huge/Giant)
return x
class PatchMerging(nn.Module):
""" This class implements the patch merging as a strided convolution with a normalization before.
Args:
dim (int): Number of input channels
norm_layer (Type[nn.Module]): Type of normalization layer to be utilized.
"""
def __init__(self, dim: int, norm_layer: Type[nn.Module] = nn.LayerNorm) -> None:
super(PatchMerging, self).__init__()
self.norm = norm_layer(4 * dim)
self.reduction = nn.Linear(in_features=4 * dim, out_features=2 * dim, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
""" Forward pass.
Args:
x (torch.Tensor): Input tensor of the shape [B, C, H, W]
Returns:
output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2]
"""
x = bchw_to_bhwc(x).unfold(dimension=1, size=2, step=2).unfold(dimension=2, size=2, step=2)
x = x.permute(0, 1, 2, 5, 4, 3).flatten(3) # permute maintains compat with ch order in official swin impl
x = self.norm(x)
x = bhwc_to_bchw(self.reduction(x))
return x
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding """
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
B, C, H, W = x.shape
_assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
_assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
x = self.proj(x)
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x
class SwinTransformerStage(nn.Module):
r"""This class implements a stage of the Swin transformer including multiple layers.
Args:
embed_dim (int): Number of input channels
depth (int): Depth of the stage (number of layers)
downscale (bool): If true input is downsampled (see Fig. 3 or V1 paper)
feat_size (Tuple[int, int]): input feature map size (H, W)
num_heads (int): Number of attention heads to be utilized
window_size (int): Window size to be utilized
mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels
drop (float): Dropout in input mapping
drop_attn (float): Dropout rate of attention map
drop_path (float): Dropout in main path
norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
grad_checkpointing (bool): If true checkpointing is utilized
extra_norm_period (int): Insert extra norm layer on main branch every N (period) blocks
sequential_attn (bool): If true sequential self-attention is performed
"""
def __init__(
self,
embed_dim: int,
depth: int,
downscale: bool,
num_heads: int,
feat_size: Tuple[int, int],
window_size: Tuple[int, int],
mlp_ratio: float = 4.0,
drop: float = 0.0,
drop_attn: float = 0.0,
drop_path: Union[List[float], float] = 0.0,
norm_layer: Type[nn.Module] = nn.LayerNorm,
grad_checkpointing: bool = False,
extra_norm_period: int = 0,
sequential_attn: bool = False,
) -> None:
super(SwinTransformerStage, self).__init__()
self.downscale: bool = downscale
self.grad_checkpointing: bool = grad_checkpointing
self.feat_size: Tuple[int, int] = (feat_size[0] // 2, feat_size[1] // 2) if downscale else feat_size
self.downsample = PatchMerging(embed_dim, norm_layer=norm_layer) if downscale else nn.Identity()
embed_dim = embed_dim * 2 if downscale else embed_dim
self.blocks = nn.Sequential(*[
SwinTransformerBlock(
dim=embed_dim,
num_heads=num_heads,
feat_size=self.feat_size,
window_size=window_size,
shift_size=tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size]),
mlp_ratio=mlp_ratio,
drop=drop,
drop_attn=drop_attn,
drop_path=drop_path[index] if isinstance(drop_path, list) else drop_path,
extra_norm=not (index + 1) % extra_norm_period if extra_norm_period else False,
sequential_attn=sequential_attn,
norm_layer=norm_layer,
)
for index in range(depth)]
)
def update_input_size(self, new_window_size: int, new_feat_size: Tuple[int, int]) -> None:
"""Method updates the resolution to utilize and the window size and so the pair-wise relative positions.
Args:
new_window_size (int): New window size
new_feat_size (Tuple[int, int]): New input resolution
"""
self.feat_size: Tuple[int, int] = (
(new_feat_size[0] // 2, new_feat_size[1] // 2) if self.downscale else new_feat_size
)
for block in self.blocks:
block.update_input_size(new_window_size=new_window_size, new_feat_size=self.feat_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Args:
x (torch.Tensor): Input tensor of the shape [B, C, H, W] or [B, L, C]
Returns:
output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2]
"""
x = self.downsample(x)
B, C, H, W = x.shape
L = H * W
x = bchw_to_bhwc(x).reshape(B, L, C)
for block in self.blocks:
# Perform checkpointing if utilized
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint.checkpoint(block, x)
else:
x = block(x)
x = bhwc_to_bchw(x.reshape(B, H, W, -1))
return x
class SwinTransformerV2Cr(nn.Module):
r""" Swin Transformer V2
A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution` -
https://arxiv.org/pdf/2111.09883
Args:
img_size (Tuple[int, int]): Input resolution.
window_size (Optional[int]): Window size. If None, img_size // window_div. Default: None
img_window_ratio (int): Window size to image size ratio. Default: 32
patch_size (int | tuple(int)): Patch size. Default: 4
in_chans (int): Number of input channels.
depths (int): Depth of the stage (number of layers).
num_heads (int): Number of attention heads to be utilized.
embed_dim (int): Patch embedding dimension. Default: 96
num_classes (int): Number of output classes. Default: 1000
mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels. Default: 4
drop_rate (float): Dropout rate. Default: 0.0
attn_drop_rate (float): Dropout rate of attention map. Default: 0.0
drop_path_rate (float): Stochastic depth rate. Default: 0.0
norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
grad_checkpointing (bool): If true checkpointing is utilized. Default: False
sequential_attn (bool): If true sequential self-attention is performed. Default: False
use_deformable (bool): If true deformable block is used. Default: False
"""
def __init__(
self,
img_size: Tuple[int, int] = (224, 224),
patch_size: int = 4,
window_size: Optional[int] = None,
img_window_ratio: int = 32,
in_chans: int = 3,
num_classes: int = 1000,
embed_dim: int = 96,
depths: Tuple[int, ...] = (2, 2, 6, 2),
num_heads: Tuple[int, ...] = (3, 6, 12, 24),
mlp_ratio: float = 4.0,
drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
norm_layer: Type[nn.Module] = nn.LayerNorm,
grad_checkpointing: bool = False,
extra_norm_period: int = 0,
sequential_attn: bool = False,
global_pool: str = 'avg',
**kwargs: Any
) -> None:
super(SwinTransformerV2Cr, self).__init__()
img_size = to_2tuple(img_size)
window_size = tuple([
s // img_window_ratio for s in img_size]) if window_size is None else to_2tuple(window_size)
self.num_classes: int = num_classes
self.patch_size: int = patch_size
self.img_size: Tuple[int, int] = img_size
self.window_size: int = window_size
self.num_features: int = int(embed_dim * 2 ** (len(depths) - 1))
self.patch_embed: nn.Module = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans,
embed_dim=embed_dim, norm_layer=norm_layer)
patch_grid_size: Tuple[int, int] = self.patch_embed.grid_size
drop_path_rate = torch.linspace(0.0, drop_path_rate, sum(depths)).tolist()
stages = []
for index, (depth, num_heads) in enumerate(zip(depths, num_heads)):
stage_scale = 2 ** max(index - 1, 0)
stages.append(
SwinTransformerStage(
embed_dim=embed_dim * stage_scale,
depth=depth,
downscale=index != 0,
feat_size=(patch_grid_size[0] // stage_scale, patch_grid_size[1] // stage_scale),
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
drop=drop_rate,
drop_attn=attn_drop_rate,
drop_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])],
grad_checkpointing=grad_checkpointing,
extra_norm_period=extra_norm_period,
sequential_attn=sequential_attn,
norm_layer=norm_layer,
)
)
self.stages = nn.Sequential(*stages)
self.global_pool: str = global_pool
self.head: nn.Module = nn.Linear(
in_features=self.num_features, out_features=num_classes) if num_classes else nn.Identity()
# FIXME weight init TBD, PyTorch default init appears to be working well,
# but differs from usual ViT or Swin init.
# named_apply(init_weights, self)
def update_input_size(
self,
new_img_size: Optional[Tuple[int, int]] = None,
new_window_size: Optional[int] = None,
img_window_ratio: int = 32,
) -> None:
"""Method updates the image resolution to be processed and window size and so the pair-wise relative positions.
Args:
new_window_size (Optional[int]): New window size, if None based on new_img_size // window_div
new_img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used
img_window_ratio (int): divisor for calculating window size from image size
"""
# Check parameters
if new_img_size is None:
new_img_size = self.img_size
else:
new_img_size = to_2tuple(new_img_size)
if new_window_size is None:
new_window_size = tuple([s // img_window_ratio for s in new_img_size])
# Compute new patch resolution & update resolution of each stage
new_patch_grid_size = (new_img_size[0] // self.patch_size, new_img_size[1] // self.patch_size)
for index, stage in enumerate(self.stages):
stage_scale = 2 ** max(index - 1, 0)
stage.update_input_size(
new_window_size=new_window_size,
new_img_size=(new_patch_grid_size[0] // stage_scale, new_patch_grid_size[1] // stage_scale),
)
def get_classifier(self) -> nn.Module:
"""Method returns the classification head of the model.
Returns:
head (nn.Module): Current classification head
"""
head: nn.Module = self.head
return head
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
"""Method results the classification head
Args:
num_classes (int): Number of classes to be predicted
global_pool (str): Unused
"""
self.num_classes: int = num_classes
if global_pool is not None:
self.global_pool = global_pool
self.head: nn.Module = nn.Linear(
in_features=self.num_features, out_features=num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
x = self.stages(x)
return x
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool == 'avg':
x = x.mean(dim=(2, 3))
return x if pre_logits else self.head(x)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.forward_features(x)
x = self.forward_head(x)
return x
def init_weights(module: nn.Module, name: str = ''):
# FIXME WIP
if isinstance(module, nn.Linear):
if 'qkv' in name:
# treat the weights of Q, K, V separately
val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1]))
nn.init.uniform_(module.weight, -val, val)
else:
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
def _create_swin_transformer_v2_cr(variant, pretrained=False, default_cfg=None, **kwargs):
if default_cfg is None:
default_cfg = deepcopy(default_cfgs[variant])
overlay_external_default_cfg(default_cfg, kwargs)
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-2:]
num_classes = kwargs.pop('num_classes', default_num_classes)
img_size = kwargs.pop('img_size', default_img_size)
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
SwinTransformerV2Cr,
variant,
pretrained,
default_cfg=default_cfg,
img_size=img_size,
num_classes=num_classes,
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs
)
return model
@register_model
def swin_v2_cr_tiny_384(pretrained=False, **kwargs):
"""Swin-T V2 CR @ 384x384, trained ImageNet-1k"""
model_kwargs = dict(
embed_dim=96,
depths=(2, 2, 6, 2),
num_heads=(3, 6, 12, 24),
**kwargs
)
return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_384', pretrained=pretrained, **model_kwargs)
@register_model
def swin_v2_cr_tiny_224(pretrained=False, **kwargs):
"""Swin-T V2 CR @ 224x224, trained ImageNet-1k"""
model_kwargs = dict(
embed_dim=96,
depths=(2, 2, 6, 2),
num_heads=(3, 6, 12, 24),
**kwargs
)
return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_v2_cr_small_384(pretrained=False, **kwargs):
"""Swin-S V2 CR @ 384x384, trained ImageNet-1k"""
model_kwargs = dict(
embed_dim=96,
depths=(2, 2, 18, 2),
num_heads=(3, 6, 12, 24),
**kwargs
)
return _create_swin_transformer_v2_cr('swin_v2_cr_small_384', pretrained=pretrained, **model_kwargs
)
@register_model
def swin_v2_cr_small_224(pretrained=False, **kwargs):
"""Swin-S V2 CR @ 224x224, trained ImageNet-1k"""
model_kwargs = dict(
embed_dim=96,
depths=(2, 2, 18, 2),
num_heads=(3, 6, 12, 24),
**kwargs
)
return _create_swin_transformer_v2_cr('swin_v2_cr_small_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_v2_cr_base_384(pretrained=False, **kwargs):
"""Swin-B V2 CR @ 384x384, trained ImageNet-1k"""
model_kwargs = dict(
embed_dim=128,
depths=(2, 2, 18, 2),
num_heads=(4, 8, 16, 32),
**kwargs
)
return _create_swin_transformer_v2_cr('swin_v2_cr_base_384', pretrained=pretrained, **model_kwargs)
@register_model
def swin_v2_cr_base_224(pretrained=False, **kwargs):
"""Swin-B V2 CR @ 224x224, trained ImageNet-1k"""
model_kwargs = dict(
embed_dim=128,
depths=(2, 2, 18, 2),
num_heads=(4, 8, 16, 32),
**kwargs
)
return _create_swin_transformer_v2_cr('swin_v2_cr_base_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_v2_cr_large_384(pretrained=False, **kwargs):
"""Swin-L V2 CR @ 384x384, trained ImageNet-1k"""
model_kwargs = dict(
embed_dim=192,
depths=(2, 2, 18, 2),
num_heads=(6, 12, 24, 48),
**kwargs
)
return _create_swin_transformer_v2_cr('swin_v2_cr_large_384', pretrained=pretrained, **model_kwargs
)
@register_model
def swin_v2_cr_large_224(pretrained=False, **kwargs):
"""Swin-L V2 CR @ 224x224, trained ImageNet-1k"""
model_kwargs = dict(
embed_dim=192,
depths=(2, 2, 18, 2),
num_heads=(6, 12, 24, 48),
**kwargs
)
return _create_swin_transformer_v2_cr('swin_v2_cr_large_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_v2_cr_huge_384(pretrained=False, **kwargs):
"""Swin-H V2 CR @ 384x384, trained ImageNet-1k"""
model_kwargs = dict(
embed_dim=352,
depths=(2, 2, 18, 2),
num_heads=(11, 22, 44, 88), # head count not certain for Huge, 384 & 224 trying diff values
extra_norm_period=6,
**kwargs
)
return _create_swin_transformer_v2_cr('swin_v2_cr_huge_384', pretrained=pretrained, **model_kwargs)
@register_model
def swin_v2_cr_huge_224(pretrained=False, **kwargs):
"""Swin-H V2 CR @ 224x224, trained ImageNet-1k"""
model_kwargs = dict(
embed_dim=352,
depths=(2, 2, 18, 2),
num_heads=(8, 16, 32, 64), # head count not certain for Huge, 384 & 224 trying diff values
extra_norm_period=6,
**kwargs
)
return _create_swin_transformer_v2_cr('swin_v2_cr_huge_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_v2_cr_giant_384(pretrained=False, **kwargs):
"""Swin-G V2 CR @ 384x384, trained ImageNet-1k"""
model_kwargs = dict(
embed_dim=512,
depths=(2, 2, 42, 2),
num_heads=(16, 32, 64, 128),
extra_norm_period=6,
**kwargs
)
return _create_swin_transformer_v2_cr('swin_v2_cr_giant_384', pretrained=pretrained, **model_kwargs
)
@register_model
def swin_v2_cr_giant_224(pretrained=False, **kwargs):
"""Swin-G V2 CR @ 224x224, trained ImageNet-1k"""
model_kwargs = dict(
embed_dim=512,
depths=(2, 2, 42, 2),
num_heads=(16, 32, 64, 128),
extra_norm_period=6,
**kwargs
)
return _create_swin_transformer_v2_cr('swin_v2_cr_giant_224', pretrained=pretrained, **model_kwargs)

@ -6,7 +6,7 @@ A PyTorch implement of the Hybrid Vision Transformers as described in:
- https://arxiv.org/abs/2010.11929 - https://arxiv.org/abs/2010.11929
`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
- https://arxiv.org/abs/2106.TODO - https://arxiv.org/abs/2106.10270
NOTE These hybrid model definitions depend on code in vision_transformer.py. NOTE These hybrid model definitions depend on code in vision_transformer.py.
They were moved here to keep file sizes sane. They were moved here to keep file sizes sane.

@ -43,7 +43,7 @@ class PlateauLRScheduler(Scheduler):
min_lr=lr_min min_lr=lr_min
) )
self.noise_range = noise_range_t self.noise_range_t = noise_range_t
self.noise_pct = noise_pct self.noise_pct = noise_pct
self.noise_type = noise_type self.noise_type = noise_type
self.noise_std = noise_std self.noise_std = noise_std
@ -82,25 +82,12 @@ class PlateauLRScheduler(Scheduler):
self.lr_scheduler.step(metric, epoch) # step the base scheduler self.lr_scheduler.step(metric, epoch) # step the base scheduler
if self.noise_range is not None: if self._is_apply_noise(epoch):
if isinstance(self.noise_range, (list, tuple)): self._apply_noise(epoch)
apply_noise = self.noise_range[0] <= epoch < self.noise_range[1]
else:
apply_noise = epoch >= self.noise_range
if apply_noise:
self._apply_noise(epoch)
def _apply_noise(self, epoch): def _apply_noise(self, epoch):
g = torch.Generator() noise = self._calculate_noise(epoch)
g.manual_seed(self.noise_seed + epoch)
if self.noise_type == 'normal':
while True:
# resample if noise out of percent limit, brute force but shouldn't spin much
noise = torch.randn(1, generator=g).item()
if abs(noise) < self.noise_pct:
break
else:
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
# apply the noise on top of previous LR, cache the old value so we can restore for normal # apply the noise on top of previous LR, cache the old value so we can restore for normal
# stepping of base scheduler # stepping of base scheduler

@ -88,21 +88,30 @@ class Scheduler:
param_group[self.param_group_field] = value param_group[self.param_group_field] = value
def _add_noise(self, lrs, t): def _add_noise(self, lrs, t):
if self._is_apply_noise(t):
noise = self._calculate_noise(t)
lrs = [v + v * noise for v in lrs]
return lrs
def _is_apply_noise(self, t) -> bool:
"""Return True if scheduler in noise range."""
apply_noise = False
if self.noise_range_t is not None: if self.noise_range_t is not None:
if isinstance(self.noise_range_t, (list, tuple)): if isinstance(self.noise_range_t, (list, tuple)):
apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
else: else:
apply_noise = t >= self.noise_range_t apply_noise = t >= self.noise_range_t
if apply_noise: return apply_noise
g = torch.Generator()
g.manual_seed(self.noise_seed + t) def _calculate_noise(self, t) -> float:
if self.noise_type == 'normal': g = torch.Generator()
while True: g.manual_seed(self.noise_seed + t)
# resample if noise out of percent limit, brute force but shouldn't spin much if self.noise_type == 'normal':
noise = torch.randn(1, generator=g).item() while True:
if abs(noise) < self.noise_pct: # resample if noise out of percent limit, brute force but shouldn't spin much
break noise = torch.randn(1, generator=g).item()
else: if abs(noise) < self.noise_pct:
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct return noise
lrs = [v + v * noise for v in lrs] else:
return lrs noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
return noise

@ -108,7 +108,7 @@ parser.add_argument('--crop-pct', default=None, type=float,
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset') help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of of dataset') help='Override std deviation of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME', parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)') help='Image resize interpolation type (overrides model)')
parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N', parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',

Loading…
Cancel
Save