Significant model refactor and additions:

* All models updated with revised foward_features / forward_head interface
* Vision transformer and MLP based models consistently output sequence from forward_features (pooling or token selection considered part of 'head')
* WIP param grouping interface to allow consistent grouping of parameters for layer-wise decay across all model types
* Add gradient checkpointing support to a significant % of models, especially popular architectures
* Formatting and interface consistency improvements across models
* layer-wise LR decay impl part of optimizer factory w/ scale support in scheduler
* Poolformer and Volo architectures added
pull/1014/head
Ross Wightman 3 years ago
parent 2c3870e107
commit 372ad5fa0d

@ -89,6 +89,8 @@ parser.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--grad-checkpointing', action='store_true', default=False,
help='Enable gradient checkpointing through model blocks/stages')
parser.add_argument('--amp', action='store_true', default=False,
help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.')
parser.add_argument('--precision', default='float32', type=str,
@ -322,6 +324,9 @@ class TrainBenchmarkRunner(BenchmarkRunner):
opt=kwargs.pop('opt', 'sgd'),
lr=kwargs.pop('lr', 1e-4))
if kwargs.pop('grad_checkpointing', False):
self.model.set_grad_checkpointing()
def _gen_target(self, batch_size):
return torch.empty(
(batch_size,) + self.target_shape, device=self.device, dtype=torch.long).random_(self.num_classes)

@ -24,7 +24,8 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
# transformer models don't support many of the spatial / feature based model functionalities
NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*']
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*',
'poolformer_*', 'volo_*']
NUM_NON_STD = len(NON_STD_FILTERS)
# exclude models that cause specific test failures
@ -144,7 +145,7 @@ def test_model_default_cfgs(model_name, batch_size):
# test forward_features (always unpooled)
outputs = model.forward_features(input_tensor)
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2], 'unpooled feature shape != config'
# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
model.reset_classifier(0)
@ -156,8 +157,8 @@ def test_model_default_cfgs(model_name, batch_size):
model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through
outputs = model.forward(input_tensor)
assert len(outputs.shape) == 4
if not isinstance(model, timm.models.MobileNetV3) and not isinstance(model, timm.models.GhostNet):
# FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)):
# mobilenetv3/ghostnet/vgg forward_features vs removed pooling differ due to location or lack of GAP
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
if 'pruned' not in model_name: # FIXME better pruned model handling
@ -165,8 +166,7 @@ def test_model_default_cfgs(model_name, batch_size):
model = create_model(model_name, pretrained=False, num_classes=0, global_pool='').eval()
outputs = model.forward(input_tensor)
assert len(outputs.shape) == 4
if not isinstance(model, timm.models.MobileNetV3) and not isinstance(model, timm.models.GhostNet):
# FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.VGG)):
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
# check classifier name matches default_cfg
@ -204,9 +204,11 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
outputs = model.forward_features(input_tensor)
if isinstance(outputs, (tuple, list)):
outputs = outputs[0]
feat_dim = -1 if outputs.ndim == 3 else 1
assert outputs.shape[feat_dim] == model.num_features
# cannot currently verify multi-tensor output.
pass
else:
feat_dim = -1 if outputs.ndim == 3 else 1
assert outputs.shape[feat_dim] == model.num_features
# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
model.reset_classifier(0)
@ -214,7 +216,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
if isinstance(outputs, (tuple, list)):
outputs = outputs[0]
feat_dim = -1 if outputs.ndim == 3 else 1
assert outputs.shape[feat_dim] == model.num_features
assert outputs.shape[feat_dim] == model.num_features, 'pooled num_features != config'
model = create_model(model_name, pretrained=False, num_classes=0).eval()
outputs = model.forward(input_tensor)
@ -319,13 +321,18 @@ def _create_fx_model(model, train=False):
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
# node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
train_nodes, eval_nodes = get_graph_node_names(
model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
tracer_kwargs = dict(
leaf_modules=list(_leaf_modules),
autowrap_functions=list(_autowrap_functions),
#enable_cpatching=True,
param_shapes_constant=True
)
train_nodes, eval_nodes = get_graph_node_names(model, tracer_kwargs=tracer_kwargs)
eval_return_nodes = [eval_nodes[-1]]
train_return_nodes = [train_nodes[-1]]
if train:
tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
tracer = NodePathTracer(**tracer_kwargs)
graph = tracer.trace(model)
graph_nodes = list(reversed(graph.nodes))
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
@ -334,8 +341,11 @@ def _create_fx_model(model, train=False):
train_return_nodes = [train_nodes[ix] for ix in output_node_indices]
fx_model = create_feature_extractor(
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes,
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
model,
train_return_nodes=train_return_nodes,
eval_return_nodes=eval_return_nodes,
tracer_kwargs=tracer_kwargs,
)
return fx_model

@ -108,7 +108,13 @@ class RepeatAugSampler(Sampler):
indices = torch.arange(start=0, end=len(self.dataset))
# produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2....]
indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist()
if isinstance(self.num_repeats, float) and not self.num_repeats.is_integer():
# resample for repeats w/ non-integer ratio
repeat_size = math.ceil(self.num_repeats * len(self.dataset))
indices = indices[torch.tensor([int(i // self.num_repeats) for i in range(repeat_size)])]
else:
indices = torch.repeat_interleave(indices, repeats=int(self.num_repeats), dim=0)
indices = indices.tolist() # leaving as tensor thrashes dataloader memory
# add extra samples to make it evenly divisible
padding_size = self.total_size - len(indices)
if padding_size > 0:

@ -7,6 +7,7 @@ Hacked together by / Copyright 2019, Ross Wightman
"""
import random
from functools import partial
from itertools import repeat
from typing import Callable
import torch.utils.data
@ -54,20 +55,37 @@ def fast_collate(batch):
assert False
def expand_to_chs(x, n):
if not isinstance(x, (tuple, list)):
x = tuple(repeat(x, n))
elif len(x) == 1:
x = x * n
else:
assert len(x) == n, 'normalization stats must match image channels'
return x
class PrefetchLoader:
def __init__(self,
loader,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
fp16=False,
re_prob=0.,
re_mode='const',
re_count=1,
re_num_splits=0):
def __init__(
self,
loader,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
channels=3,
fp16=False,
re_prob=0.,
re_mode='const',
re_count=1,
re_num_splits=0):
mean = expand_to_chs(mean, channels)
std = expand_to_chs(std, channels)
normalization_shape = (1, channels, 1, 1)
self.loader = loader
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(normalization_shape)
self.std = torch.tensor([x * 255 for x in std]).cuda().view(normalization_shape)
self.fp16 = fp16
if fp16:
self.mean = self.mean.half()
@ -247,6 +265,7 @@ def create_loader(
loader,
mean=mean,
std=std,
channels=input_size[0],
fp16=fp16,
re_prob=prefetch_re_prob,
re_mode=re_mode,

@ -30,6 +30,7 @@ from .nest import *
from .nfnet import *
from .pit import *
from .pnasnet import *
from .poolformer import *
from .regnet import *
from .res2net import *
from .resnest import *
@ -47,6 +48,7 @@ from .vgg import *
from .visformer import *
from .vision_transformer import *
from .vision_transformer_hybrid import *
from .volo import *
from .vovnet import *
from .xception import *
from .xception_aligned import *

@ -20,11 +20,12 @@ Modifications by / Copyright 2021 Ross Wightman, original copyrights below
# --------------------------------------------------------'
import math
from functools import partial
from typing import Optional
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from .helpers import build_model_with_cfg
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_
@ -71,6 +72,28 @@ default_cfgs = {
}
def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
window_area = window_size[0] * window_size[1]
coords = torch.stack(torch.meshgrid(
[torch.arange(window_size[0]),
torch.arange(window_size[1])])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = num_relative_distance - 3
relative_position_index[0:, 0] = num_relative_distance - 2
relative_position_index[0, 0] = num_relative_distance - 1
return relative_position_index
class Attention(nn.Module):
def __init__(
self, dim, num_heads=8, qkv_bias=False, attn_drop=0.,
@ -98,26 +121,7 @@ class Attention(nn.Module):
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = \
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
self.register_buffer("relative_position_index", gen_relative_position_index(window_size))
else:
self.window_size = None
self.relative_position_bias_table = None
@ -127,8 +131,17 @@ class Attention(nn.Module):
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, rel_pos_bias: Optional[torch.Tensor] = None):
def _get_rel_pos_bias(self):
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
return relative_position_bias.unsqueeze(0)
def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None):
B, N, C = x.shape
qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@ -138,15 +151,9 @@ class Attention(nn.Module):
attn = (q @ k.transpose(-2, -1))
if self.relative_position_bias_table is not None:
relative_position_bias = \
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if rel_pos_bias is not None:
attn = attn + rel_pos_bias
attn = attn + self._get_rel_pos_bias()
if shared_rel_pos_bias is not None:
attn = attn + shared_rel_pos_bias
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
@ -159,9 +166,10 @@ class Attention(nn.Module):
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
window_size=None, attn_head_dim=None):
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
window_size=None, attn_head_dim=None):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
@ -174,17 +182,17 @@ class Block(nn.Module):
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if init_values:
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
else:
self.gamma_1, self.gamma_2 = None, None
def forward(self, x, rel_pos_bias: Optional[torch.Tensor] = None):
def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None):
if self.gamma_1 is None:
x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
x = x + self.drop_path(self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias))
x = x + self.drop_path(self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias))
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
return x
@ -194,37 +202,15 @@ class RelativePositionBias(nn.Module):
def __init__(self, window_size, num_heads):
super().__init__()
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = \
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
self.window_area = window_size[0] * window_size[1]
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
# trunc_normal_(self.relative_position_bias_table, std=.02)
self.register_buffer("relative_position_index", gen_relative_position_index(window_size))
def forward(self):
relative_position_bias = \
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_area + 1, self.window_area + 1, -1) # Wh*Ww,Wh*Ww,nH
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
@ -242,6 +228,7 @@ class Beit(nn.Module):
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.grad_checkpointing = False
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
@ -258,7 +245,6 @@ class Beit(nn.Module):
self.rel_pos_bias = None
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.use_rel_pos_bias = use_rel_pos_bias
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
@ -298,45 +284,63 @@ class Beit(nn.Module):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def get_num_layers(self):
return len(self.blocks)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
nwd = {'pos_embed', 'cls_token'}
for n, _ in self.named_parameters():
if 'relative_position_bias_table' in n:
nwd.add(n)
return nwd
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^cls_token|pos_embed|patch_embed|rel_pos_bias', # stem and embed
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))],
)
return matcher
@torch.jit.ignore
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.patch_embed(x)
batch_size, seq_len, _ = x.size()
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.pos_drop(x)
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
for blk in self.blocks:
x = blk(x, rel_pos_bias=rel_pos_bias)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias)
else:
x = blk(x, shared_rel_pos_bias=rel_pos_bias)
x = self.norm(x)
return x
def forward(self, x):
x = self.forward_features(x)
def forward_head(self, x, pre_logits: bool = False):
if self.fc_norm is not None:
x = x[:, 1:].mean(dim=1)
x = self.fc_norm(x)
else:
x = x[:, 0]
x = self.head(x)
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x

@ -33,7 +33,7 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, named_apply
from .helpers import build_model_with_cfg, named_apply, checkpoint_seq
from .layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0, EvoNorm2dS0a,\
EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a, FilterResponseNormAct2d, FilterResponseNormTlu2d
@ -159,9 +159,9 @@ default_cfgs = {
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 224, 224), pool_size=(7, 7), test_input_size=(3, 288, 288), first_conv='stem.conv',
crop_pct=0.94),
'regnetz_d8_evob': _cfgr(
'regnetz_c16_evos': _cfgr(
url='',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95),
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), first_conv='stem.conv', crop_pct=0.94),
'regnetz_d8_evos': _cfgr(
url='',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95),
@ -621,20 +621,19 @@ model_cfgs = dict(
attn_kwargs=dict(rd_ratio=0.25),
block_kwargs=dict(bottle_in=True, linear_out=True),
),
regnetz_d8_evob=ByoModelCfg(
regnetz_c16_evos=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=8, br=4),
ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=8, br=4),
ByoBlockCfg(type='bottle', d=12, c=256, s=2, gs=8, br=4),
ByoBlockCfg(type='bottle', d=3, c=384, s=2, gs=8, br=4),
ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=4),
ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=4),
ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=4),
ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=4),
),
stem_chs=64,
stem_type='tiered',
stem_chs=32,
stem_pool='',
downsample='',
num_features=1792,
num_features=1536,
act_layer='silu',
norm_layer='evonormb0',
norm_layer=partial(EvoNorm2dS0a, group_size=16),
attn_layer='se',
attn_kwargs=dict(rd_ratio=0.25),
block_kwargs=dict(bottle_in=True, linear_out=True),
@ -888,10 +887,10 @@ def regnetz_b16_evos(pretrained=False, **kwargs):
@register_model
def regnetz_d8_evob(pretrained=False, **kwargs):
def regnetz_c16_evos(pretrained=False, **kwargs):
"""
"""
return _create_byobnet('regnetz_d8_evob', pretrained=pretrained, **kwargs)
return _create_byobnet('regnetz_c16_evos', pretrained=pretrained, **kwargs)
@register_model
@ -1200,9 +1199,10 @@ class SelfAttnBlock(nn.Module):
""" ResNet-like Bottleneck Block - 1x1 - optional kxk - self attn - 1x1
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
downsample='avg', extra_conv=False, linear_out=False, bottle_in=False, post_attn_na=True,
feat_size=None, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
downsample='avg', extra_conv=False, linear_out=False, bottle_in=False, post_attn_na=True,
feat_size=None, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
super(SelfAttnBlock, self).__init__()
assert layers is not None
mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
@ -1269,8 +1269,9 @@ def create_block(block: Union[str, nn.Module], **kwargs):
class Stem(nn.Sequential):
def __init__(self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool',
num_rep=3, num_act=None, chs_decay=0.5, layers: LayerFn = None):
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool',
num_rep=3, num_act=None, chs_decay=0.5, layers: LayerFn = None):
super().__init__()
assert stride in (2, 4)
layers = layers or LayerFn()
@ -1479,11 +1480,13 @@ class ByobNet(nn.Module):
Current assumption is that both stem and blocks are in conv-bn-act order (w/ block ending in act).
"""
def __init__(self, cfg: ByoModelCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
zero_init_last=True, img_size=None, drop_rate=0., drop_path_rate=0.):
def __init__(
self, cfg: ByoModelCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
zero_init_last=True, img_size=None, drop_rate=0., drop_path_rate=0.):
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
self.grad_checkpointing = False
layers = get_layer_fns(cfg)
if cfg.fixed_input_size:
assert img_size is not None, 'img_size argument is required for fixed input size model'
@ -1514,6 +1517,22 @@ class ByobNet(nn.Module):
# init weights
named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^stem',
blocks=[
(r'^stages\.(\d+)' if coarse else r'^stages\.(\d+).(\d+)', None),
(r'^final_conv', (99999,))
]
)
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
@ -1522,13 +1541,19 @@ class ByobNet(nn.Module):
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.stages, x)
else:
x = self.stages(x)
x = self.final_conv(x)
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
x = self.forward_head(x)
return x

@ -9,13 +9,13 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
from copy import deepcopy
from functools import partial
import torch
import torch.nn as nn
from functools import partial
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .helpers import build_model_with_cfg, checkpoint_seq
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_
from .registry import register_model
@ -202,14 +202,13 @@ class Cait(nn.Module):
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# with slight modifications to adapt to our cait models
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
global_pool=None,
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
block_layers=LayerScaleBlock,
block_layers_token=LayerScaleBlockClassAttn,
patch_layer=PatchEmbed,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
act_layer=nn.GELU,
attn_block=TalkingHeadAttn,
mlp_block=Mlp,
@ -220,9 +219,12 @@ class Cait(nn.Module):
mlp_ratio_token_only=4.0
):
super().__init__()
assert global_pool in ('', 'token', 'avg')
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim
self.grad_checkpointing = False
self.patch_embed = patch_layer(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
@ -271,32 +273,61 @@ class Cait(nn.Module):
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def group_matcher(self, coarse=False):
def _matcher(name):
if any([name.startswith(n) for n in ('cls_token', 'pos_embed', 'patch_embed')]):
return 0
elif name.startswith('blocks.'):
return int(name.split('.')[1]) + 1
elif name.startswith('blocks_token_only.'):
# overlap token only blocks with last blocks
to_offset = len(self.blocks) - len(self.blocks_token_only) + 1
return int(name.split('.')[1]) + to_offset
elif name.startswith('norm.'):
return len(self.blocks)
else:
return float('inf')
return _matcher
@torch.jit.ignore
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'token', 'avg')
self.global_pool = global_pool
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
B = x.shape[0]
x = self.patch_embed(x)
x = x + self.pos_embed
x = self.pos_drop(x)
x = self.blocks(x)
cls_tokens = self.cls_token.expand(B, -1, -1)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
for i, blk in enumerate(self.blocks_token_only):
cls_tokens = blk(x, cls_tokens)
x = torch.cat((cls_tokens, x), dim=1)
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = x[:, 0]
x = self.head(x)
x = self.forward_head(x)
return x

@ -9,7 +9,7 @@ Modified from timm/models/vision_transformer.py
"""
from copy import deepcopy
from functools import partial
from typing import Tuple, List
from typing import Tuple, List, Union
import torch
import torch.nn as nn
@ -125,7 +125,7 @@ class ConvRelPosEnc(nn.Module):
return EV_hat
class FactorAtt_ConvRelPosEnc(nn.Module):
class FactorAttnConvRelPosEnc(nn.Module):
""" Factorized attention with convolutional relative position encoding class. """
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., shared_crpe=None):
super().__init__()
@ -205,7 +205,7 @@ class SerialBlock(nn.Module):
self.cpe = shared_cpe
self.norm1 = norm_layer(dim)
self.factoratt_crpe = FactorAtt_ConvRelPosEnc(
self.factoratt_crpe = FactorAttnConvRelPosEnc(
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpe)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -239,15 +239,15 @@ class ParallelBlock(nn.Module):
self.norm12 = norm_layer(dims[1])
self.norm13 = norm_layer(dims[2])
self.norm14 = norm_layer(dims[3])
self.factoratt_crpe2 = FactorAtt_ConvRelPosEnc(
self.factoratt_crpe2 = FactorAttnConvRelPosEnc(
dims[1], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
shared_crpe=shared_crpes[1]
)
self.factoratt_crpe3 = FactorAtt_ConvRelPosEnc(
self.factoratt_crpe3 = FactorAttnConvRelPosEnc(
dims[2], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
shared_crpe=shared_crpes[2]
)
self.factoratt_crpe4 = FactorAtt_ConvRelPosEnc(
self.factoratt_crpe4 = FactorAttnConvRelPosEnc(
dims[3], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
shared_crpe=shared_crpes[3]
)
@ -328,17 +328,19 @@ class ParallelBlock(nn.Module):
class CoaT(nn.Module):
""" CoaT class. """
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=(0, 0, 0, 0),
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=(0, 0, 0, 0),
serial_depths=(0, 0, 0, 0), parallel_depth=0, num_heads=0, mlp_ratios=(0, 0, 0, 0), qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
return_interm_layers=False, out_features=None, crpe_window=None, **kwargs):
return_interm_layers=False, out_features=None, crpe_window=None, global_pool='token'):
super().__init__()
assert global_pool in ('token', 'avg')
crpe_window = crpe_window or {3: 2, 5: 3, 7: 3}
self.return_interm_layers = return_interm_layers
self.out_features = out_features
self.embed_dims = embed_dims
self.num_features = embed_dims[-1]
self.num_classes = num_classes
self.global_pool = global_pool
# Patch embeddings.
img_size = to_2tuple(img_size)
@ -470,61 +472,73 @@ class CoaT(nn.Module):
def no_weight_decay(self):
return {'cls_token1', 'cls_token2', 'cls_token3', 'cls_token4'}
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
assert not enable, 'gradient checkpointing not supported'
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem1=r'^cls_token1|patch_embed1|crpe1|cpe1',
serial_blocks1=r'^serial_blocks1\.(\d+)',
stem2=r'^cls_token2|patch_embed2|crpe2|cpe2',
serial_blocks2=r'^serial_blocks2\.(\d+)',
stem3=r'^cls_token3|patch_embed3|crpe3|cpe3',
serial_blocks3=r'^serial_blocks3\.(\d+)',
stem4=r'^cls_token4|patch_embed4|crpe4|cpe4',
serial_blocks4=r'^serial_blocks4\.(\d+)',
parallel_blocks=[ # FIXME (partially?) overlap parallel w/ serial blocks??
(r'^parallel_blocks\.(\d+)', None),
(r'^norm|aggregate', (99999,)),
]
)
return matcher
@torch.jit.ignore
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('token', 'avg')
self.global_pool = global_pool
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def insert_cls(self, x, cls_token):
""" Insert CLS token. """
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
return x
def remove_cls(self, x):
""" Remove CLS token. """
return x[:, 1:, :]
def forward_features(self, x0):
B = x0.shape[0]
# Serial blocks 1.
x1 = self.patch_embed1(x0)
H1, W1 = self.patch_embed1.grid_size
x1 = self.insert_cls(x1, self.cls_token1)
x1 = insert_cls(x1, self.cls_token1)
for blk in self.serial_blocks1:
x1 = blk(x1, size=(H1, W1))
x1_nocls = self.remove_cls(x1)
x1_nocls = x1_nocls.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
x1_nocls = remove_cls(x1).reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
# Serial blocks 2.
x2 = self.patch_embed2(x1_nocls)
H2, W2 = self.patch_embed2.grid_size
x2 = self.insert_cls(x2, self.cls_token2)
x2 = insert_cls(x2, self.cls_token2)
for blk in self.serial_blocks2:
x2 = blk(x2, size=(H2, W2))
x2_nocls = self.remove_cls(x2)
x2_nocls = x2_nocls.reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous()
x2_nocls = remove_cls(x2).reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous()
# Serial blocks 3.
x3 = self.patch_embed3(x2_nocls)
H3, W3 = self.patch_embed3.grid_size
x3 = self.insert_cls(x3, self.cls_token3)
x3 = insert_cls(x3, self.cls_token3)
for blk in self.serial_blocks3:
x3 = blk(x3, size=(H3, W3))
x3_nocls = self.remove_cls(x3)
x3_nocls = x3_nocls.reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous()
x3_nocls = remove_cls(x3).reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous()
# Serial blocks 4.
x4 = self.patch_embed4(x3_nocls)
H4, W4 = self.patch_embed4.grid_size
x4 = self.insert_cls(x4, self.cls_token4)
x4 = insert_cls(x4, self.cls_token4)
for blk in self.serial_blocks4:
x4 = blk(x4, size=(H4, W4))
x4_nocls = self.remove_cls(x4)
x4_nocls = x4_nocls.reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous()
x4_nocls = remove_cls(x4).reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous()
# Only serial blocks: Early return.
if self.parallel_blocks is None:
@ -554,20 +568,16 @@ class CoaT(nn.Module):
# Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).
feat_out = {}
if 'x1_nocls' in self.out_features:
x1_nocls = self.remove_cls(x1)
x1_nocls = x1_nocls.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
x1_nocls = remove_cls(x1).reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
feat_out['x1_nocls'] = x1_nocls
if 'x2_nocls' in self.out_features:
x2_nocls = self.remove_cls(x2)
x2_nocls = x2_nocls.reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous()
x2_nocls = remove_cls(x2).reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous()
feat_out['x2_nocls'] = x2_nocls
if 'x3_nocls' in self.out_features:
x3_nocls = self.remove_cls(x3)
x3_nocls = x3_nocls.reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous()
x3_nocls = remove_cls(x3).reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous()
feat_out['x3_nocls'] = x3_nocls
if 'x4_nocls' in self.out_features:
x4_nocls = self.remove_cls(x4)
x4_nocls = x4_nocls.reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous()
x4_nocls = remove_cls(x4).reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous()
feat_out['x4_nocls'] = x4_nocls
return feat_out
else:
@ -576,6 +586,18 @@ class CoaT(nn.Module):
x4 = self.norm4(x4)
return [x2, x3, x4]
def forward_head(self, x_feat: Union[torch.Tensor, List[torch.Tensor]], pre_logits: bool = False):
if isinstance(x_feat, list):
assert self.aggregate is not None
if self.global_pool == 'avg':
x = torch.cat([xl[:, 1:].mean(dim=1, keepdim=True) for xl in x_feat], dim=1) # [B, 3, C]
else:
x = torch.stack([xl[:, 0] for xl in x_feat], dim=1) # [B, 3, C]
x = self.aggregate(x).squeeze(dim=1) # Shape: [B, C]
else:
x = x_feat[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x_feat[:, 0]
return x if pre_logits else self.head(x)
def forward(self, x) -> torch.Tensor:
if not torch.jit.is_scripting() and self.return_interm_layers:
# Return intermediate features (for down-stream tasks).
@ -583,15 +605,22 @@ class CoaT(nn.Module):
else:
# Return features for classification.
x_feat = self.forward_features(x)
if isinstance(x_feat, (tuple, list)):
x = torch.cat([xl[:, :1] for xl in x_feat], dim=1) # [B, 3, C]
x = self.aggregate(x).squeeze(dim=1) # Shape: [B, C]
else:
x = x_feat[:, 0]
x = self.head(x)
x = self.forward_head(x_feat)
return x
def insert_cls(x, cls_token):
""" Insert CLS token. """
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
return x
def remove_cls(x):
""" Remove CLS token. """
return x[:, 1:, :]
def checkpoint_filter_fn(state_dict, model):
out_dict = {}
for k, v in state_dict.items():

@ -61,8 +61,8 @@ default_cfgs = {
@register_notrace_module # reason: FX can't symbolically trace control flow in forward method
class GPSA(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
locality_strength=1.):
def __init__(
self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., locality_strength=1.):
super().__init__()
self.num_heads = num_heads
self.dim = dim
@ -169,7 +169,7 @@ class MHSA(nn.Module):
indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1)
indd = indx ** 2 + indy ** 2
distances = indd ** .5
distances = distances.to('cuda')
distances = distances.to(x.device)
dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / N
if return_map:
@ -180,7 +180,7 @@ class MHSA(nn.Module):
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q, k, v = qkv.unbind(0)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
@ -194,8 +194,9 @@ class MHSA(nn.Module):
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs):
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs):
super().__init__()
self.norm1 = norm_layer(dim)
self.use_gpsa = use_gpsa
@ -219,13 +220,16 @@ class ConViT(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, global_pool=None,
local_up_to_layer=3, locality_strength=1., use_pos_embed=True):
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm,
local_up_to_layer=3, locality_strength=1., use_pos_embed=True):
super().__init__()
assert global_pool in ('', 'avg', 'token')
embed_dim *= num_heads
self.num_classes = num_classes
self.global_pool = global_pool
self.local_up_to_layer = local_up_to_layer
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.locality_strength = locality_strength
@ -285,35 +289,49 @@ class ConViT(nn.Module):
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
assert not enable, 'gradient checkpointing not supported'
@torch.jit.ignore
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'token', 'avg')
self.global_pool = global_pool
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
B = x.shape[0]
x = self.patch_embed(x)
cls_tokens = self.cls_token.expand(B, -1, -1)
if self.use_pos_embed:
x = x + self.pos_embed
x = self.pos_drop(x)
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
for u, blk in enumerate(self.blocks):
if u == self.local_up_to_layer:
x = torch.cat((cls_tokens, x), dim=1)
x = blk(x)
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = x[:, 0]
x = self.head(x)
x = self.forward_head(x)
return x

@ -1,7 +1,13 @@
""" ConvMixer
"""
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.registry import register_model
from .helpers import build_model_with_cfg
from .helpers import build_model_with_cfg, checkpoint_seq
from .layers import SelectAdaptivePool2d
def _cfg(url='', **kwargs):
@ -32,49 +38,68 @@ class Residual(nn.Module):
class ConvMixer(nn.Module):
def __init__(self, dim, depth, kernel_size=9, patch_size=7, in_chans=3, num_classes=1000, activation=nn.GELU, **kwargs):
def __init__(
self, dim, depth, kernel_size=9, patch_size=7, in_chans=3, num_classes=1000, global_pool='avg',
act_layer=nn.GELU, **kwargs):
super().__init__()
self.num_classes = num_classes
self.num_features = dim
self.head = nn.Linear(dim, num_classes) if num_classes > 0 else nn.Identity()
self.grad_checkpointing = False
self.stem = nn.Sequential(
nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size),
activation(),
act_layer(),
nn.BatchNorm2d(dim)
)
self.blocks = nn.Sequential(
*[nn.Sequential(
Residual(nn.Sequential(
nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
activation(),
act_layer(),
nn.BatchNorm2d(dim)
)),
nn.Conv2d(dim, dim, kernel_size=1),
activation(),
act_layer(),
nn.BatchNorm2d(dim)
) for i in range(depth)]
)
self.pooling = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
nn.Flatten()
)
self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
self.head = nn.Linear(dim, num_classes) if num_classes > 0 else nn.Identity()
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(stem=r'^stem', blocks=r'^blocks\.(\d+)')
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.stem(x)
x = self.blocks(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
return x
def forward_head(self, x, pre_logits: bool = False):
x = self.pooling(x)
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = self.pooling(x)
x = self.head(x)
x = self.forward_head(x)
return x
@ -90,7 +115,7 @@ def convmixer_1536_20(pretrained=False, **kwargs):
@register_model
def convmixer_768_32(pretrained=False, **kwargs):
model_args = dict(dim=768, depth=32, kernel_size=7, patch_size=7, activation=nn.ReLU, **kwargs)
model_args = dict(dim=768, depth=32, kernel_size=7, patch_size=7, act_layer=nn.ReLU, **kwargs)
return _create_convmixer('convmixer_768_32', pretrained, **model_args)

@ -18,7 +18,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_module
from .helpers import named_apply, build_model_with_cfg
from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
from .layers import trunc_normal_, ClassifierHead, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp
from .registry import register_model
@ -43,6 +43,7 @@ default_cfgs = dict(
convnext_base=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth"),
convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),
convnext_nano_hnf=_cfg(url=''),
convnext_tiny_hnf=_cfg(url=''),
convnext_base_in22ft1k=_cfg(
@ -151,6 +152,7 @@ class ConvNeXtStage(nn.Module):
self, in_chs, out_chs, stride=2, depth=2, dp_rates=None, ls_init_value=1.0, conv_mlp=False,
norm_layer=None, cl_norm_layer=None, cross_stage=False):
super().__init__()
self.grad_checkpointing = False
if in_chs != out_chs or stride > 1:
self.downsample = nn.Sequential(
@ -169,7 +171,10 @@ class ConvNeXtStage(nn.Module):
def forward(self, x):
x = self.downsample(x)
x = self.blocks(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
return x
@ -190,7 +195,7 @@ class ConvNeXt(nn.Module):
def __init__(
self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32, patch_size=4,
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), ls_init_value=1e-6, conv_mlp=False,
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), ls_init_value=1e-6, conv_mlp=False, stem_type='patch',
head_init_scale=1., head_norm_first=False, norm_layer=None, drop_rate=0., drop_path_rate=0.,
):
super().__init__()
@ -208,19 +213,29 @@ class ConvNeXt(nn.Module):
self.feature_info = []
# NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
self.stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size),
norm_layer(dims[0])
)
if stem_type == 'patch':
self.stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size),
norm_layer(dims[0])
)
curr_stride = patch_size
prev_chs = dims[0]
else:
self.stem = nn.Sequential(
nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1),
norm_layer(32),
nn.GELU(),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
)
curr_stride = 2
prev_chs = 64
self.stages = nn.Sequential()
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
curr_stride = patch_size
prev_chs = dims[0]
stages = []
# 4 feature resolution stages, each consisting of multiple residual blocks
for i in range(4):
stride = 2 if i > 0 else 1
stride = 2 if curr_stride == 2 or i > 0 else 1
# FIXME support dilation / output_stride
curr_stride *= stride
out_chs = dims[i]
@ -235,40 +250,43 @@ class ConvNeXt(nn.Module):
self.stages = nn.Sequential(*stages)
self.num_features = prev_chs
if head_norm_first:
# norm -> global pool -> fc ordering, like most other nets (not compat with FB weights)
self.norm_pre = norm_layer(self.num_features) # final norm layer, before pooling
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
else:
# pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
self.norm_pre = nn.Identity()
self.head = nn.Sequential(OrderedDict([
# if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
# otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
self.head = nn.Sequential(OrderedDict([
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
('norm', norm_layer(self.num_features)),
('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
('drop', nn.Dropout(self.drop_rate)),
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
]))
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^stem',
blocks=r'^stages\.(\d+)' if coarse else [
(r'^stages\.(\d+)\.downsample', (0,)), # blocks
(r'^stages\.(\d+)\.blocks\.(\d+)', None),
(r'^norm_pre', (99999,))
]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
for s in self.stages:
s.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes=0, global_pool='avg'):
if isinstance(self.head, ClassifierHead):
# norm -> global pool -> fc
self.head = ClassifierHead(
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
else:
# pool -> norm -> fc
self.head = nn.Sequential(OrderedDict([
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
('norm', self.head.norm),
('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
('drop', nn.Dropout(self.drop_rate)),
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
]))
def reset_classifier(self, num_classes=0, global_pool=None):
if global_pool is not None:
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.stem(x)
@ -276,9 +294,17 @@ class ConvNeXt(nn.Module):
x = self.norm_pre(x)
return x
def forward_head(self, x, pre_logits: bool = False):
# NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
x = self.head.global_pool(x)
x = self.head.norm(x)
x = self.head.flatten(x)
x = self.head.drop(x)
return x if pre_logits else self.head.fc(x)
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
x = self.forward_head(x)
return x
@ -326,19 +352,34 @@ def _create_convnext(variant, pretrained=False, **kwargs):
@register_model
def convnext_tiny(pretrained=False, **kwargs):
model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs)
model = _create_convnext('convnext_tiny', pretrained=pretrained, **model_args)
def convnext_nano_hnf(pretrained=False, **kwargs):
model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, **kwargs)
model = _create_convnext('convnext_nano_hnf', pretrained=pretrained, **model_args)
return model
@register_model
def convnext_tiny_hnf(pretrained=False, **kwargs):
model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, **kwargs)
model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs)
model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args)
return model
@register_model
def convnext_tiny_hnfd(pretrained=False, **kwargs):
model_args = dict(
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, stem_type='dual', **kwargs)
model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args)
return model
@register_model
def convnext_tiny(pretrained=False, **kwargs):
model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs)
model = _create_convnext('convnext_tiny', pretrained=pretrained, **model_args)
return model
@register_model
def convnext_small(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)

