Merge remote-tracking branch 'origin/fixes_bce_regnet' into bits_and_tpu

Branch: pull/1239/head
Author: Ross Wightman
Commit: 25d52ea71d

@@ -17,7 +17,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
 # transformer models don't support many of the spatial / feature based model functionalities
 NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
-    'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*']
+    'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*']
 NUM_NON_STD = len(NON_STD_FILTERS)

 # exclude models that cause specific test failures
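The filter entries above are shell-style glob patterns matched against model names. A minimal sketch of the matching logic, assuming fnmatch-based filtering as the test helpers use (the shortened filter list here is a stand-in for the full one above):

```python
# Minimal sketch, assuming fnmatch-style glob matching as in the test helpers.
from fnmatch import fnmatch

NON_STD_FILTERS = ['vit_*', 'crossvit_*', 'beit_*']  # shortened stand-in

def is_non_std(model_name: str) -> bool:
    # a model is treated as "non-standard" if any pattern matches its name
    return any(fnmatch(model_name, f) for f in NON_STD_FILTERS)

assert is_non_std('beit_base_patch16_224')
assert not is_non_std('resnet50')
```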
@@ -188,23 +188,22 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
     input_tensor = torch.randn((batch_size, *input_size))

     # test forward_features (always unpooled)
     outputs = model.forward_features(input_tensor)
-    if isinstance(outputs, tuple):
+    if isinstance(outputs, (tuple, list)):
         outputs = outputs[0]
     assert outputs.shape[1] == model.num_features

     # test forward after deleting the classifier, output should be pooled, size(-1) == model.num_features
     model.reset_classifier(0)
     outputs = model.forward(input_tensor)
-    if isinstance(outputs, tuple):
+    if isinstance(outputs, (tuple, list)):
         outputs = outputs[0]
     assert len(outputs.shape) == 2
     assert outputs.shape[1] == model.num_features

     model = create_model(model_name, pretrained=False, num_classes=0).eval()
     outputs = model.forward(input_tensor)
-    if isinstance(outputs, tuple):
+    if isinstance(outputs, (tuple, list)):
         outputs = outputs[0]
     assert len(outputs.shape) == 2
     assert outputs.shape[1] == model.num_features

@@ -319,10 +319,10 @@ def test_sgd(optimizer):
     #      lambda opt: ReduceLROnPlateau(opt)]
     # )
     _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, momentum=1)
+        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=3e-3, momentum=1)
     )
     _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, momentum=1, weight_decay=.1)
+        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=3e-3, momentum=1, weight_decay=.1)
     )
     _test_rosenbrock(
         lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
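For reference, each `_test_basic_cases` factory above is called with freshly created parameters; a sketch of the equivalent direct call (the parameter shapes are arbitrary assumptions):

```python
# Sketch of invoking one of the optimizer factories above directly; shapes are assumptions.
import torch
from timm.optim import create_optimizer_v2

weight = torch.nn.Parameter(torch.randn(10, 5))
bias = torch.nn.Parameter(torch.randn(10))
opt = create_optimizer_v2([weight, bias], 'sgd', lr=3e-3, momentum=1, weight_decay=.1)
```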

@@ -9,6 +9,7 @@ Hacked together by / Copyright 2020 Ross Wightman
 from typing import Tuple, Optional, Union, Callable

 import torch.utils.data
+import numpy as np

 from timm.bits import DeviceEnv

 from .collate import fast_collate
@@ -19,6 +20,12 @@ from .mixup import FastCollateMixup
 from .prefetcher_cuda import PrefetcherCuda


+def _worker_init(worker_id):
+    worker_info = torch.utils.data.get_worker_info()
+    assert worker_info.id == worker_id
+    np.random.seed(worker_info.seed % (2**32-1))
+
+
 def create_loader_v2(
     dataset: torch.utils.data.Dataset,
     batch_size: int,
@@ -94,6 +101,7 @@ def create_loader_v2(
         collate_fn=collate_fn,
         pin_memory=pin_memory,
         drop_last=is_training,
+        worker_init_fn=_worker_init,
         persistent_workers=persistent_workers)
     try:
         loader = loader_class(dataset, **loader_args)

@@ -159,7 +159,7 @@ class ParserTfds(Parser):
             #  see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
             ds = ds.repeat()  # allow wrap around and break iteration manually
         if self.shuffle:
-            ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=0)
+            ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=worker_info.seed)
         ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE))
         self.ds = tfds.as_numpy(ds)
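Taken together, the loader's new `_worker_init` and the per-worker TFDS shuffle seed give each DataLoader worker its own reproducible random stream instead of a shared `seed=0`. A standalone sketch of the mechanism:

```python
# Standalone sketch of per-worker seeding via DataLoader's worker_init_fn.
import numpy as np
import torch
import torch.utils.data

def _worker_init(worker_id):
    worker_info = torch.utils.data.get_worker_info()
    assert worker_info.id == worker_id
    # worker_info.seed combines the loader's base seed with the worker id,
    # so every worker gets a distinct numpy RNG state
    np.random.seed(worker_info.seed % (2 ** 32 - 1))

dataset = torch.utils.data.TensorDataset(torch.arange(8).float())
loader = torch.utils.data.DataLoader(dataset, num_workers=2, worker_init_fn=_worker_init)
```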

@@ -1,4 +1,4 @@
 from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel
-from .binary_cross_entropy import DenseBinaryCrossEntropy
+from .binary_cross_entropy import BinaryCrossEntropy
 from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
 from .jsd import JsdCrossEntropy

@@ -1,23 +1,47 @@
+""" Binary Cross Entropy w/ a few extras
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+from typing import Optional
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F


-class DenseBinaryCrossEntropy(nn.Module):
-    """ BCE using one-hot from dense targets w/ label smoothing
+class BinaryCrossEntropy(nn.Module):
+    """ BCE with optional one-hot from dense targets, label smoothing, thresholding
     NOTE for experiments comparing CE to BCE /w label smoothing, may remove
     """
-    def __init__(self, smoothing=0.1):
-        super(DenseBinaryCrossEntropy, self).__init__()
+    def __init__(
+            self, smoothing=0.1, target_threshold: Optional[float] = None, weight: Optional[torch.Tensor] = None,
+            reduction: str = 'mean', pos_weight: Optional[torch.Tensor] = None):
+        super(BinaryCrossEntropy, self).__init__()
         assert 0. <= smoothing < 1.0
         self.smoothing = smoothing
-        self.bce = nn.BCEWithLogitsLoss()
+        self.target_threshold = target_threshold
+        self.reduction = reduction
+        self.register_buffer('weight', weight)
+        self.register_buffer('pos_weight', pos_weight)

-    def forward(self, x, target):
-        num_classes = x.shape[-1]
-        off_value = self.smoothing / num_classes
-        on_value = 1. - self.smoothing + off_value
-        target = target.long().view(-1, 1)
-        target = torch.full(
-            (target.size()[0], num_classes), off_value, device=x.device, dtype=x.dtype).scatter_(1, target, on_value)
-        return self.bce(x, target)
+    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        assert x.shape[0] == target.shape[0]
+        if target.shape != x.shape:
+            # NOTE currently assume smoothing or other label softening is applied upstream if targets are already sparse
+            num_classes = x.shape[-1]
+            # FIXME should off/on be different for smoothing w/ BCE? Other impl out there differ
+            off_value = self.smoothing / num_classes
+            on_value = 1. - self.smoothing + off_value
+            target = target.long().view(-1, 1)
+            target = torch.full(
+                (target.size()[0], num_classes),
+                off_value,
+                device=x.device, dtype=x.dtype).scatter_(1, target, on_value)
+        if self.target_threshold is not None:
+            # Make target 0, or 1 if threshold set
+            target = target.gt(self.target_threshold).to(dtype=target.dtype)
+        return F.binary_cross_entropy_with_logits(
+            x, target,
+            self.weight,
+            pos_weight=self.pos_weight,
+            reduction=self.reduction)
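A short usage sketch of the reworked loss: dense (index) targets are one-hot encoded and smoothed internally, and `target_threshold` can re-binarize soft targets:

```python
# Usage sketch for BinaryCrossEntropy (values arbitrary).
import torch
from timm.loss import BinaryCrossEntropy

logits = torch.randn(4, 10)                 # (batch, num_classes)
dense_targets = torch.randint(0, 10, (4,))  # class indices -> one-hot + smoothing inside

loss = BinaryCrossEntropy(smoothing=0.1)(logits, dense_targets)

# with target_threshold set, the (smoothed or mixup-softened) targets are
# re-binarized to {0, 1} before the BCE computation
loss_t = BinaryCrossEntropy(smoothing=0.1, target_threshold=0.2)(logits, dense_targets)
```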

@@ -1,23 +1,23 @@
+""" Cross Entropy w/ smoothing or soft targets
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
 import torch
 import torch.nn as nn
 import torch.nn.functional as F


 class LabelSmoothingCrossEntropy(nn.Module):
-    """
-    NLL loss with label smoothing.
+    """ NLL loss with label smoothing.
     """
     def __init__(self, smoothing=0.1):
-        """
-        Constructor for the LabelSmoothing module.
-        :param smoothing: label smoothing factor
-        """
         super(LabelSmoothingCrossEntropy, self).__init__()
         assert smoothing < 1.0
         self.smoothing = smoothing
         self.confidence = 1. - smoothing

-    def forward(self, x, target):
+    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         logprobs = F.log_softmax(x, dim=-1)
         nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
         nll_loss = nll_loss.squeeze(1)
@@ -31,6 +31,6 @@ class SoftTargetCrossEntropy(nn.Module):

     def __init__(self):
         super(SoftTargetCrossEntropy, self).__init__()

-    def forward(self, x, target):
+    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
         return loss.mean()
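For reference, the annotated forwards are behavior-preserving; with `smoothing=0.` the label-smoothing loss reduces to standard cross entropy, and `SoftTargetCrossEntropy` consumes a probability distribution per sample:

```python
# Sanity-check sketch relating the two losses to F.cross_entropy.
import torch
import torch.nn.functional as F
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy

x = torch.randn(4, 10)
target = torch.randint(0, 10, (4,))

assert torch.allclose(
    LabelSmoothingCrossEntropy(smoothing=0.)(x, target), F.cross_entropy(x, target), atol=1e-6)

soft = F.one_hot(target, 10).float()  # one-hot is a valid "soft" distribution
assert torch.allclose(SoftTargetCrossEntropy()(x, soft), F.cross_entropy(x, target), atol=1e-6)
```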

@@ -1,8 +1,10 @@
+from .beit import *
 from .byoanet import *
 from .byobnet import *
 from .cait import *
 from .coat import *
 from .convit import *
+from .crossvit import *
 from .cspnet import *
 from .densenet import *
 from .dla import *
@@ -36,6 +38,7 @@ from .sknet import *
 from .swin_transformer import *
 from .tnt import *
 from .tresnet import *
+from .twins import *
 from .vgg import *
 from .visformer import *
 from .vision_transformer import *
@@ -44,7 +47,6 @@ from .vovnet import *
 from .xception import *
 from .xception_aligned import *
 from .xcit import *
-from .twins import *

 from .factory import create_model, split_model_name, safe_model_name
 from .helpers import load_checkpoint, resume_checkpoint, model_parameters
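Once these imports run, the new architectures are visible through the model registry; a quick sketch:

```python
# Sketch: the registry discovers the newly imported model modules.
import timm

print(timm.list_models('beit_*'))      # beit_base_patch16_224, ...
print(timm.list_models('crossvit_*'))  # crossvit_9_240, crossvit_15_240, ...
```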

@@ -0,0 +1,420 @@
""" BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
Model from official source: https://github.com/microsoft/unilm/tree/master/beit
At this point only the 1k fine-tuned classification weights and model configs have been added,
see original source above for pre-training models and procedure.
Modifications by / Copyright 2021 Ross Wightman, original copyrights below
"""
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm and DeiT code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dino
# --------------------------------------------------------
import math
from functools import partial
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from .helpers import build_model_with_cfg
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_
from .registry import register_model
from .vision_transformer import checkpoint_filter_fn
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = {
'beit_base_patch16_224': _cfg(
url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'),
'beit_base_patch16_384': _cfg(
url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth',
input_size=(3, 384, 384), crop_pct=1.0,
),
'beit_base_patch16_224_in22k': _cfg(
url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22k.pth',
num_classes=21841,
),
'beit_large_patch16_224': _cfg(
url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'),
'beit_large_patch16_384': _cfg(
url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth',
input_size=(3, 384, 384), crop_pct=1.0,
),
'beit_large_patch16_512': _cfg(
url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth',
input_size=(3, 512, 512), crop_pct=1.0,
),
'beit_large_patch16_224_in22k': _cfg(
url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_224_pt22k_ft22k.pth',
num_classes=21841,
),
}
class Attention(nn.Module):
def __init__(
self, dim, num_heads=8, qkv_bias=False, attn_drop=0.,
proj_drop=0., window_size=None, attn_head_dim=None):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.v_bias = None
if window_size:
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = \
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
else:
self.window_size = None
self.relative_position_bias_table = None
self.relative_position_index = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, rel_pos_bias: Optional[torch.Tensor] = None):
B, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
if torch.jit.is_scripting():
# FIXME requires_grad breaks w/ torchscript
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias), self.v_bias))
else:
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
if self.relative_position_bias_table is not None:
relative_position_bias = \
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if rel_pos_bias is not None:
attn = attn + rel_pos_bias
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
window_size=None, attn_head_dim=None):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
window_size=window_size, attn_head_dim=attn_head_dim)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if init_values:
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
else:
self.gamma_1, self.gamma_2 = None, None
def forward(self, x, rel_pos_bias: Optional[torch.Tensor] = None):
if self.gamma_1 is None:
x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
x = x + self.drop_path(self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
return x
class RelativePositionBias(nn.Module):
def __init__(self, window_size, num_heads):
super().__init__()
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = \
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
# trunc_normal_(self.relative_position_bias_table, std=.02)
def forward(self):
relative_position_bias = \
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
class Beit(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), init_values=None,
use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
use_mean_pooling=True, init_scale=0.001):
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
if use_abs_pos_emb:
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
else:
self.pos_embed = None
self.pos_drop = nn.Dropout(p=drop_rate)
if use_shared_rel_pos_bias:
self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.grid_size, num_heads=num_heads)
else:
self.rel_pos_bias = None
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.use_rel_pos_bias = use_rel_pos_bias
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
init_values=init_values, window_size=self.patch_embed.grid_size if use_rel_pos_bias else None)
for i in range(depth)])
self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
# trunc_normal_(self.mask_token, std=.02)
self.fix_init_weight()
if isinstance(self.head, nn.Linear):
trunc_normal_(self.head.weight, std=.02)
self.head.weight.data.mul_(init_scale)
self.head.bias.data.mul_(init_scale)
def fix_init_weight(self):
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def get_num_layers(self):
return len(self.blocks)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.patch_embed(x)
batch_size, seq_len, _ = x.size()
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.pos_drop(x)
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
for blk in self.blocks:
x = blk(x, rel_pos_bias=rel_pos_bias)
x = self.norm(x)
if self.fc_norm is not None:
t = x[:, 1:, :]
return self.fc_norm(t.mean(1))
else:
return x[:, 0]
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
def _create_beit(variant, pretrained=False, default_cfg=None, **kwargs):
default_cfg = default_cfg or default_cfgs[variant]
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Beit models.')
model = build_model_with_cfg(
Beit, variant, pretrained,
default_cfg=default_cfg,
# FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
return model
@register_model
def beit_base_patch16_224(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs)
model = _create_beit('beit_base_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def beit_base_patch16_384(pretrained=False, **kwargs):
model_kwargs = dict(
img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs)
model = _create_beit('beit_base_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def beit_base_patch16_224_in22k(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs)
model = _create_beit('beit_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def beit_large_patch16_224(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
model = _create_beit('beit_large_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def beit_large_patch16_384(pretrained=False, **kwargs):
model_kwargs = dict(
img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
model = _create_beit('beit_large_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def beit_large_patch16_512(pretrained=False, **kwargs):
model_kwargs = dict(
img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
model = _create_beit('beit_large_patch16_512', pretrained=pretrained, **model_kwargs)
return model
@register_model
def beit_large_patch16_224_in22k(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
model = _create_beit('beit_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
return model
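A short sketch of building one of the registered BEiT variants; the cfgs above are `fixed_input_size`, so the input must match:

```python
# Sketch: instantiate a registered BEiT variant (pretrained=True would pull the unilm weights above).
import timm
import torch

model = timm.create_model('beit_base_patch16_224', pretrained=False).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])
```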

@@ -36,22 +36,22 @@ default_cfgs = {
     'botnet26t_256': _cfg(
         url='',
         fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
-    'botnet50t_256': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/botnet50t_256-a0e6c3b1.pth',
+    'botnet50ts_256': _cfg(
+        url='',
         fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
     'eca_botnext26ts_256': _cfg(
         url='',
         fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
-    'eca_botnext50ts_256': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_botnext26ts_256-fb3bf984.pth',
-        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),

     'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
     'halonet26t': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet26t_256-9b4bf0b3.pth',
         input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
-    'sehalonet33ts': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
-    'halonet50ts': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
+    'sehalonet33ts': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sehalonet33ts_256-87e053f9.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
+    'halonet50ts': _cfg(
+        url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
     'eca_halonext26ts': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_halonext26ts_256-1e55880b.pth',
         input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
@@ -78,16 +78,17 @@ model_cfgs = dict(
         self_attn_layer='bottleneck',
         self_attn_kwargs=dict()
     ),
-    botnet50t=ByoModelCfg(
+    botnet50ts=ByoModelCfg(
         blocks=(
             ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
-            ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
-            interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
-            ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
+            interleave_blocks(types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25),
+            interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
+            interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25),
         ),
         stem_chs=64,
         stem_type='tiered',
         stem_pool='maxpool',
+        act_layer='silu',
         fixed_input_size=True,
         self_attn_layer='bottleneck',
         self_attn_kwargs=dict()
@@ -108,22 +109,6 @@ model_cfgs = dict(
         self_attn_layer='bottleneck',
         self_attn_kwargs=dict()
     ),
-    eca_botnext50ts=ByoModelCfg(
-        blocks=(
-            ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=16, br=0.25),
-            ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=16, br=0.25),
-            interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=16, br=0.25),
-            ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=16, br=0.25),
-        ),
-        stem_chs=64,
-        stem_type='tiered',
-        stem_pool='maxpool',
-        fixed_input_size=True,
-        act_layer='silu',
-        attn_layer='eca',
-        self_attn_layer='bottleneck',
-        self_attn_kwargs=dict()
-    ),

     halonet_h1=ByoModelCfg(
         blocks=(
@@ -227,38 +212,31 @@ def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs):

 @register_model
 def botnet26t_256(pretrained=False, **kwargs):
-    """ Bottleneck Transformer w/ ResNet26-T backbone. Bottleneck attn in final two stages.
-    FIXME 26t variant was mixed up with 50t arch cfg, retraining and determining why so low
+    """ Bottleneck Transformer w/ ResNet26-T backbone.
+    NOTE: this isn't performing well, may remove
     """
     kwargs.setdefault('img_size', 256)
     return _create_byoanet('botnet26t_256', 'botnet26t', pretrained=pretrained, **kwargs)


 @register_model
-def botnet50t_256(pretrained=False, **kwargs):
-    """ Bottleneck Transformer w/ ResNet50-T backbone. Bottleneck attn in final two stages.
+def botnet50ts_256(pretrained=False, **kwargs):
+    """ Bottleneck Transformer w/ ResNet50-T backbone, silu act.
+    NOTE: this isn't performing well, may remove
     """
     kwargs.setdefault('img_size', 256)
-    return _create_byoanet('botnet50t_256', 'botnet50t', pretrained=pretrained, **kwargs)
+    return _create_byoanet('botnet50ts_256', 'botnet50ts', pretrained=pretrained, **kwargs)


 @register_model
 def eca_botnext26ts_256(pretrained=False, **kwargs):
-    """ Bottleneck Transformer w/ ResNet26-T backbone, silu act, Bottleneck attn in final two stages.
-    FIXME 26ts variant was mixed up with 50ts arch cfg, retraining and determining why so low
+    """ Bottleneck Transformer w/ ResNet26-T backbone, silu act.
+    NOTE: this isn't performing well, may remove
     """
     kwargs.setdefault('img_size', 256)
     return _create_byoanet('eca_botnext26ts_256', 'eca_botnext26ts', pretrained=pretrained, **kwargs)


-@register_model
-def eca_botnext50ts_256(pretrained=False, **kwargs):
-    """ Bottleneck Transformer w/ ResNet26-T backbone, silu act, Bottleneck attn in final two stages.
-    """
-    kwargs.setdefault('img_size', 256)
-    return _create_byoanet('eca_botnext50ts_256', 'eca_botnext50ts', pretrained=pretrained, **kwargs)


 @register_model
 def halonet_h1(pretrained=False, **kwargs):
     """ HaloNet-H1. Halo attention in all stages as per the paper.

@@ -98,7 +98,7 @@ default_cfgs = {
         test_input_size=(3, 288, 288), crop_pct=1.0, interpolation='bicubic'),
     'resnext26ts': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnext26ts_256-df727fca.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnext26ts_256_ra2-8bbd9106.pth',
         first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
     'gcresnext26ts': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext26ts_256-e414378b.pth',
@@ -118,7 +118,7 @@ default_cfgs = {
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet32ts_256-aacf5250.pth',
         first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
     'resnet33ts': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnet33ts_256-0e0cd345.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet33ts_256-e91b09a4.pth',
         first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
     'gcresnet33ts': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnet33ts_256-0e0cd345.pth',
@@ -137,6 +137,17 @@ default_cfgs = {
     'gcresnext50ts': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext50ts_256-3e0f515e.pth',
         first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
+
+    # experimental models
+    'regnetz_b': _cfg(
+        url='',
+        input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
+    'regnetz_c': _cfg(
+        url='',
+        input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
+    'regnetz_d': _cfg(
+        url='',
+        input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
 }
@@ -489,6 +500,51 @@ model_cfgs = dict(
         act_layer='silu',
         attn_layer='gca',
     ),
+
+    # experimental models, closer to a RegNetZ than a ResNet. Similar to EfficientNets but w/ groups instead of DW
+    regnetz_b=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=2, c=192, s=2, gs=24, br=0.25, block_kwargs=dict(linear_out=True)),
+            ByoBlockCfg(type='bottle', d=6, c=384, s=2, gs=24, br=0.25, block_kwargs=dict(linear_out=True)),
+            ByoBlockCfg(type='bottle', d=12, c=768, s=2, gs=24, br=0.25, block_kwargs=dict(linear_out=True)),
+            ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=24, br=0.25, block_kwargs=dict(linear_out=True)),
+        ),
+        stem_chs=32,
+        stem_pool='',
+        num_features=1792,
+        act_layer='silu',
+        attn_layer='se',
+        attn_kwargs=dict(rd_ratio=0.25),
+    ),
+    regnetz_c=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=2, c=128, s=2, gs=16, br=0.5, block_kwargs=dict(linear_out=True)),
+            ByoBlockCfg(type='bottle', d=6, c=512, s=2, gs=32, br=0.25, block_kwargs=dict(linear_out=True)),
+            ByoBlockCfg(type='bottle', d=12, c=768, s=2, gs=32, br=0.25, block_kwargs=dict(linear_out=True)),
+            ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=64, br=0.25, block_kwargs=dict(linear_out=True)),
+        ),
+        stem_chs=32,
+        stem_pool='',
+        num_features=1792,
+        act_layer='silu',
+        attn_layer='se',
+        attn_kwargs=dict(rd_ratio=0.25),
+    ),
+    regnetz_d=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=64, br=0.25, block_kwargs=dict(linear_out=True)),
+            ByoBlockCfg(type='bottle', d=6, c=512, s=2, gs=64, br=0.25, block_kwargs=dict(linear_out=True)),
+            ByoBlockCfg(type='bottle', d=12, c=768, s=2, gs=64, br=0.25, block_kwargs=dict(linear_out=True)),
+            ByoBlockCfg(type='bottle', d=3, c=1536, s=2, gs=64, br=0.25, block_kwargs=dict(linear_out=True)),
+        ),
+        stem_chs=128,
+        stem_type='quad',
+        stem_pool='',
+        num_features=1792,
+        act_layer='silu',
+        attn_layer='se',
+        attn_kwargs=dict(rd_ratio=0.25),
+    ),
 )
@@ -678,6 +734,27 @@ def gcresnext50ts(pretrained=False, **kwargs):
     return _create_byobnet('gcresnext50ts', pretrained=pretrained, **kwargs)


+@register_model
+def regnetz_b(pretrained=False, **kwargs):
+    """
+    """
+    return _create_byobnet('regnetz_b', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def regnetz_c(pretrained=False, **kwargs):
+    """
+    """
+    return _create_byobnet('regnetz_c', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def regnetz_d(pretrained=False, **kwargs):
+    """
+    """
+    return _create_byobnet('regnetz_d', pretrained=pretrained, **kwargs)
+
+
 def expand_blocks_cfg(stage_blocks_cfg: Union[ByoBlockCfg, Sequence[ByoBlockCfg]]) -> List[ByoBlockCfg]:
     if not isinstance(stage_blocks_cfg, Sequence):
         stage_blocks_cfg = (stage_blocks_cfg,)
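Since the experimental `regnetz_*` cfgs ship without weights (`url=''`), they can only be built untrained for now; a sketch:

```python
# Sketch: build an experimental regnetz model from scratch (no pretrained weights yet).
import timm
import torch

model = timm.create_model('regnetz_b', pretrained=False)
out = model(torch.randn(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 1000])
```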

@@ -0,0 +1,497 @@
""" CrossViT Model
@inproceedings{
chen2021crossvit,
title={{CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification}},
author={Chun-Fu (Richard) Chen and Quanfu Fan and Rameswar Panda},
booktitle={International Conference on Computer Vision (ICCV)},
year={2021}
}
Paper link: https://arxiv.org/abs/2103.14899
Original code: https://github.com/IBM/CrossViT/blob/main/models/crossvit.py
NOTE: model names have been renamed from originals to represent actual input res all *_224 -> *_240 and *_384 -> *_408
"""
# Copyright IBM All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Modified from timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.hub
from functools import partial
from typing import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import DropPath, to_2tuple, trunc_normal_
from .registry import register_model
from .vision_transformer import Mlp, Block
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None, 'crop_pct': 0.875,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
'first_conv': ('patch_embed.0.proj', 'patch_embed.1.proj'),
'classifier': ('head.0', 'head.1'),
**kwargs
}
default_cfgs = {
'crossvit_15_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_224.pth'),
'crossvit_15_dagger_240': _cfg(
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_224.pth',
first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
),
'crossvit_15_dagger_408': _cfg(
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_384.pth',
input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
),
'crossvit_18_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_224.pth'),
'crossvit_18_dagger_240': _cfg(
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_224.pth',
first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
),
'crossvit_18_dagger_408': _cfg(
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_384.pth',
input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
),
'crossvit_9_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_224.pth'),
'crossvit_9_dagger_240': _cfg(
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_dagger_224.pth',
first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
),
'crossvit_base_240': _cfg(
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_base_224.pth'),
'crossvit_small_240': _cfg(
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_small_224.pth'),
'crossvit_tiny_240': _cfg(
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_tiny_224.pth'),
}
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, multi_conv=False):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
if multi_conv:
if patch_size[0] == 12:
self.proj = nn.Sequential(
nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3),
nn.ReLU(inplace=True),
nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=3, padding=0),
nn.ReLU(inplace=True),
nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=1, padding=1),
)
elif patch_size[0] == 16:
self.proj = nn.Sequential(
nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3),
nn.ReLU(inplace=True),
nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1),
)
else:
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
return x
class CrossAttention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
self.wq = nn.Linear(dim, dim, bias=qkv_bias)
self.wk = nn.Linear(dim, dim, bias=qkv_bias)
self.wv = nn.Linear(dim, dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
# B1C -> B1H(C/H) -> BH1(C/H)
q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
# BNC -> BNH(C/H) -> BHN(C/H)
k = self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
# BNC -> BNH(C/H) -> BHN(C/H)
v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
attn = (q @ k.transpose(-2, -1)) * self.scale # BH1(C/H) @ BH(C/H)N -> BH1N
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, 1, C) # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C
x = self.proj(x)
x = self.proj_drop(x)
return x
class CrossAttentionBlock(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = CrossAttention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
x = x[:, 0:1, ...] + self.drop_path(self.attn(self.norm1(x)))
return x
class MultiScaleBlock(nn.Module):
def __init__(self, dim, patches, depth, num_heads, mlp_ratio, qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
num_branches = len(dim)
self.num_branches = num_branches
# different branch could have different embedding size, the first one is the base
self.blocks = nn.ModuleList()
for d in range(num_branches):
tmp = []
for i in range(depth[d]):
tmp.append(Block(
dim=dim[d], num_heads=num_heads[d], mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias,
drop=drop, attn_drop=attn_drop, drop_path=drop_path[i], norm_layer=norm_layer))
if len(tmp) != 0:
self.blocks.append(nn.Sequential(*tmp))
if len(self.blocks) == 0:
self.blocks = None
self.projs = nn.ModuleList()
for d in range(num_branches):
if dim[d] == dim[(d + 1) % num_branches] and False:
tmp = [nn.Identity()]
else:
tmp = [norm_layer(dim[d]), act_layer(), nn.Linear(dim[d], dim[(d + 1) % num_branches])]
self.projs.append(nn.Sequential(*tmp))
self.fusion = nn.ModuleList()
for d in range(num_branches):
d_ = (d + 1) % num_branches
nh = num_heads[d_]
if depth[-1] == 0: # backward capability:
self.fusion.append(
CrossAttentionBlock(
dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias,
drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer))
else:
tmp = []
for _ in range(depth[-1]):
tmp.append(CrossAttentionBlock(
dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias,
drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer))
self.fusion.append(nn.Sequential(*tmp))
self.revert_projs = nn.ModuleList()
for d in range(num_branches):
if dim[(d + 1) % num_branches] == dim[d] and False:
tmp = [nn.Identity()]
else:
tmp = [norm_layer(dim[(d + 1) % num_branches]), act_layer(),
nn.Linear(dim[(d + 1) % num_branches], dim[d])]
self.revert_projs.append(nn.Sequential(*tmp))
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
outs_b = []
for i, block in enumerate(self.blocks):
outs_b.append(block(x[i]))
# only take the cls token out
proj_cls_token = torch.jit.annotate(List[torch.Tensor], [])
for i, proj in enumerate(self.projs):
proj_cls_token.append(proj(outs_b[i][:, 0:1, ...]))
# cross attention
outs = []
for i, (fusion, revert_proj) in enumerate(zip(self.fusion, self.revert_projs)):
tmp = torch.cat((proj_cls_token[i], outs_b[(i + 1) % self.num_branches][:, 1:, ...]), dim=1)
tmp = fusion(tmp)
reverted_proj_cls_token = revert_proj(tmp[:, 0:1, ...])
tmp = torch.cat((reverted_proj_cls_token, outs_b[i][:, 1:, ...]), dim=1)
outs.append(tmp)
return outs
def _compute_num_patches(img_size, patches):
return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)]
class CrossViT(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(
self, img_size=224, img_scale=(1.0, 1.0), patch_size=(8, 16), in_chans=3, num_classes=1000,
embed_dim=(192, 384), depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)), num_heads=(6, 12), mlp_ratio=(2., 2., 4.),
qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=False, crop_scale=False,
):
super().__init__()
self.num_classes = num_classes
self.img_size = to_2tuple(img_size)
img_scale = to_2tuple(img_scale)
self.img_size_scaled = [tuple([int(sj * si) for sj in self.img_size]) for si in img_scale]
self.crop_scale = crop_scale # crop instead of interpolate for scale
num_patches = _compute_num_patches(self.img_size_scaled, patch_size)
self.num_branches = len(patch_size)
self.embed_dim = embed_dim
self.num_features = embed_dim[0] # to pass the tests
self.patch_embed = nn.ModuleList()
# hard-coded for torch jit script
for i in range(self.num_branches):
setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])))
setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, embed_dim[i])))
for im_s, p, d in zip(self.img_size_scaled, patch_size, embed_dim):
self.patch_embed.append(
PatchEmbed(img_size=im_s, patch_size=p, in_chans=in_chans, embed_dim=d, multi_conv=multi_conv))
self.pos_drop = nn.Dropout(p=drop_rate)
total_depth = sum([sum(x[-2:]) for x in depth])
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)] # stochastic depth decay rule
dpr_ptr = 0
self.blocks = nn.ModuleList()
for idx, block_cfg in enumerate(depth):
curr_depth = max(block_cfg[:-1]) + block_cfg[-1]
dpr_ = dpr[dpr_ptr:dpr_ptr + curr_depth]
blk = MultiScaleBlock(
embed_dim, num_patches, block_cfg, num_heads=num_heads, mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr_, norm_layer=norm_layer)
dpr_ptr += curr_depth
self.blocks.append(blk)
self.norm = nn.ModuleList([norm_layer(embed_dim[i]) for i in range(self.num_branches)])
self.head = nn.ModuleList([
nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity()
for i in range(self.num_branches)])
for i in range(self.num_branches):
trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
out = set()
for i in range(self.num_branches):
out.add(f'cls_token_{i}')
pe = getattr(self, f'pos_embed_{i}', None)
if pe is not None and pe.requires_grad:
out.add(f'pos_embed_{i}')
return out
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.ModuleList(
[nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in
range(self.num_branches)])
def forward_features(self, x):
B, C, H, W = x.shape
xs = []
for i, patch_embed in enumerate(self.patch_embed):
x_ = x
ss = self.img_size_scaled[i]
if H != ss[0] or W != ss[1]:
if self.crop_scale and ss[0] <= H and ss[1] <= W:
cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
x_ = x_[:, :, cu:cu + ss[0], cl:cl + ss[1]]
else:
x_ = torch.nn.functional.interpolate(x_, size=ss, mode='bicubic', align_corners=False)
x_ = patch_embed(x_)
cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script
cls_tokens = cls_tokens.expand(B, -1, -1)
x_ = torch.cat((cls_tokens, x_), dim=1)
pos_embed = self.pos_embed_0 if i == 0 else self.pos_embed_1 # hard-coded for torch jit script
x_ = x_ + pos_embed
x_ = self.pos_drop(x_)
xs.append(x_)
for i, blk in enumerate(self.blocks):
xs = blk(xs)
# NOTE: was before branch token section, move to here to assure all branch token are before layer norm
xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
return [xo[:, 0] for xo in xs]
def forward(self, x):
xs = self.forward_features(x)
ce_logits = [head(xs[i]) for i, head in enumerate(self.head)]
if not isinstance(self.head[0], nn.Identity):
ce_logits = torch.mean(torch.stack(ce_logits, dim=0), dim=0)
return ce_logits
def _create_crossvit(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
def pretrained_filter_fn(state_dict):
new_state_dict = {}
for key in state_dict.keys():
if 'pos_embed' in key or 'cls_token' in key:
new_key = key.replace(".", "_")
else:
new_key = key
new_state_dict[new_key] = state_dict[key]
return new_state_dict
return build_model_with_cfg(
CrossViT, variant, pretrained,
default_cfg=default_cfgs[variant],
pretrained_filter_fn=pretrained_filter_fn,
**kwargs)
@register_model
def crossvit_tiny_240(pretrained=False, **kwargs):
model_args = dict(
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[96, 192], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
num_heads=[3, 3], mlp_ratio=[4, 4, 1], **kwargs)
model = _create_crossvit(variant='crossvit_tiny_240', pretrained=pretrained, **model_args)
return model
@register_model
def crossvit_small_240(pretrained=False, **kwargs):
model_args = dict(
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
num_heads=[6, 6], mlp_ratio=[4, 4, 1], **kwargs)
model = _create_crossvit(variant='crossvit_small_240', pretrained=pretrained, **model_args)
return model
@register_model
def crossvit_base_240(pretrained=False, **kwargs):
model_args = dict(
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[384, 768], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
num_heads=[12, 12], mlp_ratio=[4, 4, 1], **kwargs)
model = _create_crossvit(variant='crossvit_base_240', pretrained=pretrained, **model_args)
return model
@register_model
def crossvit_9_240(pretrained=False, **kwargs):
model_args = dict(
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]],
num_heads=[4, 4], mlp_ratio=[3, 3, 1], **kwargs)
model = _create_crossvit(variant='crossvit_9_240', pretrained=pretrained, **model_args)
return model
@register_model
def crossvit_15_240(pretrained=False, **kwargs):
model_args = dict(
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
num_heads=[6, 6], mlp_ratio=[3, 3, 1], **kwargs)
model = _create_crossvit(variant='crossvit_15_240', pretrained=pretrained, **model_args)
return model
@register_model
def crossvit_18_240(pretrained=False, **kwargs):
model_args = dict(
img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
num_heads=[7, 7], mlp_ratio=[3, 3, 1], **kwargs)
model = _create_crossvit(variant='crossvit_18_240', pretrained=pretrained, **model_args)
return model
@register_model
def crossvit_9_dagger_240(pretrained=False, **kwargs):
model_args = dict(
img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]],
num_heads=[4, 4], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
model = _create_crossvit(variant='crossvit_9_dagger_240', pretrained=pretrained, **model_args)
return model
@register_model
def crossvit_15_dagger_240(pretrained=False, **kwargs):
model_args = dict(
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
model = _create_crossvit(variant='crossvit_15_dagger_240', pretrained=pretrained, **model_args)
return model
@register_model
def crossvit_15_dagger_408(pretrained=False, **kwargs):
model_args = dict(
img_scale=(1.0, 384/408), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
model = _create_crossvit(variant='crossvit_15_dagger_408', pretrained=pretrained, **model_args)
return model
@register_model
def crossvit_18_dagger_240(pretrained=False, **kwargs):
model_args = dict(
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
model = _create_crossvit(variant='crossvit_18_dagger_240', pretrained=pretrained, **model_args)
return model
@register_model
def crossvit_18_dagger_408(pretrained=False, **kwargs):
model_args = dict(
img_scale=(1.0, 384/408), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
model = _create_crossvit(variant='crossvit_18_dagger_408', pretrained=pretrained, **model_args)
return model
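
A usage sketch for the variants registered above (assumes timm is installed so the @register_model decorators have run):

import timm
import torch

model = timm.create_model('crossvit_tiny_240', pretrained=False).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 240, 240))  # these variants default to 240x240 input
print(out.shape)  # torch.Size([1, 1000])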

@ -600,7 +600,7 @@ def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
block_args=decode_arch_def(arch_def), block_args=decode_arch_def(arch_def),
stem_size=32, stem_size=32,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier), round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
**kwargs **kwargs
) )
model = _create_effnet(variant, pretrained, **model_kwargs) model = _create_effnet(variant, pretrained, **model_kwargs)
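
Every generator hunk in this file makes the same one-line change: an explicitly passed norm_layer kwarg now takes precedence, with BatchNorm2d (configured from any bn_* overrides) as the fallback. A minimal sketch of that precedence rule, with resolve_bn_args stubbed out as an empty dict:

from functools import partial
import torch.nn as nn

def pick_norm_layer(**kwargs):
    bn_args = {}  # stand-in for resolve_bn_args(kwargs)
    return kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **bn_args)

assert isinstance(pick_norm_layer()(8), nn.BatchNorm2d)          # default path
assert pick_norm_layer(norm_layer=nn.GroupNorm) is nn.GroupNorm  # explicit override wins
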
@ -636,7 +636,7 @@ def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
block_args=decode_arch_def(arch_def), block_args=decode_arch_def(arch_def),
stem_size=32, stem_size=32,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier), round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
**kwargs **kwargs
) )
model = _create_effnet(variant, pretrained, **model_kwargs) model = _create_effnet(variant, pretrained, **model_kwargs)
@ -665,7 +665,7 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar
block_args=decode_arch_def(arch_def), block_args=decode_arch_def(arch_def),
stem_size=8, stem_size=8,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier), round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
**kwargs **kwargs
) )
model = _create_effnet(variant, pretrained, **model_kwargs) model = _create_effnet(variant, pretrained, **model_kwargs)
@ -694,7 +694,7 @@ def _gen_mobilenet_v2(
stem_size=32, stem_size=32,
fix_stem=fix_stem_head, fix_stem=fix_stem_head,
round_chs_fn=round_chs_fn, round_chs_fn=round_chs_fn,
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=resolve_act_layer(kwargs, 'relu6'), act_layer=resolve_act_layer(kwargs, 'relu6'),
**kwargs **kwargs
) )
@ -725,7 +725,7 @@ def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
stem_size=16, stem_size=16,
num_features=1984, # paper suggests this, but is not 100% clear num_features=1984, # paper suggests this, but is not 100% clear
round_chs_fn=partial(round_channels, multiplier=channel_multiplier), round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
**kwargs **kwargs
) )
model = _create_effnet(variant, pretrained, **model_kwargs) model = _create_effnet(variant, pretrained, **model_kwargs)
@ -760,7 +760,7 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
block_args=decode_arch_def(arch_def), block_args=decode_arch_def(arch_def),
stem_size=32, stem_size=32,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier), round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
**kwargs **kwargs
) )
model = _create_effnet(variant, pretrained, **model_kwargs) model = _create_effnet(variant, pretrained, **model_kwargs)
@ -807,7 +807,7 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
stem_size=32, stem_size=32,
round_chs_fn=round_chs_fn, round_chs_fn=round_chs_fn,
act_layer=resolve_act_layer(kwargs, 'swish'), act_layer=resolve_act_layer(kwargs, 'swish'),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
**kwargs, **kwargs,
) )
model = _create_effnet(variant, pretrained, **model_kwargs) model = _create_effnet(variant, pretrained, **model_kwargs)
@ -836,7 +836,7 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0
num_features=round_chs_fn(1280), num_features=round_chs_fn(1280),
stem_size=32, stem_size=32,
round_chs_fn=round_chs_fn, round_chs_fn=round_chs_fn,
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=resolve_act_layer(kwargs, 'relu'), act_layer=resolve_act_layer(kwargs, 'relu'),
**kwargs, **kwargs,
) )
@ -867,7 +867,7 @@ def _gen_efficientnet_condconv(
num_features=round_chs_fn(1280), num_features=round_chs_fn(1280),
stem_size=32, stem_size=32,
round_chs_fn=round_chs_fn, round_chs_fn=round_chs_fn,
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=resolve_act_layer(kwargs, 'swish'), act_layer=resolve_act_layer(kwargs, 'swish'),
**kwargs, **kwargs,
) )
@ -909,7 +909,7 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0
fix_stem=True, fix_stem=True,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier), round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
act_layer=resolve_act_layer(kwargs, 'relu6'), act_layer=resolve_act_layer(kwargs, 'relu6'),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
**kwargs, **kwargs,
) )
model = _create_effnet(variant, pretrained, **model_kwargs) model = _create_effnet(variant, pretrained, **model_kwargs)
@ -937,7 +937,7 @@ def _gen_efficientnetv2_base(
num_features=round_chs_fn(1280), num_features=round_chs_fn(1280),
stem_size=32, stem_size=32,
round_chs_fn=round_chs_fn, round_chs_fn=round_chs_fn,
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=resolve_act_layer(kwargs, 'silu'), act_layer=resolve_act_layer(kwargs, 'silu'),
**kwargs, **kwargs,
) )
@ -976,7 +976,7 @@ def _gen_efficientnetv2_s(
num_features=round_chs_fn(num_features), num_features=round_chs_fn(num_features),
stem_size=24, stem_size=24,
round_chs_fn=round_chs_fn, round_chs_fn=round_chs_fn,
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=resolve_act_layer(kwargs, 'silu'), act_layer=resolve_act_layer(kwargs, 'silu'),
**kwargs, **kwargs,
) )
@ -1006,7 +1006,7 @@ def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0,
num_features=1280, num_features=1280,
stem_size=24, stem_size=24,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier), round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=resolve_act_layer(kwargs, 'silu'), act_layer=resolve_act_layer(kwargs, 'silu'),
**kwargs, **kwargs,
) )
@ -1036,7 +1036,7 @@ def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0,
num_features=1280, num_features=1280,
stem_size=32, stem_size=32,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier), round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=resolve_act_layer(kwargs, 'silu'), act_layer=resolve_act_layer(kwargs, 'silu'),
**kwargs, **kwargs,
) )
@ -1066,7 +1066,7 @@ def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0
num_features=1280, num_features=1280,
stem_size=32, stem_size=32,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier), round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=resolve_act_layer(kwargs, 'silu'), act_layer=resolve_act_layer(kwargs, 'silu'),
**kwargs, **kwargs,
) )
@ -1100,7 +1100,7 @@ def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
num_features=1536, num_features=1536,
stem_size=16, stem_size=16,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier), round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
**kwargs **kwargs
) )
model = _create_effnet(variant, pretrained, **model_kwargs) model = _create_effnet(variant, pretrained, **model_kwargs)
@ -1133,7 +1133,7 @@ def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrai
num_features=1536, num_features=1536,
stem_size=24, stem_size=24,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier), round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
**kwargs **kwargs
) )
model = _create_effnet(variant, pretrained, **model_kwargs) model = _create_effnet(variant, pretrained, **model_kwargs)

@ -41,7 +41,7 @@ class StdConv2d(nn.Conv2d):
def forward(self, x): def forward(self, x):
weight = F.batch_norm( weight = F.batch_norm(
self.weight.view(1, self.out_channels, -1), None, None, self.weight.reshape(1, self.out_channels, -1), None, None,
training=True, momentum=0., eps=self.eps).reshape_as(self.weight) training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
return x return x
@ -67,7 +67,7 @@ class StdConv2dSame(nn.Conv2d):
if self.same_pad: if self.same_pad:
x = pad_same(x, self.kernel_size, self.stride, self.dilation) x = pad_same(x, self.kernel_size, self.stride, self.dilation)
weight = F.batch_norm( weight = F.batch_norm(
self.weight.view(1, self.out_channels, -1), None, None, self.weight.reshape(1, self.out_channels, -1), None, None,
training=True, momentum=0., eps=self.eps).reshape_as(self.weight) training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
return x return x
@ -96,7 +96,7 @@ class ScaledStdConv2d(nn.Conv2d):
def forward(self, x): def forward(self, x):
weight = F.batch_norm( weight = F.batch_norm(
self.weight.view(1, self.out_channels, -1), None, None, self.weight.reshape(1, self.out_channels, -1), None, None,
weight=(self.gain * self.scale).view(-1), weight=(self.gain * self.scale).view(-1),
training=True, momentum=0., eps=self.eps).reshape_as(self.weight) training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
@ -127,7 +127,7 @@ class ScaledStdConv2dSame(nn.Conv2d):
if self.same_pad: if self.same_pad:
x = pad_same(x, self.kernel_size, self.stride, self.dilation) x = pad_same(x, self.kernel_size, self.stride, self.dilation)
weight = F.batch_norm( weight = F.batch_norm(
self.weight.view(1, self.out_channels, -1), None, None, self.weight.reshape(1, self.out_channels, -1), None, None,
weight=(self.gain * self.scale).view(-1), weight=(self.gain * self.scale).view(-1),
training=True, momentum=0., eps=self.eps).reshape_as(self.weight) training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
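
All four std-conv hunks swap .view for .reshape on the weight before the batch_norm-based standardization; reshape tolerates a non-contiguous weight tensor (copying if necessary) where view raises. A quick demonstration:

import torch

w = torch.randn(8, 4, 3, 3).transpose(0, 1)  # non-contiguous, weight-like tensor
try:
    w.view(1, 4, -1)
except RuntimeError:
    pass  # view requires compatible (contiguous) strides
flat = w.reshape(1, 4, -1)  # reshape succeeds, copying when needed
assert flat.shape == (1, 4, 72)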

@ -344,7 +344,7 @@ class ResNetV2(nn.Module):
num_classes=1000, in_chans=3, global_pool='avg', output_stride=32, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True, width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32), act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
drop_rate=0., drop_path_rate=0., zero_init_last=True): drop_rate=0., drop_path_rate=0., zero_init_last=False):
super().__init__() super().__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.drop_rate = drop_rate self.drop_rate = drop_rate
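For context: assuming zero_init_last keeps its usual timm meaning (zero-initializing the last weight in each residual block so the branch starts as a near-identity), flipping the ResNetV2 default to False means residual branches now start from their regular random init.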

@ -683,7 +683,7 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
def vit_base_patch16_sam_224(pretrained=False, **kwargs): def vit_base_patch16_sam_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548 """ ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
""" """
# NOTE original SAM weights releaes worked with representation_size=768 # NOTE original SAM weights release worked with representation_size=768
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=0, **kwargs) model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=0, **kwargs)
model = _create_vision_transformer('vit_base_patch16_sam_224', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_base_patch16_sam_224', pretrained=pretrained, **model_kwargs)
return model return model
@ -693,7 +693,7 @@ def vit_base_patch16_sam_224(pretrained=False, **kwargs):
def vit_base_patch32_sam_224(pretrained=False, **kwargs): def vit_base_patch32_sam_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548 """ ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
""" """
# NOTE original SAM weights releaes worked with representation_size=768 # NOTE original SAM weights release worked with representation_size=768
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=0, **kwargs) model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=0, **kwargs)
model = _create_vision_transformer('vit_base_patch32_sam_224', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_base_patch32_sam_224', pretrained=pretrained, **model_kwargs)
return model return model
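Context for the NOTE above: the original SAM release trained these ViTs with a 768-dim pre-logits (representation) layer, while representation_size=0 here omits that layer, so the ported weights skip it.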

@ -86,10 +86,10 @@ parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of dataset') help='Override std deviation of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME', parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)') help='Image resize interpolation type (overrides model)')
parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N', parser.add_argument('-b', '--batch-size', type=int, default=256, metavar='N',
help='input batch size for training (default: 32)') help='input batch size for training (default: 256)')
parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N', parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
help='ratio of validation batch size to training batch size (default: 1)') help='validation batch size override (default: None)')
# Optimizer parameters # Optimizer parameters
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
@ -109,10 +109,10 @@ parser.add_argument('--clip-mode', type=str, default='norm',
# Learning rate schedule parameters # Learning rate schedule parameters
parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER', parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "step"') help='LR scheduler (default: "cosine")')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR', parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
help='learning rate (default: 0.01)') help='learning rate (default: 0.1)')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages') help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
@ -131,15 +131,15 @@ parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
help='warmup learning rate (default: 0.0001)') help='warmup learning rate (default: 0.0001)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument('--epochs', type=int, default=200, metavar='N', parser.add_argument('--epochs', type=int, default=300, metavar='N',
help='number of epochs to train (default: 2)') help='number of epochs to train (default: 300)')
parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N', parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).') help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N', parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
help='manual epoch number (useful on restarts)') help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', parser.add_argument('--decay-epochs', type=float, default=100, metavar='N',
help='epoch interval to decay LR') help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N', parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
help='epochs to warmup LR, if scheduler supports') help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends') help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
@ -169,10 +169,12 @@ parser.add_argument('--jsd-loss', action='store_true', default=False,
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.') help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
parser.add_argument('--bce-loss', action='store_true', default=False, parser.add_argument('--bce-loss', action='store_true', default=False,
help='Enable BCE loss w/ Mixup/CutMix use.') help='Enable BCE loss w/ Mixup/CutMix use.')
parser.add_argument('--bce-target-thresh', type=float, default=None,
help='Threshold for binarizing softened BCE targets (default: None, disabled)')
parser.add_argument('--reprob', type=float, default=0., metavar='PCT', parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
help='Random erase prob (default: 0.)') help='Random erase prob (default: 0.)')
parser.add_argument('--remode', type=str, default='const', parser.add_argument('--remode', type=str, default='pixel',
help='Random erase mode (default: "const")') help='Random erase mode (default: "pixel")')
parser.add_argument('--recount', type=int, default=1, parser.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)') help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False, parser.add_argument('--resplit', action='store_true', default=False,
@ -213,7 +215,7 @@ parser.add_argument('--bn-eps', type=float, default=None,
help='BatchNorm epsilon override (if not None)') help='BatchNorm epsilon override (if not None)')
parser.add_argument('--sync-bn', action='store_true', parser.add_argument('--sync-bn', action='store_true',
help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
parser.add_argument('--dist-bn', type=str, default='', parser.add_argument('--dist-bn', type=str, default='reduce',
help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
parser.add_argument('--split-bn', action='store_true', parser.add_argument('--split-bn', action='store_true',
help='Enable separate BN layers per augmentation split.') help='Enable separate BN layers per augmentation split.')
@ -460,12 +462,12 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
elif mixup_active: elif mixup_active:
# smoothing is handled with mixup target transform # smoothing is handled with mixup target transform
if args.bce_loss: if args.bce_loss:
train_loss_fn = nn.BCEWithLogitsLoss() train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh)
else: else:
train_loss_fn = SoftTargetCrossEntropy() train_loss_fn = SoftTargetCrossEntropy()
elif args.smoothing: elif args.smoothing:
if args.bce_loss: if args.bce_loss:
train_loss_fn = DenseBinaryCrossEntropy(smoothing=args.smoothing) train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh)
else: else:
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing) train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
else: else:
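
These loss hunks replace nn.BCEWithLogitsLoss / DenseBinaryCrossEntropy with timm's BinaryCrossEntropy and thread the new --bce-target-thresh through. As I read it, the threshold binarizes softened (mixup/label-smoothed) targets before the BCE term is computed; a sketch of just that thresholding step (hypothetical values, the real class also handles smoothing and reduction):

import torch

soft_targets = torch.tensor([[0.0, 0.3, 0.7],
                             [0.1, 0.9, 0.0]])  # mixup-style soft targets
hard_targets = soft_targets.gt(0.2).to(soft_targets.dtype)
# -> tensor([[0., 1., 1.],
#            [0., 1., 0.]])
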
@ -583,7 +585,7 @@ def setup_data(args, default_cfg, dev_env: DeviceEnv, mixup_active: bool):
eval_workers = min(2, args.workers) eval_workers = min(2, args.workers)
loader_eval = create_loader_v2( loader_eval = create_loader_v2(
dataset_eval, dataset_eval,
batch_size=args.validation_batch_size_multiplier * args.batch_size, batch_size=args.validation_batch_size or args.batch_size,
is_training=False, is_training=False,
normalize=not normalize_in_transform, normalize=not normalize_in_transform,
pp_cfg=eval_pp_cfg, pp_cfg=eval_pp_cfg,
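
Note the or-fallback above: leaving --validation-batch-size unset (None) now inherits the training batch size, replacing the old multiplier scheme.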

@ -249,6 +249,11 @@ def main():
model_names = list_models(args.model) model_names = list_models(args.model)
model_cfgs = [(n, '') for n in model_names] model_cfgs = [(n, '') for n in model_names]
if not model_cfgs and os.path.isfile(args.model):
with open(args.model) as f:
model_names = [line.rstrip() for line in f]
model_cfgs = [(n, None) for n in model_names if n]
if len(model_cfgs): if len(model_cfgs):
results_file = args.results_file or './results-all.csv' results_file = args.results_file or './results-all.csv'
_logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names))) _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
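
The added fallback lets args.model point at a plain text file of model names when no registered model matches the pattern. A standalone sketch of that path (models.txt is hypothetical, one name per line):

import os

model_arg = 'models.txt'  # hypothetical file, e.g. lines 'resnet50' and 'efficientnet_b0'
model_cfgs = []           # empty: list_models() found no match for the pattern
if not model_cfgs and os.path.isfile(model_arg):
    with open(model_arg) as f:
        model_names = [line.rstrip() for line in f]
    model_cfgs = [(n, None) for n in model_names if n]  # blank lines are skipped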
