Merge pull request #1593 from rwightman/multi-weight_effnet_convnext

Update efficientnet.py and convnext.py to multi-weight, add new 12k pretrained weights
pull/1612/head
Ross Wightman 2 years ago committed by GitHub
commit 4e24f75289
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

10
.gitignore vendored

@ -106,6 +106,16 @@ output/
*.tar
*.pth
*.pt
*.torch
*.gz
Untitled.ipynb
Testing notebook.ipynb
# Root dir exclusions
/*.csv
/*.yaml
/*.json
/*.jpg
/*.png
/*.zip
/*.tar.*

@ -27,7 +27,7 @@ NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*'
'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*', 'flexivit*'
]
NUM_NON_STD = len(NON_STD_FILTERS)

@ -151,7 +151,7 @@ def create_dataset(
elif name.startswith('hfds/'):
# NOTE right now, HF datasets default arrow format is a random-access Dataset,
# There will be a IterableDataset variant too, TBD
ds = ImageDataset(root, reader=name, split=split, **kwargs)
ds = ImageDataset(root, reader=name, split=split, class_map=class_map, **kwargs)
elif name.startswith('tfds/'):
ds = IterableImageDataset(
root,

@ -6,7 +6,7 @@ from .reader_image_in_tar import ReaderImageInTar
def create_reader(name, root, split='train', **kwargs):
name = name.lower()
name = name.split('/', 2)
name = name.split('/', 1)
prefix = ''
if len(name) > 1:
prefix = name[0]

@ -13,13 +13,14 @@ try:
except ImportError as e:
print("Please install Hugging Face datasets package `pip install datasets`.")
exit(1)
from .class_map import load_class_map
from .reader import Reader
def get_class_labels(info):
def get_class_labels(info, label_key='label'):
if 'label' not in info.features:
return {}
class_label = info.features['label']
class_label = info.features[label_key]
class_to_idx = {n: class_label.str2int(n) for n in class_label.names}
return class_to_idx
@ -32,6 +33,7 @@ class ReaderHfds(Reader):
name,
split='train',
class_map=None,
label_key='label',
download=False,
):
"""
@ -43,12 +45,17 @@ class ReaderHfds(Reader):
name, # 'name' maps to path arg in hf datasets
split=split,
cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path
#use_auth_token=True,
)
# leave decode for caller, plus we want easy access to original path names...
self.dataset = self.dataset.cast_column('image', datasets.Image(decode=False))
self.class_to_idx = get_class_labels(self.dataset.info)
self.label_key = label_key
self.remap_class = False
if class_map:
self.class_to_idx = load_class_map(class_map)
self.remap_class = True
else:
self.class_to_idx = get_class_labels(self.dataset.info, self.label_key)
self.split_info = self.dataset.info.splits[split]
self.num_samples = self.split_info.num_examples
@ -60,7 +67,10 @@ class ReaderHfds(Reader):
else:
assert 'path' in image and image['path']
image = open(image['path'], 'rb')
return image, item['label']
label = item[self.label_key]
if self.remap_class:
label = self.class_to_idx[label]
return image, label
def __len__(self):
return len(self.dataset)

@ -1,6 +1,7 @@
from .activations import *
from .adaptive_avgmax_pool import \
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
from .blur_pool import BlurPool2d
from .classifier import ClassifierHead, create_classifier
from .cond_conv2d import CondConv2d, get_condconv_initializer
@ -30,8 +31,12 @@ from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
from .padding import get_padding, get_same_padding, pad_same
from .patch_embed import PatchEmbed
from .patch_embed import PatchEmbed, resample_patch_embed
from .pool2d_same import AvgPool2dSame, create_pool2d
from .pos_embed import resample_abs_pos_embed
from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords
from .pos_embed_sincos import build_sincos2d_pos_embed, build_fourier_pos_embed, build_rotary_pos_embed, \
FourierEmbed, RotaryEmbedding
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
from .selective_kernel import SelectiveKernel
from .separable_conv import SeparableConv2d, SeparableConvNormAct

@ -13,7 +13,7 @@ import torch
import torch.nn as nn
from .helpers import to_2tuple
from .pos_embed import apply_rot_embed, RotaryEmbedding
from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding
from .weight_init import trunc_normal_

@ -10,7 +10,7 @@ import collections.abc
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return x
return tuple(x)
return tuple(repeat(x, n))
return parse

@ -2,15 +2,24 @@
A convolution based approach to patchifying a 2D image w/ embedding projection.
Based on the impl in https://github.com/google-research/vision_transformer
Based on code in:
* https://github.com/google-research/vision_transformer
* https://github.com/google-research/big_vision/tree/main/big_vision
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
from typing import List
import torch
from torch import nn as nn
import torch.nn.functional as F
from .helpers import to_2tuple
from .trace_utils import _assert
_logger = logging.getLogger(__name__)
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
@ -46,3 +55,122 @@ class PatchEmbed(nn.Module):
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
return x
def resample_patch_embed(
patch_embed,
new_size: List[int],
interpolation: str = 'bicubic',
antialias: bool = True,
verbose: bool = False,
):
"""Resample the weights of the patch embedding kernel to target resolution.
We resample the patch embedding kernel by approximately inverting the effect
of patch resizing.
Code based on:
https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py
With this resizing, we can for example load a B/8 filter into a B/16 model
and, on 2x larger input image, the result will match.
Args:
patch_embed: original parameter to be resized.
new_size (tuple(int, int): target shape (height, width)-only.
interpolation (str): interpolation for resize
antialias (bool): use anti-aliasing filter in resize
verbose (bool): log operation
Returns:
Resized patch embedding kernel.
"""
import numpy as np
assert len(patch_embed.shape) == 4, "Four dimensions expected"
assert len(new_size) == 2, "New shape should only be hw"
old_size = patch_embed.shape[-2:]
if tuple(old_size) == tuple(new_size):
return patch_embed
if verbose:
_logger.info(f"Resize patch embedding {patch_embed.shape} to {new_size}, w/ {interpolation} interpolation.")
def resize(x_np, _new_size):
x_tf = torch.Tensor(x_np)[None, None, ...]
x_upsampled = F.interpolate(
x_tf, size=_new_size, mode=interpolation, antialias=antialias)[0, 0, ...].numpy()
return x_upsampled
def get_resize_mat(_old_size, _new_size):
mat = []
for i in range(np.prod(_old_size)):
basis_vec = np.zeros(_old_size)
basis_vec[np.unravel_index(i, _old_size)] = 1.
mat.append(resize(basis_vec, _new_size).reshape(-1))
return np.stack(mat).T
resize_mat = get_resize_mat(old_size, new_size)
resize_mat_pinv = torch.Tensor(np.linalg.pinv(resize_mat.T))
def resample_kernel(kernel):
resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
return resampled_kernel.reshape(new_size)
v_resample_kernel = torch.vmap(torch.vmap(resample_kernel, 0, 0), 1, 1)
return v_resample_kernel(patch_embed)
# def divs(n, m=None):
# m = m or n // 2
# if m == 1:
# return [1]
# if n % m == 0:
# return [m] + divs(n, m - 1)
# return divs(n, m - 1)
#
#
# class FlexiPatchEmbed(nn.Module):
# """ 2D Image to Patch Embedding w/ Flexible Patch sizes (FlexiViT)
# FIXME WIP
# """
# def __init__(
# self,
# img_size=240,
# patch_size=16,
# in_chans=3,
# embed_dim=768,
# base_img_size=240,
# base_patch_size=32,
# norm_layer=None,
# flatten=True,
# bias=True,
# ):
# super().__init__()
# self.img_size = to_2tuple(img_size)
# self.patch_size = to_2tuple(patch_size)
# self.num_patches = 0
#
# # full range for 240 = (5, 6, 8, 10, 12, 14, 15, 16, 20, 24, 30, 40, 48)
# self.seqhw = (6, 8, 10, 12, 14, 15, 16, 20, 24, 30)
#
# self.base_img_size = to_2tuple(base_img_size)
# self.base_patch_size = to_2tuple(base_patch_size)
# self.base_grid_size = tuple([i // p for i, p in zip(self.base_img_size, self.base_patch_size)])
# self.base_num_patches = self.base_grid_size[0] * self.base_grid_size[1]
#
# self.flatten = flatten
# self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=bias)
# self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
#
# def forward(self, x):
# B, C, H, W = x.shape
#
# if self.patch_size == self.base_patch_size:
# weight = self.proj.weight
# else:
# weight = resample_patch_embed(self.proj.weight, self.patch_size)
# patch_size = self.patch_size
# x = F.conv2d(x, weight, bias=self.proj.bias, stride=patch_size)
# if self.flatten:
# x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
# x = self.norm(x)
# return x

@ -1,207 +1,52 @@
""" Position Embedding Utilities
Hacked together by / Copyright 2022 Ross Wightman
"""
import logging
import math
from typing import List, Tuple, Optional, Union
import torch
from torch import nn as nn
def pixel_freq_bands(
num_bands: int,
max_freq: float = 224.,
linear_bands: bool = True,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
):
if linear_bands:
bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=dtype, device=device)
else:
bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=dtype, device=device)
return bands * torch.pi
def inv_freq_bands(
num_bands: int,
temperature: float = 100000.,
step: int = 2,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
) -> torch.Tensor:
inv_freq = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands))
return inv_freq
def build_sincos2d_pos_embed(
feat_shape: List[int],
dim: int = 64,
temperature: float = 10000.,
reverse_coord: bool = False,
interleave_sin_cos: bool = False,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None
) -> torch.Tensor:
"""
Args:
feat_shape:
dim:
temperature:
reverse_coord: stack grid order W, H instead of H, W
interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos
dtype:
device:
Returns:
"""
assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding'
pos_dim = dim // 4
bands = inv_freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device)
if reverse_coord:
feat_shape = feat_shape[::-1] # stack W, H instead of H, W
grid = torch.stack(
torch.meshgrid([torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1)
pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
# FIXME add support for unflattened spatial dim?
stack_dim = 2 if interleave_sin_cos else 1 # stack sin, cos, sin, cos instead of sin sin cos cos
pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1)
return pos_emb
def build_fourier_pos_embed(
feat_shape: List[int],
bands: Optional[torch.Tensor] = None,
num_bands: int = 64,
max_res: int = 224,
linear_bands: bool = False,
include_grid: bool = False,
concat_out: bool = True,
in_pixels: bool = True,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
) -> List[torch.Tensor]:
if bands is None:
if in_pixels:
bands = pixel_freq_bands(num_bands, float(max_res), linear_bands=linear_bands, dtype=dtype, device=device)
else:
bands = inv_freq_bands(num_bands, step=1, dtype=dtype, device=device)
else:
if device is None:
device = bands.device
if dtype is None:
dtype = bands.dtype
if in_pixels:
grid = torch.stack(torch.meshgrid(
[torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]), dim=-1)
else:
grid = torch.stack(torch.meshgrid(
[torch.arange(s, device=device, dtype=dtype) for s in feat_shape]), dim=-1)
grid = grid.unsqueeze(-1)
pos = grid * bands
pos_sin, pos_cos = pos.sin(), pos.cos()
out = (grid, pos_sin, pos_cos) if include_grid else (pos_sin, pos_cos)
# FIXME torchscript doesn't like multiple return types, probably need to always cat?
if concat_out:
out = torch.cat(out, dim=-1)
return out
class FourierEmbed(nn.Module):
def __init__(self, max_res: int = 224, num_bands: int = 64, concat_grid=True, keep_spatial=False):
super().__init__()
self.max_res = max_res
self.num_bands = num_bands
self.concat_grid = concat_grid
self.keep_spatial = keep_spatial
self.register_buffer('bands', pixel_freq_bands(max_res, num_bands), persistent=False)
def forward(self, x):
B, C = x.shape[:2]
feat_shape = x.shape[2:]
emb = build_fourier_pos_embed(
feat_shape,
self.bands,
include_grid=self.concat_grid,
dtype=x.dtype,
device=x.device)
emb = emb.transpose(-1, -2).flatten(len(feat_shape))
batch_expand = (B,) + (-1,) * (x.ndim - 1)
# FIXME support nD
if self.keep_spatial:
x = torch.cat([x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1)
else:
x = torch.cat([x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1)
x = x.reshape(B, feat_shape.numel(), -1)
return x
import torch.nn.functional as F
def rot(x):
return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)
from .helpers import to_2tuple
_logger = logging.getLogger(__name__)
def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):
return x * cos_emb + rot(x) * sin_emb
def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
if isinstance(x, torch.Tensor):
x = [x]
return [t * cos_emb + rot(t) * sin_emb for t in x]
def apply_rot_embed_split(x: torch.Tensor, emb):
split = emb.shape[-1] // 2
return x * emb[:, :split] + rot(x) * emb[:, split:]
def build_rotary_pos_embed(
feat_shape: List[int],
bands: Optional[torch.Tensor] = None,
dim: int = 64,
max_freq: float = 224,
linear_bands: bool = False,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
def resample_abs_pos_embed(
posemb,
new_size: List[int],
old_size: Optional[List[int]] = None,
num_prefix_tokens: int = 1,
interpolation: str = 'bicubic',
antialias: bool = True,
verbose: bool = False,
):
"""
NOTE: shape arg should include spatial dim only
"""
feat_shape = torch.Size(feat_shape)
sin_emb, cos_emb = build_fourier_pos_embed(
feat_shape, bands=bands, num_bands=dim // 4, max_res=max_freq, linear_bands=linear_bands,
concat_out=False, device=device, dtype=dtype)
N = feat_shape.numel()
sin_emb = sin_emb.reshape(N, -1).repeat_interleave(2, -1)
cos_emb = cos_emb.reshape(N, -1).repeat_interleave(2, -1)
return sin_emb, cos_emb
class RotaryEmbedding(nn.Module):
""" Rotary position embedding
NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not
been well tested, and will likely change. It will be moved to its own file.
# sort out sizes, assume square if old size not provided
new_size = to_2tuple(new_size)
new_ntok = new_size[0] * new_size[1]
if not old_size:
old_size = int(math.sqrt(posemb.shape[1] - num_prefix_tokens))
old_size = to_2tuple(old_size)
if new_size == old_size: # might not both be same container type
return posemb
if num_prefix_tokens:
posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
else:
posemb_prefix, posemb = None, posemb
The following impl/resources were referenced for this impl:
* https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
* https://blog.eleuther.ai/rotary-embeddings/
"""
def __init__(self, dim, max_res=224, linear_bands: bool = False):
super().__init__()
self.dim = dim
self.register_buffer('bands', pixel_freq_bands(dim // 4, max_res, linear_bands=linear_bands), persistent=False)
# do the interpolation
posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
posemb = posemb.permute(0, 2, 3, 1).reshape(1, new_ntok, -1)
def get_embed(self, shape: List[int]):
return build_rotary_pos_embed(shape, self.bands)
if verbose:
_logger.info(f'Resized position embedding: {old_size} to {new_size}.')
def forward(self, x):
# assuming channel-first tensor where spatial dim are >= 2
sin_emb, cos_emb = self.get_embed(x.shape[2:])
return apply_rot_embed(x, sin_emb, cos_emb)
# add back extra (class, etc) prefix tokens
if posemb_prefix is not None:
print(posemb_prefix.shape, posemb.shape)
posemb = torch.cat([posemb_prefix, posemb], dim=1)
return posemb

@ -0,0 +1,283 @@
""" Relative position embedding modules and functions
Hacked together by / Copyright 2022 Ross Wightman
"""
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from .mlp import Mlp
from .weight_init import trunc_normal_
def gen_relative_position_index(
q_size: Tuple[int, int],
k_size: Tuple[int, int] = None,
class_token: bool = False) -> torch.Tensor:
# Adapted with significant modifications from Swin / BeiT codebases
# get pair-wise relative position index for each token inside the window
q_coords = torch.stack(torch.meshgrid([torch.arange(q_size[0]), torch.arange(q_size[1])])).flatten(1) # 2, Wh, Ww
if k_size is None:
k_coords = q_coords
k_size = q_size
else:
# different q vs k sizes is a WIP
k_coords = torch.stack(torch.meshgrid([torch.arange(k_size[0]), torch.arange(k_size[1])])).flatten(1)
relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0) # Wh*Ww, Wh*Ww, 2
_, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0)
if class_token:
# handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias
# NOTE not intended or tested with MLP log-coords
max_size = (max(q_size[0], k_size[0]), max(q_size[1], k_size[1]))
num_relative_distance = (2 * max_size[0] - 1) * (2 * max_size[1] - 1) + 3
relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0])
relative_position_index[0, 0:] = num_relative_distance - 3
relative_position_index[0:, 0] = num_relative_distance - 2
relative_position_index[0, 0] = num_relative_distance - 1
return relative_position_index.contiguous()
class RelPosBias(nn.Module):
""" Relative Position Bias
Adapted from Swin-V1 relative position bias impl, modularized.
"""
def __init__(self, window_size, num_heads, prefix_tokens=0):
super().__init__()
assert prefix_tokens <= 1
self.window_size = window_size
self.window_area = window_size[0] * window_size[1]
self.bias_shape = (self.window_area + prefix_tokens,) * 2 + (num_heads,)
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * prefix_tokens
self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
self.register_buffer(
"relative_position_index",
gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0),
persistent=False,
)
self.init_weights()
def init_weights(self):
trunc_normal_(self.relative_position_bias_table, std=.02)
def get_bias(self) -> torch.Tensor:
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
# win_h * win_w, win_h * win_w, num_heads
relative_position_bias = relative_position_bias.view(self.bias_shape).permute(2, 0, 1)
return relative_position_bias.unsqueeze(0).contiguous()
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
return attn + self.get_bias()
def gen_relative_log_coords(
win_size: Tuple[int, int],
pretrained_win_size: Tuple[int, int] = (0, 0),
mode='swin',
):
assert mode in ('swin', 'cr', 'rw')
# as per official swin-v2 impl, supporting timm specific 'cr' and 'rw' log coords as well
relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous() # 2*Wh-1, 2*Ww-1, 2
if mode == 'swin':
if pretrained_win_size[0] > 0:
relative_coords_table[:, :, 0] /= (pretrained_win_size[0] - 1)
relative_coords_table[:, :, 1] /= (pretrained_win_size[1] - 1)
else:
relative_coords_table[:, :, 0] /= (win_size[0] - 1)
relative_coords_table[:, :, 1] /= (win_size[1] - 1)
relative_coords_table *= 8 # normalize to -8, 8
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
1.0 + relative_coords_table.abs()) / math.log2(8)
else:
if mode == 'rw':
# cr w/ window size normalization -> [-1,1] log coords
relative_coords_table[:, :, 0] /= (win_size[0] - 1)
relative_coords_table[:, :, 1] /= (win_size[1] - 1)
relative_coords_table *= 8 # scale to -8, 8
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
1.0 + relative_coords_table.abs())
relative_coords_table /= math.log2(9) # -> [-1, 1]
else:
# mode == 'cr'
relative_coords_table = torch.sign(relative_coords_table) * torch.log(
1.0 + relative_coords_table.abs())
return relative_coords_table
class RelPosMlp(nn.Module):
""" Log-Coordinate Relative Position MLP
Based on ideas presented in Swin-V2 paper (https://arxiv.org/abs/2111.09883)
This impl covers the 'swin' implementation as well as two timm specific modes ('cr', and 'rw')
"""
def __init__(
self,
window_size,
num_heads=8,
hidden_dim=128,
prefix_tokens=0,
mode='cr',
pretrained_window_size=(0, 0)
):
super().__init__()
self.window_size = window_size
self.window_area = self.window_size[0] * self.window_size[1]
self.prefix_tokens = prefix_tokens
self.num_heads = num_heads
self.bias_shape = (self.window_area,) * 2 + (num_heads,)
if mode == 'swin':
self.bias_act = nn.Sigmoid()
self.bias_gain = 16
mlp_bias = (True, False)
elif mode == 'rw':
self.bias_act = nn.Tanh()
self.bias_gain = 4
mlp_bias = True
else:
self.bias_act = nn.Identity()
self.bias_gain = None
mlp_bias = True
self.mlp = Mlp(
2, # x, y
hidden_features=hidden_dim,
out_features=num_heads,
act_layer=nn.ReLU,
bias=mlp_bias,
drop=(0.125, 0.)
)
self.register_buffer(
"relative_position_index",
gen_relative_position_index(window_size),
persistent=False)
# get relative_coords_table
self.register_buffer(
"rel_coords_log",
gen_relative_log_coords(window_size, pretrained_window_size, mode=mode),
persistent=False)
def get_bias(self) -> torch.Tensor:
relative_position_bias = self.mlp(self.rel_coords_log)
if self.relative_position_index is not None:
relative_position_bias = relative_position_bias.view(-1, self.num_heads)[
self.relative_position_index.view(-1)] # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.view(self.bias_shape)
relative_position_bias = relative_position_bias.permute(2, 0, 1)
relative_position_bias = self.bias_act(relative_position_bias)
if self.bias_gain is not None:
relative_position_bias = self.bias_gain * relative_position_bias
if self.prefix_tokens:
relative_position_bias = F.pad(relative_position_bias, [self.prefix_tokens, 0, self.prefix_tokens, 0])
return relative_position_bias.unsqueeze(0).contiguous()
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
return attn + self.get_bias()
def generate_lookup_tensor(
length: int,
max_relative_position: Optional[int] = None,
):
"""Generate a one_hot lookup tensor to reindex embeddings along one dimension.
Args:
length: the length to reindex to.
max_relative_position: the maximum relative position to consider.
Relative position embeddings for distances above this threshold
are zeroed out.
Returns:
a lookup Tensor of size [length, length, vocab_size] that satisfies
ret[n,m,v] = 1{m - n + max_relative_position = v}.
"""
if max_relative_position is None:
max_relative_position = length - 1
# Return the cached lookup tensor, otherwise compute it and cache it.
vocab_size = 2 * max_relative_position + 1
ret = torch.zeros(length, length, vocab_size)
for i in range(length):
for x in range(length):
v = x - i + max_relative_position
if abs(x - i) > max_relative_position:
continue
ret[i, x, v] = 1
return ret
def reindex_2d_einsum_lookup(
relative_position_tensor,
height: int,
width: int,
height_lookup: torch.Tensor,
width_lookup: torch.Tensor,
) -> torch.Tensor:
"""Reindex 2d relative position bias with 2 independent einsum lookups.
Adapted from:
https://github.com/google-research/maxvit/blob/2e06a7f1f70c76e64cd3dabe5cd1b8c1a23c9fb7/maxvit/models/attention_utils.py
Args:
relative_position_tensor: tensor of shape
[..., vocab_height, vocab_width, ...].
height: height to reindex to.
width: width to reindex to.
height_lookup: one-hot height lookup
width_lookup: one-hot width lookup
Returns:
reindexed_tensor: a Tensor of shape
[..., height * width, height * width, ...]
"""
reindexed_tensor = torch.einsum('nhw,ixh->nixw', relative_position_tensor, height_lookup)
reindexed_tensor = torch.einsum('nixw,jyw->nijxy', reindexed_tensor, width_lookup)
area = height * width
return reindexed_tensor.reshape(relative_position_tensor.shape[0], area, area)
class RelPosBiasTf(nn.Module):
""" Relative Position Bias Impl (Compatible with Tensorflow MaxViT models)
Adapted from:
https://github.com/google-research/maxvit/blob/2e06a7f1f70c76e64cd3dabe5cd1b8c1a23c9fb7/maxvit/models/attention_utils.py
"""
def __init__(self, window_size, num_heads, prefix_tokens=0):
super().__init__()
assert prefix_tokens <= 1
self.window_size = window_size
self.window_area = window_size[0] * window_size[1]
self.num_heads = num_heads
vocab_height = 2 * window_size[0] - 1
vocab_width = 2 * window_size[1] - 1
self.bias_shape = (self.num_heads, vocab_height, vocab_width)
self.relative_position_bias_table = nn.Parameter(torch.zeros(self.bias_shape))
self.register_buffer('height_lookup', generate_lookup_tensor(window_size[0]), persistent=False)
self.register_buffer('width_lookup', generate_lookup_tensor(window_size[1]), persistent=False)
self.init_weights()
def init_weights(self):
nn.init.normal_(self.relative_position_bias_table, std=.02)
def get_bias(self) -> torch.Tensor:
# FIXME change to not use one-hot/einsum?
return reindex_2d_einsum_lookup(
self.relative_position_bias_table,
self.window_size[0],
self.window_size[1],
self.height_lookup,
self.width_lookup
)
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
return attn + self.get_bias()

@ -0,0 +1,219 @@
""" Sin-cos, fourier, rotary position embedding modules and functions
Hacked together by / Copyright 2022 Ross Wightman
"""
import math
from typing import List, Tuple, Optional, Union
import torch
from torch import nn as nn
def pixel_freq_bands(
num_bands: int,
max_freq: float = 224.,
linear_bands: bool = True,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
):
if linear_bands:
bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=dtype, device=device)
else:
bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=dtype, device=device)
return bands * torch.pi
def inv_freq_bands(
num_bands: int,
temperature: float = 100000.,
step: int = 2,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
) -> torch.Tensor:
inv_freq = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands))
return inv_freq
def build_sincos2d_pos_embed(
feat_shape: List[int],
dim: int = 64,
temperature: float = 10000.,
reverse_coord: bool = False,
interleave_sin_cos: bool = False,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None
) -> torch.Tensor:
"""
Args:
feat_shape:
dim:
temperature:
reverse_coord: stack grid order W, H instead of H, W
interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos
dtype:
device:
Returns:
"""
assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding'
pos_dim = dim // 4
bands = inv_freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device)
if reverse_coord:
feat_shape = feat_shape[::-1] # stack W, H instead of H, W
grid = torch.stack(
torch.meshgrid([torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1)
pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
# FIXME add support for unflattened spatial dim?
stack_dim = 2 if interleave_sin_cos else 1 # stack sin, cos, sin, cos instead of sin sin cos cos
pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1)
return pos_emb
def build_fourier_pos_embed(
feat_shape: List[int],
bands: Optional[torch.Tensor] = None,
num_bands: int = 64,
max_res: int = 224,
linear_bands: bool = False,
include_grid: bool = False,
concat_out: bool = True,
in_pixels: bool = True,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
) -> List[torch.Tensor]:
if bands is None:
if in_pixels:
bands = pixel_freq_bands(num_bands, float(max_res), linear_bands=linear_bands, dtype=dtype, device=device)
else:
bands = inv_freq_bands(num_bands, step=1, dtype=dtype, device=device)
else:
if device is None:
device = bands.device
if dtype is None:
dtype = bands.dtype
if in_pixels:
grid = torch.stack(torch.meshgrid(
[torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]), dim=-1)
else:
grid = torch.stack(torch.meshgrid(
[torch.arange(s, device=device, dtype=dtype) for s in feat_shape]), dim=-1)
grid = grid.unsqueeze(-1)
pos = grid * bands
pos_sin, pos_cos = pos.sin(), pos.cos()
out = (grid, pos_sin, pos_cos) if include_grid else (pos_sin, pos_cos)
# FIXME torchscript doesn't like multiple return types, probably need to always cat?
if concat_out:
out = torch.cat(out, dim=-1)
return out
class FourierEmbed(nn.Module):
def __init__(self, max_res: int = 224, num_bands: int = 64, concat_grid=True, keep_spatial=False):
super().__init__()
self.max_res = max_res
self.num_bands = num_bands
self.concat_grid = concat_grid
self.keep_spatial = keep_spatial
self.register_buffer('bands', pixel_freq_bands(max_res, num_bands), persistent=False)
def forward(self, x):
B, C = x.shape[:2]
feat_shape = x.shape[2:]
emb = build_fourier_pos_embed(
feat_shape,
self.bands,
include_grid=self.concat_grid,
dtype=x.dtype,
device=x.device)
emb = emb.transpose(-1, -2).flatten(len(feat_shape))
batch_expand = (B,) + (-1,) * (x.ndim - 1)
# FIXME support nD
if self.keep_spatial:
x = torch.cat([x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1)
else:
x = torch.cat([x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1)
x = x.reshape(B, feat_shape.numel(), -1)
return x
def rot(x):
return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)
def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):
return x * cos_emb + rot(x) * sin_emb
def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
if isinstance(x, torch.Tensor):
x = [x]
return [t * cos_emb + rot(t) * sin_emb for t in x]
def apply_rot_embed_split(x: torch.Tensor, emb):
split = emb.shape[-1] // 2
return x * emb[:, :split] + rot(x) * emb[:, split:]
def build_rotary_pos_embed(
feat_shape: List[int],
bands: Optional[torch.Tensor] = None,
dim: int = 64,
max_freq: float = 224,
linear_bands: bool = False,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
):
"""
NOTE: shape arg should include spatial dim only
"""
feat_shape = torch.Size(feat_shape)
sin_emb, cos_emb = build_fourier_pos_embed(
feat_shape,
bands=bands,
num_bands=dim // 4,
max_res=max_freq,
linear_bands=linear_bands,
concat_out=False,
device=device,
dtype=dtype,
)
N = feat_shape.numel()
sin_emb = sin_emb.reshape(N, -1).repeat_interleave(2, -1)
cos_emb = cos_emb.reshape(N, -1).repeat_interleave(2, -1)
return sin_emb, cos_emb
class RotaryEmbedding(nn.Module):
""" Rotary position embedding
NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not
been well tested, and will likely change. It will be moved to its own file.
The following impl/resources were referenced for this impl:
* https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
* https://blog.eleuther.ai/rotary-embeddings/
"""
def __init__(self, dim, max_res=224, linear_bands: bool = False):
super().__init__()
self.dim = dim
self.register_buffer('bands', pixel_freq_bands(dim // 4, max_res, linear_bands=linear_bands), persistent=False)
def get_embed(self, shape: List[int]):
return build_rotary_pos_embed(shape, self.bands)
def forward(self, x):
# assuming channel-first tensor where spatial dim are >= 2
sin_emb, cos_emb = self.get_embed(x.shape[2:])
return apply_rot_embed(x, sin_emb, cos_emb)

@ -1,5 +1,6 @@
import dataclasses
import logging
import os
from copy import deepcopy
from typing import Optional, Dict, Callable, Any, Tuple
@ -9,7 +10,7 @@ from torch.hub import load_state_dict_from_url
from timm.models._features import FeatureListNet, FeatureHookNet
from timm.models._features_fx import FeatureGraphNet
from timm.models._helpers import load_state_dict
from timm.models._hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf
from timm.models._manipulate import adapt_input_conv
from timm.models._pretrained import PretrainedCfg
from timm.models._prune import adapt_model_from_file
@ -32,6 +33,7 @@ def _resolve_pretrained_source(pretrained_cfg):
pretrained_url = pretrained_cfg.get('url', None)
pretrained_file = pretrained_cfg.get('file', None)
hf_hub_id = pretrained_cfg.get('hf_hub_id', None)
# resolve where to load pretrained weights from
load_from = ''
pretrained_loc = ''
@ -43,15 +45,20 @@ def _resolve_pretrained_source(pretrained_cfg):
else:
# default source == timm or unspecified
if pretrained_file:
# file load override is the highest priority if set
load_from = 'file'
pretrained_loc = pretrained_file
elif pretrained_url:
load_from = 'url'
pretrained_loc = pretrained_url
elif hf_hub_id and has_hf_hub(necessary=True):
# hf-hub available as alternate weight source in default_cfg
load_from = 'hf-hub'
pretrained_loc = hf_hub_id
else:
# next, HF hub is prioritized unless a valid cached version of weights exists already
cached_url_valid = check_cached_file(pretrained_url) if pretrained_url else False
if hf_hub_id and has_hf_hub(necessary=True) and not cached_url_valid:
# hf-hub available as alternate weight source in default_cfg
load_from = 'hf-hub'
pretrained_loc = hf_hub_id
elif pretrained_url:
load_from = 'url'
pretrained_loc = pretrained_url
if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None):
# if a filename override is set, return tuple for location w/ (hub_id, filename)
pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename']
@ -105,7 +112,7 @@ def load_custom_pretrained(
pretrained_loc = download_cached_file(
pretrained_loc,
check_hash=_CHECK_HASH,
progress=_DOWNLOAD_PROGRESS
progress=_DOWNLOAD_PROGRESS,
)
if load_fn is not None:
@ -146,12 +153,21 @@ def load_pretrained(
state_dict = load_state_dict(pretrained_loc)
elif load_from == 'url':
_logger.info(f'Loading pretrained weights from url ({pretrained_loc})')
state_dict = load_state_dict_from_url(
pretrained_loc,
map_location='cpu',
progress=_DOWNLOAD_PROGRESS,
check_hash=_CHECK_HASH,
)
if pretrained_cfg.get('custom_load', False):
pretrained_loc = download_cached_file(
pretrained_loc,
progress=_DOWNLOAD_PROGRESS,
check_hash=_CHECK_HASH,
)
model.load_pretrained(pretrained_loc)
return
else:
state_dict = load_state_dict_from_url(
pretrained_loc,
map_location='cpu',
progress=_DOWNLOAD_PROGRESS,
check_hash=_CHECK_HASH,
)
elif load_from == 'hf-hub':
_logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
if isinstance(pretrained_loc, (list, tuple)):
@ -364,20 +380,14 @@ def build_model_with_cfg(
# For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
if pretrained:
if pretrained_cfg.get('custom_load', False):
load_custom_pretrained(
model,
pretrained_cfg=pretrained_cfg,
)
else:
load_pretrained(
model,
pretrained_cfg=pretrained_cfg,
num_classes=num_classes_pretrained,
in_chans=kwargs.get('in_chans', 3),
filter_fn=pretrained_filter_fn,
strict=pretrained_strict,
)
load_pretrained(
model,
pretrained_cfg=pretrained_cfg,
num_classes=num_classes_pretrained,
in_chans=kwargs.get('in_chans', 3),
filter_fn=pretrained_filter_fn,
strict=pretrained_strict,
)
# Wrap the model in a feature extraction module if enabled
if features:

@ -1,3 +1,4 @@
import hashlib
import json
import logging
import os
@ -67,6 +68,26 @@ def download_cached_file(url, check_hash=True, progress=False):
return cached_file
def check_cached_file(url, check_hash=True):
if isinstance(url, (list, tuple)):
url, filename = url
else:
parts = urlparse(url)
filename = os.path.basename(parts.path)
cached_file = os.path.join(get_cache_dir(), filename)
if os.path.exists(cached_file):
if check_hash:
r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
hash_prefix = r.group(1) if r else None
if hash_prefix:
with open(cached_file, 'rb') as f:
hd = hashlib.sha256(f.read()).hexdigest()
if hd[:len(hash_prefix)] != hash_prefix:
return False
return True
return False
def has_hf_hub(necessary=False):
if not _has_hf_hub and necessary:
# if no HF Hub module installed, and it is necessary to continue, raise error
@ -90,14 +111,14 @@ def load_cfg_from_json(json_file: Union[str, os.PathLike]):
return json.loads(text)
def _download_from_hf(model_id: str, filename: str):
def download_from_hf(model_id: str, filename: str):
hf_model_id, hf_revision = hf_split(model_id)
return hf_hub_download(hf_model_id, filename, revision=hf_revision)
def load_model_config_from_hf(model_id: str):
assert has_hf_hub(True)
cached_file = _download_from_hf(model_id, 'config.json')
cached_file = download_from_hf(model_id, 'config.json')
hf_config = load_cfg_from_json(cached_file)
if 'pretrained_cfg' not in hf_config:
@ -124,34 +145,28 @@ def load_model_config_from_hf(model_id: str):
def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
assert has_hf_hub(True)
cached_file = _download_from_hf(model_id, filename)
cached_file = download_from_hf(model_id, filename)
state_dict = torch.load(cached_file, map_location='cpu')
return state_dict
def save_for_hf(model, save_directory, model_config=None):
assert has_hf_hub(True)
def save_config_for_hf(model, config_path, model_config=None):
model_config = model_config or {}
save_directory = Path(save_directory)
save_directory.mkdir(exist_ok=True, parents=True)
weights_path = save_directory / 'pytorch_model.bin'
torch.save(model.state_dict(), weights_path)
config_path = save_directory / 'config.json'
hf_config = {}
pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
# set some values at root config level
hf_config['architecture'] = pretrained_cfg.pop('architecture')
hf_config['num_classes'] = model_config.get('num_classes', model.num_classes)
hf_config['num_features'] = model_config.get('num_features', model.num_features)
hf_config['global_pool'] = model_config.get('global_pool', getattr(model, 'global_pool', None))
global_pool_type = model_config.get('global_pool', getattr(model, 'global_pool', None))
if isinstance(global_pool_type, str) and global_pool_type:
hf_config['global_pool'] = global_pool_type
if 'label' in model_config:
if 'labels' in model_config:
_logger.warning(
"'label' as a config field for timm models is deprecated. Please use 'label_name' and 'display_name'. "
"'labels' as a config field for timm models is deprecated. Please use 'label_name' and 'display_name'. "
"Using provided 'label' field as 'label_name'.")
model_config['label_name'] = model_config.pop('label')
model_config['label_name'] = model_config.pop('labels')
label_name = model_config.pop('label_name', None)
if label_name:
@ -173,6 +188,18 @@ def save_for_hf(model, save_directory, model_config=None):
json.dump(hf_config, f, indent=2)
def save_for_hf(model, save_directory, model_config=None):
assert has_hf_hub(True)
save_directory = Path(save_directory)
save_directory.mkdir(exist_ok=True, parents=True)
weights_path = save_directory / 'pytorch_model.bin'
torch.save(model.state_dict(), weights_path)
config_path = save_directory / 'config.json'
save_config_for_hf(model, config_path, model_config=model_config)
def push_to_hf_hub(
model,
repo_id: str,

@ -19,6 +19,7 @@ class PretrainedCfg:
source: Optional[str] = None # source of cfg / weight location used (url, file, hf-hub)
architecture: Optional[str] = None # architecture variant can be set when not implicit
tag: Optional[str] = None # pretrained tag of source
custom_load: bool = False # use custom model specific model.load_pretrained() (ie for npz files)
# input / data config
@ -44,9 +45,11 @@ class PretrainedCfg:
classifier: Optional[str] = None
license: Optional[str] = None
source_url: Optional[str] = None
paper: Optional[str] = None
notes: Optional[str] = None
description: Optional[str] = None
origin_url: Optional[str] = None
paper_name: Optional[str] = None
paper_ids: Optional[Union[str, Tuple[str]]] = None
notes: Optional[Tuple[str]] = None
@property
def has_weights(self):
@ -62,11 +65,11 @@ class PretrainedCfg:
def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True):
filtered_cfg = {}
keep_none = {'pool_size', 'first_conv', 'classifier'} # always keep these keys, even if none
keep_null = {'pool_size', 'first_conv', 'classifier'} # always keep these keys, even if none
for k, v in cfg.items():
if remove_source and k in {'url', 'file', 'hf_hub_id', 'hf_hub_id', 'hf_hub_filename', 'source'}:
continue
if remove_null and v is None and k not in keep_none:
if remove_null and v is None and k not in keep_null:
continue
filtered_cfg[k] = v
return filtered_cfg

@ -7,6 +7,7 @@ import re
import sys
from collections import defaultdict, deque
from copy import deepcopy
from dataclasses import replace
from typing import List, Optional, Union, Tuple
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
@ -20,7 +21,7 @@ _model_to_module = {} # mapping of model names to module names
_model_entrypoints = {} # mapping of model names to architecture entrypoint fns
_model_has_pretrained = set() # set of model names that have pretrained weight url present
_model_default_cfgs = dict() # central repo for model arch -> default cfg objects
_model_pretrained_cfgs = dict() # central repo for model arch + tag -> pretrained cfgs
_model_pretrained_cfgs = dict() # central repo for model arch.tag -> pretrained cfgs
_model_with_tags = defaultdict(list) # shortcut to map each model arch to all model + tag names
@ -48,24 +49,31 @@ def register_model(fn):
if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
# this will catch all models that have entrypoint matching cfg key, but miss any aliasing
# entrypoints or non-matching combos
cfg = mod.default_cfgs[model_name]
if not isinstance(cfg, DefaultCfg):
default_cfg = mod.default_cfgs[model_name]
if not isinstance(default_cfg, DefaultCfg):
# new style default cfg dataclass w/ multiple entries per model-arch
assert isinstance(cfg, dict)
assert isinstance(default_cfg, dict)
# old style cfg dict per model-arch
cfg = PretrainedCfg(**cfg)
cfg = DefaultCfg(tags=deque(['']), cfgs={'': cfg})
pretrained_cfg = PretrainedCfg(**default_cfg)
default_cfg = DefaultCfg(tags=deque(['']), cfgs={'': pretrained_cfg})
for tag_idx, tag in enumerate(cfg.tags):
for tag_idx, tag in enumerate(default_cfg.tags):
is_default = tag_idx == 0
pretrained_cfg = cfg.cfgs[tag]
pretrained_cfg = default_cfg.cfgs[tag]
model_name_tag = '.'.join([model_name, tag]) if tag else model_name
replace_items = dict(architecture=model_name, tag=tag if tag else None)
if pretrained_cfg.hf_hub_id and pretrained_cfg.hf_hub_id == 'timm/':
# auto-complete hub name w/ architecture.tag
replace_items['hf_hub_id'] = pretrained_cfg.hf_hub_id + model_name_tag
pretrained_cfg = replace(pretrained_cfg, **replace_items)
if is_default:
_model_pretrained_cfgs[model_name] = pretrained_cfg
if pretrained_cfg.has_weights:
# add tagless entry if it's default and has weights
_model_has_pretrained.add(model_name)
if tag:
model_name_tag = '.'.join([model_name, tag])
_model_pretrained_cfgs[model_name_tag] = pretrained_cfg
if pretrained_cfg.has_weights:
# add model w/ tag if tag is valid
@ -74,7 +82,7 @@ def register_model(fn):
else:
_model_with_tags[model_name].append(model_name) # has empty tag (to slowly remove these instances)
_model_default_cfgs[model_name] = cfg
_model_default_cfgs[model_name] = default_cfg
return fn
@ -198,15 +206,21 @@ def is_model_pretrained(model_name):
return model_name in _model_has_pretrained
def get_pretrained_cfg(model_name):
def get_pretrained_cfg(model_name, allow_unregistered=True):
if model_name in _model_pretrained_cfgs:
return deepcopy(_model_pretrained_cfgs[model_name])
raise RuntimeError(f'No pretrained config exists for model {model_name}.')
arch_name, tag = split_model_name_tag(model_name)
if arch_name in _model_default_cfgs:
# if model arch exists, but the tag is wrong, error out
raise RuntimeError(f'Invalid pretrained tag ({tag}) for {arch_name}.')
if allow_unregistered:
# if model arch doesn't exist, it has no pretrained_cfg registered, allow a default to be created
return None
raise RuntimeError(f'Model architecture ({arch_name}) has no pretrained cfg registered.')
def get_pretrained_cfg_value(model_name, cfg_key):
""" Get a specific model default_cfg value by key. None if key doesn't exist.
"""
if model_name in _model_pretrained_cfgs:
return getattr(_model_pretrained_cfgs[model_name], cfg_key, None)
raise RuntimeError(f'No pretrained config exist for model {model_name}.')
cfg = get_pretrained_cfg(model_name, allow_unregistered=False)
return getattr(cfg, cfg_key, None)

@ -355,64 +355,76 @@ def _cfg(url='', **kwargs):
default_cfgs = generate_default_cfgs({
'beit_base_patch16_224.in22k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'),
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth',
hf_hub_id='timm/'),
'beit_base_patch16_384.in22k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), crop_pct=1.0,
),
'beit_base_patch16_224.in22k_ft_in22k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth',
hf_hub_id='timm/',
num_classes=21841,
),
'beit_large_patch16_224.in22k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'),
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth',
hf_hub_id='timm/'),
'beit_large_patch16_384.in22k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), crop_pct=1.0,
),
'beit_large_patch16_512.in22k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth',
hf_hub_id='timm/',
input_size=(3, 512, 512), crop_pct=1.0,
),
'beit_large_patch16_224.in22k_ft_in22k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
hf_hub_id='timm/',
num_classes=21841,
),
'beitv2_base_patch16_224.in1k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'beitv2_base_patch16_224.in1k_ft_in22k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
num_classes=21841,
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
hf_hub_id='timm/',
num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'beitv2_large_patch16_224.in1k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
crop_pct=0.95,
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
hf_hub_id='timm/',
crop_pct=0.95, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'beitv2_large_patch16_224.in1k_ft_in22k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
num_classes=21841,
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
hf_hub_id='timm/',
num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'eva_giant_patch14_224.clip_ft_in1k': _cfg(
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0,
),
'eva_giant_patch14_336.clip_ft_in1k': _cfg(
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg(
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash'),
})

@ -361,7 +361,6 @@ def _create_convnext(variant, pretrained=False, **kwargs):
return model
def _cfg(url='', **kwargs):
return {
'url': url,
@ -375,90 +374,131 @@ def _cfg(url='', **kwargs):
default_cfgs = generate_default_cfgs({
# timm specific variants
'convnext_atto.timm_in1k': _cfg(
'convnext_atto.d2_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnext_atto_ols.timm_in1k': _cfg(
'convnext_atto_ols.a2_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnext_femto.timm_in1k': _cfg(
'convnext_femto.d1_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnext_femto_ols.timm_in1k': _cfg(
'convnext_femto_ols.d1_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnext_pico.timm_in1k': _cfg(
'convnext_pico.d1_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnext_pico_ols.timm_in1k': _cfg(
'convnext_pico_ols.d1_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth',
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_nano.in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_nano.timm_in1k': _cfg(
'convnext_nano.d1h_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth',
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_nano_ols.timm_in1k': _cfg(
'convnext_nano_ols.d1h_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth',
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_tiny_hnf.timm_in1k': _cfg(
'convnext_tiny_hnf.a2h_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_nano.in12k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, num_classes=11821),
'convnext_tiny.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_small.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_base.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_large.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_xlarge.untrained': _cfg(),
'convnext_xxlarge.untrained': _cfg(),
'convnext_tiny.fb_in22k_ft_in1k': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_small.fb_in22k_ft_in1k': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_base.fb_in22k_ft_in1k': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_large.fb_in22k_ft_in1k': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_xlarge.fb_in22k_ft_in1k': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_tiny.fb_in22k_ft_in1k_384': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnext_small..fb_in22k_ft_in1k_384': _cfg(
'convnext_small.fb_in22k_ft_in1k_384': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnext_base.fb_in22k_ft_in1k_384': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnext_large.fb_in22k_ft_in1k_384': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnext_xlarge.fb_in22k_ft_in1k_384': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnext_tiny_in22k.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", num_classes=21841),
'convnext_small_in22k.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", num_classes=21841),
'convnext_base_in22k.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841),
'convnext_large_in22k.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", num_classes=21841),
'convnext_xlarge_in22k.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", num_classes=21841),
'convnext_tiny.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
hf_hub_id='timm/',
num_classes=21841),
'convnext_small.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
hf_hub_id='timm/',
num_classes=21841),
'convnext_base.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
hf_hub_id='timm/',
num_classes=21841),
'convnext_large.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
hf_hub_id='timm/',
num_classes=21841),
'convnext_xlarge.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
hf_hub_id='timm/',
num_classes=21841),
})
@ -576,3 +616,10 @@ def convnext_xlarge(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
model = _create_convnext('convnext_xlarge', pretrained=pretrained, **model_args)
return model
@register_model
def convnext_xxlarge(pretrained=False, **kwargs):
model_args = dict(depths=[3, 4, 30, 3], dims=[384, 768, 1536, 3072], **kwargs)
model = _create_convnext('convnext_xxlarge', pretrained=pretrained, **model_args)
return model

File diff suppressed because it is too large Load Diff

@ -2,6 +2,7 @@
from timm.layers.activations import *
from timm.layers.adaptive_avgmax_pool import \
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from timm.layers.attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
from timm.layers.blur_pool import BlurPool2d
from timm.layers.classifier import ClassifierHead, create_classifier
from timm.layers.cond_conv2d import CondConv2d, get_condconv_initializer

@ -47,16 +47,15 @@ import torch
from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import Mlp, ConvMlp, DropPath, ClassifierHead, trunc_normal_tf_, LayerNorm
from timm.layers import SelectAdaptivePool2d, create_pool2d
from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d
from timm.layers import to_2tuple, extend_tuple, make_divisible, _assert
from timm.layers import Mlp, ConvMlp, DropPath, ClassifierHead, LayerNorm, SelectAdaptivePool2d
from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, create_pool2d
from timm.layers import trunc_normal_tf_, to_2tuple, extend_tuple, make_divisible, _assert
from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import named_apply, checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move these to common location
__all__ = ['MaxxVitCfg', 'MaxxVitConvCfg', 'MaxxVitTransformerCfg', 'MaxxVit']
@ -1076,93 +1075,6 @@ def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]):
return cfg
def generate_lookup_tensor(
length: int,
max_relative_position: Optional[int] = None,
):
"""Generate a one_hot lookup tensor to reindex embeddings along one dimension.
Args:
length: the length to reindex to.
max_relative_position: the maximum relative position to consider.
Relative position embeddings for distances above this threshold
are zeroed out.
Returns:
a lookup Tensor of size [length, length, vocab_size] that satisfies
ret[n,m,v] = 1{m - n + max_relative_position = v}.
"""
if max_relative_position is None:
max_relative_position = length - 1
# Return the cached lookup tensor, otherwise compute it and cache it.
vocab_size = 2 * max_relative_position + 1
ret = torch.zeros(length, length, vocab_size)
for i in range(length):
for x in range(length):
v = x - i + max_relative_position
if abs(x - i) > max_relative_position:
continue
ret[i, x, v] = 1
return ret
def reindex_2d_einsum_lookup(
relative_position_tensor,
height: int,
width: int,
height_lookup: torch.Tensor,
width_lookup: torch.Tensor,
) -> torch.Tensor:
"""Reindex 2d relative position bias with 2 independent einsum lookups.
Args:
relative_position_tensor: tensor of shape
[..., vocab_height, vocab_width, ...].
height: height to reindex to.
width: width to reindex to.
height_lookup: one-hot height lookup
width_lookup: one-hot width lookup
Returns:
reindexed_tensor: a Tensor of shape
[..., height * width, height * width, ...]
"""
reindexed_tensor = torch.einsum('nhw,ixh->nixw', relative_position_tensor, height_lookup)
reindexed_tensor = torch.einsum('nixw,jyw->nijxy', reindexed_tensor, width_lookup)
area = height * width
return reindexed_tensor.reshape(relative_position_tensor.shape[0], area, area)
class RelPosBiasTf(nn.Module):
def __init__(self, window_size, num_heads, prefix_tokens=0):
super().__init__()
assert prefix_tokens <= 1
self.window_size = window_size
self.window_area = window_size[0] * window_size[1]
self.num_heads = num_heads
vocab_height = 2 * window_size[0] - 1
vocab_width = 2 * window_size[1] - 1
self.bias_shape = (self.num_heads, vocab_height, vocab_width)
self.relative_position_bias_table = nn.Parameter(torch.zeros(self.bias_shape))
self.register_buffer('height_lookup', generate_lookup_tensor(window_size[0]), persistent=False)
self.register_buffer('width_lookup', generate_lookup_tensor(window_size[1]), persistent=False)
self.init_weights()
def init_weights(self):
nn.init.normal_(self.relative_position_bias_table, std=.02)
def get_bias(self) -> torch.Tensor:
# FIXME change to not use one-hot/einsum?
return reindex_2d_einsum_lookup(
self.relative_position_bias_table,
self.window_size[0],
self.window_size[1],
self.height_lookup,
self.width_lookup
)
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
return attn + self.get_bias()
class NormMlpHead(nn.Module):
def __init__(

@ -21,93 +21,12 @@ from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficie
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from ._features import FeatureInfo, FeatureHooks
from ._manipulate import checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
__all__ = ['MobileNetV3', 'MobileNetV3Features']
def _cfg(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'conv_stem', 'classifier': 'classifier',
**kwargs
}
default_cfgs = {
'mobilenetv3_large_075': _cfg(url=''),
'mobilenetv3_large_100': _cfg(
interpolation='bicubic',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth'),
'mobilenetv3_large_100_miil': _cfg(
interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.),
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_1k_miil_78_0-66471c13.pth'),
'mobilenetv3_large_100_miil_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_in21k_miil-d71cc17b.pth',
interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.), num_classes=11221),
'mobilenetv3_small_050': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth',
interpolation='bicubic'),
'mobilenetv3_small_075': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_075_lambc-384766db.pth',
interpolation='bicubic'),
'mobilenetv3_small_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_100_lamb-266a294c.pth',
interpolation='bicubic'),
'mobilenetv3_rw': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
interpolation='bicubic'),
'tf_mobilenetv3_large_075': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_mobilenetv3_large_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_mobilenetv3_large_minimal_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_mobilenetv3_small_075': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_mobilenetv3_small_100': _cfg(
url= 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_mobilenetv3_small_minimal_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'fbnetv3_b': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_b_224-ead5d2a1.pth',
test_input_size=(3, 256, 256), crop_pct=0.95),
'fbnetv3_d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_d_224-c98bce42.pth',
test_input_size=(3, 256, 256), crop_pct=0.95),
'fbnetv3_g': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_g_240-0b1df83b.pth',
input_size=(3, 240, 240), test_input_size=(3, 288, 288), crop_pct=0.95, pool_size=(8, 8)),
"lcnet_035": _cfg(),
"lcnet_050": _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_050-f447553b.pth',
interpolation='bicubic',
),
"lcnet_075": _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_075-318cad2c.pth',
interpolation='bicubic',
),
"lcnet_100": _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_100-a929038c.pth',
interpolation='bicubic',
),
"lcnet_150": _cfg(),
}
class MobileNetV3(nn.Module):
""" MobiletNet-V3
@ -124,9 +43,24 @@ class MobileNetV3(nn.Module):
"""
def __init__(
self, block_args, num_classes=1000, in_chans=3, stem_size=16, fix_stem=False, num_features=1280,
head_bias=True, pad_type='', act_layer=None, norm_layer=None, se_layer=None, se_from_exp=True,
round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'):
self,
block_args,
num_classes=1000,
in_chans=3,
stem_size=16,
fix_stem=False,
num_features=1280,
head_bias=True,
pad_type='',
act_layer=None,
norm_layer=None,
se_layer=None,
se_from_exp=True,
round_chs_fn=round_channels,
drop_rate=0.,
drop_path_rate=0.,
global_pool='avg',
):
super(MobileNetV3, self).__init__()
act_layer = act_layer or nn.ReLU
norm_layer = norm_layer or nn.BatchNorm2d
@ -145,8 +79,15 @@ class MobileNetV3(nn.Module):
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
output_stride=32, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp,
act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate)
output_stride=32,
pad_type=pad_type,
round_chs_fn=round_chs_fn,
se_from_exp=se_from_exp,
act_layer=act_layer,
norm_layer=norm_layer,
se_layer=se_layer,
drop_path_rate=drop_path_rate,
)
self.blocks = nn.Sequential(*builder(stem_size, block_args))
self.feature_info = builder.features
head_chs = builder.in_chs
@ -225,9 +166,23 @@ class MobileNetV3Features(nn.Module):
"""
def __init__(
self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3,
stem_size=16, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels,
se_from_exp=True, act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
self,
block_args,
out_indices=(0, 1, 2, 3, 4),
feature_location='bottleneck',
in_chans=3,
stem_size=16,
fix_stem=False,
output_stride=32,
pad_type='',
round_chs_fn=round_channels,
se_from_exp=True,
act_layer=None,
norm_layer=None,
se_layer=None,
drop_rate=0.,
drop_path_rate=0.,
):
super(MobileNetV3Features, self).__init__()
act_layer = act_layer or nn.ReLU
norm_layer = norm_layer or nn.BatchNorm2d
@ -243,9 +198,16 @@ class MobileNetV3Features(nn.Module):
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp,
act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer,
drop_path_rate=drop_path_rate, feature_location=feature_location)
output_stride=output_stride,
pad_type=pad_type,
round_chs_fn=round_chs_fn,
se_from_exp=se_from_exp,
act_layer=act_layer,
norm_layer=norm_layer,
se_layer=se_layer,
drop_path_rate=drop_path_rate,
feature_location=feature_location,
)
self.blocks = nn.Sequential(*builder(stem_size, block_args))
self.feature_info = FeatureInfo(builder.features, out_indices)
self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
@ -286,7 +248,9 @@ def _create_mnv3(variant, pretrained=False, **kwargs):
kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool')
model_cls = MobileNetV3Features
model = build_model_with_cfg(
model_cls, variant, pretrained,
model_cls,
variant,
pretrained,
pretrained_strict=not features_only,
kwargs_filter=kwargs_filter,
**kwargs)
@ -567,6 +531,110 @@ def _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
return model
def _cfg(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'conv_stem', 'classifier': 'classifier',
**kwargs
}
default_cfgs = generate_default_cfgs({
'mobilenetv3_large_075.untrained': _cfg(url=''),
'mobilenetv3_large_100.ra_in1k': _cfg(
interpolation='bicubic',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth',
hf_hub_id='timm/'),
'mobilenetv3_large_100.miil_in21k_ft_in1k': _cfg(
interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.),
origin_url='https://github.com/Alibaba-MIIL/ImageNet21K',
paper_ids='arXiv:2104.10972v4',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_1k_miil_78_0-66471c13.pth',
hf_hub_id='timm/'),
'mobilenetv3_large_100.miil_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_in21k_miil-d71cc17b.pth',
hf_hub_id='timm/',
origin_url='https://github.com/Alibaba-MIIL/ImageNet21K',
paper_ids='arXiv:2104.10972v4',
interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.), num_classes=11221),
'mobilenetv3_small_050.lamb_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth',
hf_hub_id='timm/',
interpolation='bicubic'),
'mobilenetv3_small_075.lamb_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_075_lambc-384766db.pth',
hf_hub_id='timm/',
interpolation='bicubic'),
'mobilenetv3_small_100.lamb_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_100_lamb-266a294c.pth',
hf_hub_id='timm/',
interpolation='bicubic'),
'mobilenetv3_rw.rmsp_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
interpolation='bicubic'),
'tf_mobilenetv3_large_075.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth',
hf_hub_id='timm/',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_mobilenetv3_large_100.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth',
hf_hub_id='timm/',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_mobilenetv3_large_minimal_100.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth',
hf_hub_id='timm/',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_mobilenetv3_small_075.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth',
hf_hub_id='timm/',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_mobilenetv3_small_100.in1k': _cfg(
url= 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth',
hf_hub_id='timm/',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_mobilenetv3_small_minimal_100.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
hf_hub_id='timm/',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'fbnetv3_b.ra2_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_b_224-ead5d2a1.pth',
hf_hub_id='timm/',
test_input_size=(3, 256, 256), crop_pct=0.95),
'fbnetv3_d.ra2_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_d_224-c98bce42.pth',
hf_hub_id='timm/',
test_input_size=(3, 256, 256), crop_pct=0.95),
'fbnetv3_g.ra2_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_g_240-0b1df83b.pth',
hf_hub_id='timm/',
input_size=(3, 240, 240), test_input_size=(3, 288, 288), crop_pct=0.95, pool_size=(8, 8)),
"lcnet_035.untrained": _cfg(),
"lcnet_050.ra2_in1k": _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_050-f447553b.pth',
hf_hub_id='timm/',
interpolation='bicubic',
),
"lcnet_075.ra2_in1k": _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_075-318cad2c.pth',
hf_hub_id='timm/',
interpolation='bicubic',
),
"lcnet_100.ra2_in1k": _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_100-a929038c.pth',
hf_hub_id='timm/',
interpolation='bicubic',
),
"lcnet_150.untrained": _cfg(),
})
@register_model
def mobilenetv3_large_075(pretrained=False, **kwargs):
""" MobileNet V3 """
@ -581,24 +649,6 @@ def mobilenetv3_large_100(pretrained=False, **kwargs):
return model
@register_model
def mobilenetv3_large_100_miil(pretrained=False, **kwargs):
""" MobileNet V3
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
"""
model = _gen_mobilenet_v3('mobilenetv3_large_100_miil', 1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def mobilenetv3_large_100_miil_in21k(pretrained=False, **kwargs):
""" MobileNet V3, 21k pretraining
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
"""
model = _gen_mobilenet_v3('mobilenetv3_large_100_miil_in21k', 1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def mobilenetv3_small_050(pretrained=False, **kwargs):
""" MobileNet V3 """

@ -8,14 +8,18 @@ A PyTorch implement of Vision Transformers as described in:
`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
- https://arxiv.org/abs/2106.10270
The official jax code is released and available at https://github.com/google-research/vision_transformer
`FlexiViT: One Model for All Patch Sizes`
- https://arxiv.org/abs/2212.08013
The official jax code is released and available at
* https://github.com/google-research/vision_transformer
* https://github.com/google-research/big_vision
Acknowledgments:
* The paper authors for releasing code and weights, thanks!
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
for some einops/einsum fun
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
* The paper authors for releasing code and weights, thanks!
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
Hacked together by / Copyright 2020, Ross Wightman
"""
@ -23,7 +27,7 @@ import logging
import math
from collections import OrderedDict
from functools import partial
from typing import Optional
from typing import Optional, List
import torch
import torch.nn as nn
@ -32,7 +36,8 @@ import torch.utils.checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
resample_abs_pos_embed
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
from ._pretrained import generate_default_cfgs
@ -449,6 +454,39 @@ def get_init_weights_vit(mode='jax', head_bias: float = 0.):
return init_weights_vit_timm
def resize_pos_embed(
posemb,
posemb_new,
num_prefix_tokens=1,
gs_new=(),
interpolation='bicubic',
antialias=False,
):
""" Rescale the grid of position embeddings when loading from state_dict.
*DEPRECATED* This function is being deprecated in favour of resample_abs_pos_embed
Adapted from:
https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
"""
ntok_new = posemb_new.shape[1]
if num_prefix_tokens:
posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:]
ntok_new -= num_prefix_tokens
else:
posemb_prefix, posemb_grid = posemb[:, :0], posemb[0]
gs_old = int(math.sqrt(len(posemb_grid)))
if not len(gs_new): # backwards compatibility
gs_new = [int(math.sqrt(ntok_new))] * 2
assert len(gs_new) >= 2
_logger.info(f'Resized position embedding: {posemb.shape} ({[gs_old, gs_old]}) to {posemb_new.shape} ({gs_new}).')
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode=interpolation, antialias=antialias, align_corners=False)
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
posemb = torch.cat([posemb_prefix, posemb_grid], dim=1)
return posemb
@torch.no_grad()
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
""" Load weights from .npz checkpoints for official Google Brain Flax implementation
@ -468,8 +506,15 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
return torch.from_numpy(w)
w = np.load(checkpoint_path)
if not prefix and 'opt/target/embedding/kernel' in w:
prefix = 'opt/target/'
interpolation = 'bilinear'
antialias = False
big_vision = False
if not prefix:
if 'opt/target/embedding/kernel' in w:
prefix = 'opt/target/'
elif 'params/embedding/kernel' in w:
prefix = 'params/'
big_vision = True
if hasattr(model.patch_embed, 'backbone'):
# hybrid
@ -495,17 +540,33 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
else:
embed_conv_w = adapt_input_conv(
model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
if embed_conv_w.shape[-2:] != model.patch_embed.proj.weight.shape[-2:]:
embed_conv_w = resample_patch_embed(
embed_conv_w,
model.patch_embed.proj.weight.shape[-2:],
interpolation=interpolation,
antialias=antialias,
verbose=True,
)
model.patch_embed.proj.weight.copy_(embed_conv_w)
model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
if model.cls_token is not None:
model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
if big_vision:
pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
else:
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
if pos_embed_w.shape != model.pos_embed.shape:
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
old_shape = pos_embed_w.shape
num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights
pos_embed_w,
model.pos_embed,
getattr(model, 'num_prefix_tokens', 1),
model.patch_embed.grid_size
new_size=model.patch_embed.grid_size,
num_prefix_tokens=num_prefix_tokens,
interpolation=interpolation,
antialias=antialias,
verbose=True,
)
model.pos_embed.copy_(pos_embed_w)
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
@ -517,9 +578,10 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2)
for i, block in enumerate(model.blocks.children()):
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
block.attn.qkv.weight.copy_(torch.cat([
@ -529,32 +591,10 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
for r in range(2):
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()):
# Rescale the grid of position embeddings when loading from state_dict. Adapted from
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
_logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
ntok_new = posemb_new.shape[1]
if num_prefix_tokens:
posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:]
ntok_new -= num_prefix_tokens
else:
posemb_prefix, posemb_grid = posemb[:, :0], posemb[0]
gs_old = int(math.sqrt(len(posemb_grid)))
if not len(gs_new): # backwards compatibility
gs_new = [int(math.sqrt(ntok_new))] * 2
assert len(gs_new) >= 2
_logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
posemb = torch.cat([posemb_prefix, posemb_grid], dim=1)
return posemb
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))
def _convert_openai_clip(state_dict, model):
@ -591,7 +631,13 @@ def _convert_openai_clip(state_dict, model):
return out_dict
def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False):
def checkpoint_filter_fn(
state_dict,
model,
adapt_layer_scale=False,
interpolation='bicubic',
antialias=True,
):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
import re
out_dict = {}
@ -603,17 +649,30 @@ def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False):
return _convert_openai_clip(state_dict, model)
for k, v in state_dict.items():
if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
# For old models that I trained prior to conv based patchification
if 'patch_embed.proj.weight' in k:
O, I, H, W = model.patch_embed.proj.weight.shape
v = v.reshape(O, -1, H, W)
if len(v.shape) < 4:
# For old models that I trained prior to conv based patchification
O, I, H, W = model.patch_embed.proj.weight.shape
v = v.reshape(O, -1, H, W)
if v.shape[-1] != W or v.shape[-2] != H:
v = resample_patch_embed(
v,
(H, W),
interpolation=interpolation,
antialias=antialias,
verbose=True,
)
elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
# To resize pos embedding when using model at different size from pretrained weights
v = resize_pos_embed(
num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
v = resample_abs_pos_embed(
v,
model.pos_embed,
0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1),
model.patch_embed.grid_size
new_size=model.patch_embed.grid_size,
num_prefix_tokens=num_prefix_tokens,
interpolation=interpolation,
antialias=antialias,
verbose=True,
)
elif adapt_layer_scale and 'gamma_' in k:
# remap layer-scale gamma into sub-module (deit3 models)
@ -641,67 +700,101 @@ default_cfgs = generate_default_cfgs({
# How to train your ViT (augreg) weights, pretrained on 21k FT on in1k
'vit_tiny_patch16_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
hf_hub_id='timm/',
custom_load=True),
'vit_tiny_patch16_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
hf_hub_id='timm/',
custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
'vit_small_patch32_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
hf_hub_id='timm/',
custom_load=True),
'vit_small_patch32_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
hf_hub_id='timm/',
custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
'vit_small_patch16_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
hf_hub_id='timm/',
custom_load=True),
'vit_small_patch16_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
hf_hub_id='timm/',
custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_patch32_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
hf_hub_id='timm/',
custom_load=True),
'vit_base_patch32_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
hf_hub_id='timm/',
custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_patch16_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
hf_hub_id='timm/',
custom_load=True),
'vit_base_patch16_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
hf_hub_id='timm/',
custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_patch8_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
hf_hub_id='timm/',
custom_load=True),
'vit_large_patch16_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
hf_hub_id='timm/',
custom_load=True),
'vit_large_patch16_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
hf_hub_id='timm/',
custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
# re-finetuned augreg 21k FT on in1k weights
'vit_base_patch16_224.augreg2_in21k_ft_in1k': _cfg(
file='b16_augreg-a-8.pth'),
'vit_base_patch16_384.augreg2_in21k_ft_in1k': _cfg(
url=''),
hf_hub_id='timm/'),
'vit_base_patch16_384.augreg2_in21k_ft_in1k': _cfg(),
'vit_base_patch8_224.augreg2_in21k_ft_in1k': _cfg(
url=''),
hf_hub_id='timm/'),
# patch models (weights from official Google JAX impl) pretrained on in21k FT on in1k
'vit_base_patch16_224.orig_in21k_ft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth'),
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
hf_hub_id='timm/'),
'vit_base_patch16_384.orig_in21k_ft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth'),
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_patch32_384.orig_in21k_ft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), crop_pct=1.0),
# How to train your ViT (augreg) weights trained on in1k
# How to train your ViT (augreg) weights trained on in1k only
'vit_small_patch16_224.augreg_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
hf_hub_id='timm/',
custom_load=True),
'vit_small_patch16_384.augreg_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
hf_hub_id='timm/',
custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_patch32_224.augreg_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
hf_hub_id='timm/',
custom_load=True),
'vit_base_patch32_384.augreg_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
hf_hub_id='timm/',
custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_patch16_224.augreg_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
hf_hub_id='timm/',
custom_load=True),
'vit_base_patch16_384.augreg_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
hf_hub_id='timm/',
custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_patch14_224.untrained': _cfg(url=''),
@ -711,77 +804,94 @@ default_cfgs = generate_default_cfgs({
# patch models, imagenet21k (weights from official Google JAX impl)
'vit_large_patch32_224.v1_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
num_classes=21843),
'vit_huge_patch14_224.v1_in21k': _cfg(
'vit_large_patch32_224.orig_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
hf_hub_id='timm/',
num_classes=21843),
'vit_huge_patch14_224.orig_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
hf_hub_id='timm/vit_huge_patch14_224_in21k',
hf_hub_id='timm/',
custom_load=True, num_classes=21843),
# How to train your ViT (augreg) weights, pretrained on in21k
'vit_tiny_patch16_224.augreg_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
hf_hub_id='timm/',
custom_load=True, num_classes=21843),
'vit_small_patch32_224.augreg_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
hf_hub_id='timm/',
custom_load=True, num_classes=21843),
'vit_small_patch16_224.augreg_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
hf_hub_id='timm/',
custom_load=True, num_classes=21843),
'vit_base_patch32_224.augreg_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',
hf_hub_id='timm/',
custom_load=True, num_classes=21843),
'vit_base_patch16_224.augreg_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
hf_hub_id='timm/',
custom_load=True, num_classes=21843),
'vit_base_patch8_224.augreg_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
hf_hub_id='timm/',
custom_load=True, num_classes=21843),
'vit_large_patch16_224.augreg_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
hf_hub_id='timm/',
custom_load=True, num_classes=21843),
# SAM trained models (https://arxiv.org/abs/2106.01548)
'vit_base_patch32_224.sam': _cfg(
url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz', custom_load=True),
url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz', custom_load=True,
hf_hub_id='timm/'),
'vit_base_patch16_224.sam': _cfg(
url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz', custom_load=True),
url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz', custom_load=True,
hf_hub_id='timm/'),
# DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only)
'vit_small_patch16_224.dino': _cfg(
url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
'vit_small_patch8_224.dino': _cfg(
url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
'vit_base_patch16_224.dino': _cfg(
url='https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
'vit_base_patch8_224.dino': _cfg(
url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
# ViT ImageNet-21K-P pretraining by MILL
'vit_base_patch16_224_miil.in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth',
hf_hub_id='timm/',
mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221),
'vit_base_patch16_224_miil.in21k_ft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_1k_miil_84_4-2deb18e3.pth',
hf_hub_id='timm/',
mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear'),
# custom timm variants
'vit_base_patch16_rpn_224.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth'),
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth',
hf_hub_id='timm/'),
'vit_medium_patch16_gap_240.in12k': _cfg(
hf_hub_id='timm/vit_medium_patch16_gap_240.in12k',
hf_hub_id='timm/',
input_size=(3, 240, 240), crop_pct=0.95, num_classes=11821),
'vit_medium_patch16_gap_256.in12k_ft_in1k': _cfg(
hf_hub_id='timm/vit_medium_patch16_gap_256.in12k_ft_in1k',
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_medium_patch16_gap_384.in12k_ft_in1k': _cfg(
hf_hub_id='timm/vit_medium_patch16_gap_384.in12k_ft_in1k',
hf_hub_id='timm/',
input_size=(3, 384, 384), crop_pct=0.95, crop_mode='squash'),
'vit_base_patch16_gap_224': _cfg(),
@ -808,24 +918,24 @@ default_cfgs = generate_default_cfgs({
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
'vit_base_patch32_clip_224.laion2b_ft_in1k': _cfg(
hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in1k',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
'vit_base_patch16_clip_224.laion2b_ft_in1k': _cfg(
hf_hub_id='timm/vit_base_patch16_clip_224.laion2b_ft_in1k',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
'vit_base_patch16_clip_384.laion2b_ft_in1k': _cfg(
hf_hub_id='timm/vit_base_patch16_clip_384.laion2b_ft_in1k',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
'vit_large_patch14_clip_224.laion2b_ft_in1k': _cfg(
hf_hub_id='timm/vit_large_patch14_clip_224.laion2b_ft_in1k',
hf_hub_id='timm/',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0),
'vit_large_patch14_clip_336.laion2b_ft_in1k': _cfg(
hf_hub_id='timm/vit_large_patch14_clip_336.laion2b_ft_in1k',
hf_hub_id='timm/',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
'vit_huge_patch14_clip_224.laion2b_ft_in1k': _cfg(
hf_hub_id='timm/vit_huge_patch14_clip_224.laion2b_ft_in1k',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
'vit_huge_patch14_clip_336.laion2b_ft_in1k': _cfg(
hf_hub_id='',
@ -833,33 +943,33 @@ default_cfgs = generate_default_cfgs({
crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
'vit_base_patch32_clip_224.laion2b_ft_in12k_in1k': _cfg(
hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in12k_in1k',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
'vit_base_patch32_clip_384.laion2b_ft_in12k_in1k': _cfg(
hf_hub_id='timm/vit_base_patch32_clip_384.laion2b_ft_in12k_in1k',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)),
'vit_base_patch32_clip_448.laion2b_ft_in12k_in1k': _cfg(
hf_hub_id='timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)),
'vit_base_patch16_clip_224.laion2b_ft_in12k_in1k': _cfg(
hf_hub_id='timm/vit_base_patch16_clip_224.laion2b_ft_in12k_in1k',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95),
'vit_base_patch16_clip_384.laion2b_ft_in12k_in1k': _cfg(
hf_hub_id='timm/vit_base_patch16_clip_384.laion2b_ft_in12k_in1k',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
'vit_large_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg(
hf_hub_id='timm/vit_large_patch14_clip_224.laion2b_ft_in12k_in1k',
hf_hub_id='timm/',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0),
'vit_large_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg(
hf_hub_id='timm/vit_large_patch14_clip_336.laion2b_ft_in12k_in1k',
hf_hub_id='timm/',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
'vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg(
hf_hub_id='timm/vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
'vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg(
hf_hub_id='timm/vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
@ -867,58 +977,58 @@ default_cfgs = generate_default_cfgs({
#hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in12k',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
'vit_base_patch16_clip_224.laion2b_ft_in12k': _cfg(
hf_hub_id='timm/vit_base_patch16_clip_224.laion2b_ft_in12k',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
'vit_large_patch14_clip_224.laion2b_ft_in12k': _cfg(
hf_hub_id='timm/vit_large_patch14_clip_224.laion2b_ft_in12k',
hf_hub_id='timm/',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=11821),
'vit_huge_patch14_clip_224.laion2b_ft_in12k': _cfg(
hf_hub_id='timm/vit_huge_patch14_clip_224.laion2b_ft_in12k',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821),
'vit_base_patch32_clip_224.openai': _cfg(
hf_hub_id='timm/clip_vit_base_patch32_224.openai',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
'vit_base_patch16_clip_224.openai': _cfg(
hf_hub_id='timm/clip_vit_base_patch16_224.openai',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
'vit_large_patch14_clip_224.openai': _cfg(
hf_hub_id='timm/clip_vit_large_patch14_224.openai',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
'vit_base_patch32_clip_224.openai_ft_in1k': _cfg(
hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in1k',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
'vit_base_patch16_clip_224.openai_ft_in1k': _cfg(
hf_hub_id='timm/vit_base_patch16_clip_224.openai_ft_in1k',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
'vit_base_patch16_clip_384.openai_ft_in1k': _cfg(
hf_hub_id='timm/vit_base_patch16_clip_384.openai_ft_in1k',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
'vit_large_patch14_clip_224.openai_ft_in1k': _cfg(
hf_hub_id='timm/vit_large_patch14_clip_224.openai_ft_in1k',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
'vit_base_patch32_clip_224.openai_ft_in12k_in1k': _cfg(
#hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k_in1k',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
'vit_base_patch32_clip_384.openai_ft_in12k_in1k': _cfg(
hf_hub_id='timm/vit_base_patch32_clip_384.openai_ft_in12k_in1k',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
'vit_base_patch16_clip_224.openai_ft_in12k_in1k': _cfg(
hf_hub_id='timm/vit_base_patch16_clip_224.openai_ft_in12k_in1k',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95),
'vit_base_patch16_clip_384.openai_ft_in12k_in1k': _cfg(
hf_hub_id='timm/vit_base_patch16_clip_384.openai_ft_in12k_in1k',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
'vit_large_patch14_clip_224.openai_ft_in12k_in1k': _cfg(
hf_hub_id='timm/vit_large_patch14_clip_224.openai_ft_in12k_in1k',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
'vit_large_patch14_clip_336.openai_ft_in12k_in1k': _cfg(
hf_hub_id='timm/vit_large_patch14_clip_336.openai_ft_in12k_in1k',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
@ -926,10 +1036,10 @@ default_cfgs = generate_default_cfgs({
#hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
'vit_base_patch16_clip_224.openai_ft_in12k': _cfg(
hf_hub_id='timm/vit_base_patch16_clip_224.openai_ft_in12k',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
'vit_large_patch14_clip_224.openai_ft_in12k': _cfg(
hf_hub_id='timm/vit_large_patch14_clip_224.openai_ft_in12k',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821),
# experimental (may be removed)
@ -942,21 +1052,81 @@ default_cfgs = generate_default_cfgs({
# EVA fine-tuned weights from MAE style MIM - EVA-CLIP target pretrain
# https://github.com/baaivision/EVA/blob/7ecf2c0a370d97967e86d047d7af9188f78d2df3/eva/README.md#eva-l-learning-better-mim-representations-from-eva-clip
'eva_large_patch14_196.in22k_ft_in22k_in1k': _cfg(
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_21k_to_1k_ft_88p6.pt',
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_21k_to_1k_ft_88p6.pt',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 196, 196), crop_pct=1.0),
'eva_large_patch14_336.in22k_ft_in22k_in1k': _cfg(
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_21k_to_1k_ft_89p2.pt',
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_21k_to_1k_ft_89p2.pt',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
'eva_large_patch14_196.in22k_ft_in1k': _cfg(
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_1k_ft_88p0.pt',
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_1k_ft_88p0.pt',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 196, 196), crop_pct=1.0),
'eva_large_patch14_336.in22k_ft_in1k': _cfg(
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_1k_ft_88p65.pt',
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_1k_ft_88p65.pt',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
'flexivit_small.1200ep_in1k': _cfg(
url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k.npz', custom_load=True,
hf_hub_id='timm/',
input_size=(3, 240, 240), crop_pct=0.95),
'flexivit_small.600ep_in1k': _cfg(
url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_600ep.npz', custom_load=True,
hf_hub_id='timm/',
input_size=(3, 240, 240), crop_pct=0.95),
'flexivit_small.300ep_in1k': _cfg(
url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_300ep.npz', custom_load=True,
hf_hub_id='timm/',
input_size=(3, 240, 240), crop_pct=0.95),
'flexivit_base.1200ep_in1k': _cfg(
url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k.npz', custom_load=True,
hf_hub_id='timm/',
input_size=(3, 240, 240), crop_pct=0.95),
'flexivit_base.600ep_in1k': _cfg(
url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_600ep.npz', custom_load=True,
hf_hub_id='timm/',
input_size=(3, 240, 240), crop_pct=0.95),
'flexivit_base.300ep_in1k': _cfg(
url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_300ep.npz', custom_load=True,
hf_hub_id='timm/',
input_size=(3, 240, 240), crop_pct=0.95),
'flexivit_base.1000ep_in21k': _cfg(
url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_1000ep.npz', custom_load=True,
hf_hub_id='timm/',
input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
'flexivit_base.300ep_in21k': _cfg(
url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_300ep.npz', custom_load=True,
hf_hub_id='timm/',
input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
'flexivit_large.1200ep_in1k': _cfg(
url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k.npz', custom_load=True,
hf_hub_id='timm/',
input_size=(3, 240, 240), crop_pct=0.95),
'flexivit_large.600ep_in1k': _cfg(
url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_600ep.npz', custom_load=True,
hf_hub_id='timm/',
input_size=(3, 240, 240), crop_pct=0.95),
'flexivit_large.300ep_in1k': _cfg(
url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_300ep.npz', custom_load=True,
hf_hub_id='timm/',
input_size=(3, 240, 240), crop_pct=0.95),
'flexivit_base.patch16_in21k': _cfg(
url='https://storage.googleapis.com/big_vision/flexivit/vit_b16_i21k_300ep.npz', custom_load=True,
hf_hub_id='timm/',
input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
'flexivit_base.patch30_in21k': _cfg(
url='https://storage.googleapis.com/big_vision/flexivit/vit_b30_i21k_300ep.npz', custom_load=True,
hf_hub_id='timm/',
input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
})
@ -964,9 +1134,16 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
if 'flexi' in variant:
# FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed
# interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation.
_filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False)
else:
_filter_fn = checkpoint_filter_fn
return build_model_with_cfg(
VisionTransformer, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
pretrained_filter_fn=_filter_fn,
**kwargs,
)
@ -1396,3 +1573,30 @@ def eva_large_patch14_336(pretrained=False, **kwargs):
patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg', **kwargs)
model = _create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **model_kwargs)
return model
@register_model
def flexivit_small(pretrained=False, **kwargs):
""" FlexiViT-Small
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, **kwargs)
model = _create_vision_transformer('flexivit_small', pretrained=pretrained, **model_kwargs)
return model
@register_model
def flexivit_base(pretrained=False, **kwargs):
""" FlexiViT-Base
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, **kwargs)
model = _create_vision_transformer('flexivit_base', pretrained=pretrained, **model_kwargs)
return model
@register_model
def flexivit_large(pretrained=False, **kwargs):
""" FlexiViT-Large
"""
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, **kwargs)
model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **model_kwargs)
return model

@ -27,72 +27,6 @@ from .resnetv2 import ResNetV2, create_resnetv2_stem
from .vision_transformer import _create_vision_transformer
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head',
**kwargs
}
default_cfgs = generate_default_cfgs({
# hybrid in-1k models (weights from official JAX impl where they exist)
'vit_tiny_r_s16_p8_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
custom_load=True,
first_conv='patch_embed.backbone.conv'),
'vit_tiny_r_s16_p8_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0, custom_load=True),
'vit_small_r26_s32_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz',
custom_load=True,
),
'vit_small_r26_s32_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0, custom_load=True),
'vit_base_r26_s32_224.untrained': _cfg(),
'vit_base_r50_s16_384.v1_in21k_ft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_r50_s32_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
custom_load=True,
),
'vit_large_r50_s32_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0, custom_load=True,
),
# hybrid in-21k models (weights from official Google JAX impl where they exist)
'vit_tiny_r_s16_p8_224.augreg_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv', custom_load=True),
'vit_small_r26_s32_224.augreg_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843, crop_pct=0.9, custom_load=True),
'vit_base_r50_s16_224.v1_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
num_classes=21843, crop_pct=0.9),
'vit_large_r50_s32_224.augreg_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz',
num_classes=21843, crop_pct=0.9, custom_load=True),
# hybrid models (using timm resnet backbones)
'vit_small_resnet26d_224': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
'vit_small_resnet50d_s16_224': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
'vit_base_resnet26d_224': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
'vit_base_resnet50d_224': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
})
class HybridEmbed(nn.Module):
""" CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim.
@ -166,6 +100,83 @@ def _resnetv2(layers=(3, 4, 9), **kwargs):
return backbone
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head',
**kwargs
}
default_cfgs = generate_default_cfgs({
# hybrid in-1k models (weights from official JAX impl where they exist)
'vit_tiny_r_s16_p8_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
hf_hub_id='timm/',
custom_load=True,
first_conv='patch_embed.backbone.conv'),
'vit_tiny_r_s16_p8_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
hf_hub_id='timm/',
first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0, custom_load=True),
'vit_small_r26_s32_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz',
hf_hub_id='timm/',
custom_load=True,
),
'vit_small_r26_s32_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
hf_hub_id='timm/',
input_size=(3, 384, 384), crop_pct=1.0, custom_load=True),
'vit_base_r26_s32_224.untrained': _cfg(),
'vit_base_r50_s16_384.orig_in21k_ft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_r50_s32_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
hf_hub_id='timm/',
custom_load=True,
),
'vit_large_r50_s32_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
hf_hub_id='timm/',
input_size=(3, 384, 384), crop_pct=1.0, custom_load=True,
),
# hybrid in-21k models (weights from official Google JAX impl where they exist)
'vit_tiny_r_s16_p8_224.augreg_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
hf_hub_id='timm/',
num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv', custom_load=True),
'vit_small_r26_s32_224.augreg_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz',
hf_hub_id='timm/',
num_classes=21843, crop_pct=0.9, custom_load=True),
'vit_base_r50_s16_224.orig_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
hf_hub_id='timm/',
num_classes=21843, crop_pct=0.9),
'vit_large_r50_s32_224.augreg_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz',
hf_hub_id='timm/',
num_classes=21843, crop_pct=0.9, custom_load=True),
# hybrid models (using timm resnet backbones)
'vit_small_resnet26d_224.untrained': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
'vit_small_resnet50d_s16_224.untrained': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
'vit_base_resnet26d_224.untrained': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
'vit_base_resnet50d_224.untrained': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
})
@register_model
def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.

@ -11,12 +11,12 @@ from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_
from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias
from ._builder import build_model_with_cfg
from ._pretrained import generate_default_cfgs
from ._registry import register_model
__all__ = ['VisionTransformerRelPos'] # model_registry will add each entrypoint fn to this
@ -24,216 +24,6 @@ __all__ = ['VisionTransformerRelPos'] # model_registry will add each entrypoint
_logger = logging.getLogger(__name__)
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = {
'vit_relpos_base_patch32_plus_rpn_256': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_replos_base_patch32_plus_rpn_256-sw-dd486f51.pth',
input_size=(3, 256, 256)),
'vit_relpos_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240)),
'vit_relpos_small_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_small_patch16_224-sw-ec2778b4.pth'),
'vit_relpos_medium_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_224-sw-11c174af.pth'),
'vit_relpos_base_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth'),
'vit_srelpos_small_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_small_patch16_224-sw-6cdb8849.pth'),
'vit_srelpos_medium_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_medium_patch16_224-sw-ad702b8c.pth'),
'vit_relpos_medium_patch16_cls_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_cls_224-sw-cfe8e259.pth'),
'vit_relpos_base_patch16_cls_224': _cfg(
url=''),
'vit_relpos_base_patch16_clsgap_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_gapcls_224-sw-1a341d6c.pth'),
'vit_relpos_small_patch16_rpn_224': _cfg(url=''),
'vit_relpos_medium_patch16_rpn_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_rpn_224-sw-5d2befd8.pth'),
'vit_relpos_base_patch16_rpn_224': _cfg(url=''),
}
def gen_relative_position_index(
q_size: Tuple[int, int],
k_size: Tuple[int, int] = None,
class_token: bool = False) -> torch.Tensor:
# Adapted with significant modifications from Swin / BeiT codebases
# get pair-wise relative position index for each token inside the window
q_coords = torch.stack(torch.meshgrid([torch.arange(q_size[0]), torch.arange(q_size[1])])).flatten(1) # 2, Wh, Ww
if k_size is None:
k_coords = q_coords
k_size = q_size
else:
# different q vs k sizes is a WIP
k_coords = torch.stack(torch.meshgrid([torch.arange(k_size[0]), torch.arange(k_size[1])])).flatten(1)
relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0) # Wh*Ww, Wh*Ww, 2
_, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0)
if class_token:
# handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias
# NOTE not intended or tested with MLP log-coords
max_size = (max(q_size[0], k_size[0]), max(q_size[1], k_size[1]))
num_relative_distance = (2 * max_size[0] - 1) * (2 * max_size[1] - 1) + 3
relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0])
relative_position_index[0, 0:] = num_relative_distance - 3
relative_position_index[0:, 0] = num_relative_distance - 2
relative_position_index[0, 0] = num_relative_distance - 1
return relative_position_index.contiguous()
def gen_relative_log_coords(
win_size: Tuple[int, int],
pretrained_win_size: Tuple[int, int] = (0, 0),
mode='swin',
):
assert mode in ('swin', 'cr', 'rw')
# as per official swin-v2 impl, supporting timm specific 'cr' and 'rw' log coords as well
relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous() # 2*Wh-1, 2*Ww-1, 2
if mode == 'swin':
if pretrained_win_size[0] > 0:
relative_coords_table[:, :, 0] /= (pretrained_win_size[0] - 1)
relative_coords_table[:, :, 1] /= (pretrained_win_size[1] - 1)
else:
relative_coords_table[:, :, 0] /= (win_size[0] - 1)
relative_coords_table[:, :, 1] /= (win_size[1] - 1)
relative_coords_table *= 8 # normalize to -8, 8
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
1.0 + relative_coords_table.abs()) / math.log2(8)
else:
if mode == 'rw':
# cr w/ window size normalization -> [-1,1] log coords
relative_coords_table[:, :, 0] /= (win_size[0] - 1)
relative_coords_table[:, :, 1] /= (win_size[1] - 1)
relative_coords_table *= 8 # scale to -8, 8
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
1.0 + relative_coords_table.abs())
relative_coords_table /= math.log2(9) # -> [-1, 1]
else:
# mode == 'cr'
relative_coords_table = torch.sign(relative_coords_table) * torch.log(
1.0 + relative_coords_table.abs())
return relative_coords_table
class RelPosMlp(nn.Module):
def __init__(
self,
window_size,
num_heads=8,
hidden_dim=128,
prefix_tokens=0,
mode='cr',
pretrained_window_size=(0, 0)
):
super().__init__()
self.window_size = window_size
self.window_area = self.window_size[0] * self.window_size[1]
self.prefix_tokens = prefix_tokens
self.num_heads = num_heads
self.bias_shape = (self.window_area,) * 2 + (num_heads,)
if mode == 'swin':
self.bias_act = nn.Sigmoid()
self.bias_gain = 16
mlp_bias = (True, False)
elif mode == 'rw':
self.bias_act = nn.Tanh()
self.bias_gain = 4
mlp_bias = True
else:
self.bias_act = nn.Identity()
self.bias_gain = None
mlp_bias = True
self.mlp = Mlp(
2, # x, y
hidden_features=hidden_dim,
out_features=num_heads,
act_layer=nn.ReLU,
bias=mlp_bias,
drop=(0.125, 0.)
)
self.register_buffer(
"relative_position_index",
gen_relative_position_index(window_size),
persistent=False)
# get relative_coords_table
self.register_buffer(
"rel_coords_log",
gen_relative_log_coords(window_size, pretrained_window_size, mode=mode),
persistent=False)
def get_bias(self) -> torch.Tensor:
relative_position_bias = self.mlp(self.rel_coords_log)
if self.relative_position_index is not None:
relative_position_bias = relative_position_bias.view(-1, self.num_heads)[
self.relative_position_index.view(-1)] # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.view(self.bias_shape)
relative_position_bias = relative_position_bias.permute(2, 0, 1)
relative_position_bias = self.bias_act(relative_position_bias)
if self.bias_gain is not None:
relative_position_bias = self.bias_gain * relative_position_bias
if self.prefix_tokens:
relative_position_bias = F.pad(relative_position_bias, [self.prefix_tokens, 0, self.prefix_tokens, 0])
return relative_position_bias.unsqueeze(0).contiguous()
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
return attn + self.get_bias()
class RelPosBias(nn.Module):
def __init__(self, window_size, num_heads, prefix_tokens=0):
super().__init__()
assert prefix_tokens <= 1
self.window_size = window_size
self.window_area = window_size[0] * window_size[1]
self.bias_shape = (self.window_area + prefix_tokens,) * 2 + (num_heads,)
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * prefix_tokens
self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
self.register_buffer(
"relative_position_index",
gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0),
persistent=False,
)
self.init_weights()
def init_weights(self):
trunc_normal_(self.relative_position_bias_table, std=.02)
def get_bias(self) -> torch.Tensor:
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
# win_h * win_w, win_h * win_w, num_heads
relative_position_bias = relative_position_bias.view(self.bias_shape).permute(2, 0, 1)
return relative_position_bias.unsqueeze(0).contiguous()
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
return attn + self.get_bias()
class RelPosAttention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, rel_pos_cls=None, attn_drop=0., proj_drop=0.):
super().__init__()
@ -513,6 +303,57 @@ def _create_vision_transformer_relpos(variant, pretrained=False, **kwargs):
return model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = generate_default_cfgs({
'vit_relpos_base_patch32_plus_rpn_256.sw_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_replos_base_patch32_plus_rpn_256-sw-dd486f51.pth',
hf_hub_id='timm/',
input_size=(3, 256, 256)),
'vit_relpos_base_patch16_plus_240.untrained': _cfg(url='', input_size=(3, 240, 240)),
'vit_relpos_small_patch16_224.sw_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_small_patch16_224-sw-ec2778b4.pth',
hf_hub_id='timm/'),
'vit_relpos_medium_patch16_224.sw_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_224-sw-11c174af.pth',
hf_hub_id='timm/'),
'vit_relpos_base_patch16_224.sw_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth',
hf_hub_id='timm/'),
'vit_srelpos_small_patch16_224.sw_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_small_patch16_224-sw-6cdb8849.pth',
hf_hub_id='timm/'),
'vit_srelpos_medium_patch16_224.sw_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_medium_patch16_224-sw-ad702b8c.pth',
hf_hub_id='timm/'),
'vit_relpos_medium_patch16_cls_224.sw_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_cls_224-sw-cfe8e259.pth',
hf_hub_id='timm/'),
'vit_relpos_base_patch16_cls_224.untrained': _cfg(),
'vit_relpos_base_patch16_clsgap_224.sw_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_gapcls_224-sw-1a341d6c.pth',
hf_hub_id='timm/'),
'vit_relpos_small_patch16_rpn_224.untrained': _cfg(),
'vit_relpos_medium_patch16_rpn_224.sw_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_rpn_224-sw-5d2befd8.pth',
hf_hub_id='timm/'),
'vit_relpos_base_patch16_rpn_224.untrained': _cfg(),
})
@register_model
def vit_relpos_base_patch32_plus_rpn_256(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/32+) w/ relative log-coord position and residual post-norm, no class token

@ -1 +1 @@
__version__ = '0.8.1dev0'
__version__ = '0.8.2dev0'

Loading…
Cancel
Save