@ -175,7 +175,6 @@ class CrossAttentionBlock(nn.Module):
def forward(self, x):
x = x[:, 0:1, ...] + self.drop_path(self.attn(self.norm1(x)))
return x
@ -289,12 +288,14 @@ class CrossViT(nn.Module):
def __init__(
self, img_size=224, img_scale=(1.0, 1.0), patch_size=(8, 16), in_chans=3, num_classes=1000,
embed_dim=(192, 384), depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)), num_heads=(6, 12), mlp_ratio=(2., 2., 4.),
qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=False, crop_scale=False,
multi_conv=False, crop_scale=False, qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
norm_layer=partial(nn.LayerNorm, eps=1e-6), global_pool='token',
):
super().__init__()
assert global_pool in ('token', 'avg')
self.num_classes = num_classes
self.global_pool = global_pool
self.img_size = to_2tuple(img_size)
img_scale = to_2tuple(img_scale)
self.img_size_scaled = [tuple([int(sj * si) for sj in self.img_size]) for si in img_scale]
@ -302,7 +303,7 @@ class CrossViT(nn.Module):
num_patches = _compute_num_patches(self.img_size_scaled, patch_size)
self.num_branches = len(patch_size)
self.embed_dim = embed_dim
self.num_features = embed_dim[0] # to pass the tests
self.num_features = sum(embed_dim)
self.patch_embed = nn.ModuleList()
# hard-coded for torch jit script
@ -359,11 +360,26 @@ class CrossViT(nn.Module):
out.add(f'pos_embed_{i}')
return out
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
assert not enable, 'gradient checkpointing not supported'
@torch.jit.ignore
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('token', 'avg')
self.global_pool = global_pool
self.head = nn.ModuleList(
[nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in
range(self.num_branches)])
@ -391,12 +407,16 @@ class CrossViT(nn.Module):
xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
return xs
def forward_head(self, xs: List[torch.Tensor], pre_logits: bool = False) -> torch.Tensor:
xs = [x[:, 1:].mean(dim=1) for x in xs] if self.global_pool == 'avg' else [x[:, 0] for x in xs]
if pre_logits or isinstance(self.head[0], nn.Identity):
return torch.cat([x for x in xs], dim=1)
return torch.mean(torch.stack([head(xs[i]) for i, head in enumerate(self.head)], dim=0), dim=0)
def forward(self, x):
xs = self.forward_features(x)
ce_logits = [head(xs[i][:, 0]) for i, head in enumerate(self.head)]
if not isinstance(self.head[0], nn.Identity):
ce_logits = torch.mean(torch.stack(ce_logits, dim=0), dim=0)
return ce_logits
x = self.forward_head(xs)
return x
def _create_crossvit(variant, pretrained=False, **kwargs):

@ -12,11 +12,13 @@ Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStage
Hacked together by / Copyright 2020 Ross Wightman
"""
from functools import partial
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .helpers import build_model_with_cfg, named_apply, MATCH_PREV_GROUP
from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, create_attn, get_norm_act_layer
from .registry import register_model
@ -172,7 +174,7 @@ class ResBottleneck(nn.Module):
self.drop_path = drop_path
self.act3 = act_layer(inplace=True)
def zero_init_last_bn(self):
def zero_init_last(self):
nn.init.zeros_(self.conv3.bn.weight)
def forward(self, x):
@ -210,7 +212,7 @@ class DarkBlock(nn.Module):
self.attn = create_attn(attn_layer, channels=out_chs)
self.drop_path = drop_path
def zero_init_last_bn(self):
def zero_init_last(self):
nn.init.zeros_(self.conv2.bn.weight)
def forward(self, x):
@ -345,9 +347,10 @@ class CspNet(nn.Module):
darknet impl. I did it this way for simplicity and less special cases.
"""
def __init__(self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0.,
act_layer=nn.LeakyReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_path_rate=0.,
zero_init_last_bn=True, stage_fn=CrossStage, block_fn=ResBottleneck):
def __init__(
self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0.,
act_layer=nn.LeakyReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_path_rate=0.,
zero_init_last=True, stage_fn=CrossStage, block_fn=ResBottleneck):
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
@ -378,20 +381,25 @@ class CspNet(nn.Module):
self.head = ClassifierHead(
in_chs=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, mean=0.0, std=0.01)
nn.init.zeros_(m.bias)
if zero_init_last_bn:
for m in self.modules():
if hasattr(m, 'zero_init_last_bn'):
m.zero_init_last_bn()
named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^stem',
blocks=r'^stages.(\d+)' if coarse else [
(r'^stages.(\d+).blocks.(\d+)', None),
(r'^stages.(\d+).*transition', MATCH_PREV_GROUP), # map to last block in stage
(r'^stages.(\d+)', (0,)),
]
)
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
assert not enable, 'gradient checkpointing not supported'
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
@ -403,12 +411,28 @@ class CspNet(nn.Module):
x = self.stages(x)
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
x = self.forward_head(x)
return x
def _init_weights(module, name, zero_init_last=False):
if isinstance(module, nn.Conv2d):
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(module, nn.BatchNorm2d):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=0.01)
nn.init.zeros_(module.bias)
elif zero_init_last and hasattr(module, 'zero_init_last'):
module.zero_init_last()
def _create_cspnet(variant, pretrained=False, **kwargs):
cfg_variant = variant.split('_')[0]
# NOTE: DarkNet is one of few models with stride==1 features w/ 6 out_indices [0..5]

@ -13,7 +13,7 @@ from torch import nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn
from .helpers import build_model_with_cfg
from .helpers import build_model_with_cfg, checkpoint_seq
from .registry import register_model
@ -66,10 +66,13 @@ class VisionTransformerDistilled(VisionTransformer):
def __init__(self, *args, **kwargs):
weight_init = kwargs.pop('weight_init', '')
super().__init__(*args, **kwargs, weight_init='skip')
assert self.global_pool in ('token',)
self.num_tokens = 2
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + self.num_tokens, self.embed_dim))
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
self.distilled_training = False
self.init_weights(weight_init)
@ -77,32 +80,50 @@ class VisionTransformerDistilled(VisionTransformer):
trunc_normal_(self.dist_token, std=.02)
super().init_weights(mode=mode)
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^cls_token|pos_embed|patch_embed|dist_token',
blocks=[
(r'^blocks.(\d+)', None),
(r'^norm', (99999,))] # final norm w/ last block
)
@torch.jit.ignore
def get_classifier(self):
return self.head, self.head_dist
def reset_classifier(self, num_classes, global_pool=''):
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
@torch.jit.ignore
def set_distilled_training(self, enable=True):
self.distilled_training = enable
def forward_features(self, x) -> torch.Tensor:
x = self.patch_embed(x)
x = torch.cat((
self.cls_token.expand(x.shape[0], -1, -1),
self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
x = self.pos_drop(x + self.pos_embed)
x = self.blocks(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
x = self.norm(x)
return x
def forward(self, x):
x = self.forward_features(x)
x_dist = self.head_dist(x[:, 1])
x = self.head(x[:, 0])
if self.training and not torch.jit.is_scripting():
def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
if pre_logits:
return (x[:, 0] + x[:, 1]) / 2
x, x_dist = self.head(x[:, 0]), self.head_dist(x[:, 1])
if self.distilled_training and self.training and not torch.jit.is_scripting():
# only return separate classification predictions when training in distilled mode
return x, x_dist
else:
# during inference, return the average of both classifier predictions
# during standard train / finetune, inference average the classifier predictions
return (x + x_dist) / 2

@ -13,7 +13,7 @@ import torch.utils.checkpoint as cp
from torch.jit.annotations import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .helpers import build_model_with_cfg, MATCH_PREV_GROUP
from .layers import BatchNormAct2d, create_norm_act_layer, BlurPool2d, create_classifier
from .registry import register_model
@ -162,10 +162,10 @@ class DenseNet(nn.Module):
but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
"""
def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), bn_size=4, stem_type='',
num_classes=1000, in_chans=3, global_pool='avg',
norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0, memory_efficient=False,
aa_stem_only=True):
def __init__(
self, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=1000, in_chans=3, global_pool='avg',
bn_size=4, stem_type='', norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0, memory_efficient=False,
aa_stem_only=True):
self.num_classes = num_classes
self.drop_rate = drop_rate
super(DenseNet, self).__init__()
@ -249,6 +249,18 @@ class DenseNet(nn.Module):
elif isinstance(m, nn.Linear):
nn.init.constant_(m.bias, 0)
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^features.conv[012]|features.norm[012]|features.pool[012]',
blocks=r'^features.(?:denseblock|transition)(\d+)' if coarse else [
(r'^features.denseblock(\d+).denselayer(\d+)', None),
(r'^features.transition(\d+)', MATCH_PREV_GROUP) # FIXME combine with previous denselayer
]
)
return matcher
@torch.jit.ignore
def get_classifier(self):
return self.classifier

@ -6,6 +6,7 @@ Res2Net additions from: https://github.com/gasvn/Res2Net/
Res2Net Paper: `Res2Net: A New Multi-scale Backbone Architecture` - https://arxiv.org/abs/1904.01169
"""
import math
from typing import List, Optional
import torch
import torch.nn as nn
@ -62,7 +63,7 @@ class DlaBasic(nn.Module):
self.bn2 = nn.BatchNorm2d(planes)
self.stride = stride
def forward(self, x, shortcut=None):
def forward(self, x, shortcut=None, children: Optional[List[torch.Tensor]] = None):
if shortcut is None:
shortcut = x
@ -99,7 +100,7 @@ class DlaBottleneck(nn.Module):
self.bn3 = nn.BatchNorm2d(outplanes)
self.relu = nn.ReLU(inplace=True)
def forward(self, x, shortcut=None):
def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None):
if shortcut is None:
shortcut = x
@ -147,14 +148,13 @@ class DlaBottle2neck(nn.Module):
bns.append(nn.BatchNorm2d(mid_planes))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
if self.is_first:
self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1) if self.is_first else None
self.conv3 = nn.Conv2d(mid_planes * scale, outplanes, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(outplanes)
self.relu = nn.ReLU(inplace=True)
def forward(self, x, shortcut=None):
def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None):
if shortcut is None:
shortcut = x
@ -164,14 +164,21 @@ class DlaBottle2neck(nn.Module):
spx = torch.split(out, self.width, 1)
spo = []
sp = spx[0] # redundant, for torchscript
for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
sp = spx[i] if i == 0 or self.is_first else sp + spx[i]
if i == 0 or self.is_first:
sp = spx[i]
else:
sp = sp + spx[i]
sp = conv(sp)
sp = bn(sp)
sp = self.relu(sp)
spo.append(sp)
if self.scale > 1:
spo.append(self.pool(spx[-1]) if self.is_first else spx[-1])
if self.pool is not None: # self.is_first == True, None check for torchscript
spo.append(self.pool(spx[-1]))
else:
spo.append(spx[-1])
out = torch.cat(spo, 1)
out = self.conv3(out)
@ -192,21 +199,20 @@ class DlaRoot(nn.Module):
self.relu = nn.ReLU(inplace=True)
self.shortcut = shortcut
def forward(self, *x):
children = x
x = self.conv(torch.cat(x, 1))
def forward(self, x_children: List[torch.Tensor]):
x = self.conv(torch.cat(x_children, 1))
x = self.bn(x)
if self.shortcut:
x += children[0]
x += x_children[0]
x = self.relu(x)
return x
class DlaTree(nn.Module):
def __init__(self, levels, block, in_channels, out_channels, stride=1,
dilation=1, cardinality=1, base_width=64,
level_root=False, root_dim=0, root_kernel_size=1, root_shortcut=False):
def __init__(
self, levels, block, in_channels, out_channels, stride=1, dilation=1, cardinality=1,
base_width=64, level_root=False, root_dim=0, root_kernel_size=1, root_shortcut=False):
super(DlaTree, self).__init__()
if root_dim == 0:
root_dim = 2 * out_channels
@ -225,38 +231,39 @@ class DlaTree(nn.Module):
self.project = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
nn.BatchNorm2d(out_channels))
self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_shortcut)
else:
cargs.update(dict(root_kernel_size=root_kernel_size, root_shortcut=root_shortcut))
self.tree1 = DlaTree(
levels - 1, block, in_channels, out_channels, stride, root_dim=0, **cargs)
self.tree2 = DlaTree(
levels - 1, block, out_channels, out_channels, root_dim=root_dim + out_channels, **cargs)
if levels == 1:
self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_shortcut)
self.root = None
self.level_root = level_root
self.root_dim = root_dim
self.levels = levels
def forward(self, x, shortcut=None, children=None):
children = [] if children is None else children
def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None):
if children is None:
children = []
bottom = self.downsample(x)
shortcut = self.project(bottom)
if self.level_root:
children.append(bottom)
x1 = self.tree1(x, shortcut)
if self.levels == 1:
if self.root is not None: # levels == 1
x2 = self.tree2(x1)
x = self.root(x2, x1, *children)
x = self.root([x2, x1] + children)
else:
children.append(x1)
x = self.tree2(x1, children=children)
x = self.tree2(x1, None, children)
return x
class DLA(nn.Module):
def __init__(self, levels, channels, output_stride=32, num_classes=1000, in_chans=3,
cardinality=1, base_width=64, block=DlaBottle2neck, shortcut_root=False,
drop_rate=0.0, global_pool='avg'):
def __init__(
self, levels, channels, output_stride=32, num_classes=1000, in_chans=3, global_pool='avg',
cardinality=1, base_width=64, block=DlaBottle2neck, shortcut_root=False, drop_rate=0.0):
super(DLA, self).__init__()
self.channels = channels
self.num_classes = num_classes
@ -302,13 +309,32 @@ class DLA(nn.Module):
modules = []
for i in range(convs):
modules.extend([
nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride if i == 0 else 1,
padding=dilation, bias=False, dilation=dilation),
nn.Conv2d(
inplanes, planes, kernel_size=3, stride=stride if i == 0 else 1,
padding=dilation, bias=False, dilation=dilation),
nn.BatchNorm2d(planes),
nn.ReLU(inplace=True)])
inplanes = planes
return nn.Sequential(*modules)
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^base_layer',
blocks=r'^level(\d+)' if coarse else [
# an unusual arch, this achieves somewhat more granularity without getting super messy
(r'^level(\d+).tree(\d+)', None),
(r'^level(\d+).root', (2,)),
(r'^level(\d+)', (1,))
]
)
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
assert not enable, 'gradient checkpointing not supported'
@torch.jit.ignore
def get_classifier(self):
return self.fc
@ -328,13 +354,19 @@ class DLA(nn.Module):
x = self.level5(x)
return x
def forward(self, x):
x = self.forward_features(x)
def forward_head(self, x, pre_logits: bool = False):
x = self.global_pool(x)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.fc(x)
x = self.flatten(x)
if pre_logits:
return x.flatten(1)
else:
x = self.fc(x)
return self.flatten(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x

@ -166,16 +166,17 @@ class DualPathBlock(nn.Module):
class DPN(nn.Module):
def __init__(self, small=False, num_init_features=64, k_r=96, groups=32,
b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), output_stride=32,
num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', fc_act=nn.ELU):
def __init__(
self, small=False, num_init_features=64, k_r=96, groups=32, global_pool='avg',
b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), output_stride=32,
num_classes=1000, in_chans=3, drop_rate=0., fc_act_layer=nn.ELU):
super(DPN, self).__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
self.b = b
assert output_stride == 32 # FIXME look into dilation support
norm_layer = partial(BatchNormAct2d, eps=.001)
fc_norm_layer = partial(BatchNormAct2d, eps=.001, act_layer=fc_act, inplace=False)
fc_norm_layer = partial(BatchNormAct2d, eps=.001, act_layer=fc_act_layer, inplace=False)
bw_factor = 1 if small else 4
blocks = OrderedDict()
@ -239,6 +240,22 @@ class DPN(nn.Module):
self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^features.conv1',
blocks=[
(r'^features.conv(\d+)' if coarse else r'^features.conv(\d+)_(\d+)', None),
(r'^features.conv5_bn_ac', (99999,))
]
)
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
assert not enable, 'gradient checkpointing not supported'
@torch.jit.ignore
def get_classifier(self):
return self.classifier
@ -251,13 +268,19 @@ class DPN(nn.Module):
def forward_features(self, x):
return self.features(x)
def forward(self, x):
x = self.forward_features(x)
def forward_head(self, x, pre_logits: bool = False):
x = self.global_pool(x)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.classifier(x)
x = self.flatten(x)
if pre_logits:
return x.flatten(1)
else:
x = self.classifier(x)
return self.flatten(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x

@ -48,7 +48,7 @@ from .efficientnet_blocks import SqueezeExcite
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .features import FeatureInfo, FeatureHooks
from .helpers import build_model_with_cfg, pretrained_cfg_for_features
from .helpers import build_model_with_cfg, pretrained_cfg_for_features, checkpoint_seq
from .layers import create_conv2d, create_classifier, get_norm_act_layer, EvoNorm2dS0, GroupNormAct
from .registry import register_model
@ -470,9 +470,10 @@ class EfficientNet(nn.Module):
* TinyNet
"""
def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32, fix_stem=False,
output_stride=32, pad_type='', round_chs_fn=round_channels, act_layer=None, norm_layer=None,
se_layer=None, drop_rate=0., drop_path_rate=0., global_pool='avg'):
def __init__(
self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32, fix_stem=False,
output_stride=32, pad_type='', round_chs_fn=round_channels, act_layer=None, norm_layer=None,
se_layer=None, drop_rate=0., drop_path_rate=0., global_pool='avg'):
super(EfficientNet, self).__init__()
act_layer = act_layer or nn.ReLU
norm_layer = norm_layer or nn.BatchNorm2d
@ -481,6 +482,7 @@ class EfficientNet(nn.Module):
self.num_classes = num_classes
self.num_features = num_features
self.drop_rate = drop_rate
self.grad_checkpointing = False
# Stem
if not fix_stem:
@ -511,6 +513,21 @@ class EfficientNet(nn.Module):
layers.extend([nn.Dropout(self.drop_rate), self.classifier])
return nn.Sequential(*layers)
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^conv_stem|bn1',
blocks=[
(r'^blocks.(\d+)' if coarse else r'^blocks.(\d+).(\d+)', None),
(r'conv_head|bn2', (99999,))
]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.classifier
@ -522,17 +539,24 @@ class EfficientNet(nn.Module):
def forward_features(self, x):
x = self.conv_stem(x)
x = self.bn1(x)
x = self.blocks(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x, flatten=True)
else:
x = self.blocks(x)
x = self.conv_head(x)
x = self.bn2(x)
return x
def forward(self, x):
x = self.forward_features(x)
def forward_head(self, x, pre_logits: bool = False):
x = self.global_pool(x)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
return self.classifier(x)
return x if pre_logits else self.classifier(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
class EfficientNetFeatures(nn.Module):
@ -542,9 +566,10 @@ class EfficientNetFeatures(nn.Module):
and object detection models.
"""
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3,
stem_size=32, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels,
act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
def __init__(
self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3,
stem_size=32, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels,
act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
super(EfficientNetFeatures, self).__init__()
act_layer = act_layer or nn.ReLU
norm_layer = norm_layer or nn.BatchNorm2d

@ -86,7 +86,7 @@ class FeatureHooks:
This module helps with the setup and extraction of hooks for extracting features from
internal nodes in a model by node name. This works quite well in eager Python but needs
redesign for torcscript.
redesign for torchscript.
"""
def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'):
@ -97,7 +97,7 @@ class FeatureHooks:
m = modules[hook_name]
hook_id = out_map[i] if out_map else hook_name
hook_fn = partial(self._collect_output_hook, hook_id)
hook_type = h['hook_type'] if 'hook_type' in h else default_hook_type
hook_type = h.get('hook_type', default_hook_type)
if hook_type == 'forward_pre':
m.register_forward_pre_hook(hook_fn)
elif hook_type == 'forward':

@ -89,13 +89,13 @@ class FeatureGraphNet(nn.Module):
return list(self.graph_module(x).values())
class FeatureExtractNet(nn.Module):
class GraphExtractNet(nn.Module):
""" A standalone feature extraction wrapper that maps dict -> list or single tensor
NOTE:
* one can use feature_extractor directly if dictionary output is desired
* unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info
metadata for builtin feature extraction mode
* feature_extractor can be used directly if dictionary output is acceptable
* create_feature_extractor can be used directly if dictionary output is acceptable
Args:
model: model to extract features from

@ -15,7 +15,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .layers import SelectAdaptivePool2d, Linear, make_divisible
from .efficientnet_blocks import SqueezeExcite, ConvBnAct
from .helpers import build_model_with_cfg
from .helpers import build_model_with_cfg, checkpoint_seq
from .registry import register_model
@ -24,7 +24,7 @@ __all__ = ['GhostNet']
def _cfg(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1),
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'conv_stem', 'classifier': 'classifier',
@ -133,13 +133,15 @@ class GhostBottleneck(nn.Module):
class GhostNet(nn.Module):
def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3, output_stride=32, global_pool='avg'):
def __init__(
self, cfgs, num_classes=1000, width=1.0, in_chans=3, output_stride=32, global_pool='avg', drop_rate=0.2):
super(GhostNet, self).__init__()
# setting of inverted residual blocks
assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported'
self.cfgs = cfgs
self.num_classes = num_classes
self.dropout = dropout
self.drop_rate = drop_rate
self.grad_checkpointing = False
self.feature_info = []
# building first layer
@ -184,6 +186,24 @@ class GhostNet(nn.Module):
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
self.classifier = Linear(out_chs, num_classes) if num_classes > 0 else nn.Identity()
# FIXME init
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^conv_stem|bn1',
blocks=[
(r'^blocks.(\d+)' if coarse else r'^blocks.(\d+).(\d+)', None),
(r'conv_head', (99999,))
]
)
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.classifier
@ -198,18 +218,25 @@ class GhostNet(nn.Module):
x = self.conv_stem(x)
x = self.bn1(x)
x = self.act1(x)
x = self.blocks(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x, flatten=True)
else:
x = self.blocks(x)
return x
def forward_head(self, x):
x = self.global_pool(x)
x = self.conv_head(x)
x = self.act2(x)
x = self.flatten(x)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.classifier(x)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.flatten(x)
if self.dropout > 0.:
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.classifier(x)
x = self.forward_head(x)
return x

@ -8,6 +8,7 @@ Hacked together by / Copyright 2020 Ross Wightman
"""
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -178,6 +179,23 @@ class Xception65(nn.Module):
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^conv[12]|bn[12]',
blocks=[
(r'^mid.block(\d+)', None),
(r'^block(\d+)', None),
(r'^conv[345]|bn[345]', (99,)),
],
)
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
assert not enable, "gradient checkpointing not supported"
@torch.jit.ignore
def get_classifier(self):
return self.fc
@ -222,14 +240,18 @@ class Xception65(nn.Module):
x = self.act5(x)
return x
def forward(self, x):
x = self.forward_features(x)
def forward_head(self, x):
x = self.global_pool(x)
if self.drop_rate:
F.dropout(x, self.drop_rate, training=self.training)
x = self.fc(x)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def _create_gluon_xception(variant, pretrained=False, **kwargs):
return build_model_with_cfg(

@ -13,7 +13,7 @@ from .registry import register_model
def _cfg(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1),
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'conv_stem', 'classifier': 'classifier',

@ -2,16 +2,20 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
import collections.abc
import logging
import os
import math
from collections import OrderedDict
import os
import re
from collections import OrderedDict, defaultdict
from copy import deepcopy
from typing import Any, Callable, Optional, Tuple, Dict
from itertools import chain
from typing import Any, Callable, Optional, Tuple, Dict, Union
import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url
from torch.utils.checkpoint import checkpoint
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
from .fx_features import FeatureGraphNet
@ -68,7 +72,8 @@ def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True):
raise NotImplementedError('Model cannot load numpy checkpoint')
return
state_dict = load_state_dict(checkpoint_path, use_ema)
model.load_state_dict(state_dict, strict=strict)
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
return incompatible_keys
def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
@ -479,7 +484,7 @@ def build_model_with_cfg(
pretrained_cfg: Optional[Dict] = None,
model_cfg: Optional[Any] = None,
feature_cfg: Optional[Dict] = None,
pretrained_strict: bool = True,
pretrained_strict: bool = False,
pretrained_filter_fn: Optional[Callable] = None,
pretrained_custom_load: bool = False,
kwargs_filter: Optional[Tuple[str]] = None,
@ -592,3 +597,194 @@ def named_modules(module: nn.Module, name='', depth_first=True, include_root=Fal
module=child_module, name=child_name, depth_first=depth_first, include_root=True)
if depth_first and include_root:
yield name, module
def named_modules_with_params(module: nn.Module, name='', depth_first=True, include_root=False):
if module._parameters and not depth_first and include_root:
yield name, module
for child_name, child_module in module.named_children():
child_name = '.'.join((name, child_name)) if name else child_name
yield from named_modules_with_params(
module=child_module, name=child_name, depth_first=depth_first, include_root=True)
if module._parameters and depth_first and include_root:
yield name, module
MATCH_PREV_GROUP = (99999,)
def group_with_matcher(
named_objects,
group_matcher: Union[Dict, Callable],
output_values: bool = False,
reverse: bool = False
):
if isinstance(group_matcher, dict):
# dictionary matcher contains a dict of raw-string regex expr that must be compiled
compiled = []
for group_ordinal, (group_name, mspec) in enumerate(group_matcher.items()):
if mspec is None:
continue
# map all matching specifications into 3-tuple (compiled re, prefix, suffix)
if isinstance(mspec, (tuple, list)):
# multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix)
for sspec in mspec:
compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])]
else:
compiled += [(re.compile(mspec), (group_ordinal,), None)]
group_matcher = compiled
def _get_grouping(name):
if isinstance(group_matcher, (list, tuple)):
for match_fn, prefix, suffix in group_matcher:
r = match_fn.match(name)
if r:
parts = (prefix, r.groups(), suffix)
# map all tuple elem to int for numeric sort, filter out None entries
return tuple(map(float, chain.from_iterable(filter(None, parts))))
return float('inf'), # un-matched layers (neck, head) mapped to largest ordinal
else:
ord = group_matcher(name)
if not isinstance(ord, collections.abc.Iterable):
return ord,
return tuple(ord)
# map layers into groups via ordinals (ints or tuples of ints) from matcher
grouping = defaultdict(list)
for k, v in named_objects:
grouping[_get_grouping(k)].append(v if output_values else k)
# remap to integers
layer_id_to_param = defaultdict(list)
lid = -1
for k in sorted(filter(lambda x: x is not None, grouping.keys())):
if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]:
lid += 1
print(lid, k, grouping[k])
layer_id_to_param[lid].extend(grouping[k])
if reverse:
assert not output_values, "reverse mapping only sensible for name output"
# output reverse mapping
param_to_layer_id = {}
for lid, lm in layer_id_to_param.items():
for n in lm:
param_to_layer_id[n] = lid
return param_to_layer_id
return layer_id_to_param
def group_parameters(
module: nn.Module,
group_matcher,
output_values=False,
reverse=False,
):
return group_with_matcher(
module.named_parameters(), group_matcher, output_values=output_values, reverse=reverse)
def group_modules(
module: nn.Module,
group_matcher,
output_values=False,
reverse=False,
):
return group_with_matcher(
named_modules_with_params(module), group_matcher, output_values=output_values, reverse=reverse)
def checkpoint_seq(
functions,
x,
every=1,
flatten=False,
skip_last=False,
preserve_rng_state=True
):
r"""A helper function for checkpointing sequential models.
Sequential models execute a list of modules/functions in order
(sequentially). Therefore, we can divide such a sequence into segments
and checkpoint each segment. All segments except run in :func:`torch.no_grad`
manner, i.e., not storing the intermediate activations. The inputs of each
checkpointed segment will be saved for re-running the segment in the backward pass.
See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.
.. warning::
Checkpointing currently only supports :func:`torch.autograd.backward`
and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
is not supported.
.. warning:
At least one of the inputs needs to have :code:`requires_grad=True` if
grads are needed for model inputs, otherwise the checkpointed part of the
model won't have gradients.
Args:
functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially.
x: A Tensor that is input to :attr:`functions`
every: checkpoint every-n functions (default: 1)
flatten (bool): flatten nn.Sequential of nn.Sequentials
skip_last (bool): skip checkpointing the last function in the sequence if True
preserve_rng_state (bool, optional, default=True): Omit stashing and restoring
the RNG state during each checkpoint.
Returns:
Output of running :attr:`functions` sequentially on :attr:`*inputs`
Example:
>>> model = nn.Sequential(...)
>>> input_var = checkpoint_seq(model, input_var, every=2)
"""
def run_function(start, end, functions):
def forward(_x):
for j in range(start, end + 1):
_x = functions[j](_x)
return _x
return forward
if isinstance(functions, torch.nn.Sequential):
functions = functions.children()
if flatten:
functions = chain.from_iterable(functions)
if not isinstance(functions, (tuple, list)):
functions = tuple(functions)
num_checkpointed = len(functions)
if skip_last:
num_checkpointed -= 1
end = -1
for start in range(0, num_checkpointed, every):
end = min(start + every - 1, num_checkpointed - 1)
x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state)
if skip_last:
return run_function(end + 1, len(functions) - 1, functions)(x)
return x
def flatten_modules(named_modules, depth=1, prefix='', module_types='sequential'):
prefix_is_tuple = isinstance(prefix, tuple)
if isinstance(module_types, str):
if module_types == 'container':
module_types = (nn.Sequential, nn.ModuleList, nn.ModuleDict)
else:
module_types = (nn.Sequential,)
for name, module in named_modules:
if depth and isinstance(module, module_types):
yield from flatten_modules(
module.named_children(),
depth - 1,
prefix=(name,) if prefix_is_tuple else name,
module_types=module_types,
)
else:
if prefix_is_tuple:
name = prefix + (name,)
yield name, module
else:
if prefix:
name = '.'.join([prefix, name])
yield name, module

@ -386,13 +386,13 @@ cfg_cls = dict(
class HighResolutionModule(nn.Module):
def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
def __init__(self, num_branches, blocks, num_blocks, num_in_chs,
num_channels, fuse_method, multi_scale_output=True):
super(HighResolutionModule, self).__init__()
self._check_branches(
num_branches, blocks, num_blocks, num_inchannels, num_channels)
num_branches, blocks, num_blocks, num_in_chs, num_channels)
self.num_inchannels = num_inchannels
self.num_in_chs = num_in_chs
self.fuse_method = fuse_method
self.num_branches = num_branches
@ -403,32 +403,32 @@ class HighResolutionModule(nn.Module):
self.fuse_layers = self._make_fuse_layers()
self.fuse_act = nn.ReLU(False)
def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels):
def _check_branches(self, num_branches, blocks, num_blocks, num_in_chs, num_channels):
error_msg = ''
if num_branches != len(num_blocks):
error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(num_branches, len(num_blocks))
elif num_branches != len(num_channels):
error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(num_branches, len(num_channels))
elif num_branches != len(num_inchannels):
error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(num_branches, len(num_inchannels))
elif num_branches != len(num_in_chs):
error_msg = 'NUM_BRANCHES({}) <> num_in_chs({})'.format(num_branches, len(num_in_chs))
if error_msg:
_logger.error(error_msg)
raise ValueError(error_msg)
def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
downsample = None
if stride != 1 or self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
if stride != 1 or self.num_in_chs[branch_index] != num_channels[branch_index] * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(
self.num_inchannels[branch_index], num_channels[branch_index] * block.expansion,
self.num_in_chs[branch_index], num_channels[branch_index] * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=_BN_MOMENTUM),
)
layers = [block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample)]
self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion
layers = [block(self.num_in_chs[branch_index], num_channels[branch_index], stride, downsample)]
self.num_in_chs[branch_index] = num_channels[branch_index] * block.expansion
for i in range(1, num_blocks[branch_index]):
layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index]))
layers.append(block(self.num_in_chs[branch_index], num_channels[branch_index]))
return nn.Sequential(*layers)
@ -444,15 +444,15 @@ class HighResolutionModule(nn.Module):
return nn.Identity()
num_branches = self.num_branches
num_inchannels = self.num_inchannels
num_in_chs = self.num_in_chs
fuse_layers = []
for i in range(num_branches if self.multi_scale_output else 1):
fuse_layer = []
for j in range(num_branches):
if j > i:
fuse_layer.append(nn.Sequential(
nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False),
nn.BatchNorm2d(num_inchannels[i], momentum=_BN_MOMENTUM),
nn.Conv2d(num_in_chs[j], num_in_chs[i], 1, 1, 0, bias=False),
nn.BatchNorm2d(num_in_chs[i], momentum=_BN_MOMENTUM),
nn.Upsample(scale_factor=2 ** (j - i), mode='nearest')))
elif j == i:
fuse_layer.append(nn.Identity())
@ -460,14 +460,14 @@ class HighResolutionModule(nn.Module):
conv3x3s = []
for k in range(i - j):
if k == i - j - 1:
num_outchannels_conv3x3 = num_inchannels[i]
num_outchannels_conv3x3 = num_in_chs[i]
conv3x3s.append(nn.Sequential(
nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
nn.Conv2d(num_in_chs[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
nn.BatchNorm2d(num_outchannels_conv3x3, momentum=_BN_MOMENTUM)))
else:
num_outchannels_conv3x3 = num_inchannels[j]
num_outchannels_conv3x3 = num_in_chs[j]
conv3x3s.append(nn.Sequential(
nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
nn.Conv2d(num_in_chs[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
nn.BatchNorm2d(num_outchannels_conv3x3, momentum=_BN_MOMENTUM),
nn.ReLU(False)))
fuse_layer.append(nn.Sequential(*conv3x3s))
@ -475,8 +475,8 @@ class HighResolutionModule(nn.Module):
return nn.ModuleList(fuse_layers)
def get_num_inchannels(self):
return self.num_inchannels
def get_num_in_chs(self):
return self.num_in_chs
def forward(self, x: List[torch.Tensor]):
if self.num_branches == 1:
@ -652,7 +652,7 @@ class HighResolutionNet(nn.Module):
return nn.Sequential(*layers)
def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True):
def _make_stage(self, layer_config, num_in_chs, multi_scale_output=True):
num_modules = layer_config['NUM_MODULES']
num_branches = layer_config['NUM_BRANCHES']
num_blocks = layer_config['NUM_BLOCKS']
@ -665,12 +665,13 @@ class HighResolutionNet(nn.Module):
# multi_scale_output is only used last module
reset_multi_scale_output = multi_scale_output or i < num_modules - 1
modules.append(HighResolutionModule(
num_branches, block, num_blocks, num_inchannels, num_channels, fuse_method, reset_multi_scale_output)
num_branches, block, num_blocks, num_in_chs, num_channels, fuse_method, reset_multi_scale_output)
)
num_inchannels = modules[-1].get_num_inchannels()
num_in_chs = modules[-1].get_num_in_chs()
return nn.Sequential(*modules), num_inchannels
return nn.Sequential(*modules), num_in_chs
@torch.jit.ignore
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
@ -680,6 +681,23 @@ class HighResolutionNet(nn.Module):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^conv[12]|bn[12]',
blocks=r'^(?:layer|stage|transition)(\d+)' if coarse else [
(r'^layer(\d+).(\d+)', None),
(r'^stage(\d+).(\d+)', None),
(r'^transition(\d+)', (99999,)),
],
)
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
assert not enable, "gradient checkpointing not supported"
@torch.jit.ignore
def get_classifier(self):
return self.classifier
@ -712,20 +730,24 @@ class HighResolutionNet(nn.Module):
# Stages
yl = self.stages(x)
# Classification Head
if self.incre_modules is None or self.downsamp_modules is None:
return yl
y = self.incre_modules[0](yl[0])
for i, down in enumerate(self.downsamp_modules):
y = self.incre_modules[i + 1](yl[i + 1]) + down(y)
y = self.final_layer(y)
return y
def forward(self, x):
x = self.forward_features(x)
def forward_head(self, x, pre_logits: bool = False):
# Classification Head
x = self.global_pool(x)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.classifier(x)
return x if pre_logits else self.classifier(x)
def forward(self, x):
y = self.forward_features(x)
x = self.forward_head(y)
return x

@ -7,7 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg
from .helpers import build_model_with_cfg, flatten_modules
from .layers import create_classifier
from .registry import register_model
@ -300,6 +300,30 @@ class InceptionResnetV2(nn.Module):
self.global_pool, self.classif = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
@torch.jit.ignore
def group_matcher(self, coarse=False):
module_map = {k: i for i, (k, _) in enumerate(flatten_modules(self.named_children(), prefix=()))}
module_map.pop(('classif',))
def _matcher(name):
if any([name.startswith(n) for n in ('conv2d_1', 'conv2d_2')]):
return 0
elif any([name.startswith(n) for n in ('conv2d_3', 'conv2d_4')]):
return 1
elif any([name.startswith(n) for n in ('block8', 'conv2d_7')]):
return len(module_map) + 1
else:
for k in module_map.keys():
if k == tuple(name.split('.')[:len(k)]):
return module_map[k]
return float('inf')
return _matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
assert not enable, "checkpointing not supported"
@torch.jit.ignore
def get_classifier(self):
return self.classif
@ -325,12 +349,15 @@ class InceptionResnetV2(nn.Module):
x = self.conv2d_7b(x)
return x
def forward(self, x):
x = self.forward_features(x)
def forward_head(self, x, pre_logits: bool = False):
x = self.global_pool(x)
if self.drop_rate > 0:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.classif(x)
return x if pre_logits else self.classif(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x

@ -8,7 +8,7 @@ import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg, resolve_pretrained_cfg
from .helpers import build_model_with_cfg, resolve_pretrained_cfg, flatten_modules
from .registry import register_model
from .layers import trunc_normal_, create_classifier, Linear
@ -336,47 +336,57 @@ class InceptionV3(nn.Module):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
@torch.jit.ignore
def group_matcher(self, coarse=False):
module_map = {k: i for i, (k, _) in enumerate(flatten_modules(self.named_children(), prefix=()))}
module_map.pop(('fc',))
def _matcher(name):
if any([name.startswith(n) for n in ('Conv2d_1', 'Conv2d_2')]):
return 0
elif any([name.startswith(n) for n in ('Conv2d_3', 'Conv2d_4')]):
return 1
else:
for k in module_map.keys():
if k == tuple(name.split('.')[:len(k)]):
return module_map[k]
return float('inf')
return _matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
assert not enable, 'gradient checkpointing not supported'
@torch.jit.ignore
def get_classifier(self):
return self.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
def forward_preaux(self, x):
# N x 3 x 299 x 299
x = self.Conv2d_1a_3x3(x)
# N x 32 x 149 x 149
x = self.Conv2d_2a_3x3(x)
# N x 32 x 147 x 147
x = self.Conv2d_2b_3x3(x)
# N x 64 x 147 x 147
x = self.Pool1(x)
# N x 64 x 73 x 73
x = self.Conv2d_3b_1x1(x)
# N x 80 x 73 x 73
x = self.Conv2d_4a_3x3(x)
# N x 192 x 71 x 71
x = self.Pool2(x)
# N x 192 x 35 x 35
x = self.Mixed_5b(x)
# N x 256 x 35 x 35
x = self.Mixed_5c(x)
# N x 288 x 35 x 35
x = self.Mixed_5d(x)
# N x 288 x 35 x 35
x = self.Mixed_6a(x)
# N x 768 x 17 x 17
x = self.Mixed_6b(x)
# N x 768 x 17 x 17
x = self.Mixed_6c(x)
# N x 768 x 17 x 17
x = self.Mixed_6d(x)
# N x 768 x 17 x 17
x = self.Mixed_6e(x)
# N x 768 x 17 x 17
x = self.Conv2d_1a_3x3(x) # N x 32 x 149 x 149
x = self.Conv2d_2a_3x3(x) # N x 32 x 147 x 147
x = self.Conv2d_2b_3x3(x) # N x 64 x 147 x 147
x = self.Pool1(x) # N x 64 x 73 x 73
x = self.Conv2d_3b_1x1(x) # N x 80 x 73 x 73
x = self.Conv2d_4a_3x3(x) # N x 192 x 71 x 71
x = self.Pool2(x) # N x 192 x 35 x 35
x = self.Mixed_5b(x) # N x 256 x 35 x 35
x = self.Mixed_5c(x) # N x 288 x 35 x 35
x = self.Mixed_5d(x) # N x 288 x 35 x 35
x = self.Mixed_6a(x) # N x 768 x 17 x 17
x = self.Mixed_6b(x) # N x 768 x 17 x 17
x = self.Mixed_6c(x) # N x 768 x 17 x 17
x = self.Mixed_6d(x) # N x 768 x 17 x 17
x = self.Mixed_6e(x) # N x 768 x 17 x 17
return x
def forward_postaux(self, x):
x = self.Mixed_7a(x)
# N x 1280 x 8 x 8
x = self.Mixed_7b(x)
# N x 2048 x 8 x 8
x = self.Mixed_7c(x)
# N x 2048 x 8 x 8
x = self.Mixed_7a(x) # N x 1280 x 8 x 8
x = self.Mixed_7b(x) # N x 2048 x 8 x 8
x = self.Mixed_7c(x) # N x 2048 x 8 x 8
return x
def forward_features(self, x):
@ -384,21 +394,18 @@ class InceptionV3(nn.Module):
x = self.forward_postaux(x)
return x
def get_classifier(self):
return self.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
def forward(self, x):
x = self.forward_features(x)
def forward_head(self, x):
x = self.global_pool(x)
if self.drop_rate > 0:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.fc(x)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
class InceptionV3Aux(InceptionV3):
"""InceptionV3 with AuxLogits
@ -416,10 +423,7 @@ class InceptionV3Aux(InceptionV3):
def forward(self, x):
x, aux = self.forward_features(x)
x = self.global_pool(x)
if self.drop_rate > 0:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.fc(x)
x = self.forward_head(x)
return x, aux

@ -283,6 +283,18 @@ class InceptionV4(nn.Module):
self.global_pool, self.last_linear = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^features\.[012]\.',
blocks=r'^features\.(\d+)'
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
assert not enable, 'gradient checkpointing not supported'
@torch.jit.ignore
def get_classifier(self):
return self.last_linear
@ -294,12 +306,15 @@ class InceptionV4(nn.Module):
def forward_features(self, x):
return self.features(x)
def forward(self, x):
x = self.forward_features(x)
def forward_head(self, x, pre_logits: bool = False):
x = self.global_pool(x)
if self.drop_rate > 0:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.last_linear(x)
return x if pre_logits else self.last_linear(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x

@ -7,67 +7,16 @@ https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/cli
Hacked together by / Copyright 2021 Ross Wightman
"""
import math
from typing import List, Union, Tuple
from typing import Union, Tuple
import torch
import torch.nn as nn
from .helpers import to_2tuple
from .pos_embed import apply_rot_embed, RotaryEmbedding
from .weight_init import trunc_normal_
def rot(x):
return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)
def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):
return x * cos_emb + rot(x) * sin_emb
def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
if isinstance(x, torch.Tensor):
x = [x]
return [t * cos_emb + rot(t) * sin_emb for t in x]
class RotaryEmbedding(nn.Module):
""" Rotary position embedding
NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not
been well tested, and will likely change. It will be moved to its own file.
The following impl/resources were referenced for this impl:
* https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
* https://blog.eleuther.ai/rotary-embeddings/
"""
def __init__(self, dim, max_freq=4):
super().__init__()
self.dim = dim
self.register_buffer('bands', 2 ** torch.linspace(0., max_freq - 1, self.dim // 4), persistent=False)
def get_embed(self, shape: torch.Size, device: torch.device = None, dtype: torch.dtype = None):
"""
NOTE: shape arg should include spatial dim only
"""
device = device or self.bands.device
dtype = dtype or self.bands.dtype
if not isinstance(shape, torch.Size):
shape = torch.Size(shape)
N = shape.numel()
grid = torch.stack(torch.meshgrid(
[torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in shape]), dim=-1).unsqueeze(-1)
emb = grid * math.pi * self.bands
sin = emb.sin().reshape(N, -1).repeat_interleave(2, -1)
cos = emb.cos().reshape(N, -1).repeat_interleave(2, -1)
return sin, cos
def forward(self, x):
# assuming channel-first tensor where spatial dim are >= 2
sin_emb, cos_emb = self.get_embed(x.shape[2:])
return apply_rot_embed(x, sin_emb, cos_emb)
class RotAttentionPool2d(nn.Module):
""" Attention based 2D feature pooling w/ rotary (relative) pos embedding.
This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.
@ -103,7 +52,6 @@ class RotAttentionPool2d(nn.Module):
def forward(self, x):
B, _, H, W = x.shape
N = H * W
sin_emb, cos_emb = self.pos_embed.get_embed(x.shape[2:])
x = x.reshape(B, -1, N).permute(0, 2, 1)
x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
@ -112,6 +60,7 @@ class RotAttentionPool2d(nn.Module):
q, k, v = x[0], x[1], x[2]
qc, q = q[:, :, :1], q[:, :, 1:]
sin_emb, cos_emb = self.pos_embed.get_embed((H, W))
q = apply_rot_embed(q, sin_emb, cos_emb)
q = torch.cat([qc, q], dim=2)

@ -45,10 +45,12 @@ class ClassifierHead(nn.Module):
self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
def forward(self, x):
def forward(self, x, pre_logits: bool = False):
x = self.global_pool(x)
if self.drop_rate:
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
x = self.fc(x)
x = self.flatten(x)
return x
if pre_logits:
return x.flatten(1)
else:
x = self.fc(x)
return self.flatten(x)

@ -97,7 +97,7 @@ def group_rms(x, groups: int = 32, eps: float = 1e-5):
class EvoNorm2dB0(nn.Module):
def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_):
def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-3, **_):
super().__init__()
self.apply_act = apply_act # apply activation (non-linearity)
self.momentum = momentum
@ -237,7 +237,7 @@ class EvoNorm2dS0(nn.Module):
class EvoNorm2dS0a(EvoNorm2dS0):
def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-5, **_):
def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-3, **_):
super().__init__(
num_features, groups=groups, group_size=group_size, apply_act=apply_act, eps=eps)
@ -290,7 +290,7 @@ class EvoNorm2dS1(nn.Module):
class EvoNorm2dS1a(EvoNorm2dS1):
def __init__(
self, num_features, groups=32, group_size=None,
apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_):
apply_act=True, act_layer=nn.SiLU, eps=1e-3, **_):
super().__init__(
num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps)
@ -338,7 +338,7 @@ class EvoNorm2dS2(nn.Module):
class EvoNorm2dS2a(EvoNorm2dS2):
def __init__(
self, num_features, groups=32, group_size=None,
apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_):
apply_act=True, act_layer=nn.SiLU, eps=1e-3, **_):
super().__init__(
num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps)

@ -0,0 +1,207 @@
import math
from typing import List, Tuple, Optional, Union
import torch
from torch import nn as nn
def pixel_freq_bands(
num_bands: int,
max_freq: float = 224.,
linear_bands: bool = True,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
):
if linear_bands:
bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=dtype, device=device)
else:
bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=dtype, device=device)
return bands * torch.pi
def inv_freq_bands(
num_bands: int,
temperature: float = 100000.,
step: int = 2,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
) -> torch.Tensor:
inv_freq = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands))
return inv_freq
def build_sincos2d_pos_embed(
feat_shape: List[int],
dim: int = 64,
temperature: float = 10000.,
reverse_coord: bool = False,
interleave_sin_cos: bool = False,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None
) -> torch.Tensor:
"""
Args:
feat_shape:
dim:
temperature:
reverse_coord: stack grid order W, H instead of H, W
interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos
dtype:
device:
Returns:
"""
assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding'
pos_dim = dim // 4
bands = inv_freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device)
if reverse_coord:
feat_shape = feat_shape[::-1] # stack W, H instead of H, W
grid = torch.stack(
torch.meshgrid([torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1)
pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
# FIXME add support for unflattened spatial dim?
stack_dim = 2 if interleave_sin_cos else 1 # stack sin, cos, sin, cos instead of sin sin cos cos
pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1)
return pos_emb
def build_fourier_pos_embed(
feat_shape: List[int],
bands: Optional[torch.Tensor] = None,
num_bands: int = 64,
max_res: int = 224,
linear_bands: bool = False,
include_grid: bool = False,
concat_out: bool = True,
in_pixels: bool = True,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
) -> List[torch.Tensor]:
if bands is None:
if in_pixels:
bands = pixel_freq_bands(num_bands, float(max_res), linear_bands=linear_bands, dtype=dtype, device=device)
else:
bands = inv_freq_bands(num_bands, step=1, dtype=dtype, device=device)
else:
if device is None:
device = bands.device
if dtype is None:
dtype = bands.dtype
if in_pixels:
grid = torch.stack(torch.meshgrid(
[torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]), dim=-1)
else:
grid = torch.stack(torch.meshgrid(
[torch.arange(s, device=device, dtype=dtype) for s in feat_shape]), dim=-1)
grid = grid.unsqueeze(-1)
pos = grid * bands
pos_sin, pos_cos = pos.sin(), pos.cos()
out = (grid, pos_sin, pos_cos) if include_grid else (pos_sin, pos_cos)
# FIXME torchscript doesn't like multiple return types, probably need to always cat?
if concat_out:
out = torch.cat(out, dim=-1)
return out
class FourierEmbed(nn.Module):
def __init__(self, max_res: int = 224, num_bands: int = 64, concat_grid=True, keep_spatial=False):
super().__init__()
self.max_res = max_res
self.num_bands = num_bands
self.concat_grid = concat_grid
self.keep_spatial = keep_spatial
self.register_buffer('bands', pixel_freq_bands(max_res, num_bands), persistent=False)
def forward(self, x):
B, C = x.shape[:2]
feat_shape = x.shape[2:]
emb = build_fourier_pos_embed(
feat_shape,
self.bands,
include_grid=self.concat_grid,
dtype=x.dtype,
device=x.device)
emb = emb.transpose(-1, -2).flatten(len(feat_shape))
batch_expand = (B,) + (-1,) * (x.ndim - 1)
# FIXME support nD
if self.keep_spatial:
x = torch.cat([x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1)
else:
x = torch.cat([x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1)
x = x.reshape(B, feat_shape.numel(), -1)
return x
def rot(x):
return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)
def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):
return x * cos_emb + rot(x) * sin_emb
def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
if isinstance(x, torch.Tensor):
x = [x]
return [t * cos_emb + rot(t) * sin_emb for t in x]
def apply_rot_embed_split(x: torch.Tensor, emb):
split = emb.shape[-1] // 2
return x * emb[:, :split] + rot(x) * emb[:, split:]
def build_rotary_pos_embed(
feat_shape: List[int],
bands: Optional[torch.Tensor] = None,
dim: int = 64,
max_freq: float = 224,
linear_bands: bool = False,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
):
"""
NOTE: shape arg should include spatial dim only
"""
feat_shape = torch.Size(feat_shape)
sin_emb, cos_emb = build_fourier_pos_embed(
feat_shape, bands=bands, num_bands=dim // 4, max_res=max_freq, linear_bands=linear_bands,
concat_out=False, device=device, dtype=dtype)
N = feat_shape.numel()
sin_emb = sin_emb.reshape(N, -1).repeat_interleave(2, -1)
cos_emb = cos_emb.reshape(N, -1).repeat_interleave(2, -1)
return sin_emb, cos_emb
class RotaryEmbedding(nn.Module):
""" Rotary position embedding
NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not
been well tested, and will likely change. It will be moved to its own file.
The following impl/resources were referenced for this impl:
* https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
* https://blog.eleuther.ai/rotary-embeddings/
"""
def __init__(self, dim, max_res=224, linear_bands: bool = False):
super().__init__()
self.dim = dim
self.register_buffer('bands', pixel_freq_bands(dim // 4, max_res, linear_bands=linear_bands), persistent=False)
def get_embed(self, shape: List[int]):
return build_rotary_pos_embed(shape, self.bands)
def forward(self, x):
# assuming channel-first tensor where spatial dim are >= 2
sin_emb, cos_emb = self.get_embed(x.shape[2:])
return apply_rot_embed(x, sin_emb, cos_emb)

@ -32,7 +32,7 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
from .helpers import build_model_with_cfg
from .helpers import build_model_with_cfg, checkpoint_seq
from .layers import to_ntuple, get_act_layer
from .vision_transformer import trunc_normal_
from .registry import register_model
@ -65,6 +65,8 @@ default_cfgs = dict(
levit_384=_cfg(
url='https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'
),
levit_256d=_cfg(url='', classifier='head.l'),
)
model_cfgs = dict(
@ -78,6 +80,9 @@ model_cfgs = dict(
embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 4, 4)),
levit_384=dict(
embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4)),
levit_256d=dict(
embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 8, 6)),
)
__all__ = ['Levit']
@ -113,15 +118,21 @@ def levit_384(pretrained=False, use_conv=False, **kwargs):
'levit_384', pretrained=pretrained, use_conv=use_conv, **kwargs)
@register_model
def levit_256d(pretrained=False, use_conv=False, **kwargs):
return create_levit(
'levit_256d', pretrained=pretrained, use_conv=use_conv, distilled=False, **kwargs)
class ConvNorm(nn.Sequential):
def __init__(
self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000):
self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1,
groups=1, bn_weight_init=1, resolution=-10000):
super().__init__()
self.add_module('c', nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
bn = nn.BatchNorm2d(b)
nn.init.constant_(bn.weight, bn_weight_init)
nn.init.constant_(bn.bias, 0)
self.add_module('bn', bn)
self.add_module('c', nn.Conv2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False))
self.add_module('bn', nn.BatchNorm2d(out_chs))
nn.init.constant_(self.bn.weight, bn_weight_init)
@torch.no_grad()
def fuse(self):
@ -138,13 +149,12 @@ class ConvNorm(nn.Sequential):
class LinearNorm(nn.Sequential):
def __init__(self, a, b, bn_weight_init=1, resolution=-100000):
def __init__(self, in_features, out_features, bn_weight_init=1, resolution=-100000):
super().__init__()
self.add_module('c', nn.Linear(a, b, bias=False))
bn = nn.BatchNorm1d(b)
nn.init.constant_(bn.weight, bn_weight_init)
nn.init.constant_(bn.bias, 0)
self.add_module('bn', bn)
self.add_module('c', nn.Linear(in_features, out_features, bias=False))
self.add_module('bn', nn.BatchNorm1d(out_features))
nn.init.constant_(self.bn.weight, bn_weight_init)
@torch.no_grad()
def fuse(self):
@ -163,14 +173,14 @@ class LinearNorm(nn.Sequential):
class NormLinear(nn.Sequential):
def __init__(self, a, b, bias=True, std=0.02):
def __init__(self, in_features, out_features, bias=True, std=0.02):
super().__init__()
self.add_module('bn', nn.BatchNorm1d(a))
l = nn.Linear(a, b, bias=bias)
trunc_normal_(l.weight, std=std)
if bias:
nn.init.constant_(l.bias, 0)
self.add_module('l', l)
self.add_module('bn', nn.BatchNorm1d(in_features))
self.add_module('l', nn.Linear(in_features, out_features, bias=bias))
trunc_normal_(self.l.weight, std=std)
if self.l.bias is not None:
nn.init.constant_(self.l.bias, 0)
@torch.no_grad()
def fuse(self):
@ -231,34 +241,26 @@ class Attention(nn.Module):
def __init__(
self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False):
super().__init__()
ln_layer = ConvNorm if use_conv else LinearNorm
self.use_conv = use_conv
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * num_heads
self.attn_ratio = attn_ratio
self.use_conv = use_conv
ln_layer = ConvNorm if self.use_conv else LinearNorm
h = self.dh + nh_kd * 2
self.qkv = ln_layer(dim, h, resolution=resolution)
self.key_attn_dim = key_dim * num_heads
self.val_dim = int(attn_ratio * key_dim)
self.val_attn_dim = int(attn_ratio * key_dim) * num_heads
self.qkv = ln_layer(dim, self.val_attn_dim + self.key_attn_dim * 2, resolution=resolution)
self.proj = nn.Sequential(
act_layer(),
ln_layer(self.dh, dim, bn_weight_init=0, resolution=resolution))
points = list(itertools.product(range(resolution), range(resolution)))
N = len(points)
attention_offsets = {}
idxs = []
for p1 in points:
for p2 in points:
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N))
ln_layer(self.val_attn_dim, dim, bn_weight_init=0, resolution=resolution)
)
self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution ** 2))
pos = torch.stack(torch.meshgrid(torch.arange(resolution), torch.arange(resolution))).flatten(1)
rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
rel_pos = (rel_pos[0] * resolution) + rel_pos[1]
self.register_buffer('attention_bias_idxs', rel_pos)
self.ab = {}
@torch.no_grad()
@ -279,7 +281,8 @@ class Attention(nn.Module):
def forward(self, x): # x (B,C,H,W)
if self.use_conv:
B, C, H, W = x.shape
q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2)
q, k, v = self.qkv(x).view(
B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.val_dim], dim=2)
attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
attn = attn.softmax(dim=-1)
@ -287,8 +290,8 @@ class Attention(nn.Module):
x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
else:
B, N, C = x.shape
qkv = self.qkv(x)
q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)
q, k, v = self.qkv(x).view(
B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.val_dim], dim=3)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 3, 1)
v = v.permute(0, 2, 1, 3)
@ -296,7 +299,7 @@ class Attention(nn.Module):
attn = q @ k * self.scale + self.get_attention_biases(x.device)
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.val_attn_dim)
x = self.proj(x)
return x
@ -306,17 +309,18 @@ class AttentionSubsample(nn.Module):
def __init__(
self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2,
act_layer=None, stride=2, resolution=14, resolution_=7, use_conv=False):
act_layer=None, stride=2, resolution=14, resolution_out=7, use_conv=False):
super().__init__()
self.stride = stride
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = self.d * self.num_heads
self.attn_ratio = attn_ratio
self.resolution_ = resolution_
self.resolution_2 = resolution_ ** 2
self.key_attn_dim = key_dim * num_heads
self.val_dim = int(attn_ratio * key_dim)
self.val_attn_dim = self.val_dim * self.num_heads
self.resolution = resolution
self.resolution_out_area = resolution_out ** 2
self.use_conv = use_conv
if self.use_conv:
ln_layer = ConvNorm
@ -325,34 +329,25 @@ class AttentionSubsample(nn.Module):
ln_layer = LinearNorm
sub_layer = partial(Subsample, resolution=resolution)
h = self.dh + nh_kd
self.kv = ln_layer(in_dim, h, resolution=resolution)
self.kv = ln_layer(in_dim, self.val_attn_dim + self.key_attn_dim, resolution=resolution)
self.q = nn.Sequential(
sub_layer(stride=stride),
ln_layer(in_dim, nh_kd, resolution=resolution_))
ln_layer(in_dim, self.key_attn_dim, resolution=resolution_out)
)
self.proj = nn.Sequential(
act_layer(),
ln_layer(self.dh, out_dim, resolution=resolution_))
ln_layer(self.val_attn_dim, out_dim, resolution=resolution_out)
)
self.attention_biases = nn.Parameter(torch.zeros(num_heads, self.resolution ** 2))
k_pos = torch.stack(torch.meshgrid(torch.arange(resolution), torch.arange(resolution))).flatten(1)
q_pos = torch.stack(torch.meshgrid(
torch.arange(0, resolution, step=stride),
torch.arange(0, resolution, step=stride))).flatten(1)
rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
rel_pos = (rel_pos[0] * resolution) + rel_pos[1]
self.register_buffer('attention_bias_idxs', rel_pos)
self.stride = stride
self.resolution = resolution
points = list(itertools.product(range(resolution), range(resolution)))
points_ = list(itertools.product(range(resolution_), range(resolution_)))
N = len(points)
N_ = len(points_)
attention_offsets = {}
idxs = []
for p1 in points_:
for p2 in points:
size = 1
offset = (
abs(p1[0] * stride - p2[0] + (size - 1) / 2),
abs(p1[1] * stride - p2[1] + (size - 1) / 2))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N))
self.ab = {} # per-device attention_biases cache
@torch.no_grad()
@ -373,24 +368,24 @@ class AttentionSubsample(nn.Module):
def forward(self, x):
if self.use_conv:
B, C, H, W = x.shape
k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2)
q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2)
k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.val_dim], dim=2)
q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_out_area)
attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
attn = attn.softmax(dim=-1)
x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_)
x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution, self.resolution)
else:
B, N, C = x.shape
k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3)
k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.val_dim], dim=3)
k = k.permute(0, 2, 3, 1) # BHCN
v = v.permute(0, 2, 1, 3) # BHNC
q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
q = self.q(x).view(B, self.resolution_out_area, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
attn = q @ k * self.scale + self.get_attention_biases(x.device)
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)
x = (attn @ v).transpose(1, 2).reshape(B, -1, self.val_attn_dim)
x = self.proj(x)
return x
@ -418,35 +413,37 @@ class Levit(nn.Module):
down_ops=None,
act_layer='hard_swish',
attn_act_layer='hard_swish',
distillation=True,
use_conv=False,
global_pool='avg',
drop_rate=0.,
drop_path_rate=0.):
super().__init__()
act_layer = get_act_layer(act_layer)
attn_act_layer = get_act_layer(attn_act_layer)
ln_layer = ConvNorm if use_conv else LinearNorm
self.use_conv = use_conv
if isinstance(img_size, tuple):
# FIXME origin impl passes single img/res dim through whole hierarchy,
# not sure this model will be used enough to spend time fixing it.
assert img_size[0] == img_size[1]
img_size = img_size[0]
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = embed_dim[-1]
self.embed_dim = embed_dim
N = len(embed_dim)
assert len(depth) == len(num_heads) == N
key_dim = to_ntuple(N)(key_dim)
attn_ratio = to_ntuple(N)(attn_ratio)
mlp_ratio = to_ntuple(N)(mlp_ratio)
self.grad_checkpointing = False
num_stages = len(embed_dim)
assert len(depth) == len(num_heads) == num_stages
key_dim = to_ntuple(num_stages)(key_dim)
attn_ratio = to_ntuple(num_stages)(attn_ratio)
mlp_ratio = to_ntuple(num_stages)(mlp_ratio)
down_ops = down_ops or (
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
('Subsample', key_dim[0], embed_dim[0] // key_dim[0], 4, 2, 2),
('Subsample', key_dim[0], embed_dim[1] // key_dim[1], 4, 2, 2),
('',)
)
self.distillation = distillation
self.use_conv = use_conv
ln_layer = ConvNorm if self.use_conv else LinearNorm
self.patch_embed = hybrid_backbone or stem_b16(in_chans, embed_dim[0], activation=act_layer)
@ -471,13 +468,13 @@ class Levit(nn.Module):
), drop_path_rate))
if do[0] == 'Subsample':
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
resolution_ = (resolution - 1) // do[5] + 1
resolution_out = (resolution - 1) // do[5] + 1
self.blocks.append(
AttentionSubsample(
*embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2],
attn_ratio=do[3], act_layer=attn_act_layer, stride=do[5],
resolution=resolution, resolution_=resolution_, use_conv=use_conv))
resolution = resolution_
resolution=resolution, resolution_out=resolution_out, use_conv=use_conv))
resolution = resolution_out
if do[4] > 0: # mlp_ratio
h = int(embed_dim[i + 1] * do[4])
self.blocks.append(
@ -490,52 +487,87 @@ class Levit(nn.Module):
# Classifier head
self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = None
if distillation:
self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
@torch.jit.ignore
def no_weight_decay(self):
return {x for x in self.state_dict().keys() if 'attention_biases' in x}
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
blocks=[(r'^blocks.(\d+)', None), (r'^norm', (99999,))]
)
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
if self.head_dist is None:
return self.head
else:
return self.head, self.head_dist
return self.head
def reset_classifier(self, num_classes, global_pool='', distillation=None):
def reset_classifier(self, num_classes, global_pool=None, distillation=None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool
self.head = NormLinear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
if distillation is not None:
self.distillation = distillation
if self.distillation:
self.head_dist = NormLinear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
else:
self.head_dist = None
def forward_features(self, x):
x = self.patch_embed(x)
if not self.use_conv:
x = x.flatten(2).transpose(1, 2)
x = self.blocks(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
return x
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool == 'avg':
x = x.mean(dim=(-2, -1)) if self.use_conv else x.mean(dim=1)
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = x.mean((-2, -1)) if self.use_conv else x.mean(1)
if self.head_dist is not None:
x, x_dist = self.head(x), self.head_dist(x)
if self.training and not torch.jit.is_scripting():
return x, x_dist
else:
# during inference, return the average of both classifier predictions
return (x + x_dist) / 2
else:
x = self.head(x)
x = self.forward_head(x)
return x
class LevitDistilled(Levit):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.head_dist = NormLinear(self.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity()
self.distilled_training = False
@torch.jit.ignore
def get_classifier(self):
return self.head, self.head_dist
def reset_classifier(self, num_classes, global_pool=None, distillation=None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool
self.head = NormLinear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = NormLinear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
@torch.jit.ignore
def set_distilled_training(self, enable=True):
self.distilled_training = enable
def forward_head(self, x):
if self.global_pool == 'avg':
x = x.mean(dim=(-2, -1)) if self.use_conv else x.mean(dim=1)
x, x_dist = self.head(x), self.head_dist(x)
if self.distilled_training and self.training and not torch.jit.is_scripting():
# only return separate classification predictions when training in distilled mode
return x, x_dist
else:
# during standard train/finetune, inference average the classifier predictions
return (x + x_dist) / 2
def checkpoint_filter_fn(state_dict, model):
if 'model' in state_dict:
# For deit models
@ -547,16 +579,14 @@ def checkpoint_filter_fn(state_dict, model):
return state_dict
def create_levit(variant, pretrained=False, default_cfg=None, fuse=False, **kwargs):
def create_levit(variant, pretrained=False, distilled=True, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model_cfg = dict(**model_cfgs[variant], **kwargs)
model = build_model_with_cfg(
Levit, variant, pretrained,
LevitDistilled if distilled else Levit, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
**model_cfg)
#if fuse:
# utils.replace_batchnorm(model)
return model

@ -46,7 +46,7 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, named_apply
from .helpers import build_model_with_cfg, named_apply, checkpoint_seq
from .layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple
from .registry import register_model
@ -260,10 +260,13 @@ class MlpMixer(nn.Module):
drop_path_rate=0.,
nlhb=False,
stem_norm=False,
global_pool='avg',
):
super().__init__()
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.grad_checkpointing = False
self.stem = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans,
@ -279,26 +282,46 @@ class MlpMixer(nn.Module):
self.init_weights(nlhb=nlhb)
@torch.jit.ignore
def init_weights(self, nlhb=False):
head_bias = -math.log(self.num_classes) if nlhb else 0.
named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^stem', # stem and embed
blocks=[(r'^blocks.(\d+)', None), (r'^norm', (99999,))]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg')
self.global_pool = global_pool
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.stem(x)
x = self.blocks(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
x = self.norm(x)
return x
def forward(self, x):
x = self.forward_features(x)
x = x.mean(dim=1)
if self.global_pool == 'avg':
x = x.mean(dim=1)
x = self.head(x)
return x

@ -18,7 +18,7 @@ from .efficientnet_blocks import SqueezeExcite
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .features import FeatureInfo, FeatureHooks
from .helpers import build_model_with_cfg, pretrained_cfg_for_features
from .helpers import build_model_with_cfg, pretrained_cfg_for_features, checkpoint_seq
from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, get_norm_act_layer
from .registry import register_model
@ -27,7 +27,7 @@ __all__ = ['MobileNetV3', 'MobileNetV3Features']
def _cfg(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1),
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'conv_stem', 'classifier': 'classifier',
@ -88,7 +88,7 @@ default_cfgs = {
test_input_size=(3, 256, 256), crop_pct=0.95),
'fbnetv3_g': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_g_240-0b1df83b.pth',
input_size=(3, 240, 240), test_input_size=(3, 288, 288), crop_pct=0.95),
input_size=(3, 240, 240), test_input_size=(3, 288, 288), crop_pct=0.95, pool_size=(8, 8)),
"lcnet_035": _cfg(),
"lcnet_050": _cfg(
@ -134,6 +134,7 @@ class MobileNetV3(nn.Module):
self.num_classes = num_classes
self.num_features = num_features
self.drop_rate = drop_rate
self.grad_checkpointing = False
# Stem
if not fix_stem:
@ -166,6 +167,18 @@ class MobileNetV3(nn.Module):
layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
return nn.Sequential(*layers)
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^conv_stem|bn1',
blocks=r'^blocks.(\d+)' if coarse else r'^blocks.(\d+).(\d+)'
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.classifier
@ -179,18 +192,28 @@ class MobileNetV3(nn.Module):
def forward_features(self, x):
x = self.conv_stem(x)
x = self.bn1(x)
x = self.blocks(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x, flatten=True)
else:
x = self.blocks(x)
return x
def forward_head(self, x, pre_logits: bool = False):
x = self.global_pool(x)
x = self.conv_head(x)
x = self.act2(x)
return x
if pre_logits:
return x.flatten(1)
else:
x = self.flatten(x)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
return self.classifier(x)
def forward(self, x):
x = self.forward_features(x)
x = self.flatten(x)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
return self.classifier(x)
x = self.forward_head(x)
return x
class MobileNetV3Features(nn.Module):

@ -20,6 +20,7 @@ from torch import nn
import torch.nn.functional as F
from .byobnet import register_block, ByoBlockCfg, ByoModelCfg, ByobNet, LayerFn, num_groups
from .fx_features import register_notrace_module
from .layers import to_2tuple, make_divisible
from .vision_transformer import Block as TransformerBlock
from .helpers import build_model_with_cfg
@ -139,6 +140,7 @@ model_cfgs = dict(
)
@register_notrace_module
class MobileViTBlock(nn.Module):
""" MobileViT block
Paper: https://arxiv.org/abs/2110.02178?context=cs.LG
@ -206,7 +208,7 @@ class MobileViTBlock(nn.Module):
# Unfold (feature map -> patches)
patch_h, patch_w = self.patch_size
B, C, H, W = x.shape
new_h, new_w = int(math.ceil(H / patch_h) * patch_h), int(math.ceil(W / patch_w) * patch_w)
new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(W / patch_w) * patch_w
num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w # n_h, n_w
num_patches = num_patch_h * num_patch_w # N
interpolate = False

@ -407,8 +407,9 @@ class ReductionCell1(nn.Module):
class NASNetALarge(nn.Module):
"""NASNetALarge (6 @ 4032) """
def __init__(self, num_classes=1000, in_chans=3, stem_size=96, channel_multiplier=2,
num_features=4032, output_stride=32, drop_rate=0., global_pool='avg', pad_type='same'):
def __init__(
self, num_classes=1000, in_chans=3, stem_size=96, channel_multiplier=2,
num_features=4032, output_stride=32, drop_rate=0., global_pool='avg', pad_type='same'):
super(NASNetALarge, self).__init__()
self.num_classes = num_classes
self.stem_size = stem_size
@ -503,6 +504,23 @@ class NASNetALarge(nn.Module):
self.global_pool, self.last_linear = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^conv0|cell_stem_[01]',
blocks=[
(r'^cell_(\d+)', None),
(r'^reduction_cell_0', (6,)),
(r'^reduction_cell_1', (12,)),
]
)
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
assert not enable, 'gradient checkpointing not supported'
@torch.jit.ignore
def get_classifier(self):
return self.last_linear
@ -542,14 +560,18 @@ class NASNetALarge(nn.Module):
x = self.act(x_cell_17)
return x
def forward(self, x):
x = self.forward_features(x)
def forward_head(self, x):
x = self.global_pool(x)
if self.drop_rate > 0:
x = F.dropout(x, self.drop_rate, training=self.training)
x = self.last_linear(x)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def _create_nasnet(variant, pretrained=False, **kwargs):
return build_model_with_cfg(

@ -26,7 +26,7 @@ from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg, named_apply
from .helpers import build_model_with_cfg, named_apply, checkpoint_seq
from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_
from .layers import _assert
from .layers import create_conv2d, create_pool2d, to_ntuple
@ -179,6 +179,8 @@ class NestLevel(nn.Module):
norm_layer=None, act_layer=None, pad_type=''):
super().__init__()
self.block_size = block_size
self.grad_checkpointing = False
self.pos_embed = nn.Parameter(torch.zeros(1, num_blocks, seq_length, embed_dim))
if prev_embed_dim is not None:
@ -204,7 +206,10 @@ class NestLevel(nn.Module):
x = x.permute(0, 2, 3, 1) # (B, H', W', C), switch to channels last for transformer
x = blockify(x, self.block_size) # (B, T, N, C')
x = x + self.pos_embed
x = self.transformer_encoder(x) # (B, T, N, C')
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.transformer_encoder, x)
else:
x = self.transformer_encoder(x) # (B, T, N, C')
x = deblockify(x, self.block_size) # (B, H', W', C')
# Channel-first for block aggregation, and generally to replicate convnet feature map at each stage
return x.permute(0, 3, 1, 2) # (B, C, H', W')
@ -217,10 +222,12 @@ class Nest(nn.Module):
- https://arxiv.org/abs/2105.12723
"""
def __init__(self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_dims=(128, 256, 512),
num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None, act_layer=None,
pad_type='', weight_init='', global_pool='avg'):
def __init__(
self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_dims=(128, 256, 512),
num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None, act_layer=None,
pad_type='', weight_init='', global_pool='avg'
):
"""
Args:
img_size (int, tuple): input image size
@ -310,6 +317,7 @@ class Nest(nn.Module):
self.init_weights(weight_init)
@torch.jit.ignore
def init_weights(self, mode=''):
assert mode in ('nlhb', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
@ -321,6 +329,24 @@ class Nest(nn.Module):
def no_weight_decay(self):
return {f'level.{i}.pos_embed' for i in range(len(self.levels))}
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^patch_embed', # stem and embed
blocks=[
(r'^levels.(\d+)' if coarse else r'^levels.(\d+).transformer_encoder.(\d+)', None),
(r'^levels.(\d+).(?:pool|pos_embed)', (0,)),
(r'^norm', (99999,))
]
)
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
for l in self.levels:
l.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head
@ -330,22 +356,22 @@ class Nest(nn.Module):
self.num_features, self.num_classes, pool_type=global_pool)
def forward_features(self, x):
""" x shape (B, C, H, W)
"""
x = self.patch_embed(x)
x = self.levels(x)
# Layer norm done over channel dim only (to NHWC and back)
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x
def forward(self, x):
""" x shape (B, C, H, W)
"""
x = self.forward_features(x)
def forward_head(self, x, pre_logits: bool = False):
x = self.global_pool(x)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
return self.head(x)
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def _init_nest_weights(module: nn.Module, name: str = '', head_bias: float = 0.):
@ -364,9 +390,6 @@ def _init_nest_weights(module: nn.Module, name: str = '', head_bias: float = 0.)
trunc_normal_(module.weight, std=.02, a=-2, b=2)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
nn.init.zeros_(module.bias)
nn.init.ones_(module.weight)
def resize_pos_embed(posemb, posemb_new):

@ -27,7 +27,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_module
from .helpers import build_model_with_cfg
from .helpers import build_model_with_cfg, checkpoint_seq
from .registry import register_model
from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\
get_act_layer, get_act_fn, get_attn, make_divisible
@ -84,23 +84,6 @@ default_cfgs = dict(
nfnet_f7=_dcfg(
url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)),
nfnet_f0s=_dcfg(
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
nfnet_f1s=_dcfg(
url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320)),
nfnet_f2s=_dcfg(
url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352)),
nfnet_f3s=_dcfg(
url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416)),
nfnet_f4s=_dcfg(
url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512)),
nfnet_f5s=_dcfg(
url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544)),
nfnet_f6s=_dcfg(
url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576)),
nfnet_f7s=_dcfg(
url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)),
nfnet_l0=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nfnet_l0_ra2-45c6688d.pth',
pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288), crop_pct=1.0),
@ -222,7 +205,7 @@ model_cfgs = dict(
dm_nfnet_f5=_dm_nfnet_cfg(depths=(6, 12, 36, 18)),
dm_nfnet_f6=_dm_nfnet_cfg(depths=(7, 14, 42, 21)),
# NFNet-F models w/ GELU (I will likely deprecate/remove these models and just keep dm_ ver for GELU)
# NFNet-F models w/ GELU
nfnet_f0=_nfnet_cfg(depths=(1, 2, 6, 3)),
nfnet_f1=_nfnet_cfg(depths=(2, 4, 12, 6)),
nfnet_f2=_nfnet_cfg(depths=(3, 6, 18, 9)),
@ -232,16 +215,6 @@ model_cfgs = dict(
nfnet_f6=_nfnet_cfg(depths=(7, 14, 42, 21)),
nfnet_f7=_nfnet_cfg(depths=(8, 16, 48, 24)),
# NFNet-F models w/ SiLU (much faster in PyTorch)
nfnet_f0s=_nfnet_cfg(depths=(1, 2, 6, 3), act_layer='silu'),
nfnet_f1s=_nfnet_cfg(depths=(2, 4, 12, 6), act_layer='silu'),
nfnet_f2s=_nfnet_cfg(depths=(3, 6, 18, 9), act_layer='silu'),
nfnet_f3s=_nfnet_cfg(depths=(4, 8, 24, 12), act_layer='silu'),
nfnet_f4s=_nfnet_cfg(depths=(5, 10, 30, 15), act_layer='silu'),
nfnet_f5s=_nfnet_cfg(depths=(6, 12, 36, 18), act_layer='silu'),
nfnet_f6s=_nfnet_cfg(depths=(7, 14, 42, 21), act_layer='silu'),
nfnet_f7s=_nfnet_cfg(depths=(8, 16, 48, 24), act_layer='silu'),
# Experimental 'light' versions of NFNet-F that are little leaner
nfnet_l0=_nfnet_cfg(
depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25,
@ -477,11 +450,15 @@ class NormFreeNet(nn.Module):
* skipinit is disabled by default, it seems to have a rather drastic impact on GPU memory use and throughput
for what it is/does. Approx 8-10% throughput loss.
"""
def __init__(self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
drop_rate=0., drop_path_rate=0.):
def __init__(
self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
drop_rate=0., drop_path_rate=0.
):
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
self.grad_checkpointing = False
assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})."
conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d
if cfg.gamma_in_act:
@ -568,6 +545,22 @@ class NormFreeNet(nn.Module):
if m.bias is not None:
nn.init.zeros_(m.bias)
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^stem',
blocks=[
(r'^stages.(\d+)' if coarse else r'^stages.(\d+).(\d+)', None),
(r'^final_conv', (99999,))
]
)
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
@ -576,14 +569,20 @@ class NormFreeNet(nn.Module):
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.stages, x)
else:
x = self.stages(x)
x = self.final_conv(x)
x = self.final_act(x)
return x
def forward_head(self, x):
return self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
x = self.forward_head(x)
return x
@ -732,78 +731,6 @@ def nfnet_f7(pretrained=False, **kwargs):
return _create_normfreenet('nfnet_f7', pretrained=pretrained, **kwargs)
@register_model
def nfnet_f0s(pretrained=False, **kwargs):
""" NFNet-F0 w/ SiLU
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('nfnet_f0s', pretrained=pretrained, **kwargs)
@register_model
def nfnet_f1s(pretrained=False, **kwargs):
""" NFNet-F1 w/ SiLU
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('nfnet_f1s', pretrained=pretrained, **kwargs)
@register_model
def nfnet_f2s(pretrained=False, **kwargs):
""" NFNet-F2 w/ SiLU
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('nfnet_f2s', pretrained=pretrained, **kwargs)
@register_model
def nfnet_f3s(pretrained=False, **kwargs):
""" NFNet-F3 w/ SiLU
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('nfnet_f3s', pretrained=pretrained, **kwargs)
@register_model
def nfnet_f4s(pretrained=False, **kwargs):
""" NFNet-F4 w/ SiLU
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('nfnet_f4s', pretrained=pretrained, **kwargs)
@register_model
def nfnet_f5s(pretrained=False, **kwargs):
""" NFNet-F5 w/ SiLU
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('nfnet_f5s', pretrained=pretrained, **kwargs)
@register_model
def nfnet_f6s(pretrained=False, **kwargs):
""" NFNet-F6 w/ SiLU
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('nfnet_f6s', pretrained=pretrained, **kwargs)
@register_model
def nfnet_f7s(pretrained=False, **kwargs):
""" NFNet-F7 w/ SiLU
`High-Performance Large-Scale Image Recognition Without Normalization`
- https://arxiv.org/abs/2102.06171
"""
return _create_normfreenet('nfnet_f7s', pretrained=pretrained, **kwargs)
@register_model
def nfnet_l0(pretrained=False, **kwargs):
""" NFNet-L0b w/ SiLU

@ -148,9 +148,10 @@ class PoolingVisionTransformer(nn.Module):
- https://arxiv.org/abs/2103.16302
"""
def __init__(self, img_size, patch_size, stride, base_dims, depth, heads,
mlp_ratio, num_classes=1000, in_chans=3, distilled=False,
mlp_ratio, num_classes=1000, in_chans=3, distilled=False, global_pool='token',
attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0):
super(PoolingVisionTransformer, self).__init__()
assert global_pool in ('token',)
padding = 0
img_size = to_2tuple(img_size)
@ -161,6 +162,7 @@ class PoolingVisionTransformer(nn.Module):
self.base_dims = base_dims
self.heads = heads
self.num_classes = num_classes
self.global_pool = global_pool
self.num_tokens = 2 if distilled else 1
self.patch_size = patch_size
@ -205,13 +207,17 @@ class PoolingVisionTransformer(nn.Module):
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
assert not enable, 'gradient checkpointing not supported'
def get_classifier(self):
if self.head_dist is not None:
return self.head, self.head_dist
else:
return self.head
def reset_classifier(self, num_classes, global_pool=''):
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if self.head_dist is not None:

@ -296,6 +296,15 @@ class PNASNet5Large(nn.Module):
self.global_pool, self.last_linear = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(stem=r'^conv_0|cell_stem_[01]', blocks=r'^cell_(\d+)')
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
assert not enable, 'gradient checkpointing not supported'
@torch.jit.ignore
def get_classifier(self):
return self.last_linear
@ -323,12 +332,15 @@ class PNASNet5Large(nn.Module):
x = self.act(x_cell_11)
return x
def forward(self, x):
x = self.forward_features(x)
def forward_head(self, x, pre_logits: bool = False):
x = self.global_pool(x)
if self.drop_rate > 0:
x = F.dropout(x, self.drop_rate, training=self.training)
x = self.last_linear(x)
return x if pre_logits else self.last_linear(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x

@ -0,0 +1,322 @@
""" PoolFormer implementation
Paper: `PoolFormer: MetaFormer is Actually What You Need for Vision` - https://arxiv.org/abs/2111.11418
Code adapted from official impl at https://github.com/sail-sg/poolformer, original copyright in comment below
Modifications and additions for timm by / Copyright 2022, Ross Wightman
"""
# Copyright 2021 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, checkpoint_seq
from .layers import DropPath, trunc_normal_, to_2tuple, ConvMlp
from .registry import register_model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .95, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = dict(
poolformer_s12=_cfg(
url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s12.pth.tar',
crop_pct=0.9),
poolformer_s24=_cfg(
url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s24.pth.tar',
crop_pct=0.9),
poolformer_s36=_cfg(
url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s36.pth.tar',
crop_pct=0.9),
poolformer_m36=_cfg(
url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m36.pth.tar',
crop_pct=0.95),
poolformer_m48=_cfg(
url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m48.pth.tar',
crop_pct=0.95),
)
class PatchEmbed(nn.Module):
""" Patch Embedding that is implemented by a layer of conv.
Input: tensor in shape [B, C, H, W]
Output: tensor in shape [B, C, H/stride, W/stride]
"""
def __init__(self, in_chs=3, embed_dim=768, patch_size=16, stride=16, padding=0, norm_layer=None):
super().__init__()
patch_size = to_2tuple(patch_size)
stride = to_2tuple(stride)
padding = to_2tuple(padding)
self.proj = nn.Conv2d(in_chs, embed_dim, kernel_size=patch_size, stride=stride, padding=padding)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class GroupNorm1(nn.GroupNorm):
""" Group Normalization with 1 group.
Input: tensor in shape [B, C, H, W]
"""
def __init__(self, num_channels, **kwargs):
super().__init__(1, num_channels, **kwargs)
class Pooling(nn.Module):
def __init__(self, pool_size=3):
super().__init__()
self.pool = nn.AvgPool2d(pool_size, stride=1, padding=pool_size // 2, count_include_pad=False)
def forward(self, x):
return self.pool(x) - x
class PoolFormerBlock(nn.Module):
"""
Args:
dim: embedding dim
pool_size: pooling size
mlp_ratio: mlp expansion ratio
act_layer: activation
norm_layer: normalization
drop: dropout rate
drop path: Stochastic Depth, refer to https://arxiv.org/abs/1603.09382
use_layer_scale, --layer_scale_init_value: LayerScale, refer to https://arxiv.org/abs/2103.17239
"""
def __init__(
self, dim, pool_size=3, mlp_ratio=4.,
act_layer=nn.GELU, norm_layer=GroupNorm1,
drop=0., drop_path=0., layer_scale_init_value=1e-5):
super().__init__()
self.norm1 = norm_layer(dim)
self.token_mixer = Pooling(pool_size=pool_size)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = ConvMlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
if layer_scale_init_value:
self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
else:
self.layer_scale_1 = None
self.layer_scale_2 = None
def forward(self, x):
if self.layer_scale_1 is not None:
x = x + self.drop_path1(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.token_mixer(self.norm1(x)))
x = x + self.drop_path2(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))
else:
x = x + self.drop_path1(self.token_mixer(self.norm1(x)))
x = x + self.drop_path2(self.mlp(self.norm2(x)))
return x
def basic_blocks(
dim, index, layers,
pool_size=3, mlp_ratio=4.,
act_layer=nn.GELU, norm_layer=GroupNorm1,
drop_rate=.0, drop_path_rate=0.,
layer_scale_init_value=1e-5,
):
""" generate PoolFormer blocks for a stage """
blocks = []
for block_idx in range(layers[index]):
block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
blocks.append(PoolFormerBlock(
dim, pool_size=pool_size, mlp_ratio=mlp_ratio,
act_layer=act_layer, norm_layer=norm_layer,
drop=drop_rate, drop_path=block_dpr,
layer_scale_init_value=layer_scale_init_value,
))
blocks = nn.Sequential(*blocks)
return blocks
class PoolFormer(nn.Module):
""" PoolFormer
"""
def __init__(
self,
layers,
embed_dims=(64, 128, 320, 512),
mlp_ratios=(4, 4, 4, 4),
downsamples=(True, True, True, True),
pool_size=3,
in_chans=3,
num_classes=1000,
global_pool='avg',
norm_layer=GroupNorm1,
act_layer=nn.GELU,
in_patch_size=7,
in_stride=4,
in_pad=2,
down_patch_size=3,
down_stride=2,
down_pad=1,
drop_rate=0., drop_path_rate=0.,
layer_scale_init_value=1e-5,
**kwargs):
super().__init__()
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = embed_dims[-1]
self.grad_checkpointing = False
self.patch_embed = PatchEmbed(
patch_size=in_patch_size, stride=in_stride, padding=in_pad,
in_chs=in_chans, embed_dim=embed_dims[0])
# set the main block in network
network = []
for i in range(len(layers)):
network.append(basic_blocks(
embed_dims[i], i, layers,
pool_size=pool_size, mlp_ratio=mlp_ratios[i],
act_layer=act_layer, norm_layer=norm_layer,
drop_rate=drop_rate, drop_path_rate=drop_path_rate,
layer_scale_init_value=layer_scale_init_value)
)
if i < len(layers) - 1 and (downsamples[i] or embed_dims[i] != embed_dims[i + 1]):
# downsampling between stages
network.append(PatchEmbed(
in_chs=embed_dims[i], embed_dim=embed_dims[i + 1],
patch_size=down_patch_size, stride=down_stride, padding=down_pad)
)
self.network = nn.Sequential(*network)
self.norm = norm_layer(self.num_features)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
# init for classification
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^patch_embed', # stem and embed
blocks=[
(r'^network\.(\d+)\.(\d+)', None),
(r'^network\.(\d+)', (0,)),
(r'^norm', (99999,))
],
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.patch_embed(x)
x = self.network(x)
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool == 'avg':
x = x.mean([-2, -1])
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def _create_poolformer(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(PoolFormer, variant, pretrained, **kwargs)
return model
@register_model
def poolformer_s12(pretrained=False, **kwargs):
""" PoolFormer-S12 model, Params: 12M """
model = _create_poolformer('poolformer_s12', pretrained=pretrained, layers=(2, 2, 6, 2), **kwargs)
return model
@register_model
def poolformer_s24(pretrained=False, **kwargs):
""" PoolFormer-S24 model, Params: 21M """
model = _create_poolformer('poolformer_s24', pretrained=pretrained, layers=(4, 4, 12, 4), **kwargs)
return model
@register_model
def poolformer_s36(pretrained=False, **kwargs):
""" PoolFormer-S36 model, Params: 31M """
model = _create_poolformer(
'poolformer_s36', pretrained=pretrained, layers=(6, 6, 18, 6), layer_scale_init_value=1e-6, **kwargs)
return model
@register_model
def poolformer_m36(pretrained=False, **kwargs):
""" PoolFormer-M36 model, Params: 56M """
layers = (6, 6, 18, 6)
embed_dims = (96, 192, 384, 768)
model = _create_poolformer(
'poolformer_m36', pretrained=pretrained, layers=layers, embed_dims=embed_dims,
layer_scale_init_value=1e-6, **kwargs)
return model
@register_model
def poolformer_m48(pretrained=False, **kwargs):
""" PoolFormer-M48 model, Params: 73M """
layers = (8, 8, 24, 8)
embed_dims = (96, 192, 384, 768)
model = _create_poolformer(
'poolformer_m48', pretrained=pretrained, layers=layers, embed_dims=embed_dims,
layer_scale_init_value=1e-6, **kwargs)
return model

@ -19,10 +19,11 @@ from functools import partial
from typing import Optional, Union, Callable
import numpy as np
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, named_apply
from .helpers import build_model_with_cfg, named_apply, checkpoint_seq
from .layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct
from .layers import get_act_layer, get_norm_act_layer, create_conv2d
from .registry import register_model
@ -80,14 +81,13 @@ model_cfgs = dict(
regnety_040s_gn=RegNetCfg(
w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25,
act_layer='silu', norm_layer=partial(GroupNormAct, group_size=16)),
# regnetv = 'preact regnet y'
regnetv_040=RegNetCfg(
depth=22, w0=96, wa=31.41, wm=2.24, group_size=64, se_ratio=0.25, preact=True, act_layer='silu'),
# regnetw = 'preact regnet z'
regnetw_040=RegNetCfg(
depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25,
downsample=None, preact=True, num_features=1536, act_layer='silu',
),
regnetv_064=RegNetCfg(
depth=25, w0=112, wa=33.22, wm=2.27, group_size=72, se_ratio=0.25, preact=True, act_layer='silu',
downsample='avg'),
# RegNet-Z (unverified)
regnetz_005=RegNetCfg(
@ -95,6 +95,10 @@ model_cfgs = dict(
downsample=None, linear_out=True, num_features=1024, act_layer='silu',
),
regnetz_040=RegNetCfg(
depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25,
downsample=None, linear_out=True, num_features=0, act_layer='silu',
),
regnetz_040h=RegNetCfg(
depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25,
downsample=None, linear_out=True, num_features=1536, act_layer='silu',
),
@ -144,10 +148,11 @@ default_cfgs = dict(
regnety_040s_gn=_cfg(url=''),
regnetv_040=_cfg(url='', first_conv='stem'),
regnetw_040=_cfg(url='', first_conv='stem', input_size=(3, 256, 256), pool_size=(8, 8)),
regnetv_064=_cfg(url='', first_conv='stem'),
regnetz_005=_cfg(url=''),
regnetz_040=_cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
regnetz_040h=_cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
)
@ -326,6 +331,8 @@ class RegStage(nn.Module):
self, depth, in_chs, out_chs, stride, dilation,
drop_path_rates=None, block_fn=Bottleneck, **block_kwargs):
super(RegStage, self).__init__()
self.grad_checkpointing = False
first_dilation = 1 if dilation in (1, 2) else 2
for i in range(depth):
block_stride = stride if i == 0 else 1
@ -341,8 +348,11 @@ class RegStage(nn.Module):
first_dilation = dilation
def forward(self, x):
for block in self.children():
x = block(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.children(), x)
else:
for block in self.children():
x = block(x)
return x
@ -375,6 +385,7 @@ class RegNet(nn.Module):
curr_stride = 2
per_stage_args, common_args = self._get_stage_args(
cfg, output_stride=output_stride, drop_path_rate=drop_path_rate)
assert len(per_stage_args) == 4
block_fn = PreBottleneck if cfg.preact else Bottleneck
for i, stage_args in enumerate(per_stage_args):
stage_name = "s{}".format(i + 1)
@ -429,6 +440,19 @@ class RegNet(nn.Module):
act_layer=cfg.act_layer, norm_layer=cfg.norm_layer)
return per_stage_args, common_args
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^stem',
blocks=r'^stages.(\d+)' if coarse else r'^stages.(\d+).blocks.(\d+)',
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
for s in list(self.children())[1:-1]:
s.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
@ -436,13 +460,20 @@ class RegNet(nn.Module):
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
def forward_features(self, x):
for block in list(self.children())[:-1]:
x = block(x)
x = self.stem(x)
x = self.s1(x)
x = self.s2(x)
x = self.s3(x)
x = self.s4(x)
x = self.final_conv(x)
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
def forward(self, x):
for block in self.children():
x = block(x)
x = self.forward_features(x)
x = self.forward_head(x)
return x
@ -634,9 +665,9 @@ def regnetv_040(pretrained=False, **kwargs):
@register_model
def regnetw_040(pretrained=False, **kwargs):
def regnetv_064(pretrained=False, **kwargs):
""""""
return _create_regnet('regnetw_040', pretrained, **kwargs)
return _create_regnet('regnetv_064', pretrained, **kwargs)
@register_model
@ -655,3 +686,12 @@ def regnetz_040(pretrained=False, **kwargs):
but it's not clear it is equivalent to paper model as not detailed in the paper.
"""
return _create_regnet('regnetz_040', pretrained, zero_init_last=False, **kwargs)
@register_model
def regnetz_040h(pretrained=False, **kwargs):
"""RegNetZ-4.0GF
NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py
but it's not clear it is equivalent to paper model as not detailed in the paper.
"""
return _create_regnet('regnetz_040h', pretrained, zero_init_last=False, **kwargs)

@ -50,9 +50,10 @@ class Bottle2neck(nn.Module):
"""
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=26, scale=4, dilation=1, first_dilation=None,
act_layer=nn.ReLU, norm_layer=None, attn_layer=None, **_):
def __init__(
self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=26, scale=4, dilation=1, first_dilation=None,
act_layer=nn.ReLU, norm_layer=None, attn_layer=None, **_):
super(Bottle2neck, self).__init__()
self.scale = scale
self.is_first = stride > 1 or downsample is not None
@ -87,7 +88,7 @@ class Bottle2neck(nn.Module):
self.relu = act_layer(inplace=True)
self.downsample = downsample
def zero_init_last_bn(self):
def zero_init_last(self):
nn.init.zeros_(self.bn3.weight)
def forward(self, x):
@ -110,8 +111,7 @@ class Bottle2neck(nn.Module):
sp = self.relu(sp)
spo.append(sp)
if self.scale > 1:
if self.pool is not None:
# self.is_first == True, None check for torchscript
if self.pool is not None: # self.is_first == True, None check for torchscript
spo.append(self.pool(spx[-1]))
else:
spo.append(spx[-1])

@ -57,10 +57,11 @@ class ResNestBottleneck(nn.Module):
# pylint: disable=unused-argument
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None,
radix=1, cardinality=1, base_width=64, avd=False, avd_first=False, is_first=False,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
def __init__(
self, inplanes, planes, stride=1, downsample=None,
radix=1, cardinality=1, base_width=64, avd=False, avd_first=False, is_first=False,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
super(ResNestBottleneck, self).__init__()
assert reduce_first == 1 # not supported
assert attn_layer is None # not supported
@ -102,7 +103,7 @@ class ResNestBottleneck(nn.Module):
self.act3 = act_layer(inplace=True)
self.downsample = downsample
def zero_init_last_bn(self):
def zero_init_last(self):
nn.init.zeros_(self.bn3.weight)
def forward(self, x):

@ -15,7 +15,7 @@ import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .helpers import build_model_with_cfg, checkpoint_seq
from .layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, create_classifier
from .registry import register_model
@ -105,7 +105,9 @@ default_cfgs = {
first_conv='conv1.0'),
'resnext101_32x4d': _cfg(url=''),
'resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth'),
'resnext101_64x4d': _cfg(url=''),
'resnext101_64x4d': _cfg(
url='',
interpolation='bicubic', crop_pct=1.0, test_input_size=(3, 288, 288)),
'tv_resnext50_32x4d': _cfg(url='https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth'),
# ResNeXt models - Weakly Supervised Pretraining on Instagram Hashtags
@ -345,7 +347,7 @@ class BasicBlock(nn.Module):
self.dilation = dilation
self.drop_path = drop_path
def zero_init_last_bn(self):
def zero_init_last(self):
nn.init.zeros_(self.bn2.weight)
def forward(self, x):
@ -411,7 +413,7 @@ class Bottleneck(nn.Module):
self.dilation = dilation
self.drop_path = drop_path
def zero_init_last_bn(self):
def zero_init_last(self):
nn.init.zeros_(self.bn3.weight)
def forward(self, x):
@ -600,12 +602,13 @@ class ResNet(nn.Module):
cardinality=1, base_width=64, stem_width=64, stem_type='', replace_stem_pool=False,
output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0.,
drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None):
drop_block_rate=0., global_pool='avg', zero_init_last=True, block_args=None):
super(ResNet, self).__init__()
block_args = block_args or dict()
assert output_stride in (8, 16, 32)
self.num_classes = num_classes
self.drop_rate = drop_rate
super(ResNet, self).__init__()
self.grad_checkpointing = False
# Stem
deep_stem = 'deep' in stem_type
@ -632,7 +635,7 @@ class ResNet(nn.Module):
if replace_stem_pool:
self.maxpool = nn.Sequential(*filter(None, [
nn.Conv2d(inplanes, inplanes, 3, stride=1 if aa_layer else 2, padding=1, bias=False),
create_aa(aa_layer, channels=inplanes, stride=2),
create_aa(aa_layer, channels=inplanes, stride=2) if aa_layer is not None else None,
norm_layer(inplanes),
act_layer(inplace=True)
]))
@ -662,22 +665,33 @@ class ResNet(nn.Module):
self.num_features = 512 * block.expansion
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
self.init_weights(zero_init_last_bn=zero_init_last_bn)
self.init_weights(zero_init_last=zero_init_last)
def init_weights(self, zero_init_last_bn=True):
@torch.jit.ignore
def init_weights(self, zero_init_last=True):
for n, m in self.named_modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
if zero_init_last_bn:
if zero_init_last:
for m in self.modules():
if hasattr(m, 'zero_init_last_bn'):
m.zero_init_last_bn()
if hasattr(m, 'zero_init_last'):
m.zero_init_last()
def get_classifier(self):
return self.fc
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(stem=r'^conv1|bn1|maxpool', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+)\.(\d+)')
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self, name_only=False):
return 'fc' if name_only else self.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
@ -689,10 +703,13 @@ class ResNet(nn.Module):
x = self.act1(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq([self.layer1, self.layer2, self.layer3, self.layer4], x, flatten=True)
else:
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
return x
def forward(self, x):

@ -36,10 +36,9 @@ import torch.nn as nn
from functools import partial
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
from .helpers import build_model_with_cfg, named_apply, adapt_input_conv, checkpoint_seq
from .registry import register_model
from .layers import GroupNormAct, BatchNormAct2d, EvoNorm2dB0, EvoNorm2dS0,\
EvoNorm2dS1, EvoNorm2dS2, FilterResponseNormTlu2d, FilterResponseNormAct2d,\
from .layers import GroupNormAct, BatchNormAct2d, EvoNorm2dB0, EvoNorm2dS0, EvoNorm2dS1, FilterResponseNormTlu2d,\
ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d
@ -280,9 +279,10 @@ class DownsampleAvg(nn.Module):
class ResNetStage(nn.Module):
"""ResNet Stage."""
def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio=0.25, groups=1,
avg_down=False, block_dpr=None, block_fn=PreActBottleneck,
act_layer=None, conv_layer=None, norm_layer=None, **block_kwargs):
def __init__(
self, in_chs, out_chs, stride, dilation, depth, bottle_ratio=0.25, groups=1,
avg_down=False, block_dpr=None, block_fn=PreActBottleneck,
act_layer=None, conv_layer=None, norm_layer=None, **block_kwargs):
super(ResNetStage, self).__init__()
first_dilation = 1 if dilation in (1, 2) else 2
layer_kwargs = dict(act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer)
@ -397,7 +397,9 @@ class ResNetV2(nn.Module):
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
self.init_weights(zero_init_last=zero_init_last)
self.grad_checkpointing = False
@torch.jit.ignore
def init_weights(self, zero_init_last=True):
named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
@ -405,6 +407,22 @@ class ResNetV2(nn.Module):
def load_pretrained(self, checkpoint_path, prefix='resnet/'):
_load_weights(self, checkpoint_path, prefix)
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^stem',
blocks=r'^stages.(\d+)' if coarse else [
(r'^stages.(\d+).blocks.(\d+)', None),
(r'^norm', (99999,))
]
)
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
@ -415,13 +433,19 @@ class ResNetV2(nn.Module):
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.stages, x, flatten=True)
else:
x = self.stages(x)
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
x = self.forward_head(x)
return x

@ -16,7 +16,7 @@ from functools import partial
from math import ceil
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .helpers import build_model_with_cfg, checkpoint_seq
from .layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath, make_divisible, SEModule
from .registry import register_model
from .efficientnet_builder import efficientnet_init_weights
@ -54,8 +54,9 @@ SEWithNorm = partial(SEModule, norm_layer=nn.BatchNorm2d)
class LinearBottleneck(nn.Module):
def __init__(self, in_chs, out_chs, stride, exp_ratio=1.0, se_ratio=0., ch_div=1,
act_layer='swish', dw_act_layer='relu6', drop_path=None):
def __init__(
self, in_chs, out_chs, stride, exp_ratio=1.0, se_ratio=0., ch_div=1,
act_layer='swish', dw_act_layer='relu6', drop_path=None):
super(LinearBottleneck, self).__init__()
self.use_shortcut = stride == 1 and in_chs <= out_chs
self.in_channels = in_chs
@ -143,12 +144,15 @@ def _build_blocks(
class ReXNetV1(nn.Module):
def __init__(self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32,
initial_chs=16, final_chs=180, width_mult=1.0, depth_mult=1.0, se_ratio=1/12.,
ch_div=1, act_layer='swish', dw_act_layer='relu6', drop_rate=0.2, drop_path_rate=0.):
def __init__(
self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32,
initial_chs=16, final_chs=180, width_mult=1.0, depth_mult=1.0, se_ratio=1/12.,
ch_div=1, act_layer='swish', dw_act_layer='relu6', drop_rate=0.2, drop_path_rate=0.
):
super(ReXNetV1, self).__init__()
self.drop_rate = drop_rate
self.num_classes = num_classes
self.drop_rate = drop_rate
self.grad_checkpointing = False
assert output_stride == 32 # FIXME support dilation
stem_base_chs = 32 / width_mult if width_mult < 1.0 else 32
@ -165,6 +169,19 @@ class ReXNetV1(nn.Module):
efficientnet_init_weights(self)
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^stem',
blocks=r'^features.(\d+)',
)
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
@ -173,12 +190,18 @@ class ReXNetV1(nn.Module):
def forward_features(self, x):
x = self.stem(x)
x = self.features(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.features, x, flatten=True)
else:
x = self.features(x)
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
x = self.forward_head(x)
return x

@ -174,6 +174,19 @@ class SelecSLS(nn.Module):
nn.init.constant_(m.weight, 1.)
nn.init.constant_(m.bias, 0.)
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^stem',
blocks=r'^features\.(\d+)',
blocks_head=r'^head'
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
assert not enable, 'gradient checkpointing not supported'
@torch.jit.ignore
def get_classifier(self):
return self.fc
@ -187,12 +200,15 @@ class SelecSLS(nn.Module):
x = self.head(self.from_seq(x))
return x
def forward(self, x):
x = self.forward_features(x)
def forward_head(self, x, pre_logits: bool = False):
x = self.global_pool(x)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.fc(x)
return x if pre_logits else self.fc(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x

@ -14,6 +14,7 @@ support for extras like dilation, switchable BN/activations, feature extraction,
import math
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -120,8 +121,7 @@ class SEBottleneck(Bottleneck):
"""
expansion = 4
def __init__(self, inplanes, planes, groups, reduction, stride=1,
downsample=None):
def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None):
super(SEBottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes * 2)
@ -129,8 +129,7 @@ class SEBottleneck(Bottleneck):
planes * 2, planes * 4, kernel_size=3, stride=stride,
padding=1, groups=groups, bias=False)
self.bn2 = nn.BatchNorm2d(planes * 4)
self.conv3 = nn.Conv2d(
planes * 4, planes * 4, kernel_size=1, bias=False)
self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4)
self.relu = nn.ReLU(inplace=True)
self.se_module = SEModule(planes * 4, reduction=reduction)
@ -146,14 +145,11 @@ class SEResNetBottleneck(Bottleneck):
"""
expansion = 4
def __init__(self, inplanes, planes, groups, reduction, stride=1,
downsample=None):
def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None):
super(SEResNetBottleneck, self).__init__()
self.conv1 = nn.Conv2d(
inplanes, planes, kernel_size=1, bias=False, stride=stride)
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False, stride=stride)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(
planes, planes, kernel_size=3, padding=1, groups=groups, bias=False)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, groups=groups, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4)
@ -169,15 +165,12 @@ class SEResNeXtBottleneck(Bottleneck):
"""
expansion = 4
def __init__(self, inplanes, planes, groups, reduction, stride=1,
downsample=None, base_width=4):
def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None, base_width=4):
super(SEResNeXtBottleneck, self).__init__()
width = math.floor(planes * (base_width / 64)) * groups
self.conv1 = nn.Conv2d(
inplanes, width, kernel_size=1, bias=False, stride=1)
self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, stride=1)
self.bn1 = nn.BatchNorm2d(width)
self.conv2 = nn.Conv2d(
width, width, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False)
self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False)
self.bn2 = nn.BatchNorm2d(width)
self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4)
@ -192,11 +185,9 @@ class SEResNetBlock(nn.Module):
def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None):
super(SEResNetBlock, self).__init__()
self.conv1 = nn.Conv2d(
inplanes, planes, kernel_size=3, padding=1, stride=stride, bias=False)
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, padding=1, stride=stride, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(
planes, planes, kernel_size=3, padding=1, groups=groups, bias=False)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, groups=groups, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.se_module = SEModule(planes, reduction=reduction)
@ -225,9 +216,10 @@ class SEResNetBlock(nn.Module):
class SENet(nn.Module):
def __init__(self, block, layers, groups, reduction, drop_rate=0.2,
in_chans=3, inplanes=64, input_3x3=False, downsample_kernel_size=1,
downsample_padding=0, num_classes=1000, global_pool='avg'):
def __init__(
self, block, layers, groups, reduction, drop_rate=0.2,
in_chans=3, inplanes=64, input_3x3=False, downsample_kernel_size=1,
downsample_padding=0, num_classes=1000, global_pool='avg'):
"""
Parameters
----------
@ -366,6 +358,16 @@ class SENet(nn.Module):
return nn.Sequential(*layers)
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(stem=r'^layer0', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+).(\d+)')
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
assert not enable, 'gradient checkpointing not supported'
@torch.jit.ignore
def get_classifier(self):
return self.last_linear
@ -383,16 +385,15 @@ class SENet(nn.Module):
x = self.layer4(x)
return x
def logits(self, x):
def forward_head(self, x, pre_logits: bool = False):
x = self.global_pool(x)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.last_linear(x)
return x
return x if pre_logits else self.last_linear(x)
def forward(self, x):
x = self.forward_features(x)
x = self.logits(x)
x = self.forward_head(x)
return x

@ -46,9 +46,10 @@ default_cfgs = {
class SelectiveKernelBasic(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
def __init__(
self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
super(SelectiveKernelBasic, self).__init__()
sk_kwargs = sk_kwargs or {}
@ -69,7 +70,7 @@ class SelectiveKernelBasic(nn.Module):
self.downsample = downsample
self.drop_path = drop_path
def zero_init_last_bn(self):
def zero_init_last(self):
nn.init.zeros_(self.conv2.bn.weight)
def forward(self, x):
@ -90,10 +91,10 @@ class SelectiveKernelBasic(nn.Module):
class SelectiveKernelBottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=64, sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None,
drop_block=None, drop_path=None):
def __init__(
self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, sk_kwargs=None,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
super(SelectiveKernelBottleneck, self).__init__()
sk_kwargs = sk_kwargs or {}
@ -113,7 +114,7 @@ class SelectiveKernelBottleneck(nn.Module):
self.downsample = downsample
self.drop_path = drop_path
def zero_init_last_bn(self):
def zero_init_last(self):
nn.init.zeros_(self.conv3.bn.weight)
def forward(self, x):
@ -146,7 +147,7 @@ def skresnet18(pretrained=False, **kwargs):
sk_kwargs = dict(rd_ratio=1 / 8, rd_divisor=16, split_input=True)
model_args = dict(
block=SelectiveKernelBasic, layers=[2, 2, 2, 2], block_args=dict(sk_kwargs=sk_kwargs),
zero_init_last_bn=False, **kwargs)
zero_init_last=False, **kwargs)
return _create_skresnet('skresnet18', pretrained, **model_args)
@ -160,7 +161,7 @@ def skresnet34(pretrained=False, **kwargs):
sk_kwargs = dict(rd_ratio=1 / 8, rd_divisor=16, split_input=True)
model_args = dict(
block=SelectiveKernelBasic, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs),
zero_init_last_bn=False, **kwargs)
zero_init_last=False, **kwargs)
return _create_skresnet('skresnet34', pretrained, **model_args)
@ -174,7 +175,7 @@ def skresnet50(pretrained=False, **kwargs):
sk_kwargs = dict(split_input=True)
model_args = dict(
block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs),
zero_init_last_bn=False, **kwargs)
zero_init_last=False, **kwargs)
return _create_skresnet('skresnet50', pretrained, **model_args)
@ -188,7 +189,7 @@ def skresnet50d(pretrained=False, **kwargs):
sk_kwargs = dict(split_input=True)
model_args = dict(
block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last=False, **kwargs)
return _create_skresnet('skresnet50d', pretrained, **model_args)
@ -200,6 +201,6 @@ def skresnext50_32x4d(pretrained=False, **kwargs):
sk_kwargs = dict(rd_ratio=1/16, rd_divisor=32, split_input=False)
model_args = dict(
block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last=False, **kwargs)
return _create_skresnet('skresnext50_32x4d', pretrained, **model_args)

@ -19,14 +19,13 @@ from typing import Optional
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg, named_apply
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert
from .helpers import build_model_with_cfg, named_apply, checkpoint_seq
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert
from .registry import register_model
from .vision_transformer import checkpoint_filter_fn, _init_vit_weights
from .vision_transformer import checkpoint_filter_fn, get_init_weights_vit
_logger = logging.getLogger(__name__)
@ -85,6 +84,15 @@ default_cfgs = {
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth',
num_classes=21841),
'swin_s3_tiny_224': _cfg(
url='https://github.com/silent-chen/AutoFormerV2-model-zoo/releases/download/v1.0.0/S3-T.pth'
),
'swin_s3_small_224': _cfg(
url='https://github.com/silent-chen/AutoFormerV2-model-zoo/releases/download/v1.0.0/S3-S.pth'
),
'swin_s3_base_224': _cfg(
url='https://github.com/silent-chen/AutoFormerV2-model-zoo/releases/download/v1.0.0/S3-B.pth'
)
}
@ -121,53 +129,64 @@ def window_reverse(windows, window_size: int, H: int, W: int):
return x
def get_relative_position_index(win_h, win_w):
# get pair-wise relative position index for each token inside the window
coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += win_h - 1 # shift to start from 0
relative_coords[:, :, 1] += win_w - 1
relative_coords[:, :, 0] *= 2 * win_w - 1
return relative_coords.sum(-1) # Wh*Ww, Wh*Ww
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
head_dim (int): Number of channels per head (dim // num_heads if not set)
window_size (tuple[int]): The height and width of the window.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
def __init__(self, dim, num_heads, head_dim=None, window_size=7, qkv_bias=True, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.window_size = to_2tuple(window_size) # Wh, Ww
win_h, win_w = self.window_size
self.window_area = win_h * win_w
self.num_heads = num_heads
head_dim = dim // num_heads
head_dim = head_dim or dim // num_heads
attn_dim = head_dim * num_heads
self.scale = head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH
self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads))
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w))
self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj = nn.Linear(attn_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def _get_rel_pos_bias(self) -> torch.Tensor:
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)].view(self.window_area, self.window_area, -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
return relative_position_bias.unsqueeze(0)
def forward(self, x, mask: Optional[torch.Tensor] = None):
"""
Args:
@ -175,20 +194,16 @@ class WindowAttention(nn.Module):
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
attn = attn + self._get_rel_pos_bias()
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
num_win = mask.shape[0]
attn = attn.view(B_ // num_win, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
@ -196,7 +211,7 @@ class WindowAttention(nn.Module):
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = (attn @ v).transpose(1, 2).reshape(B_, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
@ -208,8 +223,9 @@ class SwinTransformerBlock(nn.Module):
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resulotion.
num_heads (int): Number of attention heads.
window_size (int): Window size.
num_heads (int): Number of attention heads.
head_dim (int): Enforce the number of channels per head
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
@ -220,13 +236,13 @@ class SwinTransformerBlock(nn.Module):
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
def __init__(
self, dim, input_resolution, num_heads=4, head_dim=None, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
@ -238,31 +254,29 @@ class SwinTransformerBlock(nn.Module):
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
attn_drop=attn_drop, proj_drop=drop)
dim, num_heads=num_heads, head_dim=head_dim, window_size=to_2tuple(self.window_size),
qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
for h in (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None)):
for w in (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None)):
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = window_partition(img_mask, self.window_size) # num_win, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
@ -287,11 +301,11 @@ class SwinTransformerBlock(nn.Module):
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
x_windows = window_partition(shifted_x, self.window_size) # num_win*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # num_win*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
attn_windows = self.attn(x_windows, mask=self.attn_mask) # num_win*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
@ -320,12 +334,13 @@ class PatchMerging(nn.Module):
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
def __init__(self, input_resolution, dim, out_dim=None, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.out_dim = out_dim or 2 * dim
self.norm = norm_layer(4 * dim)
self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False)
def forward(self, x):
"""
@ -350,15 +365,6 @@ class PatchMerging(nn.Module):
return x
def extra_repr(self) -> str:
return f"input_resolution={self.input_resolution}, dim={self.dim}"
def flops(self):
H, W = self.input_resolution
flops = H * W * self.dim
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
return flops
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
@ -368,6 +374,7 @@ class BasicLayer(nn.Module):
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
head_dim (int): Channels per head (dim // num_heads if not set)
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
@ -376,47 +383,43 @@ class BasicLayer(nn.Module):
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
def __init__(
self, dim, out_dim, input_resolution, depth, num_heads=4, head_dim=None,
window_size=7, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
self.grad_checkpointing = False
# build blocks
self.blocks = nn.ModuleList([
self.blocks = nn.Sequential(*[
SwinTransformerBlock(
dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop,
dim=dim, input_resolution=input_resolution, num_heads=num_heads, head_dim=head_dim,
window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
self.downsample = downsample(input_resolution, dim=dim, out_dim=out_dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x):
for blk in self.blocks:
if not torch.jit.is_scripting() and self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
if self.downsample is not None:
x = self.downsample(x)
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
class SwinTransformer(nn.Module):
r""" Swin Transformer
@ -431,6 +434,7 @@ class SwinTransformer(nn.Module):
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
head_dim (int, tuple(int)):
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
@ -440,31 +444,26 @@ class SwinTransformer(nn.Module):
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
"""
def __init__(
self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, global_pool='avg',
embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), head_dim=None,
window_size=7, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, weight_init='', **kwargs):
norm_layer=nn.LayerNorm, ape=False, patch_norm=True, weight_init='', **kwargs):
super().__init__()
assert global_pool in ('', 'avg')
self.num_classes = num_classes
self.global_pool = global_pool
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
norm_layer=norm_layer if patch_norm else None)
num_patches = self.patch_embed.num_patches
self.patch_grid = self.patch_embed.grid_size
@ -473,52 +472,80 @@ class SwinTransformer(nn.Module):
self.pos_drop = nn.Dropout(p=drop_rate)
# build layers
if not isinstance(embed_dim, (tuple, list)):
embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
embed_out_dim = embed_dim[1:] + [None]
head_dim = to_ntuple(self.num_layers)(head_dim)
window_size = to_ntuple(self.num_layers)(window_size)
mlp_ratio = to_ntuple(self.num_layers)(mlp_ratio)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
layers = []
for i_layer in range(self.num_layers):
for i in range(self.num_layers):
layers += [BasicLayer(
dim=int(embed_dim * 2 ** i_layer),
input_resolution=(self.patch_grid[0] // (2 ** i_layer), self.patch_grid[1] // (2 ** i_layer)),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
dim=embed_dim[i],
out_dim=embed_out_dim[i],
input_resolution=(self.patch_grid[0] // (2 ** i), self.patch_grid[1] // (2 ** i)),
depth=depths[i],
num_heads=num_heads[i],
head_dim=head_dim[i],
window_size=window_size[i],
mlp_ratio=mlp_ratio[i],
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint)
]
downsample=PatchMerging if (i < self.num_layers - 1) else None
)]
self.layers = nn.Sequential(*layers)
self.norm = norm_layer(self.num_features)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.init_weights(weight_init)
if weight_init != 'skip':
self.init_weights(weight_init)
@torch.jit.ignore
def init_weights(self, mode=''):
assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
assert mode in ('jax', 'jax_nlhb', 'moco', '')
if self.absolute_pos_embed is not None:
trunc_normal_(self.absolute_pos_embed, std=.02)
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl='jax' in mode), self)
named_apply(get_init_weights_vit(mode, head_bias=head_bias), self)
@torch.jit.ignore
def no_weight_decay(self):
return {'absolute_pos_embed'}
nwd = {'absolute_pos_embed'}
for n, _ in self.named_parameters():
if 'relative_position_bias_table' in n:
nwd.add(n)
return nwd
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'relative_position_bias_table'}
def group_matcher(self, coarse=False):
return dict(
stem=r'^absolute_pos_embed|patch_embed', # stem and embed
blocks=r'^layers.(\d+)' if coarse else [
(r'^layers.(\d+).downsample', (0,)),
(r'^layers.(\d+).\w+.(\d+)', None),
(r'^norm', (99999,)),
]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
for l in self.layers:
l.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
self.global_pool = global_pool
if global_pool is not None:
assert global_pool in ('', 'avg')
self.global_pool = global_pool
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
@ -530,11 +557,14 @@ class SwinTransformer(nn.Module):
x = self.norm(x) # B L C
return x
def forward(self, x):
x = self.forward_features(x)
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool == 'avg':
x = x.mean(dim=1)
x = self.head(x)
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
@ -547,7 +577,6 @@ def _create_swin_transformer(variant, pretrained=False, **kwargs):
return model
@register_model
def swin_base_patch4_window12_384(pretrained=False, **kwargs):
""" Swin-B @ 384x384, pretrained ImageNet-22k, fine tune 1k
@ -636,3 +665,34 @@ def swin_large_patch4_window7_224_in22k(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
return _create_swin_transformer('swin_large_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)
@register_model
def swin_s3_tiny_224(pretrained=False, **kwargs):
""" Swin-S3-T @ 224x224, ImageNet-1k
"""
model_kwargs = dict(
patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 6, 2),
num_heads=(3, 6, 12, 24), **kwargs)
return _create_swin_transformer('swin_s3_tiny_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_s3_small_224(pretrained=False, **kwargs):
""" Swin-S3-S @ 224x224, trained ImageNet-1k
"""
model_kwargs = dict(
patch_size=4, window_size=(14, 14, 14, 7), embed_dim=96, depths=(2, 2, 18, 2),
num_heads=(3, 6, 12, 24), **kwargs)
return _create_swin_transformer('swin_s3_small_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_s3_base_224(pretrained=False, **kwargs):
""" Swin-S3-B @ 224x224, trained ImageNet-1k
"""
model_kwargs = dict(
patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 30, 2),
num_heads=(3, 6, 12, 24), **kwargs)
return _create_swin_transformer('swin_s3_base_224', pretrained=pretrained, **model_kwargs)

@ -9,6 +9,7 @@ https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT
import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import build_model_with_cfg
@ -77,7 +78,8 @@ class Attention(nn.Module):
class Block(nn.Module):
""" TNT Block
"""
def __init__(self, dim, in_dim, num_pixel, num_heads=12, in_num_head=4, mlp_ratio=4.,
def __init__(
self, dim, in_dim, num_pixel, num_heads=12, in_num_head=4, mlp_ratio=4.,
qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
# Inner transformer
@ -153,12 +155,16 @@ class PixelEmbed(nn.Module):
class TNT(nn.Module):
""" Transformer in Transformer - https://arxiv.org/abs/2103.00112
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, in_dim=48, depth=12,
num_heads=12, in_num_head=4, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., norm_layer=nn.LayerNorm, first_stride=4):
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
embed_dim=768, in_dim=48, depth=12, num_heads=12, in_num_head=4, mlp_ratio=4., qkv_bias=False,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, first_stride=4):
super().__init__()
assert global_pool in ('', 'token', 'avg')
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.grad_checkpointing = False
self.pixel_embed = PixelEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, in_dim=in_dim, stride=first_stride)
@ -206,11 +212,29 @@ class TNT(nn.Module):
def no_weight_decay(self):
return {'patch_pos', 'pixel_pos', 'cls_token'}
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^cls_token|patch_pos|pixel_pos|pixel_embed|norm[12]_proj|proj', # stem and embed / pos
blocks=[
(r'^blocks.(\d+)', None),
(r'^norm', (99999,)),
]
)
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'token', 'avg')
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
@ -222,16 +246,24 @@ class TNT(nn.Module):
patch_embed = patch_embed + self.patch_pos
patch_embed = self.pos_drop(patch_embed)
for blk in self.blocks:
pixel_embed, patch_embed = blk(pixel_embed, patch_embed)
if self.grad_checkpointing and not torch.jit.is_scripting():
for blk in self.blocks:
pixel_embed, patch_embed = checkpoint(blk, pixel_embed, patch_embed)
else:
for blk in self.blocks:
pixel_embed, patch_embed = blk(pixel_embed, patch_embed)
patch_embed = self.norm(patch_embed)
return patch_embed
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = x[:, 0]
x = self.head(x)
x = self.forward_head(x)
return x

@ -107,8 +107,9 @@ class BasicBlock(nn.Module):
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True,
act_layer="leaky_relu", aa_layer=None):
def __init__(
self, inplanes, planes, stride=1, downsample=None, use_se=True,
act_layer="leaky_relu", aa_layer=None):
super(Bottleneck, self).__init__()
self.conv1 = conv2d_iabn(
inplanes, planes, kernel_size=1, stride=1, act_layer=act_layer, act_param=1e-3)
@ -130,7 +131,7 @@ class Bottleneck(nn.Module):
self.conv3 = conv2d_iabn(
planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity")
self.relu = nn.ReLU(inplace=True)
self.act = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
@ -144,10 +145,9 @@ class Bottleneck(nn.Module):
out = self.conv2(out)
if self.se is not None:
out = self.se(out)
out = self.conv3(out)
out = out + shortcut # no inplace
out = self.relu(out)
out = self.act(out)
return out
@ -194,7 +194,7 @@ class TResNet(nn.Module):
self.num_features = (self.planes * 8) * Bottleneck.expansion
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
# model initilization
# model initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
@ -231,6 +231,16 @@ class TResNet(nn.Module):
block(self.inplanes, planes, use_se=use_se, aa_layer=aa_layer))
return nn.Sequential(*layers)
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(stem=r'^body.conv1', blocks=r'^body.layer(\d+)' if coarse else r'^body.layer(\d+).(\d+)')
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
assert not enable, 'gradient checkpointing not supported'
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
@ -241,9 +251,12 @@ class TResNet(nn.Module):
def forward_features(self, x):
return self.body(x)
def forward_head(self, x, pre_logits: bool = False):
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
x = self.forward_head(x)
return x

@ -198,8 +198,9 @@ class GlobalSubSampleAttn(nn.Module):
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=None):
def __init__(
self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=None):
super().__init__()
self.norm1 = norm_layer(dim)
if ws is None:
@ -273,15 +274,17 @@ class Twins(nn.Module):
Adapted from PVT (PyramidVisionTransformer) class at https://github.com/whai362/PVT.git
"""
def __init__(
self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=(64, 128, 256, 512),
num_heads=(1, 2, 4, 8), mlp_ratios=(4, 4, 4, 4), drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=(3, 4, 6, 3), sr_ratios=(8, 4, 2, 1), wss=None,
block_cls=Block):
self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, global_pool='avg',
embed_dims=(64, 128, 256, 512), num_heads=(1, 2, 4, 8), mlp_ratios=(4, 4, 4, 4), depths=(3, 4, 6, 3),
sr_ratios=(8, 4, 2, 1), wss=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
norm_layer=partial(nn.LayerNorm, eps=1e-6), block_cls=Block):
super().__init__()
self.num_classes = num_classes
self.global_pool = global_pool
self.depths = depths
self.embed_dims = embed_dims
self.num_features = embed_dims[-1]
self.grad_checkpointing = False
img_size = to_2tuple(img_size)
prev_chs = in_chans
@ -319,11 +322,34 @@ class Twins(nn.Module):
def no_weight_decay(self):
return set(['pos_block.' + n for n, p in self.pos_block.named_parameters()])
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
stem=r'^patch_embeds.0', # stem and embed
blocks=[
(r'^(?:blocks|patch_embeds|pos_block).(\d+)', None),
('^norm', (99999,))
] if coarse else [
(r'^blocks.(\d+).(\d+)', None),
(r'^(?:patch_embeds|pos_block).(\d+)', (0,)),
(r'^norm', (99999,))
]
)
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
assert not enable, 'gradient checkpointing not supported'
@torch.jit.ignore
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg')
self.global_pool = global_pool
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def _init_weights(self, m):
@ -340,9 +366,6 @@ class Twins(nn.Module):
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1.0)
m.bias.data.zero_()
def forward_features(self, x):
B = x.shape[0]
@ -359,10 +382,14 @@ class Twins(nn.Module):
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool == 'avg':
x = x.mean(dim=1)
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = x.mean(dim=1)
x = self.head(x)
x = self.forward_head(x)
return x

@ -11,7 +11,7 @@ import torch.nn.functional as F
from typing import Union, List, Dict, Any, cast
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .helpers import build_model_with_cfg, checkpoint_seq
from .fx_features import register_notrace_module
from .layers import ClassifierHead
from .registry import register_model
@ -25,7 +25,7 @@ __all__ = [
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1),
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'features.0', 'classifier': 'head.fc',
@ -56,8 +56,9 @@ cfgs: Dict[str, List[Union[str, int]]] = {
@register_notrace_module # reason: FX can't symbolically trace control flow in forward method
class ConvMlp(nn.Module):
def __init__(self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0,
drop_rate: float = 0.2, act_layer: nn.Module = None, conv_layer: nn.Module = None):
def __init__(
self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0,
drop_rate: float = 0.2, act_layer: nn.Module = None, conv_layer: nn.Module = None):
super(ConvMlp, self).__init__()
self.input_kernel_size = kernel_size
mid_features = int(out_features * mlp_ratio)
@ -83,23 +84,25 @@ class ConvMlp(nn.Module):
class VGG(nn.Module):
def __init__(
self,
cfg: List[Any],
num_classes: int = 1000,
in_chans: int = 3,
output_stride: int = 32,
mlp_ratio: float = 1.0,
act_layer: nn.Module = nn.ReLU,
conv_layer: nn.Module = nn.Conv2d,
norm_layer: nn.Module = None,
global_pool: str = 'avg',
drop_rate: float = 0.,
self,
cfg: List[Any],
num_classes: int = 1000,
in_chans: int = 3,
output_stride: int = 32,
mlp_ratio: float = 1.0,
act_layer: nn.Module = nn.ReLU,
conv_layer: nn.Module = nn.Conv2d,
norm_layer: nn.Module = None,
global_pool: str = 'avg',
drop_rate: float = 0.,
) -> None:
super(VGG, self).__init__()
assert output_stride == 32
self.num_classes = num_classes
self.num_features = 4096
self.drop_rate = drop_rate
self.grad_checkpointing = False
self.use_norm = norm_layer is not None
self.feature_info = []
prev_chs = in_chans
net_stride = 1
@ -121,6 +124,7 @@ class VGG(nn.Module):
prev_chs = v
self.features = nn.Sequential(*layers)
self.feature_info.append(dict(num_chs=prev_chs, reduction=net_stride, module=f'features.{len(layers) - 1}'))
self.pre_logits = ConvMlp(
prev_chs, self.num_features, 7, mlp_ratio=mlp_ratio,
drop_rate=drop_rate, act_layer=act_layer, conv_layer=conv_layer)
@ -129,6 +133,16 @@ class VGG(nn.Module):
self._initialize_weights()
@torch.jit.ignore
def group_matcher(self, coarse=False):
# this treats BN layers as separate groups for bn variants, a lot of effort to fix that
return dict(stem=r'^features.0', blocks=r'^features.(\d+)')
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
assert not enable, 'gradient checkpointing not supported'
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
@ -139,12 +153,15 @@ class VGG(nn.Module):
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.features(x)
x = self.pre_logits(x)
return x
def forward_head(self, x: torch.Tensor, pre_logits: bool = False):
x = self.pre_logits(x)
return x if pre_logits else self.head(x)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.forward_features(x)
x = self.head(x)
x = self.forward_head(x)
return x
def _initialize_weights(self) -> None:

@ -13,7 +13,7 @@ import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .helpers import build_model_with_cfg, checkpoint_seq
from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier
from .registry import register_model
@ -41,8 +41,9 @@ default_cfgs = dict(
class SpatialMlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None,
act_layer=nn.GELU, drop=0., group=8, spatial_conv=False):
def __init__(
self, in_features, hidden_features=None, out_features=None,
act_layer=nn.GELU, drop=0., group=8, spatial_conv=False):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
@ -99,7 +100,7 @@ class Attention(nn.Module):
def forward(self, x):
B, C, H, W = x.shape
x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3)
q, k, v = x[0], x[1], x[2]
q, k, v = x.unbind(0)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
@ -113,9 +114,10 @@ class Attention(nn.Module):
class Block(nn.Module):
def __init__(self, dim, num_heads, head_dim_ratio=1., mlp_ratio=4.,
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm2d,
group=8, attn_disabled=False, spatial_conv=False):
def __init__(
self, dim, num_heads, head_dim_ratio=1., mlp_ratio=4.,
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm2d,
group=8, attn_disabled=False, spatial_conv=False):
super().__init__()
self.spatial_conv = spatial_conv
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -128,9 +130,8 @@ class Block(nn.Module):
dim, num_heads=num_heads, head_dim_ratio=head_dim_ratio, attn_drop=attn_drop, proj_drop=drop)
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = SpatialMlp(
in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,
in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop,
group=group, spatial_conv=spatial_conv) # new setting
def forward(self, x):
@ -141,10 +142,11 @@ class Block(nn.Module):
class Visformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384,
depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
norm_layer=LayerNorm2d, attn_stage='111', pos_embed=True, spatial_conv='111',
vit_stem=False, group=8, global_pool='avg', conv_init=False, embed_norm=None):
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384,
depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
norm_layer=LayerNorm2d, attn_stage='111', pos_embed=True, spatial_conv='111',
vit_stem=False, group=8, global_pool='avg', conv_init=False, embed_norm=None):
super().__init__()
img_size = to_2tuple(img_size)
self.num_classes = num_classes
@ -160,8 +162,9 @@ class Visformer(nn.Module):
self.stage_num1 = self.stage_num3 = depth // 3
self.stage_num2 = depth - self.stage_num1 - self.stage_num3
self.pos_embed = pos_embed
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.grad_checkpointing = False
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
# stage 1
if self.vit_stem:
self.stem = None
@ -194,7 +197,7 @@ class Visformer(nn.Module):
else:
self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, *img_size))
self.pos_drop = nn.Dropout(p=drop_rate)
self.stage1 = nn.ModuleList([
self.stage1 = nn.Sequential(*[
Block(
dim=embed_dim//2, num_heads=num_heads, head_dim_ratio=0.5, mlp_ratio=mlp_ratio,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
@ -211,7 +214,7 @@ class Visformer(nn.Module):
img_size = [x // (patch_size // 8) for x in img_size]
if self.pos_embed:
self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, *img_size))
self.stage2 = nn.ModuleList([
self.stage2 = nn.Sequential(*[
Block(
dim=embed_dim, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
@ -228,7 +231,7 @@ class Visformer(nn.Module):
img_size = [x // (patch_size // 8) for x in img_size]
if self.pos_embed:
self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, *img_size))
self.stage3 = nn.ModuleList([
self.stage3 = nn.Sequential(*[
Block(
dim=embed_dim*2, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
@ -255,12 +258,6 @@ class Visformer(nn.Module):
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
if self.conv_init:
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
@ -269,6 +266,22 @@ class Visformer(nn.Module):
if m.bias is not None:
nn.init.constant_(m.bias, 0.)
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^patch_embed1|pos_embed1|stem', # stem and embed
blocks=[
(r'^stage(\d+).(\d+)' if coarse else r'^stage(\d+).(\d+)', None),
(r'^(?:patch_embed|pos_embed)(\d+)', (0,)),
(r'^norm', (99999,))
]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head
@ -283,36 +296,42 @@ class Visformer(nn.Module):
# stage 1
x = self.patch_embed1(x)
if self.pos_embed:
x = x + self.pos_embed1
x = self.pos_drop(x)
for b in self.stage1:
x = b(x)
x = self.pos_drop(x + self.pos_embed1)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.stage1, x)
else:
x = self.stage1(x)
# stage 2
if not self.vit_stem:
x = self.patch_embed2(x)
if self.pos_embed:
x = x + self.pos_embed2
x = self.pos_drop(x)
for b in self.stage2:
x = b(x)
x = self.pos_drop(x + self.pos_embed2)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.stage2, x)
else:
x = self.stage2(x)
# stage3
if not self.vit_stem:
x = self.patch_embed3(x)
if self.pos_embed:
x = x + self.pos_embed3
x = self.pos_drop(x)
for b in self.stage3:
x = b(x)
x = self.pos_drop(x + self.pos_embed3)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.stage3, x)
else:
x = self.stage3(x)
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False):
x = self.global_pool(x)
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = self.global_pool(x)
x = self.head(x)
x = self.forward_head(x)
return x

@ -27,9 +27,10 @@ from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv
from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv, checkpoint_seq
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
from .registry import register_model
@ -202,20 +203,23 @@ class Attention(nn.Module):
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
x = x + self.drop_path1(self.attn(self.norm1(x)))
x = x + self.drop_path2(self.mlp(self.norm2(x)))
return x
@ -227,8 +231,8 @@ class VisionTransformer(nn.Module):
"""
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, global_pool='',
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='',
embed_layer=PatchEmbed, norm_layer=None, act_layer=None):
"""
@ -237,6 +241,7 @@ class VisionTransformer(nn.Module):
patch_size (int, tuple): patch size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
global_pool (str): type of global pooling for final sequence (default: 'token')
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
@ -252,12 +257,15 @@ class VisionTransformer(nn.Module):
act_layer: (nn.Module): MLP activation layer
"""
super().__init__()
assert global_pool in ('', 'avg', 'token')
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 1
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
self.grad_checkpointing = False
self.patch_embed = embed_layer(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
@ -301,17 +309,15 @@ class VisionTransformer(nn.Module):
self.pre_logits = nn.Identity()
def init_weights(self, mode=''):
assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
assert mode in ('jax', 'jax_nlhb', 'moco', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
trunc_normal_(self.pos_embed, std=.02)
if 'jax' not in mode:
# init cls token to truncated normal if not following jax impl, jax impl is zero
trunc_normal_(self.cls_token, std=.02)
named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl='jax' in mode), self)
nn.init.normal_(self.cls_token, std=1e-6)
named_apply(get_init_weights_vit(mode, head_bias), self)
def _init_weights(self, m):
# this fn left here for compat with downstream users
_init_vit_weights(m)
init_weights_vit_timm(m)
@torch.jit.ignore()
def load_pretrained(self, checkpoint_path, prefix=''):
@ -321,12 +327,26 @@ class VisionTransformer(nn.Module):
def no_weight_decay(self):
return {'pos_embed', 'cls_token', 'dist_token'}
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
blocks=[(r'^blocks.(\d+)', None), (r'^norm', (99999,))]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool='', representation_size=None):
def reset_classifier(self, num_classes: int, global_pool=None, representation_size=None):
self.num_classes = num_classes
self.global_pool = global_pool
if global_pool is not None:
assert global_pool in ('', 'avg', 'token')
self.global_pool = global_pool
if representation_size is not None:
self._reset_representation(representation_size)
final_chs = self.representation_size if self.representation_size else self.embed_dim
@ -336,28 +356,36 @@ class VisionTransformer(nn.Module):
x = self.patch_embed(x)
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = self.pos_drop(x + self.pos_embed)
x = self.blocks(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
x = self.norm(x)
return x
def forward(self, x):
x = self.forward_features(x)
if self.global_pool == 'avg':
x = x[:, self.num_tokens:].mean(dim=1)
else:
x = x[:, 0]
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = self.fc_norm(x)
x = self.pre_logits(x)
x = self.head(x)
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
""" ViT weight initialization
* When called without n, head_bias, jax_impl args it will behave exactly the same
as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
* When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
"""
def init_weights_vit_timm(module: nn.Module, name: str = ''):
""" ViT weight initialization, original timm impl (for reproducibility) """
if isinstance(module, nn.Linear):
trunc_normal_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.):
""" ViT weight initialization, matching JAX (Flax) impl """
if isinstance(module, nn.Linear):
if name.startswith('head'):
nn.init.zeros_(module.weight)
@ -366,25 +394,35 @@ def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0.,
lecun_normal_(module.weight)
nn.init.zeros_(module.bias)
else:
if jax_impl:
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
if 'mlp' in name:
nn.init.normal_(module.bias, std=1e-6)
else:
nn.init.zeros_(module.bias)
else:
trunc_normal_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif jax_impl and isinstance(module, nn.Conv2d):
# NOTE conv was left to pytorch default in my original init
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias)
elif isinstance(module, nn.Conv2d):
lecun_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
nn.init.zeros_(module.bias)
nn.init.ones_(module.weight)
def init_weights_vit_moco(module: nn.Module, name: str = ''):
""" ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """
if isinstance(module, nn.Linear):
if 'qkv' in name:
# treat the weights of Q, K, V separately
val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1]))
nn.init.uniform_(module.weight, -val, val)
else:
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
def get_init_weights_vit(mode='jax', head_bias: float = 0.):
if 'jax' in mode:
return partial(init_weights_vit_jax, head_bias=head_bias)
elif 'moco' in mode:
return init_weights_vit_moco
else:
return init_weights_vit_timm
@torch.no_grad()

@ -0,0 +1,750 @@
""" Vision OutLOoker (VOLO) implementation
Paper: `VOLO: Vision Outlooker for Visual Recognition` - https://arxiv.org/abs/2106.13112
Code adapted from official impl at https://github.com/sail-sg/volo, original copyright in comment below
Modifications and additions for timm by / Copyright 2022, Ross Wightman
"""
# Copyright 2021 Sea Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropPath, Mlp, to_2tuple, to_ntuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.helpers import build_model_with_cfg
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .96, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.conv.0', 'classifier': ('head', 'aux_head'),
**kwargs
}
default_cfgs = {
'volo_d1_224': _cfg(
url='https://github.com/sail-sg/volo/releases/download/volo_1/d1_224_84.2.pth.tar',
crop_pct=0.96),
'volo_d1_384': _cfg(
url='https://github.com/sail-sg/volo/releases/download/volo_1/d1_384_85.2.pth.tar',
crop_pct=1.0, input_size=(3, 384, 384)),
'volo_d2_224': _cfg(
url='https://github.com/sail-sg/volo/releases/download/volo_1/d2_224_85.2.pth.tar',
crop_pct=0.96),
'volo_d2_384': _cfg(
url='https://github.com/sail-sg/volo/releases/download/volo_1/d2_384_86.0.pth.tar',
crop_pct=1.0, input_size=(3, 384, 384)),
'volo_d3_224': _cfg(
url='https://github.com/sail-sg/volo/releases/download/volo_1/d3_224_85.4.pth.tar',
crop_pct=0.96),
'volo_d3_448': _cfg(
url='https://github.com/sail-sg/volo/releases/download/volo_1/d3_448_86.3.pth.tar',
crop_pct=1.0, input_size=(3, 448, 448)),
'volo_d4_224': _cfg(
url='https://github.com/sail-sg/volo/releases/download/volo_1/d4_224_85.7.pth.tar',
crop_pct=0.96),
'volo_d4_448': _cfg(
url='https://github.com/sail-sg/volo/releases/download/volo_1/d4_448_86.79.pth.tar',
crop_pct=1.15, input_size=(3, 448, 448)),
'volo_d5_224': _cfg(
url='https://github.com/sail-sg/volo/releases/download/volo_1/d5_224_86.10.pth.tar',
crop_pct=0.96),
'volo_d5_448': _cfg(
url='https://github.com/sail-sg/volo/releases/download/volo_1/d5_448_87.0.pth.tar',
crop_pct=1.15, input_size=(3, 448, 448)),
'volo_d5_512': _cfg(
url='https://github.com/sail-sg/volo/releases/download/volo_1/d5_512_87.07.pth.tar',
crop_pct=1.15, input_size=(3, 512, 512)),
}
class OutlookAttention(nn.Module):
def __init__(self, dim, num_heads, kernel_size=3, padding=1, stride=1, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
head_dim = dim // num_heads
self.num_heads = num_heads
self.kernel_size = kernel_size
self.padding = padding
self.stride = stride
self.scale = head_dim ** -0.5
self.v = nn.Linear(dim, dim, bias=qkv_bias)
self.attn = nn.Linear(dim, kernel_size ** 4 * num_heads)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.unfold = nn.Unfold(kernel_size=kernel_size, padding=padding, stride=stride)
self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True)
def forward(self, x):
B, H, W, C = x.shape
v = self.v(x).permute(0, 3, 1, 2) # B, C, H, W
h, w = math.ceil(H / self.stride), math.ceil(W / self.stride)
v = self.unfold(v).reshape(
B, self.num_heads, C // self.num_heads,
self.kernel_size * self.kernel_size, h * w).permute(0, 1, 4, 3, 2) # B,H,N,kxk,C/H
attn = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
attn = self.attn(attn).reshape(
B, h * w, self.num_heads, self.kernel_size * self.kernel_size,
self.kernel_size * self.kernel_size).permute(0, 2, 1, 3, 4) # B,H,N,kxk,kxk
attn = attn * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).permute(0, 1, 4, 3, 2).reshape(B, C * self.kernel_size * self.kernel_size, h * w)
x = F.fold(x, output_size=(H, W), kernel_size=self.kernel_size, padding=self.padding, stride=self.stride)
x = self.proj(x.permute(0, 2, 3, 1))
x = self.proj_drop(x)
return x
class Outlooker(nn.Module):
def __init__(
self, dim, kernel_size, padding, stride=1, num_heads=1, mlp_ratio=3., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, qkv_bias=False
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = OutlookAttention(
dim, num_heads, kernel_size=kernel_size,
padding=padding, stride=stride,
qkv_bias=qkv_bias, attn_drop=attn_drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class Attention(nn.Module):
def __init__(
self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, H, W, C = x.shape
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, H, W, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Transformer(nn.Module):
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False,
attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class ClassAttention(nn.Module):
def __init__(
self, dim, num_heads=8, head_dim=None, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
if head_dim is not None:
self.head_dim = head_dim
else:
head_dim = dim // num_heads
self.head_dim = head_dim
self.scale = head_dim ** -0.5
self.kv = nn.Linear(dim, self.head_dim * self.num_heads * 2, bias=qkv_bias)
self.q = nn.Linear(dim, self.head_dim * self.num_heads, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(self.head_dim * self.num_heads, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
k, v = kv.unbind(0)
q = self.q(x[:, :1, :]).reshape(B, self.num_heads, 1, self.head_dim)
attn = ((q * self.scale) @ k.transpose(-2, -1))
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
cls_embed = (attn @ v).transpose(1, 2).reshape(B, 1, self.head_dim * self.num_heads)
cls_embed = self.proj(cls_embed)
cls_embed = self.proj_drop(cls_embed)
return cls_embed
class ClassBlock(nn.Module):
def __init__(
self, dim, num_heads, head_dim=None, mlp_ratio=4., qkv_bias=False,
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = ClassAttention(
dim, num_heads=num_heads, head_dim=head_dim, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x):
cls_embed = x[:, :1]
cls_embed = cls_embed + self.drop_path(self.attn(self.norm1(x)))
cls_embed = cls_embed + self.drop_path(self.mlp(self.norm2(cls_embed)))
return torch.cat([cls_embed, x[:, 1:]], dim=1)
def get_block(block_type, **kargs):
if block_type == 'ca':
return ClassBlock(**kargs)
def rand_bbox(size, lam, scale=1):
"""
get bounding box as token labeling (https://github.com/zihangJiang/TokenLabeling)
return: bounding box
"""
W = size[1] // scale
H = size[2] // scale
cut_rat = np.sqrt(1. - lam)
cut_w = np.int(W * cut_rat)
cut_h = np.int(H * cut_rat)
# uniform
cx = np.random.randint(W)
cy = np.random.randint(H)
bbx1 = np.clip(cx - cut_w // 2, 0, W)
bby1 = np.clip(cy - cut_h // 2, 0, H)
bbx2 = np.clip(cx + cut_w // 2, 0, W)
bby2 = np.clip(cy + cut_h // 2, 0, H)
return bbx1, bby1, bbx2, bby2
class PatchEmbed(nn.Module):
""" Image to Patch Embedding.
Different with ViT use 1 conv layer, we use 4 conv layers to do patch embedding
"""
def __init__(
self, img_size=224, stem_conv=False, stem_stride=1,
patch_size=8, in_chans=3, hidden_dim=64, embed_dim=384):
super().__init__()
assert patch_size in [4, 8, 16]
if stem_conv:
self.conv = nn.Sequential(
nn.Conv2d(in_chans, hidden_dim, kernel_size=7, stride=stem_stride, padding=3, bias=False), # 112x112
nn.BatchNorm2d(hidden_dim),
nn.ReLU(inplace=True),
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112
nn.BatchNorm2d(hidden_dim),
nn.ReLU(inplace=True),
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112
nn.BatchNorm2d(hidden_dim),
nn.ReLU(inplace=True),
)
else:
self.conv = None
self.proj = nn.Conv2d(
hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride)
self.num_patches = (img_size // patch_size) * (img_size // patch_size)
def forward(self, x):
if self.conv is not None:
x = self.conv(x)
x = self.proj(x) # B, C, H, W
return x
class Downsample(nn.Module):
""" Image to Patch Embedding, downsampling between stage1 and stage2
"""
def __init__(self, in_embed_dim, out_embed_dim, patch_size=2):
super().__init__()
self.proj = nn.Conv2d(in_embed_dim, out_embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
x = x.permute(0, 3, 1, 2)
x = self.proj(x) # B, C, H, W
x = x.permute(0, 2, 3, 1)
return x
def outlooker_blocks(
block_fn, index, dim, layers, num_heads=1, kernel_size=3, padding=1, stride=2,
mlp_ratio=3., qkv_bias=False, attn_drop=0, drop_path_rate=0., **kwargs):
"""
generate outlooker layer in stage1
return: outlooker layers
"""
blocks = []
for block_idx in range(layers[index]):
block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
blocks.append(
block_fn(
dim, kernel_size=kernel_size, padding=padding,
stride=stride, num_heads=num_heads, mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, attn_drop=attn_drop, drop_path=block_dpr))
blocks = nn.Sequential(*blocks)
return blocks
def transformer_blocks(
block_fn, index, dim, layers, num_heads, mlp_ratio=3.,
qkv_bias=False, attn_drop=0, drop_path_rate=0., **kwargs):
"""
generate transformer layers in stage2
return: transformer layers
"""
blocks = []
for block_idx in range(layers[index]):
block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
blocks.append(
block_fn(
dim, num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
drop_path=block_dpr))
blocks = nn.Sequential(*blocks)
return blocks
class VOLO(nn.Module):
"""
Vision Outlooker, the main class of our model
"""
def __init__(
self,
layers,
img_size=224,
in_chans=3,
num_classes=1000,
global_pool='token',
patch_size=8,
stem_hidden_dim=64,
embed_dims=None,
num_heads=None,
downsamples=(True, False, False, False),
outlook_attention=(True, False, False, False),
mlp_ratio=3.0,
qkv_bias=False,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=nn.LayerNorm,
post_layers=('ca', 'ca'),
use_aux_head=True,
use_mix_token=False,
pooling_scale=2,
):
super().__init__()
num_layers = len(layers)
mlp_ratio = to_ntuple(num_layers)(mlp_ratio)
img_size = to_2tuple(img_size)
self.num_classes = num_classes
self.global_pool = global_pool
self.mix_token = use_mix_token
self.pooling_scale = pooling_scale
self.num_features = embed_dims[-1]
if use_mix_token: # enable token mixing, see token labeling for details.
self.beta = 1.0
assert global_pool == 'token', "return all tokens if mix_token is enabled"
self.grad_checkpointing = False
self.patch_embed = PatchEmbed(
stem_conv=True, stem_stride=2, patch_size=patch_size,
in_chans=in_chans, hidden_dim=stem_hidden_dim,
embed_dim=embed_dims[0])
# inital positional encoding, we add positional encoding after outlooker blocks
patch_grid = (img_size[0] // patch_size // pooling_scale, img_size[1] // patch_size // pooling_scale)
self.pos_embed = nn.Parameter(torch.zeros(1, patch_grid[0], patch_grid[1], embed_dims[-1]))
self.pos_drop = nn.Dropout(p=drop_rate)
# set the main block in network
network = []
for i in range(len(layers)):
if outlook_attention[i]:
# stage 1
stage = outlooker_blocks(
Outlooker, i, embed_dims[i], layers, num_heads[i], mlp_ratio=mlp_ratio[i],
qkv_bias=qkv_bias, attn_drop=attn_drop_rate, norm_layer=norm_layer)
network.append(stage)
else:
# stage 2
stage = transformer_blocks(
Transformer, i, embed_dims[i], layers, num_heads[i], mlp_ratio=mlp_ratio[i], qkv_bias=qkv_bias,
drop_path_rate=drop_path_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer)
network.append(stage)
if downsamples[i]:
# downsampling between two stages
network.append(Downsample(embed_dims[i], embed_dims[i + 1], 2))
self.network = nn.ModuleList(network)
# set post block, for example, class attention layers
self.post_network = None
if post_layers is not None:
self.post_network = nn.ModuleList(
[
get_block(
post_layers[i],
dim=embed_dims[-1],
num_heads=num_heads[-1],
mlp_ratio=mlp_ratio[-1],
qkv_bias=qkv_bias,
attn_drop=attn_drop_rate,
drop_path=0.,
norm_layer=norm_layer)
for i in range(len(post_layers))
])
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[-1]))
trunc_normal_(self.cls_token, std=.02)
# set output type
if use_aux_head:
self.aux_head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
else:
self.aux_head = None
self.norm = norm_layer(self.num_features)
# Classifier head
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
trunc_normal_(self.pos_embed, std=.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
blocks=[
(r'^network\.(\d+)\.(\d+)', None),
(r'^network\.(\d+)', (0,)),
],
blocks2=[
(r'^cls_token', (0,)),
(r'^post_network\.(\d+)', None),
(r'^norm', (99999,))
],
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
self.global_pool = global_pool
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
if self.aux_head is not None:
self.aux_head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_tokens(self, x):
for idx, block in enumerate(self.network):
if idx == 2:
# add positional encoding after outlooker blocks
x = x + self.pos_embed
x = self.pos_drop(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(block, x)
else:
x = block(x)
B, H, W, C = x.shape
x = x.reshape(B, -1, C)
return x
def forward_cls(self, x):
B, N, C = x.shape
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat([cls_tokens, x], dim=1)
for block in self.post_network:
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(block, x)
else:
x = block(x)
return x
def forward_train(self, x):
""" A separate forward fn for training with mix_token (if a train script supports).
Combining multiple modes in as single forward with different return types is torchscript hell.
"""
x = self.patch_embed(x)
x = x.permute(0, 2, 3, 1) # B,C,H,W-> B,H,W,C
# mix token, see token labeling for details.
if self.mix_token and self.training:
lam = np.random.beta(self.beta, self.beta)
patch_h, patch_w = x.shape[1] // self.pooling_scale, x.shape[2] // self.pooling_scale
bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam, scale=self.pooling_scale)
temp_x = x.clone()
sbbx1, sbby1 = self.pooling_scale * bbx1, self.pooling_scale * bby1
sbbx2, sbby2 = self.pooling_scale * bbx2, self.pooling_scale * bby2
temp_x[:, sbbx1:sbbx2, sbby1:sbby2, :] = x.flip(0)[:, sbbx1:sbbx2, sbby1:sbby2, :]
x = temp_x
else:
bbx1, bby1, bbx2, bby2 = 0, 0, 0, 0
# step2: tokens learning in the two stages
x = self.forward_tokens(x)
# step3: post network, apply class attention or not
if self.post_network is not None:
x = self.forward_cls(x)
x = self.norm(x)
if self.global_pool == 'avg':
x_cls = x.mean(dim=1)
elif self.global_pool == 'token':
x_cls = x[:, 0]
else:
x_cls = x
if self.aux_head is None:
return x_cls
x_aux = self.aux_head(x[:, 1:]) # generate classes in all feature tokens, see token labeling
if not self.training:
return x_cls + 0.5 * x_aux.max(1)[0]
if self.mix_token and self.training: # reverse "mix token", see token labeling for details.
x_aux = x_aux.reshape(x_aux.shape[0], patch_h, patch_w, x_aux.shape[-1])
temp_x = x_aux.clone()
temp_x[:, bbx1:bbx2, bby1:bby2, :] = x_aux.flip(0)[:, bbx1:bbx2, bby1:bby2, :]
x_aux = temp_x
x_aux = x_aux.reshape(x_aux.shape[0], patch_h * patch_w, x_aux.shape[-1])
# return these: 1. class token, 2. classes from all feature tokens, 3. bounding box
return x_cls, x_aux, (bbx1, bby1, bbx2, bby2)
def forward_features(self, x):
x = self.patch_embed(x).permute(0, 2, 3, 1) # B,C,H,W-> B,H,W,C
# step2: tokens learning in the two stages
x = self.forward_tokens(x)
# step3: post network, apply class attention or not
if self.post_network is not None:
x = self.forward_cls(x)
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool == 'avg':
out = x.mean(dim=1)
elif self.global_pool == 'token':
out = x[:, 0]
else:
out = x
if pre_logits:
return out
out = self.head(out)
if self.aux_head is not None:
# generate classes in all feature tokens, see token labeling
aux = self.aux_head(x[:, 1:])
out = out + 0.5 * aux.max(1)[0]
return out
def forward(self, x):
""" simplified forward (without mix token training) """
x = self.forward_features(x)
x = self.forward_head(x)
return x
def _create_volo(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
return build_model_with_cfg(VOLO, variant, pretrained, **kwargs)
@register_model
def volo_d1_224(pretrained=False, **kwargs):
""" VOLO-D1 model, Params: 27M """
model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs)
model = _create_volo('volo_d1_224', pretrained=pretrained, **model_args)
return model
@register_model
def volo_d1_384(pretrained=False, **kwargs):
""" VOLO-D1 model, Params: 27M """
model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs)
model = _create_volo('volo_d1_384', pretrained=pretrained, **model_args)
return model
@register_model
def volo_d2_224(pretrained=False, **kwargs):
""" VOLO-D2 model, Params: 59M """
model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
model = _create_volo('volo_d2_224', pretrained=pretrained, **model_args)
return model
@register_model
def volo_d2_384(pretrained=False, **kwargs):
""" VOLO-D2 model, Params: 59M """
model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
model = _create_volo('volo_d2_384', pretrained=pretrained, **model_args)
return model
@register_model
def volo_d3_224(pretrained=False, **kwargs):
""" VOLO-D3 model, Params: 86M """
model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
model = _create_volo('volo_d3_224', pretrained=pretrained, **model_args)
return model
@register_model
def volo_d3_448(pretrained=False, **kwargs):
""" VOLO-D3 model, Params: 86M """
model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
model = _create_volo('volo_d3_448', pretrained=pretrained, **model_args)
return model
@register_model
def volo_d4_224(pretrained=False, **kwargs):
""" VOLO-D4 model, Params: 193M """
model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs)
model = _create_volo('volo_d4_224', pretrained=pretrained, **model_args)
return model
@register_model
def volo_d4_448(pretrained=False, **kwargs):
""" VOLO-D4 model, Params: 193M """
model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs)
model = _create_volo('volo_d4_448', pretrained=pretrained, **model_args)
return model
@register_model
def volo_d5_224(pretrained=False, **kwargs):
""" VOLO-D5 model, Params: 296M
stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5
"""
model_args = dict(
layers=(12, 12, 20, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16),
mlp_ratio=4, stem_hidden_dim=128, **kwargs)
model = _create_volo('volo_d5_224', pretrained=pretrained, **model_args)
return model
@register_model
def volo_d5_448(pretrained=False, **kwargs):
""" VOLO-D5 model, Params: 296M
stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5
"""
model_args = dict(
layers=(12, 12, 20, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16),
mlp_ratio=4, stem_hidden_dim=128, **kwargs)
model = _create_volo('volo_d5_448', pretrained=pretrained, **model_args)
return model
@register_model
def volo_d5_512(pretrained=False, **kwargs):
""" VOLO-D5 model, Params: 296M
stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5
"""
model_args = dict(
layers=(12, 12, 20, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16),
mlp_ratio=4, stem_hidden_dim=128, **kwargs)
model = _create_volo('volo_d5_512', pretrained=pretrained, **model_args)
return model

@ -19,7 +19,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .registry import register_model
from .helpers import build_model_with_cfg
from .helpers import build_model_with_cfg, checkpoint_seq
from .layers import ConvNormAct, SeparableConvNormAct, BatchNormAct2d, ClassifierHead, DropPath,\
create_attn, create_norm_act_layer, get_norm_act_layer
@ -178,8 +178,9 @@ class SequentialAppendList(nn.Sequential):
class OsaBlock(nn.Module):
def __init__(self, in_chs, mid_chs, out_chs, layer_per_block, residual=False,
depthwise=False, attn='', norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path=None):
def __init__(
self, in_chs, mid_chs, out_chs, layer_per_block, residual=False,
depthwise=False, attn='', norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path=None):
super(OsaBlock, self).__init__()
self.residual = residual
@ -207,10 +208,7 @@ class OsaBlock(nn.Module):
next_in_chs = in_chs + layer_per_block * mid_chs
self.conv_concat = ConvNormAct(next_in_chs, out_chs, **conv_kwargs)
if attn:
self.attn = create_attn(attn, out_chs)
else:
self.attn = None
self.attn = create_attn(attn, out_chs) if attn else None
self.drop_path = drop_path
@ -231,10 +229,12 @@ class OsaBlock(nn.Module):
class OsaStage(nn.Module):
def __init__(self, in_chs, mid_chs, out_chs, block_per_stage, layer_per_block, downsample=True,
residual=True, depthwise=False, attn='ese', norm_layer=BatchNormAct2d, act_layer=nn.ReLU,
drop_path_rates=None):
def __init__(
self, in_chs, mid_chs, out_chs, block_per_stage, layer_per_block, downsample=True,
residual=True, depthwise=False, attn='ese', norm_layer=BatchNormAct2d, act_layer=nn.ReLU,
drop_path_rates=None):
super(OsaStage, self).__init__()
self.grad_checkpointing = False
if downsample:
self.pool = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
@ -258,14 +258,18 @@ class OsaStage(nn.Module):
def forward(self, x):
if self.pool is not None:
x = self.pool(x)
x = self.blocks(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
return x
class VovNet(nn.Module):
def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4,
output_stride=32, norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path_rate=0.):
def __init__(
self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4,
output_stride=32, norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path_rate=0.):
""" VovNet (v2)
"""
super(VovNet, self).__init__()
@ -315,12 +319,23 @@ class VovNet(nn.Module):
for n, m in self.named_modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1.)
nn.init.constant_(m.bias, 0.)
elif isinstance(m, nn.Linear):
nn.init.zeros_(m.bias)
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^stem',
blocks=r'^stages.(\d+)' if coarse else r'^stages.(\d+).blocks.(\d+)',
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
for s in self.stages:
s.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
@ -331,9 +346,13 @@ class VovNet(nn.Module):
x = self.stem(x)
return self.stages(x)
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
def forward(self, x):
x = self.forward_features(x)
return self.head(x)
x = self.forward_head(x)
return x
def _create_vovnet(variant, pretrained=False, **kwargs):

@ -21,7 +21,7 @@ normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
"""
import torch.jit
import torch.nn as nn
import torch.nn.functional as F
@ -172,6 +172,21 @@ class Xception(nn.Module):
m.weight.data.fill_(1)
m.bias.data.zero_()
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^conv[12]|bn[12]',
blocks=[
(r'^block(\d+)', None),
(r'^conv[34]|bn[34]', (99,)),
],
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
assert not enable, "gradient checkpointing not supported"
@torch.jit.ignore
def get_classifier(self):
return self.fc
@ -210,12 +225,15 @@ class Xception(nn.Module):
x = self.act4(x)
return x
def forward(self, x):
x = self.forward_features(x)
def forward_head(self, x, pre_logits: bool = False):
x = self.global_pool(x)
if self.drop_rate:
F.dropout(x, self.drop_rate, training=self.training)
x = self.fc(x)
return x if pre_logits else self.fc(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x

@ -7,11 +7,11 @@ Hacked together by / Copyright 2020 Ross Wightman
"""
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg
from .helpers import build_model_with_cfg, checkpoint_seq
from .layers import ClassifierHead, ConvNormAct, create_conv2d, get_norm_act_layer
from .layers.helpers import to_3tuple
from .registry import register_model
@ -39,6 +39,7 @@ default_cfgs = dict(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth'),
xception41p=_cfg(url=''),
xception65p=_cfg(url=''),
)
@ -167,12 +168,14 @@ class XceptionAligned(nn.Module):
"""Modified Aligned Xception
"""
def __init__(self, block_cfg, num_classes=1000, in_chans=3, output_stride=32, preact=False,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0., global_pool='avg'):
def __init__(
self, block_cfg, num_classes=1000, in_chans=3, output_stride=32, preact=False,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0., global_pool='avg'):
super(XceptionAligned, self).__init__()
assert output_stride in (8, 16, 32)
self.num_classes = num_classes
self.drop_rate = drop_rate
assert output_stride in (8, 16, 32)
self.grad_checkpointing = False
layer_args = dict(act_layer=act_layer, norm_layer=norm_layer)
self.stem = nn.Sequential(*[
@ -206,6 +209,18 @@ class XceptionAligned(nn.Module):
self.head = ClassifierHead(
in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^stem',
blocks=r'^blocks.(\d+)',
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
@ -214,13 +229,19 @@ class XceptionAligned(nn.Module):
def forward_features(self, x):
x = self.stem(x)
x = self.blocks(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
x = self.act(x)
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
x = self.forward_head(x)
return x
@ -307,3 +328,23 @@ def xception41p(pretrained=False, **kwargs):
]
model_args = dict(block_cfg=block_cfg, preact=True, norm_layer=nn.BatchNorm2d, **kwargs)
return _xception('xception41p', pretrained=pretrained, **model_args)
@register_model
def xception65p(pretrained=False, **kwargs):
""" Modified Aligned Xception-65 w/ Pre-Act
"""
block_cfg = [
# entry flow
dict(in_chs=64, out_chs=128, stride=2),
dict(in_chs=128, out_chs=256, stride=2),
dict(in_chs=256, out_chs=728, stride=2),
# middle flow
*([dict(in_chs=728, out_chs=728, stride=1)] * 16),
# exit flow
dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True),
]
model_args = dict(
block_cfg=block_cfg, preact=True, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs)
return _xception('xception65p', pretrained=pretrained, **model_args)

@ -16,6 +16,7 @@ from functools import partial
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
@ -215,8 +216,9 @@ class LPI(nn.Module):
class ClassAttentionBlock(nn.Module):
"""Class Attention Layer as in CaiT https://arxiv.org/abs/2103.17239"""
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=1., tokens_norm=False):
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=1., tokens_norm=False):
super().__init__()
self.norm1 = norm_layer(dim)
@ -292,8 +294,9 @@ class XCA(nn.Module):
class XCABlock(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=1.):
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=1.):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = XCA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
@ -325,9 +328,10 @@ class XCiT(nn.Module):
https://github.com/facebookresearch/deit/
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
act_layer=None, norm_layer=None, cls_attn_layers=2, use_pos_embed=True, eta=1., tokens_norm=False):
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', embed_dim=768,
depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
act_layer=None, norm_layer=None, cls_attn_layers=2, use_pos_embed=True, eta=1., tokens_norm=False):
"""
Args:
img_size (int, tuple): input image size
@ -353,14 +357,17 @@ class XCiT(nn.Module):
interaction (class LPI) and the patch embedding (class ConvPatchEmbed)
"""
super().__init__()
assert global_pool in ('', 'avg', 'token')
img_size = to_2tuple(img_size)
assert (img_size[0] % patch_size == 0) and (img_size[0] % patch_size == 0), \
'`patch_size` should divide image dimensions evenly'
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
self.global_pool = global_pool
self.grad_checkpointing = False
self.patch_embed = ConvPatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, act_layer=act_layer)
@ -396,19 +403,32 @@ class XCiT(nn.Module):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
blocks=r'^blocks.(\d+)',
cls_attn_blocks=[(r'^cls_attn_blocks.(\d+)', None), (r'^norm', (99999,))]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg', 'token')
self.global_pool = global_pool
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
@ -420,24 +440,33 @@ class XCiT(nn.Module):
# `pos_embed` (B, C, Hp, Wp), reshape -> (B, C, N), permute -> (B, N, C)
pos_encoding = self.pos_embed(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
x = x + pos_encoding
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x, Hp, Wp)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(blk, x, Hp, Wp)
else:
x = blk(x, Hp, Wp)
x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
for blk in self.cls_attn_blocks:
x = blk(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(blk, x)
else:
x = blk(x)
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = x[:, 0]
x = self.head(x)
x = self.forward_head(x)
return x

@ -1,12 +1,16 @@
""" Optimizer Factory w/ Custom Weight Decay
Hacked together by / Copyright 2021 Ross Wightman
"""
from typing import Optional
import json
from itertools import islice
from typing import Optional, Callable, Tuple
import torch
import torch.nn as nn
import torch.optim as optim
from timm.models.helpers import group_parameters
from .adabelief import AdaBelief
from .adafactor import Adafactor
from .adahessian import Adahessian
@ -28,21 +32,122 @@ except ImportError:
has_apex = False
def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
def param_groups_weight_decay(
model: nn.Module,
weight_decay=1e-5,
no_weight_decay_list=()
):
no_weight_decay_list = set(no_weight_decay_list)
decay = []
no_decay = []
for name, param in model.named_parameters():
if not param.requires_grad:
continue # frozen weights
if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
continue
if param.ndim or name.endswith(".bias") or name in no_weight_decay_list:
no_decay.append(param)
else:
decay.append(param)
return [
{'params': no_decay, 'weight_decay': 0.},
{'params': decay, 'weight_decay': weight_decay}]
def _group(it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())
def _layer_map(model, layers_per_group=12, num_groups=None):
def _in_head(n, hp):
if not hp:
return True
elif isinstance(hp, (tuple, list)):
return any([n.startswith(hpi) for hpi in hp])
else:
return n.startswith(hp)
head_prefix = getattr(model, 'pretrained_cfg', {}).get('classifier', None)
names_trunk = []
names_head = []
for n, _ in model.named_parameters():
names_head.append(n) if _in_head(n, head_prefix) else names_trunk.append(n)
# group non-head layers
num_trunk_layers = len(names_trunk)
if num_groups is not None:
layers_per_group = -(num_trunk_layers // -num_groups)
names_trunk = list(_group(names_trunk, layers_per_group))
num_trunk_groups = len(names_trunk)
layer_map = {n: i for i, l in enumerate(names_trunk) for n in l}
layer_map.update({n: num_trunk_groups for n in names_head})
return layer_map
def param_groups_layer_decay(
model: nn.Module,
weight_decay: float = 0.05,
no_weight_decay_list: Tuple[str] = (),
layer_decay: float = .75,
end_layer_decay: Optional[float] = None,
):
"""
Parameter groups for layer-wise lr decay & weight decay
Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
"""
no_weight_decay_list = set(no_weight_decay_list)
param_group_names = {} # NOTE for debugging
param_groups = {}
if hasattr(model, 'group_matcher'):
# FIXME interface needs more work
layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True)
else:
# fallback
layer_map = _layer_map(model)
num_layers = max(layer_map.values()) + 1
layer_max = num_layers - 1
layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))
for name, param in model.named_parameters():
if not param.requires_grad:
continue
# no decay: all 1D parameters and model specific ones
if param.ndim == 1 or name in no_weight_decay_list:
g_decay = "no_decay"
this_decay = 0.
else:
g_decay = "decay"
this_decay = weight_decay
layer_id = layer_map.get(name, layer_max)
group_name = "layer_%d_%s" % (layer_id, g_decay)
if group_name not in param_groups:
this_scale = layer_scales[layer_id]
param_group_names[group_name] = {
"lr_scale": this_scale,
"weight_decay": this_decay,
"param_names": [],
}
param_groups[group_name] = {
"lr_scale": this_scale,
"weight_decay": this_decay,
"params": [],
}
param_group_names[group_name]["param_names"].append(name)
param_groups[group_name]["params"].append(param)
# FIXME temporary output to debug new feature
print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
return list(param_groups.values())
def optimizer_kwargs(cfg):
""" cfg/argparse to kwargs helper
Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn.
@ -56,6 +161,8 @@ def optimizer_kwargs(cfg):
kwargs['eps'] = cfg.opt_eps
if getattr(cfg, 'opt_betas', None) is not None:
kwargs['betas'] = cfg.opt_betas
if getattr(cfg, 'layer_decay', None) is not None:
kwargs['layer_decay'] = cfg.layer_decay
if getattr(cfg, 'opt_args', None) is not None:
kwargs.update(cfg.opt_args)
return kwargs
@ -79,6 +186,8 @@ def create_optimizer_v2(
weight_decay: float = 0.,
momentum: float = 0.9,
filter_bias_and_bn: bool = True,
layer_decay: Optional[float] = None,
param_group_fn: Optional[Callable] = None,
**kwargs):
""" Create an optimizer.
@ -101,11 +210,21 @@ def create_optimizer_v2(
"""
if isinstance(model_or_params, nn.Module):
# a model was passed in, extract parameters and add weight decays to appropriate layers
if weight_decay and filter_bias_and_bn:
skip = {}
if hasattr(model_or_params, 'no_weight_decay'):
skip = model_or_params.no_weight_decay()
parameters = add_weight_decay(model_or_params, weight_decay, skip)
no_weight_decay = {}
if hasattr(model_or_params, 'no_weight_decay'):
no_weight_decay = model_or_params.no_weight_decay()
if param_group_fn:
parameters = param_group_fn(model_or_params)
elif layer_decay is not None:
parameters = param_groups_layer_decay(
model_or_params,
weight_decay=weight_decay,
layer_decay=layer_decay,
no_weight_decay_list=no_weight_decay)
weight_decay = 0.
elif weight_decay and filter_bias_and_bn:
parameters = param_groups_weight_decay(model_or_params, weight_decay, no_weight_decay)
weight_decay = 0.
else:
parameters = model_or_params.parameters()

@ -82,7 +82,10 @@ class Scheduler:
if not isinstance(values, (list, tuple)):
values = [values] * len(self.optimizer.param_groups)
for param_group, value in zip(self.optimizer.param_groups, values):
param_group[self.param_group_field] = value
if 'lr_scale' in param_group:
param_group[self.param_group_field] = value * param_group['lr_scale']
else:
param_group[self.param_group_field] = value
def _add_noise(self, lrs, t):
if self.noise_range_t is not None:

@ -112,9 +112,17 @@ parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
help='input batch size for training (default: 128)')
help='Input batch size for training (default: 128)')
parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
help='validation batch size override (default: None)')
help='Validation batch size override (default: None)')
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='torch.jit.script the full model')
parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--grad-checkpointing', action='store_true', default=False,
help='Enable gradient checkpointing through model blocks/stages')
# Optimizer parameters
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
@ -131,7 +139,8 @@ parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--clip-mode', type=str, default='norm',
help='Gradient clipping mode. One of ("norm", "value", "agc")')
parser.add_argument('--layer-decay', type=float, default=None,
help='weight decay (default: None)')
# Learning rate schedule parameters
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
@ -188,7 +197,7 @@ parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)')
parser.add_argument('--aa', type=str, default=None, metavar='NAME',
help='Use AutoAugment policy. "v0" or "original". (default: None)'),
parser.add_argument('--aug-repeats', type=int, default=0,
parser.add_argument('--aug-repeats', type=float, default=0,
help='Number of augmentation repetitions (distributed training only) (default: 0)')
parser.add_argument('--aug-splits', type=int, default=0,
help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
@ -276,8 +285,6 @@ parser.add_argument('--native-amp', action='store_true', default=False,
help='Use Native Torch AMP mixed precision')
parser.add_argument('--no-ddp-bb', action='store_true', default=False,
help='Force broadcast buffers for native DDP to off.')
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
@ -293,10 +300,6 @@ parser.add_argument('--tta', type=int, default=0, metavar='N',
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
help='use the multi-epochs-loader to save time at the beginning of every epoch')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--log-wandb', action='store_true', default=False,
help='log training and validation metrics to wandb')
@ -386,6 +389,9 @@ def main():
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
if args.grad_checkpointing:
model.set_grad_checkpointing(enable=True)
if args.local_rank == 0:
_logger.info(
f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')
@ -458,7 +464,7 @@ def main():
# setup exponential moving average of model weights, SWA could be used here too
model_ema = None
if args.model_ema:
# Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
# Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
model_ema = ModelEmaV2(
model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
if args.resume:

Loading…
Cancel
Save