diff --git a/.gitignore b/.gitignore index e5142b32..9f8f33d9 100644 --- a/.gitignore +++ b/.gitignore @@ -106,6 +106,16 @@ output/ *.tar *.pth *.pt +*.torch *.gz Untitled.ipynb Testing notebook.ipynb + +# Root dir exclusions +/*.csv +/*.yaml +/*.json +/*.jpg +/*.png +/*.zip +/*.tar.* \ No newline at end of file diff --git a/tests/test_models.py b/tests/test_models.py index 2392a190..b6a61727 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -27,7 +27,7 @@ NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*', 'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*', - 'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*' + 'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*', 'flexivit*' ] NUM_NON_STD = len(NON_STD_FILTERS) diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py index 757c2e5d..a4c18e39 100644 --- a/timm/data/dataset_factory.py +++ b/timm/data/dataset_factory.py @@ -151,7 +151,7 @@ def create_dataset( elif name.startswith('hfds/'): # NOTE right now, HF datasets default arrow format is a random-access Dataset, # There will be a IterableDataset variant too, TBD - ds = ImageDataset(root, reader=name, split=split, **kwargs) + ds = ImageDataset(root, reader=name, split=split, class_map=class_map, **kwargs) elif name.startswith('tfds/'): ds = IterableImageDataset( root, diff --git a/timm/data/readers/reader_factory.py b/timm/data/readers/reader_factory.py index 58ff56cd..226e3857 100644 --- a/timm/data/readers/reader_factory.py +++ b/timm/data/readers/reader_factory.py @@ -6,7 +6,7 @@ from .reader_image_in_tar import ReaderImageInTar def create_reader(name, root, split='train', **kwargs): name = name.lower() - name = name.split('/', 2) + name = name.split('/', 1) prefix = '' if len(name) > 1: prefix = name[0] diff --git a/timm/data/readers/reader_hfds.py b/timm/data/readers/reader_hfds.py index 901cf4bc..62ae5f4d 100644 --- a/timm/data/readers/reader_hfds.py +++ b/timm/data/readers/reader_hfds.py @@ -13,13 +13,14 @@ try: except ImportError as e: print("Please install Hugging Face datasets package `pip install datasets`.") exit(1) +from .class_map import load_class_map from .reader import Reader -def get_class_labels(info): +def get_class_labels(info, label_key='label'): if 'label' not in info.features: return {} - class_label = info.features['label'] + class_label = info.features[label_key] class_to_idx = {n: class_label.str2int(n) for n in class_label.names} return class_to_idx @@ -32,6 +33,7 @@ class ReaderHfds(Reader): name, split='train', class_map=None, + label_key='label', download=False, ): """ @@ -43,12 +45,17 @@ class ReaderHfds(Reader): name, # 'name' maps to path arg in hf datasets split=split, cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path - #use_auth_token=True, ) # leave decode for caller, plus we want easy access to original path names... self.dataset = self.dataset.cast_column('image', datasets.Image(decode=False)) - self.class_to_idx = get_class_labels(self.dataset.info) + self.label_key = label_key + self.remap_class = False + if class_map: + self.class_to_idx = load_class_map(class_map) + self.remap_class = True + else: + self.class_to_idx = get_class_labels(self.dataset.info, self.label_key) self.split_info = self.dataset.info.splits[split] self.num_samples = self.split_info.num_examples @@ -60,7 +67,10 @@ class ReaderHfds(Reader): else: assert 'path' in image and image['path'] image = open(image['path'], 'rb') - return image, item['label'] + label = item[self.label_key] + if self.remap_class: + label = self.class_to_idx[label] + return image, label def __len__(self): return len(self.dataset) diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index 21c641b6..03c4d8eb 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -1,6 +1,7 @@ from .activations import * from .adaptive_avgmax_pool import \ adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d +from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding from .blur_pool import BlurPool2d from .classifier import ClassifierHead, create_classifier from .cond_conv2d import CondConv2d, get_condconv_initializer @@ -30,8 +31,12 @@ from .non_local_attn import NonLocalAttn, BatNonLocalAttn from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm from .padding import get_padding, get_same_padding, pad_same -from .patch_embed import PatchEmbed +from .patch_embed import PatchEmbed, resample_patch_embed from .pool2d_same import AvgPool2dSame, create_pool2d +from .pos_embed import resample_abs_pos_embed +from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords +from .pos_embed_sincos import build_sincos2d_pos_embed, build_fourier_pos_embed, build_rotary_pos_embed, \ + FourierEmbed, RotaryEmbedding from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite from .selective_kernel import SelectiveKernel from .separable_conv import SeparableConv2d, SeparableConvNormAct diff --git a/timm/layers/attention_pool2d.py b/timm/layers/attention_pool2d.py index a13a6881..765efa08 100644 --- a/timm/layers/attention_pool2d.py +++ b/timm/layers/attention_pool2d.py @@ -13,7 +13,7 @@ import torch import torch.nn as nn from .helpers import to_2tuple -from .pos_embed import apply_rot_embed, RotaryEmbedding +from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding from .weight_init import trunc_normal_ diff --git a/timm/layers/helpers.py b/timm/layers/helpers.py index 2fa296bc..bc75ef3e 100644 --- a/timm/layers/helpers.py +++ b/timm/layers/helpers.py @@ -10,7 +10,7 @@ import collections.abc def _ntuple(n): def parse(x): if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): - return x + return tuple(x) return tuple(repeat(x, n)) return parse diff --git a/timm/layers/patch_embed.py b/timm/layers/patch_embed.py index be8740ce..b7416260 100644 --- a/timm/layers/patch_embed.py +++ b/timm/layers/patch_embed.py @@ -2,15 +2,24 @@ A convolution based approach to patchifying a 2D image w/ embedding projection. -Based on the impl in https://github.com/google-research/vision_transformer +Based on code in: + * https://github.com/google-research/vision_transformer + * https://github.com/google-research/big_vision/tree/main/big_vision Hacked together by / Copyright 2020 Ross Wightman """ +import logging +from typing import List + +import torch from torch import nn as nn +import torch.nn.functional as F from .helpers import to_2tuple from .trace_utils import _assert +_logger = logging.getLogger(__name__) + class PatchEmbed(nn.Module): """ 2D Image to Patch Embedding @@ -46,3 +55,122 @@ class PatchEmbed(nn.Module): x = x.flatten(2).transpose(1, 2) # BCHW -> BNC x = self.norm(x) return x + + +def resample_patch_embed( + patch_embed, + new_size: List[int], + interpolation: str = 'bicubic', + antialias: bool = True, + verbose: bool = False, +): + """Resample the weights of the patch embedding kernel to target resolution. + We resample the patch embedding kernel by approximately inverting the effect + of patch resizing. + + Code based on: + https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py + + With this resizing, we can for example load a B/8 filter into a B/16 model + and, on 2x larger input image, the result will match. + + Args: + patch_embed: original parameter to be resized. + new_size (tuple(int, int): target shape (height, width)-only. + interpolation (str): interpolation for resize + antialias (bool): use anti-aliasing filter in resize + verbose (bool): log operation + Returns: + Resized patch embedding kernel. + """ + import numpy as np + + assert len(patch_embed.shape) == 4, "Four dimensions expected" + assert len(new_size) == 2, "New shape should only be hw" + old_size = patch_embed.shape[-2:] + if tuple(old_size) == tuple(new_size): + return patch_embed + + if verbose: + _logger.info(f"Resize patch embedding {patch_embed.shape} to {new_size}, w/ {interpolation} interpolation.") + + def resize(x_np, _new_size): + x_tf = torch.Tensor(x_np)[None, None, ...] + x_upsampled = F.interpolate( + x_tf, size=_new_size, mode=interpolation, antialias=antialias)[0, 0, ...].numpy() + return x_upsampled + + def get_resize_mat(_old_size, _new_size): + mat = [] + for i in range(np.prod(_old_size)): + basis_vec = np.zeros(_old_size) + basis_vec[np.unravel_index(i, _old_size)] = 1. + mat.append(resize(basis_vec, _new_size).reshape(-1)) + return np.stack(mat).T + + resize_mat = get_resize_mat(old_size, new_size) + resize_mat_pinv = torch.Tensor(np.linalg.pinv(resize_mat.T)) + + def resample_kernel(kernel): + resampled_kernel = resize_mat_pinv @ kernel.reshape(-1) + return resampled_kernel.reshape(new_size) + + v_resample_kernel = torch.vmap(torch.vmap(resample_kernel, 0, 0), 1, 1) + return v_resample_kernel(patch_embed) + + +# def divs(n, m=None): +# m = m or n // 2 +# if m == 1: +# return [1] +# if n % m == 0: +# return [m] + divs(n, m - 1) +# return divs(n, m - 1) +# +# +# class FlexiPatchEmbed(nn.Module): +# """ 2D Image to Patch Embedding w/ Flexible Patch sizes (FlexiViT) +# FIXME WIP +# """ +# def __init__( +# self, +# img_size=240, +# patch_size=16, +# in_chans=3, +# embed_dim=768, +# base_img_size=240, +# base_patch_size=32, +# norm_layer=None, +# flatten=True, +# bias=True, +# ): +# super().__init__() +# self.img_size = to_2tuple(img_size) +# self.patch_size = to_2tuple(patch_size) +# self.num_patches = 0 +# +# # full range for 240 = (5, 6, 8, 10, 12, 14, 15, 16, 20, 24, 30, 40, 48) +# self.seqhw = (6, 8, 10, 12, 14, 15, 16, 20, 24, 30) +# +# self.base_img_size = to_2tuple(base_img_size) +# self.base_patch_size = to_2tuple(base_patch_size) +# self.base_grid_size = tuple([i // p for i, p in zip(self.base_img_size, self.base_patch_size)]) +# self.base_num_patches = self.base_grid_size[0] * self.base_grid_size[1] +# +# self.flatten = flatten +# self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=bias) +# self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() +# +# def forward(self, x): +# B, C, H, W = x.shape +# +# if self.patch_size == self.base_patch_size: +# weight = self.proj.weight +# else: +# weight = resample_patch_embed(self.proj.weight, self.patch_size) +# patch_size = self.patch_size +# x = F.conv2d(x, weight, bias=self.proj.bias, stride=patch_size) +# if self.flatten: +# x = x.flatten(2).transpose(1, 2) # BCHW -> BNC +# x = self.norm(x) +# return x diff --git a/timm/layers/pos_embed.py b/timm/layers/pos_embed.py index 99a122a0..d0e67521 100644 --- a/timm/layers/pos_embed.py +++ b/timm/layers/pos_embed.py @@ -1,207 +1,52 @@ +""" Position Embedding Utilities + +Hacked together by / Copyright 2022 Ross Wightman +""" +import logging import math from typing import List, Tuple, Optional, Union import torch -from torch import nn as nn - - -def pixel_freq_bands( - num_bands: int, - max_freq: float = 224., - linear_bands: bool = True, - dtype: torch.dtype = torch.float32, - device: Optional[torch.device] = None, -): - if linear_bands: - bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=dtype, device=device) - else: - bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=dtype, device=device) - return bands * torch.pi - - -def inv_freq_bands( - num_bands: int, - temperature: float = 100000., - step: int = 2, - dtype: torch.dtype = torch.float32, - device: Optional[torch.device] = None, -) -> torch.Tensor: - inv_freq = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands)) - return inv_freq - - -def build_sincos2d_pos_embed( - feat_shape: List[int], - dim: int = 64, - temperature: float = 10000., - reverse_coord: bool = False, - interleave_sin_cos: bool = False, - dtype: torch.dtype = torch.float32, - device: Optional[torch.device] = None -) -> torch.Tensor: - """ - - Args: - feat_shape: - dim: - temperature: - reverse_coord: stack grid order W, H instead of H, W - interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos - dtype: - device: - - Returns: - - """ - assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding' - pos_dim = dim // 4 - bands = inv_freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device) - - if reverse_coord: - feat_shape = feat_shape[::-1] # stack W, H instead of H, W - grid = torch.stack( - torch.meshgrid([torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1) - pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0) - # FIXME add support for unflattened spatial dim? - - stack_dim = 2 if interleave_sin_cos else 1 # stack sin, cos, sin, cos instead of sin sin cos cos - pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1) - return pos_emb - - -def build_fourier_pos_embed( - feat_shape: List[int], - bands: Optional[torch.Tensor] = None, - num_bands: int = 64, - max_res: int = 224, - linear_bands: bool = False, - include_grid: bool = False, - concat_out: bool = True, - in_pixels: bool = True, - dtype: torch.dtype = torch.float32, - device: Optional[torch.device] = None, -) -> List[torch.Tensor]: - if bands is None: - if in_pixels: - bands = pixel_freq_bands(num_bands, float(max_res), linear_bands=linear_bands, dtype=dtype, device=device) - else: - bands = inv_freq_bands(num_bands, step=1, dtype=dtype, device=device) - else: - if device is None: - device = bands.device - if dtype is None: - dtype = bands.dtype - - if in_pixels: - grid = torch.stack(torch.meshgrid( - [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]), dim=-1) - else: - grid = torch.stack(torch.meshgrid( - [torch.arange(s, device=device, dtype=dtype) for s in feat_shape]), dim=-1) - grid = grid.unsqueeze(-1) - pos = grid * bands - - pos_sin, pos_cos = pos.sin(), pos.cos() - out = (grid, pos_sin, pos_cos) if include_grid else (pos_sin, pos_cos) - # FIXME torchscript doesn't like multiple return types, probably need to always cat? - if concat_out: - out = torch.cat(out, dim=-1) - return out - - -class FourierEmbed(nn.Module): - - def __init__(self, max_res: int = 224, num_bands: int = 64, concat_grid=True, keep_spatial=False): - super().__init__() - self.max_res = max_res - self.num_bands = num_bands - self.concat_grid = concat_grid - self.keep_spatial = keep_spatial - self.register_buffer('bands', pixel_freq_bands(max_res, num_bands), persistent=False) - - def forward(self, x): - B, C = x.shape[:2] - feat_shape = x.shape[2:] - emb = build_fourier_pos_embed( - feat_shape, - self.bands, - include_grid=self.concat_grid, - dtype=x.dtype, - device=x.device) - emb = emb.transpose(-1, -2).flatten(len(feat_shape)) - batch_expand = (B,) + (-1,) * (x.ndim - 1) - - # FIXME support nD - if self.keep_spatial: - x = torch.cat([x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1) - else: - x = torch.cat([x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1) - x = x.reshape(B, feat_shape.numel(), -1) - - return x - +import torch.nn.functional as F -def rot(x): - return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape) +from .helpers import to_2tuple +_logger = logging.getLogger(__name__) -def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb): - return x * cos_emb + rot(x) * sin_emb - -def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb): - if isinstance(x, torch.Tensor): - x = [x] - return [t * cos_emb + rot(t) * sin_emb for t in x] - - -def apply_rot_embed_split(x: torch.Tensor, emb): - split = emb.shape[-1] // 2 - return x * emb[:, :split] + rot(x) * emb[:, split:] - - -def build_rotary_pos_embed( - feat_shape: List[int], - bands: Optional[torch.Tensor] = None, - dim: int = 64, - max_freq: float = 224, - linear_bands: bool = False, - dtype: torch.dtype = torch.float32, - device: Optional[torch.device] = None, +def resample_abs_pos_embed( + posemb, + new_size: List[int], + old_size: Optional[List[int]] = None, + num_prefix_tokens: int = 1, + interpolation: str = 'bicubic', + antialias: bool = True, + verbose: bool = False, ): - """ - NOTE: shape arg should include spatial dim only - """ - feat_shape = torch.Size(feat_shape) - - sin_emb, cos_emb = build_fourier_pos_embed( - feat_shape, bands=bands, num_bands=dim // 4, max_res=max_freq, linear_bands=linear_bands, - concat_out=False, device=device, dtype=dtype) - N = feat_shape.numel() - sin_emb = sin_emb.reshape(N, -1).repeat_interleave(2, -1) - cos_emb = cos_emb.reshape(N, -1).repeat_interleave(2, -1) - return sin_emb, cos_emb - - -class RotaryEmbedding(nn.Module): - """ Rotary position embedding - - NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not - been well tested, and will likely change. It will be moved to its own file. + # sort out sizes, assume square if old size not provided + new_size = to_2tuple(new_size) + new_ntok = new_size[0] * new_size[1] + if not old_size: + old_size = int(math.sqrt(posemb.shape[1] - num_prefix_tokens)) + old_size = to_2tuple(old_size) + if new_size == old_size: # might not both be same container type + return posemb + + if num_prefix_tokens: + posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:] + else: + posemb_prefix, posemb = None, posemb - The following impl/resources were referenced for this impl: - * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py - * https://blog.eleuther.ai/rotary-embeddings/ - """ - def __init__(self, dim, max_res=224, linear_bands: bool = False): - super().__init__() - self.dim = dim - self.register_buffer('bands', pixel_freq_bands(dim // 4, max_res, linear_bands=linear_bands), persistent=False) + # do the interpolation + posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2) + posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias) + posemb = posemb.permute(0, 2, 3, 1).reshape(1, new_ntok, -1) - def get_embed(self, shape: List[int]): - return build_rotary_pos_embed(shape, self.bands) + if verbose: + _logger.info(f'Resized position embedding: {old_size} to {new_size}.') - def forward(self, x): - # assuming channel-first tensor where spatial dim are >= 2 - sin_emb, cos_emb = self.get_embed(x.shape[2:]) - return apply_rot_embed(x, sin_emb, cos_emb) + # add back extra (class, etc) prefix tokens + if posemb_prefix is not None: + print(posemb_prefix.shape, posemb.shape) + posemb = torch.cat([posemb_prefix, posemb], dim=1) + return posemb diff --git a/timm/layers/pos_embed_rel.py b/timm/layers/pos_embed_rel.py new file mode 100644 index 00000000..2ef25670 --- /dev/null +++ b/timm/layers/pos_embed_rel.py @@ -0,0 +1,283 @@ +""" Relative position embedding modules and functions + +Hacked together by / Copyright 2022 Ross Wightman +""" +import math +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .mlp import Mlp +from .weight_init import trunc_normal_ + + +def gen_relative_position_index( + q_size: Tuple[int, int], + k_size: Tuple[int, int] = None, + class_token: bool = False) -> torch.Tensor: + # Adapted with significant modifications from Swin / BeiT codebases + # get pair-wise relative position index for each token inside the window + q_coords = torch.stack(torch.meshgrid([torch.arange(q_size[0]), torch.arange(q_size[1])])).flatten(1) # 2, Wh, Ww + if k_size is None: + k_coords = q_coords + k_size = q_size + else: + # different q vs k sizes is a WIP + k_coords = torch.stack(torch.meshgrid([torch.arange(k_size[0]), torch.arange(k_size[1])])).flatten(1) + relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0) # Wh*Ww, Wh*Ww, 2 + _, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0) + + if class_token: + # handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias + # NOTE not intended or tested with MLP log-coords + max_size = (max(q_size[0], k_size[0]), max(q_size[1], k_size[1])) + num_relative_distance = (2 * max_size[0] - 1) * (2 * max_size[1] - 1) + 3 + relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0]) + relative_position_index[0, 0:] = num_relative_distance - 3 + relative_position_index[0:, 0] = num_relative_distance - 2 + relative_position_index[0, 0] = num_relative_distance - 1 + + return relative_position_index.contiguous() + + +class RelPosBias(nn.Module): + """ Relative Position Bias + Adapted from Swin-V1 relative position bias impl, modularized. + """ + + def __init__(self, window_size, num_heads, prefix_tokens=0): + super().__init__() + assert prefix_tokens <= 1 + self.window_size = window_size + self.window_area = window_size[0] * window_size[1] + self.bias_shape = (self.window_area + prefix_tokens,) * 2 + (num_heads,) + + num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * prefix_tokens + self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads)) + self.register_buffer( + "relative_position_index", + gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0), + persistent=False, + ) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.relative_position_bias_table, std=.02) + + def get_bias(self) -> torch.Tensor: + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)] + # win_h * win_w, win_h * win_w, num_heads + relative_position_bias = relative_position_bias.view(self.bias_shape).permute(2, 0, 1) + return relative_position_bias.unsqueeze(0).contiguous() + + def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None): + return attn + self.get_bias() + + +def gen_relative_log_coords( + win_size: Tuple[int, int], + pretrained_win_size: Tuple[int, int] = (0, 0), + mode='swin', +): + assert mode in ('swin', 'cr', 'rw') + # as per official swin-v2 impl, supporting timm specific 'cr' and 'rw' log coords as well + relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32) + relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32) + relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w])) + relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous() # 2*Wh-1, 2*Ww-1, 2 + if mode == 'swin': + if pretrained_win_size[0] > 0: + relative_coords_table[:, :, 0] /= (pretrained_win_size[0] - 1) + relative_coords_table[:, :, 1] /= (pretrained_win_size[1] - 1) + else: + relative_coords_table[:, :, 0] /= (win_size[0] - 1) + relative_coords_table[:, :, 1] /= (win_size[1] - 1) + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + 1.0 + relative_coords_table.abs()) / math.log2(8) + else: + if mode == 'rw': + # cr w/ window size normalization -> [-1,1] log coords + relative_coords_table[:, :, 0] /= (win_size[0] - 1) + relative_coords_table[:, :, 1] /= (win_size[1] - 1) + relative_coords_table *= 8 # scale to -8, 8 + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + 1.0 + relative_coords_table.abs()) + relative_coords_table /= math.log2(9) # -> [-1, 1] + else: + # mode == 'cr' + relative_coords_table = torch.sign(relative_coords_table) * torch.log( + 1.0 + relative_coords_table.abs()) + + return relative_coords_table + + +class RelPosMlp(nn.Module): + """ Log-Coordinate Relative Position MLP + Based on ideas presented in Swin-V2 paper (https://arxiv.org/abs/2111.09883) + + This impl covers the 'swin' implementation as well as two timm specific modes ('cr', and 'rw') + """ + def __init__( + self, + window_size, + num_heads=8, + hidden_dim=128, + prefix_tokens=0, + mode='cr', + pretrained_window_size=(0, 0) + ): + super().__init__() + self.window_size = window_size + self.window_area = self.window_size[0] * self.window_size[1] + self.prefix_tokens = prefix_tokens + self.num_heads = num_heads + self.bias_shape = (self.window_area,) * 2 + (num_heads,) + if mode == 'swin': + self.bias_act = nn.Sigmoid() + self.bias_gain = 16 + mlp_bias = (True, False) + elif mode == 'rw': + self.bias_act = nn.Tanh() + self.bias_gain = 4 + mlp_bias = True + else: + self.bias_act = nn.Identity() + self.bias_gain = None + mlp_bias = True + + self.mlp = Mlp( + 2, # x, y + hidden_features=hidden_dim, + out_features=num_heads, + act_layer=nn.ReLU, + bias=mlp_bias, + drop=(0.125, 0.) + ) + + self.register_buffer( + "relative_position_index", + gen_relative_position_index(window_size), + persistent=False) + + # get relative_coords_table + self.register_buffer( + "rel_coords_log", + gen_relative_log_coords(window_size, pretrained_window_size, mode=mode), + persistent=False) + + def get_bias(self) -> torch.Tensor: + relative_position_bias = self.mlp(self.rel_coords_log) + if self.relative_position_index is not None: + relative_position_bias = relative_position_bias.view(-1, self.num_heads)[ + self.relative_position_index.view(-1)] # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.view(self.bias_shape) + relative_position_bias = relative_position_bias.permute(2, 0, 1) + relative_position_bias = self.bias_act(relative_position_bias) + if self.bias_gain is not None: + relative_position_bias = self.bias_gain * relative_position_bias + if self.prefix_tokens: + relative_position_bias = F.pad(relative_position_bias, [self.prefix_tokens, 0, self.prefix_tokens, 0]) + return relative_position_bias.unsqueeze(0).contiguous() + + def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None): + return attn + self.get_bias() + + +def generate_lookup_tensor( + length: int, + max_relative_position: Optional[int] = None, +): + """Generate a one_hot lookup tensor to reindex embeddings along one dimension. + + Args: + length: the length to reindex to. + max_relative_position: the maximum relative position to consider. + Relative position embeddings for distances above this threshold + are zeroed out. + Returns: + a lookup Tensor of size [length, length, vocab_size] that satisfies + ret[n,m,v] = 1{m - n + max_relative_position = v}. + """ + if max_relative_position is None: + max_relative_position = length - 1 + # Return the cached lookup tensor, otherwise compute it and cache it. + vocab_size = 2 * max_relative_position + 1 + ret = torch.zeros(length, length, vocab_size) + for i in range(length): + for x in range(length): + v = x - i + max_relative_position + if abs(x - i) > max_relative_position: + continue + ret[i, x, v] = 1 + return ret + + +def reindex_2d_einsum_lookup( + relative_position_tensor, + height: int, + width: int, + height_lookup: torch.Tensor, + width_lookup: torch.Tensor, +) -> torch.Tensor: + """Reindex 2d relative position bias with 2 independent einsum lookups. + + Adapted from: + https://github.com/google-research/maxvit/blob/2e06a7f1f70c76e64cd3dabe5cd1b8c1a23c9fb7/maxvit/models/attention_utils.py + + Args: + relative_position_tensor: tensor of shape + [..., vocab_height, vocab_width, ...]. + height: height to reindex to. + width: width to reindex to. + height_lookup: one-hot height lookup + width_lookup: one-hot width lookup + Returns: + reindexed_tensor: a Tensor of shape + [..., height * width, height * width, ...] + """ + reindexed_tensor = torch.einsum('nhw,ixh->nixw', relative_position_tensor, height_lookup) + reindexed_tensor = torch.einsum('nixw,jyw->nijxy', reindexed_tensor, width_lookup) + area = height * width + return reindexed_tensor.reshape(relative_position_tensor.shape[0], area, area) + + +class RelPosBiasTf(nn.Module): + """ Relative Position Bias Impl (Compatible with Tensorflow MaxViT models) + Adapted from: + https://github.com/google-research/maxvit/blob/2e06a7f1f70c76e64cd3dabe5cd1b8c1a23c9fb7/maxvit/models/attention_utils.py + """ + def __init__(self, window_size, num_heads, prefix_tokens=0): + super().__init__() + assert prefix_tokens <= 1 + self.window_size = window_size + self.window_area = window_size[0] * window_size[1] + self.num_heads = num_heads + + vocab_height = 2 * window_size[0] - 1 + vocab_width = 2 * window_size[1] - 1 + self.bias_shape = (self.num_heads, vocab_height, vocab_width) + self.relative_position_bias_table = nn.Parameter(torch.zeros(self.bias_shape)) + self.register_buffer('height_lookup', generate_lookup_tensor(window_size[0]), persistent=False) + self.register_buffer('width_lookup', generate_lookup_tensor(window_size[1]), persistent=False) + self.init_weights() + + def init_weights(self): + nn.init.normal_(self.relative_position_bias_table, std=.02) + + def get_bias(self) -> torch.Tensor: + # FIXME change to not use one-hot/einsum? + return reindex_2d_einsum_lookup( + self.relative_position_bias_table, + self.window_size[0], + self.window_size[1], + self.height_lookup, + self.width_lookup + ) + + def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None): + return attn + self.get_bias() diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py new file mode 100644 index 00000000..5603a5cd --- /dev/null +++ b/timm/layers/pos_embed_sincos.py @@ -0,0 +1,219 @@ +""" Sin-cos, fourier, rotary position embedding modules and functions + +Hacked together by / Copyright 2022 Ross Wightman +""" +import math +from typing import List, Tuple, Optional, Union + +import torch +from torch import nn as nn + + +def pixel_freq_bands( + num_bands: int, + max_freq: float = 224., + linear_bands: bool = True, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +): + if linear_bands: + bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=dtype, device=device) + else: + bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=dtype, device=device) + return bands * torch.pi + + +def inv_freq_bands( + num_bands: int, + temperature: float = 100000., + step: int = 2, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +) -> torch.Tensor: + inv_freq = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands)) + return inv_freq + + +def build_sincos2d_pos_embed( + feat_shape: List[int], + dim: int = 64, + temperature: float = 10000., + reverse_coord: bool = False, + interleave_sin_cos: bool = False, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None +) -> torch.Tensor: + """ + + Args: + feat_shape: + dim: + temperature: + reverse_coord: stack grid order W, H instead of H, W + interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos + dtype: + device: + + Returns: + + """ + assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding' + pos_dim = dim // 4 + bands = inv_freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device) + + if reverse_coord: + feat_shape = feat_shape[::-1] # stack W, H instead of H, W + grid = torch.stack( + torch.meshgrid([torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1) + pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0) + # FIXME add support for unflattened spatial dim? + + stack_dim = 2 if interleave_sin_cos else 1 # stack sin, cos, sin, cos instead of sin sin cos cos + pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1) + return pos_emb + + +def build_fourier_pos_embed( + feat_shape: List[int], + bands: Optional[torch.Tensor] = None, + num_bands: int = 64, + max_res: int = 224, + linear_bands: bool = False, + include_grid: bool = False, + concat_out: bool = True, + in_pixels: bool = True, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +) -> List[torch.Tensor]: + if bands is None: + if in_pixels: + bands = pixel_freq_bands(num_bands, float(max_res), linear_bands=linear_bands, dtype=dtype, device=device) + else: + bands = inv_freq_bands(num_bands, step=1, dtype=dtype, device=device) + else: + if device is None: + device = bands.device + if dtype is None: + dtype = bands.dtype + + if in_pixels: + grid = torch.stack(torch.meshgrid( + [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]), dim=-1) + else: + grid = torch.stack(torch.meshgrid( + [torch.arange(s, device=device, dtype=dtype) for s in feat_shape]), dim=-1) + grid = grid.unsqueeze(-1) + pos = grid * bands + + pos_sin, pos_cos = pos.sin(), pos.cos() + out = (grid, pos_sin, pos_cos) if include_grid else (pos_sin, pos_cos) + # FIXME torchscript doesn't like multiple return types, probably need to always cat? + if concat_out: + out = torch.cat(out, dim=-1) + return out + + +class FourierEmbed(nn.Module): + + def __init__(self, max_res: int = 224, num_bands: int = 64, concat_grid=True, keep_spatial=False): + super().__init__() + self.max_res = max_res + self.num_bands = num_bands + self.concat_grid = concat_grid + self.keep_spatial = keep_spatial + self.register_buffer('bands', pixel_freq_bands(max_res, num_bands), persistent=False) + + def forward(self, x): + B, C = x.shape[:2] + feat_shape = x.shape[2:] + emb = build_fourier_pos_embed( + feat_shape, + self.bands, + include_grid=self.concat_grid, + dtype=x.dtype, + device=x.device) + emb = emb.transpose(-1, -2).flatten(len(feat_shape)) + batch_expand = (B,) + (-1,) * (x.ndim - 1) + + # FIXME support nD + if self.keep_spatial: + x = torch.cat([x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1) + else: + x = torch.cat([x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1) + x = x.reshape(B, feat_shape.numel(), -1) + + return x + + +def rot(x): + return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape) + + +def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb): + return x * cos_emb + rot(x) * sin_emb + + +def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb): + if isinstance(x, torch.Tensor): + x = [x] + return [t * cos_emb + rot(t) * sin_emb for t in x] + + +def apply_rot_embed_split(x: torch.Tensor, emb): + split = emb.shape[-1] // 2 + return x * emb[:, :split] + rot(x) * emb[:, split:] + + +def build_rotary_pos_embed( + feat_shape: List[int], + bands: Optional[torch.Tensor] = None, + dim: int = 64, + max_freq: float = 224, + linear_bands: bool = False, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +): + """ + NOTE: shape arg should include spatial dim only + """ + feat_shape = torch.Size(feat_shape) + + sin_emb, cos_emb = build_fourier_pos_embed( + feat_shape, + bands=bands, + num_bands=dim // 4, + max_res=max_freq, + linear_bands=linear_bands, + concat_out=False, + device=device, + dtype=dtype, + ) + N = feat_shape.numel() + sin_emb = sin_emb.reshape(N, -1).repeat_interleave(2, -1) + cos_emb = cos_emb.reshape(N, -1).repeat_interleave(2, -1) + return sin_emb, cos_emb + + +class RotaryEmbedding(nn.Module): + """ Rotary position embedding + + NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not + been well tested, and will likely change. It will be moved to its own file. + + The following impl/resources were referenced for this impl: + * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py + * https://blog.eleuther.ai/rotary-embeddings/ + """ + + def __init__(self, dim, max_res=224, linear_bands: bool = False): + super().__init__() + self.dim = dim + self.register_buffer('bands', pixel_freq_bands(dim // 4, max_res, linear_bands=linear_bands), persistent=False) + + def get_embed(self, shape: List[int]): + return build_rotary_pos_embed(shape, self.bands) + + def forward(self, x): + # assuming channel-first tensor where spatial dim are >= 2 + sin_emb, cos_emb = self.get_embed(x.shape[2:]) + return apply_rot_embed(x, sin_emb, cos_emb) diff --git a/timm/models/_builder.py b/timm/models/_builder.py index f634650e..901d7d44 100644 --- a/timm/models/_builder.py +++ b/timm/models/_builder.py @@ -1,5 +1,6 @@ import dataclasses import logging +import os from copy import deepcopy from typing import Optional, Dict, Callable, Any, Tuple @@ -9,7 +10,7 @@ from torch.hub import load_state_dict_from_url from timm.models._features import FeatureListNet, FeatureHookNet from timm.models._features_fx import FeatureGraphNet from timm.models._helpers import load_state_dict -from timm.models._hub import has_hf_hub, download_cached_file, load_state_dict_from_hf +from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf from timm.models._manipulate import adapt_input_conv from timm.models._pretrained import PretrainedCfg from timm.models._prune import adapt_model_from_file @@ -32,6 +33,7 @@ def _resolve_pretrained_source(pretrained_cfg): pretrained_url = pretrained_cfg.get('url', None) pretrained_file = pretrained_cfg.get('file', None) hf_hub_id = pretrained_cfg.get('hf_hub_id', None) + # resolve where to load pretrained weights from load_from = '' pretrained_loc = '' @@ -43,15 +45,20 @@ def _resolve_pretrained_source(pretrained_cfg): else: # default source == timm or unspecified if pretrained_file: + # file load override is the highest priority if set load_from = 'file' pretrained_loc = pretrained_file - elif pretrained_url: - load_from = 'url' - pretrained_loc = pretrained_url - elif hf_hub_id and has_hf_hub(necessary=True): - # hf-hub available as alternate weight source in default_cfg - load_from = 'hf-hub' - pretrained_loc = hf_hub_id + else: + # next, HF hub is prioritized unless a valid cached version of weights exists already + cached_url_valid = check_cached_file(pretrained_url) if pretrained_url else False + if hf_hub_id and has_hf_hub(necessary=True) and not cached_url_valid: + # hf-hub available as alternate weight source in default_cfg + load_from = 'hf-hub' + pretrained_loc = hf_hub_id + elif pretrained_url: + load_from = 'url' + pretrained_loc = pretrained_url + if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None): # if a filename override is set, return tuple for location w/ (hub_id, filename) pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename'] @@ -105,7 +112,7 @@ def load_custom_pretrained( pretrained_loc = download_cached_file( pretrained_loc, check_hash=_CHECK_HASH, - progress=_DOWNLOAD_PROGRESS + progress=_DOWNLOAD_PROGRESS, ) if load_fn is not None: @@ -146,12 +153,21 @@ def load_pretrained( state_dict = load_state_dict(pretrained_loc) elif load_from == 'url': _logger.info(f'Loading pretrained weights from url ({pretrained_loc})') - state_dict = load_state_dict_from_url( - pretrained_loc, - map_location='cpu', - progress=_DOWNLOAD_PROGRESS, - check_hash=_CHECK_HASH, - ) + if pretrained_cfg.get('custom_load', False): + pretrained_loc = download_cached_file( + pretrained_loc, + progress=_DOWNLOAD_PROGRESS, + check_hash=_CHECK_HASH, + ) + model.load_pretrained(pretrained_loc) + return + else: + state_dict = load_state_dict_from_url( + pretrained_loc, + map_location='cpu', + progress=_DOWNLOAD_PROGRESS, + check_hash=_CHECK_HASH, + ) elif load_from == 'hf-hub': _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})') if isinstance(pretrained_loc, (list, tuple)): @@ -364,20 +380,14 @@ def build_model_with_cfg( # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000)) if pretrained: - if pretrained_cfg.get('custom_load', False): - load_custom_pretrained( - model, - pretrained_cfg=pretrained_cfg, - ) - else: - load_pretrained( - model, - pretrained_cfg=pretrained_cfg, - num_classes=num_classes_pretrained, - in_chans=kwargs.get('in_chans', 3), - filter_fn=pretrained_filter_fn, - strict=pretrained_strict, - ) + load_pretrained( + model, + pretrained_cfg=pretrained_cfg, + num_classes=num_classes_pretrained, + in_chans=kwargs.get('in_chans', 3), + filter_fn=pretrained_filter_fn, + strict=pretrained_strict, + ) # Wrap the model in a feature extraction module if enabled if features: diff --git a/timm/models/_hub.py b/timm/models/_hub.py index e6b7d558..7c64df0b 100644 --- a/timm/models/_hub.py +++ b/timm/models/_hub.py @@ -1,3 +1,4 @@ +import hashlib import json import logging import os @@ -67,6 +68,26 @@ def download_cached_file(url, check_hash=True, progress=False): return cached_file +def check_cached_file(url, check_hash=True): + if isinstance(url, (list, tuple)): + url, filename = url + else: + parts = urlparse(url) + filename = os.path.basename(parts.path) + cached_file = os.path.join(get_cache_dir(), filename) + if os.path.exists(cached_file): + if check_hash: + r = HASH_REGEX.search(filename) # r is Optional[Match[str]] + hash_prefix = r.group(1) if r else None + if hash_prefix: + with open(cached_file, 'rb') as f: + hd = hashlib.sha256(f.read()).hexdigest() + if hd[:len(hash_prefix)] != hash_prefix: + return False + return True + return False + + def has_hf_hub(necessary=False): if not _has_hf_hub and necessary: # if no HF Hub module installed, and it is necessary to continue, raise error @@ -90,14 +111,14 @@ def load_cfg_from_json(json_file: Union[str, os.PathLike]): return json.loads(text) -def _download_from_hf(model_id: str, filename: str): +def download_from_hf(model_id: str, filename: str): hf_model_id, hf_revision = hf_split(model_id) return hf_hub_download(hf_model_id, filename, revision=hf_revision) def load_model_config_from_hf(model_id: str): assert has_hf_hub(True) - cached_file = _download_from_hf(model_id, 'config.json') + cached_file = download_from_hf(model_id, 'config.json') hf_config = load_cfg_from_json(cached_file) if 'pretrained_cfg' not in hf_config: @@ -124,34 +145,28 @@ def load_model_config_from_hf(model_id: str): def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'): assert has_hf_hub(True) - cached_file = _download_from_hf(model_id, filename) + cached_file = download_from_hf(model_id, filename) state_dict = torch.load(cached_file, map_location='cpu') return state_dict -def save_for_hf(model, save_directory, model_config=None): - assert has_hf_hub(True) +def save_config_for_hf(model, config_path, model_config=None): model_config = model_config or {} - save_directory = Path(save_directory) - save_directory.mkdir(exist_ok=True, parents=True) - - weights_path = save_directory / 'pytorch_model.bin' - torch.save(model.state_dict(), weights_path) - - config_path = save_directory / 'config.json' hf_config = {} pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True) # set some values at root config level hf_config['architecture'] = pretrained_cfg.pop('architecture') hf_config['num_classes'] = model_config.get('num_classes', model.num_classes) hf_config['num_features'] = model_config.get('num_features', model.num_features) - hf_config['global_pool'] = model_config.get('global_pool', getattr(model, 'global_pool', None)) + global_pool_type = model_config.get('global_pool', getattr(model, 'global_pool', None)) + if isinstance(global_pool_type, str) and global_pool_type: + hf_config['global_pool'] = global_pool_type - if 'label' in model_config: + if 'labels' in model_config: _logger.warning( - "'label' as a config field for timm models is deprecated. Please use 'label_name' and 'display_name'. " + "'labels' as a config field for timm models is deprecated. Please use 'label_name' and 'display_name'. " "Using provided 'label' field as 'label_name'.") - model_config['label_name'] = model_config.pop('label') + model_config['label_name'] = model_config.pop('labels') label_name = model_config.pop('label_name', None) if label_name: @@ -173,6 +188,18 @@ def save_for_hf(model, save_directory, model_config=None): json.dump(hf_config, f, indent=2) +def save_for_hf(model, save_directory, model_config=None): + assert has_hf_hub(True) + save_directory = Path(save_directory) + save_directory.mkdir(exist_ok=True, parents=True) + + weights_path = save_directory / 'pytorch_model.bin' + torch.save(model.state_dict(), weights_path) + + config_path = save_directory / 'config.json' + save_config_for_hf(model, config_path, model_config=model_config) + + def push_to_hf_hub( model, repo_id: str, diff --git a/timm/models/_pretrained.py b/timm/models/_pretrained.py index b5ecbc50..dca81eb0 100644 --- a/timm/models/_pretrained.py +++ b/timm/models/_pretrained.py @@ -19,6 +19,7 @@ class PretrainedCfg: source: Optional[str] = None # source of cfg / weight location used (url, file, hf-hub) architecture: Optional[str] = None # architecture variant can be set when not implicit + tag: Optional[str] = None # pretrained tag of source custom_load: bool = False # use custom model specific model.load_pretrained() (ie for npz files) # input / data config @@ -44,9 +45,11 @@ class PretrainedCfg: classifier: Optional[str] = None license: Optional[str] = None - source_url: Optional[str] = None - paper: Optional[str] = None - notes: Optional[str] = None + description: Optional[str] = None + origin_url: Optional[str] = None + paper_name: Optional[str] = None + paper_ids: Optional[Union[str, Tuple[str]]] = None + notes: Optional[Tuple[str]] = None @property def has_weights(self): @@ -62,11 +65,11 @@ class PretrainedCfg: def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True): filtered_cfg = {} - keep_none = {'pool_size', 'first_conv', 'classifier'} # always keep these keys, even if none + keep_null = {'pool_size', 'first_conv', 'classifier'} # always keep these keys, even if none for k, v in cfg.items(): if remove_source and k in {'url', 'file', 'hf_hub_id', 'hf_hub_id', 'hf_hub_filename', 'source'}: continue - if remove_null and v is None and k not in keep_none: + if remove_null and v is None and k not in keep_null: continue filtered_cfg[k] = v return filtered_cfg diff --git a/timm/models/_registry.py b/timm/models/_registry.py index fc7b3437..80eb2e94 100644 --- a/timm/models/_registry.py +++ b/timm/models/_registry.py @@ -7,6 +7,7 @@ import re import sys from collections import defaultdict, deque from copy import deepcopy +from dataclasses import replace from typing import List, Optional, Union, Tuple from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag @@ -20,7 +21,7 @@ _model_to_module = {} # mapping of model names to module names _model_entrypoints = {} # mapping of model names to architecture entrypoint fns _model_has_pretrained = set() # set of model names that have pretrained weight url present _model_default_cfgs = dict() # central repo for model arch -> default cfg objects -_model_pretrained_cfgs = dict() # central repo for model arch + tag -> pretrained cfgs +_model_pretrained_cfgs = dict() # central repo for model arch.tag -> pretrained cfgs _model_with_tags = defaultdict(list) # shortcut to map each model arch to all model + tag names @@ -48,24 +49,31 @@ def register_model(fn): if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs: # this will catch all models that have entrypoint matching cfg key, but miss any aliasing # entrypoints or non-matching combos - cfg = mod.default_cfgs[model_name] - if not isinstance(cfg, DefaultCfg): + default_cfg = mod.default_cfgs[model_name] + if not isinstance(default_cfg, DefaultCfg): # new style default cfg dataclass w/ multiple entries per model-arch - assert isinstance(cfg, dict) + assert isinstance(default_cfg, dict) # old style cfg dict per model-arch - cfg = PretrainedCfg(**cfg) - cfg = DefaultCfg(tags=deque(['']), cfgs={'': cfg}) + pretrained_cfg = PretrainedCfg(**default_cfg) + default_cfg = DefaultCfg(tags=deque(['']), cfgs={'': pretrained_cfg}) - for tag_idx, tag in enumerate(cfg.tags): + for tag_idx, tag in enumerate(default_cfg.tags): is_default = tag_idx == 0 - pretrained_cfg = cfg.cfgs[tag] + pretrained_cfg = default_cfg.cfgs[tag] + model_name_tag = '.'.join([model_name, tag]) if tag else model_name + replace_items = dict(architecture=model_name, tag=tag if tag else None) + if pretrained_cfg.hf_hub_id and pretrained_cfg.hf_hub_id == 'timm/': + # auto-complete hub name w/ architecture.tag + replace_items['hf_hub_id'] = pretrained_cfg.hf_hub_id + model_name_tag + pretrained_cfg = replace(pretrained_cfg, **replace_items) + if is_default: _model_pretrained_cfgs[model_name] = pretrained_cfg if pretrained_cfg.has_weights: # add tagless entry if it's default and has weights _model_has_pretrained.add(model_name) + if tag: - model_name_tag = '.'.join([model_name, tag]) _model_pretrained_cfgs[model_name_tag] = pretrained_cfg if pretrained_cfg.has_weights: # add model w/ tag if tag is valid @@ -74,7 +82,7 @@ def register_model(fn): else: _model_with_tags[model_name].append(model_name) # has empty tag (to slowly remove these instances) - _model_default_cfgs[model_name] = cfg + _model_default_cfgs[model_name] = default_cfg return fn @@ -198,15 +206,21 @@ def is_model_pretrained(model_name): return model_name in _model_has_pretrained -def get_pretrained_cfg(model_name): +def get_pretrained_cfg(model_name, allow_unregistered=True): if model_name in _model_pretrained_cfgs: return deepcopy(_model_pretrained_cfgs[model_name]) - raise RuntimeError(f'No pretrained config exists for model {model_name}.') + arch_name, tag = split_model_name_tag(model_name) + if arch_name in _model_default_cfgs: + # if model arch exists, but the tag is wrong, error out + raise RuntimeError(f'Invalid pretrained tag ({tag}) for {arch_name}.') + if allow_unregistered: + # if model arch doesn't exist, it has no pretrained_cfg registered, allow a default to be created + return None + raise RuntimeError(f'Model architecture ({arch_name}) has no pretrained cfg registered.') def get_pretrained_cfg_value(model_name, cfg_key): """ Get a specific model default_cfg value by key. None if key doesn't exist. """ - if model_name in _model_pretrained_cfgs: - return getattr(_model_pretrained_cfgs[model_name], cfg_key, None) - raise RuntimeError(f'No pretrained config exist for model {model_name}.') \ No newline at end of file + cfg = get_pretrained_cfg(model_name, allow_unregistered=False) + return getattr(cfg, cfg_key, None) diff --git a/timm/models/beit.py b/timm/models/beit.py index de71f441..12ec493d 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -355,64 +355,76 @@ def _cfg(url='', **kwargs): default_cfgs = generate_default_cfgs({ 'beit_base_patch16_224.in22k_ft_in22k_in1k': _cfg( - url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'), + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth', + hf_hub_id='timm/'), 'beit_base_patch16_384.in22k_ft_in22k_in1k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth', + hf_hub_id='timm/', input_size=(3, 384, 384), crop_pct=1.0, ), 'beit_base_patch16_224.in22k_ft_in22k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth', + hf_hub_id='timm/', num_classes=21841, ), 'beit_large_patch16_224.in22k_ft_in22k_in1k': _cfg( - url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'), + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth', + hf_hub_id='timm/'), 'beit_large_patch16_384.in22k_ft_in22k_in1k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth', + hf_hub_id='timm/', input_size=(3, 384, 384), crop_pct=1.0, ), 'beit_large_patch16_512.in22k_ft_in22k_in1k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth', + hf_hub_id='timm/', input_size=(3, 512, 512), crop_pct=1.0, ), 'beit_large_patch16_224.in22k_ft_in22k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth', + hf_hub_id='timm/', num_classes=21841, ), 'beitv2_base_patch16_224.in1k_ft_in22k_in1k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth', + hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD ), 'beitv2_base_patch16_224.in1k_ft_in22k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth', - num_classes=21841, - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD + hf_hub_id='timm/', + num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD ), 'beitv2_large_patch16_224.in1k_ft_in22k_in1k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth', - crop_pct=0.95, - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD + hf_hub_id='timm/', + crop_pct=0.95, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD ), 'beitv2_large_patch16_224.in1k_ft_in22k': _cfg( url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth', - num_classes=21841, - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD + hf_hub_id='timm/', + num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD ), 'eva_giant_patch14_224.clip_ft_in1k': _cfg( - hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt', + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, ), 'eva_giant_patch14_336.clip_ft_in1k': _cfg( - hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt', + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), 'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg( - hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt', + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt', + hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), 'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg( - hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt', + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt', + hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash'), }) diff --git a/timm/models/convnext.py b/timm/models/convnext.py index eea5782a..d30e4137 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -361,7 +361,6 @@ def _create_convnext(variant, pretrained=False, **kwargs): return model - def _cfg(url='', **kwargs): return { 'url': url, @@ -375,90 +374,131 @@ def _cfg(url='', **kwargs): default_cfgs = generate_default_cfgs({ # timm specific variants - 'convnext_atto.timm_in1k': _cfg( + 'convnext_atto.d2_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth', + hf_hub_id='timm/', test_input_size=(3, 288, 288), test_crop_pct=0.95), - 'convnext_atto_ols.timm_in1k': _cfg( + 'convnext_atto_ols.a2_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth', + hf_hub_id='timm/', test_input_size=(3, 288, 288), test_crop_pct=0.95), - 'convnext_femto.timm_in1k': _cfg( + 'convnext_femto.d1_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth', + hf_hub_id='timm/', test_input_size=(3, 288, 288), test_crop_pct=0.95), - 'convnext_femto_ols.timm_in1k': _cfg( + 'convnext_femto_ols.d1_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth', + hf_hub_id='timm/', test_input_size=(3, 288, 288), test_crop_pct=0.95), - 'convnext_pico.timm_in1k': _cfg( + 'convnext_pico.d1_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth', + hf_hub_id='timm/', test_input_size=(3, 288, 288), test_crop_pct=0.95), - 'convnext_pico_ols.timm_in1k': _cfg( + 'convnext_pico_ols.d1_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth', + hf_hub_id='timm/', + crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_nano.in12k_ft_in1k': _cfg( + hf_hub_id='timm/', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), - 'convnext_nano.timm_in1k': _cfg( + 'convnext_nano.d1h_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth', + hf_hub_id='timm/', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), - 'convnext_nano_ols.timm_in1k': _cfg( + 'convnext_nano_ols.d1h_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth', + hf_hub_id='timm/', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), - 'convnext_tiny_hnf.timm_in1k': _cfg( + 'convnext_tiny_hnf.a2h_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth', + hf_hub_id='timm/', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_nano.in12k': _cfg( + hf_hub_id='timm/', + crop_pct=0.95, num_classes=11821), + 'convnext_tiny.fb_in1k': _cfg( url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", + hf_hub_id='timm/', test_input_size=(3, 288, 288), test_crop_pct=1.0), 'convnext_small.fb_in1k': _cfg( url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", + hf_hub_id='timm/', test_input_size=(3, 288, 288), test_crop_pct=1.0), 'convnext_base.fb_in1k': _cfg( url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", + hf_hub_id='timm/', test_input_size=(3, 288, 288), test_crop_pct=1.0), 'convnext_large.fb_in1k': _cfg( url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", + hf_hub_id='timm/', test_input_size=(3, 288, 288), test_crop_pct=1.0), 'convnext_xlarge.untrained': _cfg(), + 'convnext_xxlarge.untrained': _cfg(), 'convnext_tiny.fb_in22k_ft_in1k': _cfg( url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth', + hf_hub_id='timm/', test_input_size=(3, 288, 288), test_crop_pct=1.0), 'convnext_small.fb_in22k_ft_in1k': _cfg( url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth', + hf_hub_id='timm/', test_input_size=(3, 288, 288), test_crop_pct=1.0), 'convnext_base.fb_in22k_ft_in1k': _cfg( url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth', + hf_hub_id='timm/', test_input_size=(3, 288, 288), test_crop_pct=1.0), 'convnext_large.fb_in22k_ft_in1k': _cfg( url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth', + hf_hub_id='timm/', test_input_size=(3, 288, 288), test_crop_pct=1.0), 'convnext_xlarge.fb_in22k_ft_in1k': _cfg( url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth', + hf_hub_id='timm/', test_input_size=(3, 288, 288), test_crop_pct=1.0), 'convnext_tiny.fb_in22k_ft_in1k_384': _cfg( url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth', + hf_hub_id='timm/', input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), - 'convnext_small..fb_in22k_ft_in1k_384': _cfg( + 'convnext_small.fb_in22k_ft_in1k_384': _cfg( url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth', + hf_hub_id='timm/', input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'convnext_base.fb_in22k_ft_in1k_384': _cfg( url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth', + hf_hub_id='timm/', input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'convnext_large.fb_in22k_ft_in1k_384': _cfg( url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth', + hf_hub_id='timm/', input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'convnext_xlarge.fb_in22k_ft_in1k_384': _cfg( url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth', + hf_hub_id='timm/', input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), - 'convnext_tiny_in22k.fb_in22k': _cfg( - url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", num_classes=21841), - 'convnext_small_in22k.fb_in22k': _cfg( - url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", num_classes=21841), - 'convnext_base_in22k.fb_in22k': _cfg( - url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841), - 'convnext_large_in22k.fb_in22k': _cfg( - url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", num_classes=21841), - 'convnext_xlarge_in22k.fb_in22k': _cfg( - url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", num_classes=21841), + 'convnext_tiny.fb_in22k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", + hf_hub_id='timm/', + num_classes=21841), + 'convnext_small.fb_in22k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", + hf_hub_id='timm/', + num_classes=21841), + 'convnext_base.fb_in22k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", + hf_hub_id='timm/', + num_classes=21841), + 'convnext_large.fb_in22k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", + hf_hub_id='timm/', + num_classes=21841), + 'convnext_xlarge.fb_in22k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", + hf_hub_id='timm/', + num_classes=21841), }) @@ -576,3 +616,10 @@ def convnext_xlarge(pretrained=False, **kwargs): model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) model = _create_convnext('convnext_xlarge', pretrained=pretrained, **model_args) return model + + +@register_model +def convnext_xxlarge(pretrained=False, **kwargs): + model_args = dict(depths=[3, 4, 30, 3], dims=[384, 768, 1536, 3072], **kwargs) + model = _create_convnext('convnext_xxlarge', pretrained=pretrained, **model_args) + return model diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index a1324ae3..a3866fec 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -50,410 +50,12 @@ from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficie round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT from ._features import FeatureInfo, FeatureHooks from ._manipulate import checkpoint_seq +from ._pretrained import generate_default_cfgs from ._registry import register_model __all__ = ['EfficientNet', 'EfficientNetFeatures'] -def _cfg(url='', **kwargs): - return { - 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.875, 'interpolation': 'bicubic', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'conv_stem', 'classifier': 'classifier', - **kwargs - } - - -default_cfgs = { - 'mnasnet_050': _cfg(url=''), - 'mnasnet_075': _cfg(url=''), - 'mnasnet_100': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth'), - 'mnasnet_140': _cfg(url=''), - - 'semnasnet_050': _cfg(url=''), - 'semnasnet_075': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/semnasnet_075-18710866.pth'), - 'semnasnet_100': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth'), - 'semnasnet_140': _cfg(url=''), - 'mnasnet_small': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_small_lamb-aff75073.pth'), - - 'mobilenetv2_035': _cfg( - url=''), - 'mobilenetv2_050': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_050-3d30d450.pth', - interpolation='bicubic', - ), - 'mobilenetv2_075': _cfg( - url=''), - 'mobilenetv2_100': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_100_ra-b33bc2c4.pth'), - 'mobilenetv2_110d': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_110d_ra-77090ade.pth'), - 'mobilenetv2_120d': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_120d_ra-5987e2ed.pth'), - 'mobilenetv2_140': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_140_ra-21a4e913.pth'), - - 'fbnetc_100': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth', - interpolation='bilinear'), - 'spnasnet_100': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth', - interpolation='bilinear'), - - # NOTE experimenting with alternate attention - 'efficientnet_b0': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth'), - 'efficientnet_b1': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth', - test_input_size=(3, 256, 256), crop_pct=1.0), - 'efficientnet_b2': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2_ra-bcdf34b7.pth', - input_size=(3, 256, 256), pool_size=(8, 8), test_input_size=(3, 288, 288), crop_pct=1.0), - 'efficientnet_b3': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b3_ra2-cf984f9c.pth', - input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0), - 'efficientnet_b4': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b4_ra2_320-7eb33cd5.pth', - input_size=(3, 320, 320), pool_size=(10, 10), test_input_size=(3, 384, 384), crop_pct=1.0), - 'efficientnet_b5': _cfg( - url='', input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), - 'efficientnet_b6': _cfg( - url='', input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), - 'efficientnet_b7': _cfg( - url='', input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), - 'efficientnet_b8': _cfg( - url='', input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), - 'efficientnet_l2': _cfg( - url='', input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.961), - - # FIXME experimental - 'efficientnet_b0_gn': _cfg( - url=''), - 'efficientnet_b0_g8_gn': _cfg( - url=''), - 'efficientnet_b0_g16_evos': _cfg( - url=''), - 'efficientnet_b3_gn': _cfg( - url='', - input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0), - 'efficientnet_b3_g8_gn': _cfg( - url='', - input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0), - - 'efficientnet_es': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth'), - 'efficientnet_em': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_em_ra2-66250f76.pth', - input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), - 'efficientnet_el': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_el-3b455510.pth', - input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), - - 'efficientnet_es_pruned': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_pruned75-1b7248cf.pth'), - 'efficientnet_el_pruned': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_el_pruned70-ef2a2ccf.pth', - input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), - - 'efficientnet_cc_b0_4e': _cfg(url=''), - 'efficientnet_cc_b0_8e': _cfg(url=''), - 'efficientnet_cc_b1_8e': _cfg(url='', input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), - - 'efficientnet_lite0': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_lite0_ra-37913777.pth'), - 'efficientnet_lite1': _cfg( - url='', - input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), - 'efficientnet_lite2': _cfg( - url='', - input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), - 'efficientnet_lite3': _cfg( - url='', - input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), - 'efficientnet_lite4': _cfg( - url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), - - 'efficientnet_b1_pruned': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/effnetb1_pruned-bea43a3a.pth', - input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), - 'efficientnet_b2_pruned': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/effnetb2_pruned-08c1b27c.pth', - input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), - 'efficientnet_b3_pruned': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/effnetb3_pruned-59ecf72d.pth', - input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), - - 'efficientnetv2_rw_t': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnetv2_t_agc-3620981a.pth', - input_size=(3, 224, 224), test_input_size=(3, 288, 288), pool_size=(7, 7), crop_pct=1.0), - 'gc_efficientnetv2_rw_t': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gc_efficientnetv2_rw_t_agc-927a0bde.pth', - input_size=(3, 224, 224), test_input_size=(3, 288, 288), pool_size=(7, 7), crop_pct=1.0), - 'efficientnetv2_rw_s': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_v2s_ra2_288-a6477665.pth', - input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0), - 'efficientnetv2_rw_m': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnetv2_rw_m_agc-3d90cb1e.pth', - input_size=(3, 320, 320), test_input_size=(3, 416, 416), pool_size=(10, 10), crop_pct=1.0), - - 'efficientnetv2_s': _cfg( - url='', - input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0), - 'efficientnetv2_m': _cfg( - url='', - input_size=(3, 320, 320), test_input_size=(3, 416, 416), pool_size=(10, 10), crop_pct=1.0), - 'efficientnetv2_l': _cfg( - url='', - input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), - 'efficientnetv2_xl': _cfg( - url='', - input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0), - - 'tf_efficientnet_b0': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth', - input_size=(3, 224, 224)), - 'tf_efficientnet_b1': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth', - input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), - 'tf_efficientnet_b2': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth', - input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), - 'tf_efficientnet_b3': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth', - input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), - 'tf_efficientnet_b4': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth', - input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), - 'tf_efficientnet_b5': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth', - input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), - 'tf_efficientnet_b6': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth', - input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), - 'tf_efficientnet_b7': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth', - input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), - 'tf_efficientnet_b8': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth', - input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), - - 'tf_efficientnet_b0_ap': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, input_size=(3, 224, 224)), - 'tf_efficientnet_b1_ap': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ap-44ef0a3d.pth', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, - input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), - 'tf_efficientnet_b2_ap': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ap-2f8e7636.pth', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, - input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), - 'tf_efficientnet_b3_ap': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ap-aad25bdd.pth', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, - input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), - 'tf_efficientnet_b4_ap': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ap-dedb23e6.pth', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, - input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), - 'tf_efficientnet_b5_ap': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ap-9e82fae8.pth', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, - input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), - 'tf_efficientnet_b6_ap': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ap-4ffb161f.pth', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, - input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), - 'tf_efficientnet_b7_ap': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ap-ddb28fec.pth', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, - input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), - 'tf_efficientnet_b8_ap': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, - input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), - - 'tf_efficientnet_b0_ns': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth', - input_size=(3, 224, 224)), - 'tf_efficientnet_b1_ns': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth', - input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), - 'tf_efficientnet_b2_ns': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth', - input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), - 'tf_efficientnet_b3_ns': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth', - input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), - 'tf_efficientnet_b4_ns': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth', - input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), - 'tf_efficientnet_b5_ns': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth', - input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), - 'tf_efficientnet_b6_ns': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth', - input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), - 'tf_efficientnet_b7_ns': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth', - input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), - 'tf_efficientnet_l2_ns_475': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth', - input_size=(3, 475, 475), pool_size=(15, 15), crop_pct=0.936), - 'tf_efficientnet_l2_ns': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth', - input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.96), - - 'tf_efficientnet_es': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 224, 224), ), - 'tf_efficientnet_em': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_em-e78cfe58.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), - 'tf_efficientnet_el': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), - - 'tf_efficientnet_cc_b0_4e': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_4e-4362b6b2.pth', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), - 'tf_efficientnet_cc_b0_8e': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_8e-66184a25.pth', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), - 'tf_efficientnet_cc_b1_8e': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b1_8e-f7c79ae1.pth', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, - input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), - - 'tf_efficientnet_lite0': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res - ), - 'tf_efficientnet_lite1': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite1-bde8b488.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882, - interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res - ), - 'tf_efficientnet_lite2': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite2-dcccb7df.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890, - interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res - ), - 'tf_efficientnet_lite3': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite3-b733e338.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, interpolation='bilinear'), - 'tf_efficientnet_lite4': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite4-741542c3.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.920, interpolation='bilinear'), - - 'tf_efficientnetv2_s': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s-eb54923e.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0), - 'tf_efficientnetv2_m': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m-cc09e0cd.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), - 'tf_efficientnetv2_l': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l-d664b728.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), - - 'tf_efficientnetv2_s_in21ft1k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0), - 'tf_efficientnetv2_m_in21ft1k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21ft1k-bf41664a.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), - 'tf_efficientnetv2_l_in21ft1k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l_21ft1k-60127a9d.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), - 'tf_efficientnetv2_xl_in21ft1k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_xl_in21ft1k-06c35c48.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), - - 'tf_efficientnetv2_s_in21k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21k-6337ad01.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, - input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0), - 'tf_efficientnetv2_m_in21k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21k-361418a2.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, - input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), - 'tf_efficientnetv2_l_in21k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l_21k-91a19ec9.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, - input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), - 'tf_efficientnetv2_xl_in21k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_xl_in21k-fd7e8abf.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, - input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), - - 'tf_efficientnetv2_b0': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b0-c7cc451f.pth', - input_size=(3, 192, 192), test_input_size=(3, 224, 224), pool_size=(6, 6)), - 'tf_efficientnetv2_b1': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b1-be6e41b0.pth', - input_size=(3, 192, 192), test_input_size=(3, 240, 240), pool_size=(6, 6), crop_pct=0.882), - 'tf_efficientnetv2_b2': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b2-847de54e.pth', - input_size=(3, 208, 208), test_input_size=(3, 260, 260), pool_size=(7, 7), crop_pct=0.890), - 'tf_efficientnetv2_b3': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b3-57773f13.pth', - input_size=(3, 240, 240), test_input_size=(3, 300, 300), pool_size=(8, 8), crop_pct=0.904), - - 'mixnet_s': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth'), - 'mixnet_m': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth'), - 'mixnet_l': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth'), - 'mixnet_xl': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl_ra-aac3c00c.pth'), - 'mixnet_xxl': _cfg(), - - 'tf_mixnet_s': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth'), - 'tf_mixnet_m': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_m-0f4d8805.pth'), - 'tf_mixnet_l': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth'), - - "tinynet_a": _cfg( - input_size=(3, 192, 192), pool_size=(6, 6), # int(224 * 0.86) - url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_a.pth'), - "tinynet_b": _cfg( - input_size=(3, 188, 188), pool_size=(6, 6), # int(224 * 0.84) - url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_b.pth'), - "tinynet_c": _cfg( - input_size=(3, 184, 184), pool_size=(6, 6), # int(224 * 0.825) - url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_c.pth'), - "tinynet_d": _cfg( - input_size=(3, 152, 152), pool_size=(5, 5), # int(224 * 0.68) - url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_d.pth'), - "tinynet_e": _cfg( - input_size=(3, 106, 106), pool_size=(4, 4), # int(224 * 0.475) - url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_e.pth'), -} - - class EfficientNet(nn.Module): """ EfficientNet @@ -471,9 +73,23 @@ class EfficientNet(nn.Module): """ def __init__( - self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32, fix_stem=False, - output_stride=32, pad_type='', round_chs_fn=round_channels, act_layer=None, norm_layer=None, - se_layer=None, drop_rate=0., drop_path_rate=0., global_pool='avg'): + self, + block_args, + num_classes=1000, + num_features=1280, + in_chans=3, + stem_size=32, + fix_stem=False, + output_stride=32, + pad_type='', + round_chs_fn=round_channels, + act_layer=None, + norm_layer=None, + se_layer=None, + drop_rate=0., + drop_path_rate=0., + global_pool='avg' + ): super(EfficientNet, self).__init__() act_layer = act_layer or nn.ReLU norm_layer = norm_layer or nn.BatchNorm2d @@ -492,8 +108,14 @@ class EfficientNet(nn.Module): # Middle stages (IR/ER/DS Blocks) builder = EfficientNetBuilder( - output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn, - act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate) + output_stride=output_stride, + pad_type=pad_type, + round_chs_fn=round_chs_fn, + act_layer=act_layer, + norm_layer=norm_layer, + se_layer=se_layer, + drop_path_rate=drop_path_rate, + ) self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.feature_info = builder.features head_chs = builder.in_chs @@ -567,9 +189,22 @@ class EfficientNetFeatures(nn.Module): """ def __init__( - self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3, - stem_size=32, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels, - act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.): + self, + block_args, + out_indices=(0, 1, 2, 3, 4), + feature_location='bottleneck', + in_chans=3, + stem_size=32, + fix_stem=False, + output_stride=32, + pad_type='', + round_chs_fn=round_channels, + act_layer=None, + norm_layer=None, + se_layer=None, + drop_rate=0., + drop_path_rate=0. + ): super(EfficientNetFeatures, self).__init__() act_layer = act_layer or nn.ReLU norm_layer = norm_layer or nn.BatchNorm2d @@ -585,9 +220,15 @@ class EfficientNetFeatures(nn.Module): # Middle stages (IR/ER/DS Blocks) builder = EfficientNetBuilder( - output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn, - act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate, - feature_location=feature_location) + output_stride=output_stride, + pad_type=pad_type, + round_chs_fn=round_chs_fn, + act_layer=act_layer, + norm_layer=norm_layer, + se_layer=se_layer, + drop_path_rate=drop_path_rate, + feature_location=feature_location, + ) self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.feature_info = FeatureInfo(builder.features, out_indices) self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices} @@ -1233,23 +874,518 @@ def _gen_tinynet( return model -@register_model -def mnasnet_050(pretrained=False, **kwargs): - """ MNASNet B1, depth multiplier of 0.5. """ - model = _gen_mnasnet_b1('mnasnet_050', 0.5, pretrained=pretrained, **kwargs) - return model - +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv_stem', 'classifier': 'classifier', + **kwargs + } -@register_model -def mnasnet_075(pretrained=False, **kwargs): - """ MNASNet B1, depth multiplier of 0.75. """ - model = _gen_mnasnet_b1('mnasnet_075', 0.75, pretrained=pretrained, **kwargs) - return model +default_cfgs = generate_default_cfgs({ + 'mnasnet_050.untrained': _cfg(), + 'mnasnet_075.untrained': _cfg(), + 'mnasnet_100.rmsp_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth', + hf_hub_id='timm/'), + 'mnasnet_140.untrained': _cfg(), + + 'semnasnet_050.untrained': _cfg(), + 'semnasnet_075.rmsp_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/semnasnet_075-18710866.pth', + hf_hub_id='timm/'), + 'semnasnet_100.rmsp_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth', + hf_hub_id='timm/'), + 'semnasnet_140.untrained': _cfg(), + 'mnasnet_small.lamb_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_small_lamb-aff75073.pth', + hf_hub_id='timm/'), + + 'mobilenetv2_035.untrained': _cfg(), + 'mobilenetv2_050.lamb_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_050-3d30d450.pth', + hf_hub_id='timm/', + interpolation='bicubic', + ), + 'mobilenetv2_075.untrained': _cfg(), + 'mobilenetv2_100.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_100_ra-b33bc2c4.pth', + hf_hub_id='timm/'), + 'mobilenetv2_110d.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_110d_ra-77090ade.pth', + hf_hub_id='timm/'), + 'mobilenetv2_120d.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_120d_ra-5987e2ed.pth', + hf_hub_id='timm/'), + 'mobilenetv2_140.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_140_ra-21a4e913.pth', + hf_hub_id='timm/'), + + 'fbnetc_100.rmsp_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth', + hf_hub_id='timm/', + interpolation='bilinear'), + 'spnasnet_100.rmsp_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth', + hf_hub_id='timm/', + interpolation='bilinear'), -@register_model -def mnasnet_100(pretrained=False, **kwargs): - """ MNASNet B1, depth multiplier of 1.0. """ + # NOTE experimenting with alternate attention + 'efficientnet_b0.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth', + hf_hub_id='timm/'), + 'efficientnet_b1.ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth', + hf_hub_id='timm/', + test_input_size=(3, 256, 256), crop_pct=1.0), + 'efficientnet_b2.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2_ra-bcdf34b7.pth', + hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8), test_input_size=(3, 288, 288), crop_pct=1.0), + 'efficientnet_b3.ra2_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b3_ra2-cf984f9c.pth', + hf_hub_id='timm/', + input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0), + 'efficientnet_b4.ra2_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b4_ra2_320-7eb33cd5.pth', + hf_hub_id='timm/', + input_size=(3, 320, 320), pool_size=(10, 10), test_input_size=(3, 384, 384), crop_pct=1.0), + 'efficientnet_b5.in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, crop_mode='squash'), + 'efficientnet_b5.in12k': _cfg( + hf_hub_id='timm/', + input_size=(3, 416, 416), pool_size=(13, 13), crop_pct=0.95, num_classes=11821), + 'efficientnet_b6.untrained': _cfg( + url='', input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), + 'efficientnet_b7.untrained': _cfg( + url='', input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), + 'efficientnet_b8.untrained': _cfg( + url='', input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), + 'efficientnet_l2.untrained': _cfg( + url='', input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.961), + + # FIXME experimental + 'efficientnet_b0_gn.untrained': _cfg(), + 'efficientnet_b0_g8_gn.untrained': _cfg(), + 'efficientnet_b0_g16_evos.untrained': _cfg(), + 'efficientnet_b3_gn.untrained': _cfg( + input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0), + 'efficientnet_b3_g8_gn.untrained': _cfg( + input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0), + + 'efficientnet_es.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth', + hf_hub_id='timm/'), + 'efficientnet_em.ra2_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_em_ra2-66250f76.pth', + hf_hub_id='timm/', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'efficientnet_el.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_el-3b455510.pth', + hf_hub_id='timm/', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + + 'efficientnet_es_pruned.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_pruned75-1b7248cf.pth', + hf_hub_id='timm/'), + 'efficientnet_el_pruned.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_el_pruned70-ef2a2ccf.pth', + hf_hub_id='timm/', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + + 'efficientnet_cc_b0_4e.untrained': _cfg(), + 'efficientnet_cc_b0_8e.untrained': _cfg(), + 'efficientnet_cc_b1_8e.untrained': _cfg(input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + + 'efficientnet_lite0.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_lite0_ra-37913777.pth', + hf_hub_id='timm/'), + 'efficientnet_lite1.untrained': _cfg( + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'efficientnet_lite2.untrained': _cfg( + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'efficientnet_lite3.untrained': _cfg( + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'efficientnet_lite4.untrained': _cfg( + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + + 'efficientnet_b1_pruned.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/effnetb1_pruned-bea43a3a.pth', + hf_hub_id='timm/', + input_size=(3, 240, 240), pool_size=(8, 8), + crop_pct=0.882, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'efficientnet_b2_pruned.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/effnetb2_pruned-08c1b27c.pth', + hf_hub_id='timm/', + input_size=(3, 260, 260), pool_size=(9, 9), + crop_pct=0.890, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'efficientnet_b3_pruned.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/effnetb3_pruned-59ecf72d.pth', + hf_hub_id='timm/', + input_size=(3, 300, 300), pool_size=(10, 10), + crop_pct=0.904, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + + 'efficientnetv2_rw_t.ra2_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnetv2_t_agc-3620981a.pth', + hf_hub_id='timm/', + input_size=(3, 224, 224), test_input_size=(3, 288, 288), pool_size=(7, 7), crop_pct=1.0), + 'gc_efficientnetv2_rw_t.agc_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gc_efficientnetv2_rw_t_agc-927a0bde.pth', + hf_hub_id='timm/', + input_size=(3, 224, 224), test_input_size=(3, 288, 288), pool_size=(7, 7), crop_pct=1.0), + 'efficientnetv2_rw_s.ra2_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_v2s_ra2_288-a6477665.pth', + hf_hub_id='timm/', + input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0), + 'efficientnetv2_rw_m.agc_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnetv2_rw_m_agc-3d90cb1e.pth', + hf_hub_id='timm/', + input_size=(3, 320, 320), test_input_size=(3, 416, 416), pool_size=(10, 10), crop_pct=1.0), + + 'efficientnetv2_s.untrained': _cfg( + input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0), + 'efficientnetv2_m.untrained': _cfg( + input_size=(3, 320, 320), test_input_size=(3, 416, 416), pool_size=(10, 10), crop_pct=1.0), + 'efficientnetv2_l.untrained': _cfg( + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + 'efficientnetv2_xl.untrained': _cfg( + input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0), + + 'tf_efficientnet_b0.aa_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth', + hf_hub_id='timm/', + input_size=(3, 224, 224)), + 'tf_efficientnet_b1.aa_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth', + hf_hub_id='timm/', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'tf_efficientnet_b2.aa_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth', + hf_hub_id='timm/', + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'tf_efficientnet_b3.aa_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth', + hf_hub_id='timm/', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'tf_efficientnet_b4.aa_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth', + hf_hub_id='timm/', + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + 'tf_efficientnet_b5.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth', + hf_hub_id='timm/', + input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + 'tf_efficientnet_b6.aa_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth', + hf_hub_id='timm/', + input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), + 'tf_efficientnet_b7.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth', + hf_hub_id='timm/', + input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), + 'tf_efficientnet_b8.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth', + hf_hub_id='timm/', + input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), + + 'tf_efficientnet_b0.ap_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, input_size=(3, 224, 224)), + 'tf_efficientnet_b1.ap_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ap-44ef0a3d.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'tf_efficientnet_b2.ap_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ap-2f8e7636.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'tf_efficientnet_b3.ap_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ap-aad25bdd.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'tf_efficientnet_b4.ap_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ap-dedb23e6.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + 'tf_efficientnet_b5.ap_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ap-9e82fae8.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + 'tf_efficientnet_b6.ap_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ap-4ffb161f.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), + 'tf_efficientnet_b7.ap_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ap-ddb28fec.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), + 'tf_efficientnet_b8.ap_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), + + 'tf_efficientnet_b0.ns_jft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth', + hf_hub_id='timm/', + input_size=(3, 224, 224)), + 'tf_efficientnet_b1.ns_jft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth', + hf_hub_id='timm/', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'tf_efficientnet_b2.ns_jft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth', + hf_hub_id='timm/', + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'tf_efficientnet_b3.ns_jft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth', + hf_hub_id='timm/', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'tf_efficientnet_b4.ns_jft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth', + hf_hub_id='timm/', + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + 'tf_efficientnet_b5.ns_jft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth', + hf_hub_id='timm/', + input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + 'tf_efficientnet_b6.ns_jft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth', + hf_hub_id='timm/', + input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), + 'tf_efficientnet_b7.ns_jft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth', + hf_hub_id='timm/', + input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), + 'tf_efficientnet_l2.ns_jft_in1k_475': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth', + hf_hub_id='timm/', + input_size=(3, 475, 475), pool_size=(15, 15), crop_pct=0.936), + 'tf_efficientnet_l2.ns_jft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth', + hf_hub_id='timm/', + input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.96), + + 'tf_efficientnet_es.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 224, 224), ), + 'tf_efficientnet_em.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_em-e78cfe58.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'tf_efficientnet_el.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + + 'tf_efficientnet_cc_b0_4e.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_4e-4362b6b2.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_efficientnet_cc_b0_8e.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_8e-66184a25.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_efficientnet_cc_b1_8e.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b1_8e-f7c79ae1.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + + 'tf_efficientnet_lite0.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res + ), + 'tf_efficientnet_lite1.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite1-bde8b488.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882, + interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res + ), + 'tf_efficientnet_lite2.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite2-dcccb7df.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890, + interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res + ), + 'tf_efficientnet_lite3.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite3-b733e338.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, interpolation='bilinear'), + 'tf_efficientnet_lite4.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite4-741542c3.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.920, interpolation='bilinear'), + + 'tf_efficientnetv2_s.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s-eb54923e.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0), + 'tf_efficientnetv2_m.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m-cc09e0cd.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'tf_efficientnetv2_l.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l-d664b728.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + + 'tf_efficientnetv2_s.in21k_ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0), + 'tf_efficientnetv2_m.in21k_ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21ft1k-bf41664a.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'tf_efficientnetv2_l.in21k_ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l_21ft1k-60127a9d.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'tf_efficientnetv2_xl.in21k_ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_xl_in21ft1k-06c35c48.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + + 'tf_efficientnetv2_s.in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21k-6337ad01.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, + input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0), + 'tf_efficientnetv2_m.in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21k-361418a2.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'tf_efficientnetv2_l.in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l_21k-91a19ec9.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'tf_efficientnetv2_xl.in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_xl_in21k-fd7e8abf.pth', + hf_hub_id='timm/', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, + input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + + 'tf_efficientnetv2_b0.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b0-c7cc451f.pth', + hf_hub_id='timm/', + input_size=(3, 192, 192), test_input_size=(3, 224, 224), pool_size=(6, 6)), + 'tf_efficientnetv2_b1.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b1-be6e41b0.pth', + hf_hub_id='timm/', + input_size=(3, 192, 192), test_input_size=(3, 240, 240), pool_size=(6, 6), crop_pct=0.882), + 'tf_efficientnetv2_b2.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b2-847de54e.pth', + hf_hub_id='timm/', + input_size=(3, 208, 208), test_input_size=(3, 260, 260), pool_size=(7, 7), crop_pct=0.890), + 'tf_efficientnetv2_b3.in21k_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 240, 240), test_input_size=(3, 300, 300), pool_size=(8, 8), crop_pct=0.9, crop_mode='squash'), + 'tf_efficientnetv2_b3.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b3-57773f13.pth', + hf_hub_id='timm/', + input_size=(3, 240, 240), test_input_size=(3, 300, 300), pool_size=(8, 8), crop_pct=0.904), + 'tf_efficientnetv2_b3.in21k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=21843, + input_size=(3, 240, 240), test_input_size=(3, 300, 300), pool_size=(8, 8), crop_pct=0.904), + + 'mixnet_s.ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth', + hf_hub_id='timm/'), + 'mixnet_m.ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth', + hf_hub_id='timm/'), + 'mixnet_l.ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth', + hf_hub_id='timm/'), + 'mixnet_xl.ra_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl_ra-aac3c00c.pth', + hf_hub_id='timm/'), + 'mixnet_xxl.untrained': _cfg(), + + 'tf_mixnet_s.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth', + hf_hub_id='timm/'), + 'tf_mixnet_m.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_m-0f4d8805.pth', + hf_hub_id='timm/'), + 'tf_mixnet_l.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth', + hf_hub_id='timm/'), + + "tinynet_a.in1k": _cfg( + input_size=(3, 192, 192), pool_size=(6, 6), # int(224 * 0.86) + url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_a.pth', + hf_hub_id='timm/'), + "tinynet_b.in1k": _cfg( + input_size=(3, 188, 188), pool_size=(6, 6), # int(224 * 0.84) + url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_b.pth', + hf_hub_id='timm/'), + "tinynet_c.in1k": _cfg( + input_size=(3, 184, 184), pool_size=(6, 6), # int(224 * 0.825) + url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_c.pth', + hf_hub_id='timm/'), + "tinynet_d.in1k": _cfg( + input_size=(3, 152, 152), pool_size=(5, 5), # int(224 * 0.68) + url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_d.pth', + hf_hub_id='timm/'), + "tinynet_e.in1k": _cfg( + input_size=(3, 106, 106), pool_size=(4, 4), # int(224 * 0.475) + url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_e.pth', + hf_hub_id='timm/'), +}) + + +@register_model +def mnasnet_050(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 0.5. """ + model = _gen_mnasnet_b1('mnasnet_050', 0.5, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mnasnet_075(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 0.75. """ + model = _gen_mnasnet_b1('mnasnet_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mnasnet_100(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.0. """ model = _gen_mnasnet_b1('mnasnet_100', 1.0, pretrained=pretrained, **kwargs) return model @@ -1830,199 +1966,13 @@ def tf_efficientnet_b8(pretrained=False, **kwargs): @register_model -def tf_efficientnet_b0_ap(pretrained=False, **kwargs): - """ EfficientNet-B0 AdvProp. Tensorflow compatible variant """ - kwargs['bn_eps'] = BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet( - 'tf_efficientnet_b0_ap', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) - return model - - -@register_model -def tf_efficientnet_b1_ap(pretrained=False, **kwargs): - """ EfficientNet-B1 AdvProp. Tensorflow compatible variant """ - kwargs['bn_eps'] = BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet( - 'tf_efficientnet_b1_ap', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) - return model - - -@register_model -def tf_efficientnet_b2_ap(pretrained=False, **kwargs): - """ EfficientNet-B2 AdvProp. Tensorflow compatible variant """ - kwargs['bn_eps'] = BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet( - 'tf_efficientnet_b2_ap', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) - return model - - -@register_model -def tf_efficientnet_b3_ap(pretrained=False, **kwargs): - """ EfficientNet-B3 AdvProp. Tensorflow compatible variant """ - kwargs['bn_eps'] = BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet( - 'tf_efficientnet_b3_ap', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) - return model - - -@register_model -def tf_efficientnet_b4_ap(pretrained=False, **kwargs): - """ EfficientNet-B4 AdvProp. Tensorflow compatible variant """ - kwargs['bn_eps'] = BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet( - 'tf_efficientnet_b4_ap', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) - return model - - -@register_model -def tf_efficientnet_b5_ap(pretrained=False, **kwargs): - """ EfficientNet-B5 AdvProp. Tensorflow compatible variant """ - kwargs['bn_eps'] = BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet( - 'tf_efficientnet_b5_ap', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) - return model - - -@register_model -def tf_efficientnet_b6_ap(pretrained=False, **kwargs): - """ EfficientNet-B6 AdvProp. Tensorflow compatible variant """ - # NOTE for train, drop_rate should be 0.5 - kwargs['bn_eps'] = BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet( - 'tf_efficientnet_b6_ap', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) - return model - - -@register_model -def tf_efficientnet_b7_ap(pretrained=False, **kwargs): - """ EfficientNet-B7 AdvProp. Tensorflow compatible variant """ - # NOTE for train, drop_rate should be 0.5 - kwargs['bn_eps'] = BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet( - 'tf_efficientnet_b7_ap', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) - return model - - -@register_model -def tf_efficientnet_b8_ap(pretrained=False, **kwargs): - """ EfficientNet-B8 AdvProp. Tensorflow compatible variant """ - # NOTE for train, drop_rate should be 0.5 - kwargs['bn_eps'] = BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet( - 'tf_efficientnet_b8_ap', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) - return model - - -@register_model -def tf_efficientnet_b0_ns(pretrained=False, **kwargs): - """ EfficientNet-B0 NoisyStudent. Tensorflow compatible variant """ - kwargs['bn_eps'] = BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet( - 'tf_efficientnet_b0_ns', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) - return model - - -@register_model -def tf_efficientnet_b1_ns(pretrained=False, **kwargs): - """ EfficientNet-B1 NoisyStudent. Tensorflow compatible variant """ - kwargs['bn_eps'] = BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet( - 'tf_efficientnet_b1_ns', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) - return model - - -@register_model -def tf_efficientnet_b2_ns(pretrained=False, **kwargs): - """ EfficientNet-B2 NoisyStudent. Tensorflow compatible variant """ - kwargs['bn_eps'] = BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet( - 'tf_efficientnet_b2_ns', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) - return model - - -@register_model -def tf_efficientnet_b3_ns(pretrained=False, **kwargs): - """ EfficientNet-B3 NoisyStudent. Tensorflow compatible variant """ - kwargs['bn_eps'] = BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet( - 'tf_efficientnet_b3_ns', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) - return model - - -@register_model -def tf_efficientnet_b4_ns(pretrained=False, **kwargs): - """ EfficientNet-B4 NoisyStudent. Tensorflow compatible variant """ - kwargs['bn_eps'] = BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet( - 'tf_efficientnet_b4_ns', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) - return model - - -@register_model -def tf_efficientnet_b5_ns(pretrained=False, **kwargs): - """ EfficientNet-B5 NoisyStudent. Tensorflow compatible variant """ - kwargs['bn_eps'] = BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet( - 'tf_efficientnet_b5_ns', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) - return model - - -@register_model -def tf_efficientnet_b6_ns(pretrained=False, **kwargs): - """ EfficientNet-B6 NoisyStudent. Tensorflow compatible variant """ - # NOTE for train, drop_rate should be 0.5 - kwargs['bn_eps'] = BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet( - 'tf_efficientnet_b6_ns', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) - return model - - -@register_model -def tf_efficientnet_b7_ns(pretrained=False, **kwargs): - """ EfficientNet-B7 NoisyStudent. Tensorflow compatible variant """ - # NOTE for train, drop_rate should be 0.5 - kwargs['bn_eps'] = BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet( - 'tf_efficientnet_b7_ns', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) - return model - - -@register_model -def tf_efficientnet_l2_ns_475(pretrained=False, **kwargs): - """ EfficientNet-L2 NoisyStudent @ 475x475. Tensorflow compatible variant """ - # NOTE for train, drop_rate should be 0.5 - kwargs['bn_eps'] = BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnet( - 'tf_efficientnet_l2_ns_475', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) - return model - - -@register_model -def tf_efficientnet_l2_ns(pretrained=False, **kwargs): +def tf_efficientnet_l2(pretrained=False, **kwargs): """ EfficientNet-L2 NoisyStudent. Tensorflow compatible variant """ # NOTE for train, drop_rate should be 0.5 kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' model = _gen_efficientnet( - 'tf_efficientnet_l2_ns', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) + 'tf_efficientnet_l2', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) return model @@ -2146,7 +2096,6 @@ def tf_efficientnet_lite4(pretrained=False, **kwargs): return model - @register_model def tf_efficientnetv2_s(pretrained=False, **kwargs): """ EfficientNet-V2 Small. Tensorflow compatible variant """ @@ -2175,82 +2124,12 @@ def tf_efficientnetv2_l(pretrained=False, **kwargs): @register_model -def tf_efficientnetv2_s_in21ft1k(pretrained=False, **kwargs): - """ EfficientNet-V2 Small. Pretrained on ImageNet-21k, fine-tuned on 1k. Tensorflow compatible variant - """ - kwargs['bn_eps'] = BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnetv2_s('tf_efficientnetv2_s_in21ft1k', pretrained=pretrained, **kwargs) - return model - - -@register_model -def tf_efficientnetv2_m_in21ft1k(pretrained=False, **kwargs): - """ EfficientNet-V2 Medium. Pretrained on ImageNet-21k, fine-tuned on 1k. Tensorflow compatible variant - """ - kwargs['bn_eps'] = BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnetv2_m('tf_efficientnetv2_m_in21ft1k', pretrained=pretrained, **kwargs) - return model - - -@register_model -def tf_efficientnetv2_l_in21ft1k(pretrained=False, **kwargs): - """ EfficientNet-V2 Large. Pretrained on ImageNet-21k, fine-tuned on 1k. Tensorflow compatible variant - """ - kwargs['bn_eps'] = BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnetv2_l('tf_efficientnetv2_l_in21ft1k', pretrained=pretrained, **kwargs) - return model - - -@register_model -def tf_efficientnetv2_xl_in21ft1k(pretrained=False, **kwargs): - """ EfficientNet-V2 Xtra-Large. Pretrained on ImageNet-21k, fine-tuned on 1k. Tensorflow compatible variant - """ - kwargs['bn_eps'] = BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnetv2_xl('tf_efficientnetv2_xl_in21ft1k', pretrained=pretrained, **kwargs) - return model - - -@register_model -def tf_efficientnetv2_s_in21k(pretrained=False, **kwargs): - """ EfficientNet-V2 Small w/ ImageNet-21k pretrained weights. Tensorflow compatible variant - """ - kwargs['bn_eps'] = BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnetv2_s('tf_efficientnetv2_s_in21k', pretrained=pretrained, **kwargs) - return model - - -@register_model -def tf_efficientnetv2_m_in21k(pretrained=False, **kwargs): - """ EfficientNet-V2 Medium w/ ImageNet-21k pretrained weights. Tensorflow compatible variant - """ - kwargs['bn_eps'] = BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnetv2_m('tf_efficientnetv2_m_in21k', pretrained=pretrained, **kwargs) - return model - - -@register_model -def tf_efficientnetv2_l_in21k(pretrained=False, **kwargs): - """ EfficientNet-V2 Large w/ ImageNet-21k pretrained weights. Tensorflow compatible variant - """ - kwargs['bn_eps'] = BN_EPS_TF_DEFAULT - kwargs['pad_type'] = 'same' - model = _gen_efficientnetv2_l('tf_efficientnetv2_l_in21k', pretrained=pretrained, **kwargs) - return model - - -@register_model -def tf_efficientnetv2_xl_in21k(pretrained=False, **kwargs): - """ EfficientNet-V2 Xtra-Large w/ ImageNet-21k pretrained weights. Tensorflow compatible variant +def tf_efficientnetv2_xl(pretrained=False, **kwargs): + """ EfficientNet-V2 Xtra-Large. Tensorflow compatible variant """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' - model = _gen_efficientnetv2_xl('tf_efficientnetv2_xl_in21k', pretrained=pretrained, **kwargs) + model = _gen_efficientnetv2_xl('tf_efficientnetv2_xl', pretrained=pretrained, **kwargs) return model diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 97e70563..dd5b27d9 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -2,6 +2,7 @@ from timm.layers.activations import * from timm.layers.adaptive_avgmax_pool import \ adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d +from timm.layers.attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding from timm.layers.blur_pool import BlurPool2d from timm.layers.classifier import ClassifierHead, create_classifier from timm.layers.cond_conv2d import CondConv2d, get_condconv_initializer diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index 1e2666e5..1170e7e3 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -47,16 +47,15 @@ import torch from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import Mlp, ConvMlp, DropPath, ClassifierHead, trunc_normal_tf_, LayerNorm -from timm.layers import SelectAdaptivePool2d, create_pool2d -from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d -from timm.layers import to_2tuple, extend_tuple, make_divisible, _assert +from timm.layers import Mlp, ConvMlp, DropPath, ClassifierHead, LayerNorm, SelectAdaptivePool2d +from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, create_pool2d +from timm.layers import trunc_normal_tf_, to_2tuple, extend_tuple, make_divisible, _assert +from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf from ._builder import build_model_with_cfg from ._features_fx import register_notrace_function from ._manipulate import named_apply, checkpoint_seq from ._pretrained import generate_default_cfgs from ._registry import register_model -from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move these to common location __all__ = ['MaxxVitCfg', 'MaxxVitConvCfg', 'MaxxVitTransformerCfg', 'MaxxVit'] @@ -1076,93 +1075,6 @@ def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]): return cfg -def generate_lookup_tensor( - length: int, - max_relative_position: Optional[int] = None, -): - """Generate a one_hot lookup tensor to reindex embeddings along one dimension. - Args: - length: the length to reindex to. - max_relative_position: the maximum relative position to consider. - Relative position embeddings for distances above this threshold - are zeroed out. - Returns: - a lookup Tensor of size [length, length, vocab_size] that satisfies - ret[n,m,v] = 1{m - n + max_relative_position = v}. - """ - if max_relative_position is None: - max_relative_position = length - 1 - # Return the cached lookup tensor, otherwise compute it and cache it. - vocab_size = 2 * max_relative_position + 1 - ret = torch.zeros(length, length, vocab_size) - for i in range(length): - for x in range(length): - v = x - i + max_relative_position - if abs(x - i) > max_relative_position: - continue - ret[i, x, v] = 1 - return ret - - -def reindex_2d_einsum_lookup( - relative_position_tensor, - height: int, - width: int, - height_lookup: torch.Tensor, - width_lookup: torch.Tensor, -) -> torch.Tensor: - """Reindex 2d relative position bias with 2 independent einsum lookups. - Args: - relative_position_tensor: tensor of shape - [..., vocab_height, vocab_width, ...]. - height: height to reindex to. - width: width to reindex to. - height_lookup: one-hot height lookup - width_lookup: one-hot width lookup - Returns: - reindexed_tensor: a Tensor of shape - [..., height * width, height * width, ...] - """ - reindexed_tensor = torch.einsum('nhw,ixh->nixw', relative_position_tensor, height_lookup) - reindexed_tensor = torch.einsum('nixw,jyw->nijxy', reindexed_tensor, width_lookup) - area = height * width - return reindexed_tensor.reshape(relative_position_tensor.shape[0], area, area) - - -class RelPosBiasTf(nn.Module): - - def __init__(self, window_size, num_heads, prefix_tokens=0): - super().__init__() - assert prefix_tokens <= 1 - self.window_size = window_size - self.window_area = window_size[0] * window_size[1] - self.num_heads = num_heads - - vocab_height = 2 * window_size[0] - 1 - vocab_width = 2 * window_size[1] - 1 - self.bias_shape = (self.num_heads, vocab_height, vocab_width) - self.relative_position_bias_table = nn.Parameter(torch.zeros(self.bias_shape)) - self.register_buffer('height_lookup', generate_lookup_tensor(window_size[0]), persistent=False) - self.register_buffer('width_lookup', generate_lookup_tensor(window_size[1]), persistent=False) - self.init_weights() - - def init_weights(self): - nn.init.normal_(self.relative_position_bias_table, std=.02) - - def get_bias(self) -> torch.Tensor: - # FIXME change to not use one-hot/einsum? - return reindex_2d_einsum_lookup( - self.relative_position_bias_table, - self.window_size[0], - self.window_size[1], - self.height_lookup, - self.width_lookup - ) - - def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None): - return attn + self.get_bias() - - class NormMlpHead(nn.Module): def __init__( diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index cf4f268d..e1da91a2 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -21,93 +21,12 @@ from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficie round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT from ._features import FeatureInfo, FeatureHooks from ._manipulate import checkpoint_seq +from ._pretrained import generate_default_cfgs from ._registry import register_model __all__ = ['MobileNetV3', 'MobileNetV3Features'] -def _cfg(url='', **kwargs): - return { - 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.875, 'interpolation': 'bilinear', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'conv_stem', 'classifier': 'classifier', - **kwargs - } - - -default_cfgs = { - 'mobilenetv3_large_075': _cfg(url=''), - 'mobilenetv3_large_100': _cfg( - interpolation='bicubic', - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth'), - 'mobilenetv3_large_100_miil': _cfg( - interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.), - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_1k_miil_78_0-66471c13.pth'), - 'mobilenetv3_large_100_miil_in21k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_in21k_miil-d71cc17b.pth', - interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.), num_classes=11221), - - 'mobilenetv3_small_050': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth', - interpolation='bicubic'), - 'mobilenetv3_small_075': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_075_lambc-384766db.pth', - interpolation='bicubic'), - 'mobilenetv3_small_100': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_100_lamb-266a294c.pth', - interpolation='bicubic'), - - 'mobilenetv3_rw': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth', - interpolation='bicubic'), - - 'tf_mobilenetv3_large_075': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), - 'tf_mobilenetv3_large_100': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), - 'tf_mobilenetv3_large_minimal_100': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), - 'tf_mobilenetv3_small_075': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), - 'tf_mobilenetv3_small_100': _cfg( - url= 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), - 'tf_mobilenetv3_small_minimal_100': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), - - 'fbnetv3_b': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_b_224-ead5d2a1.pth', - test_input_size=(3, 256, 256), crop_pct=0.95), - 'fbnetv3_d': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_d_224-c98bce42.pth', - test_input_size=(3, 256, 256), crop_pct=0.95), - 'fbnetv3_g': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_g_240-0b1df83b.pth', - input_size=(3, 240, 240), test_input_size=(3, 288, 288), crop_pct=0.95, pool_size=(8, 8)), - - "lcnet_035": _cfg(), - "lcnet_050": _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_050-f447553b.pth', - interpolation='bicubic', - ), - "lcnet_075": _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_075-318cad2c.pth', - interpolation='bicubic', - ), - "lcnet_100": _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_100-a929038c.pth', - interpolation='bicubic', - ), - "lcnet_150": _cfg(), -} - - class MobileNetV3(nn.Module): """ MobiletNet-V3 @@ -124,9 +43,24 @@ class MobileNetV3(nn.Module): """ def __init__( - self, block_args, num_classes=1000, in_chans=3, stem_size=16, fix_stem=False, num_features=1280, - head_bias=True, pad_type='', act_layer=None, norm_layer=None, se_layer=None, se_from_exp=True, - round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'): + self, + block_args, + num_classes=1000, + in_chans=3, + stem_size=16, + fix_stem=False, + num_features=1280, + head_bias=True, + pad_type='', + act_layer=None, + norm_layer=None, + se_layer=None, + se_from_exp=True, + round_chs_fn=round_channels, + drop_rate=0., + drop_path_rate=0., + global_pool='avg', + ): super(MobileNetV3, self).__init__() act_layer = act_layer or nn.ReLU norm_layer = norm_layer or nn.BatchNorm2d @@ -145,8 +79,15 @@ class MobileNetV3(nn.Module): # Middle stages (IR/ER/DS Blocks) builder = EfficientNetBuilder( - output_stride=32, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp, - act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate) + output_stride=32, + pad_type=pad_type, + round_chs_fn=round_chs_fn, + se_from_exp=se_from_exp, + act_layer=act_layer, + norm_layer=norm_layer, + se_layer=se_layer, + drop_path_rate=drop_path_rate, + ) self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.feature_info = builder.features head_chs = builder.in_chs @@ -225,9 +166,23 @@ class MobileNetV3Features(nn.Module): """ def __init__( - self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3, - stem_size=16, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels, - se_from_exp=True, act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.): + self, + block_args, + out_indices=(0, 1, 2, 3, 4), + feature_location='bottleneck', + in_chans=3, + stem_size=16, + fix_stem=False, + output_stride=32, + pad_type='', + round_chs_fn=round_channels, + se_from_exp=True, + act_layer=None, + norm_layer=None, + se_layer=None, + drop_rate=0., + drop_path_rate=0., + ): super(MobileNetV3Features, self).__init__() act_layer = act_layer or nn.ReLU norm_layer = norm_layer or nn.BatchNorm2d @@ -243,9 +198,16 @@ class MobileNetV3Features(nn.Module): # Middle stages (IR/ER/DS Blocks) builder = EfficientNetBuilder( - output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp, - act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, - drop_path_rate=drop_path_rate, feature_location=feature_location) + output_stride=output_stride, + pad_type=pad_type, + round_chs_fn=round_chs_fn, + se_from_exp=se_from_exp, + act_layer=act_layer, + norm_layer=norm_layer, + se_layer=se_layer, + drop_path_rate=drop_path_rate, + feature_location=feature_location, + ) self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.feature_info = FeatureInfo(builder.features, out_indices) self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices} @@ -286,7 +248,9 @@ def _create_mnv3(variant, pretrained=False, **kwargs): kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool') model_cls = MobileNetV3Features model = build_model_with_cfg( - model_cls, variant, pretrained, + model_cls, + variant, + pretrained, pretrained_strict=not features_only, kwargs_filter=kwargs_filter, **kwargs) @@ -567,6 +531,110 @@ def _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs): return model +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv_stem', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'mobilenetv3_large_075.untrained': _cfg(url=''), + 'mobilenetv3_large_100.ra_in1k': _cfg( + interpolation='bicubic', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth', + hf_hub_id='timm/'), + 'mobilenetv3_large_100.miil_in21k_ft_in1k': _cfg( + interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.), + origin_url='https://github.com/Alibaba-MIIL/ImageNet21K', + paper_ids='arXiv:2104.10972v4', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_1k_miil_78_0-66471c13.pth', + hf_hub_id='timm/'), + 'mobilenetv3_large_100.miil_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_in21k_miil-d71cc17b.pth', + hf_hub_id='timm/', + origin_url='https://github.com/Alibaba-MIIL/ImageNet21K', + paper_ids='arXiv:2104.10972v4', + interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.), num_classes=11221), + + 'mobilenetv3_small_050.lamb_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth', + hf_hub_id='timm/', + interpolation='bicubic'), + 'mobilenetv3_small_075.lamb_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_075_lambc-384766db.pth', + hf_hub_id='timm/', + interpolation='bicubic'), + 'mobilenetv3_small_100.lamb_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_100_lamb-266a294c.pth', + hf_hub_id='timm/', + interpolation='bicubic'), + + 'mobilenetv3_rw.rmsp_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth', + interpolation='bicubic'), + + 'tf_mobilenetv3_large_075.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_large_100.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_large_minimal_100.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_small_075.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_small_100.in1k': _cfg( + url= 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_small_minimal_100.in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth', + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + + 'fbnetv3_b.ra2_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_b_224-ead5d2a1.pth', + hf_hub_id='timm/', + test_input_size=(3, 256, 256), crop_pct=0.95), + 'fbnetv3_d.ra2_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_d_224-c98bce42.pth', + hf_hub_id='timm/', + test_input_size=(3, 256, 256), crop_pct=0.95), + 'fbnetv3_g.ra2_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_g_240-0b1df83b.pth', + hf_hub_id='timm/', + input_size=(3, 240, 240), test_input_size=(3, 288, 288), crop_pct=0.95, pool_size=(8, 8)), + + "lcnet_035.untrained": _cfg(), + "lcnet_050.ra2_in1k": _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_050-f447553b.pth', + hf_hub_id='timm/', + interpolation='bicubic', + ), + "lcnet_075.ra2_in1k": _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_075-318cad2c.pth', + hf_hub_id='timm/', + interpolation='bicubic', + ), + "lcnet_100.ra2_in1k": _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_100-a929038c.pth', + hf_hub_id='timm/', + interpolation='bicubic', + ), + "lcnet_150.untrained": _cfg(), +}) + + @register_model def mobilenetv3_large_075(pretrained=False, **kwargs): """ MobileNet V3 """ @@ -581,24 +649,6 @@ def mobilenetv3_large_100(pretrained=False, **kwargs): return model -@register_model -def mobilenetv3_large_100_miil(pretrained=False, **kwargs): - """ MobileNet V3 - Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K - """ - model = _gen_mobilenet_v3('mobilenetv3_large_100_miil', 1.0, pretrained=pretrained, **kwargs) - return model - - -@register_model -def mobilenetv3_large_100_miil_in21k(pretrained=False, **kwargs): - """ MobileNet V3, 21k pretraining - Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K - """ - model = _gen_mobilenet_v3('mobilenetv3_large_100_miil_in21k', 1.0, pretrained=pretrained, **kwargs) - return model - - @register_model def mobilenetv3_small_050(pretrained=False, **kwargs): """ MobileNet V3 """ diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 5b93628f..d6865549 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -8,14 +8,18 @@ A PyTorch implement of Vision Transformers as described in: `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` - https://arxiv.org/abs/2106.10270 -The official jax code is released and available at https://github.com/google-research/vision_transformer +`FlexiViT: One Model for All Patch Sizes` + - https://arxiv.org/abs/2212.08013 + +The official jax code is released and available at + * https://github.com/google-research/vision_transformer + * https://github.com/google-research/big_vision Acknowledgments: -* The paper authors for releasing code and weights, thanks! -* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out -for some einops/einsum fun -* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT -* Bert reference code checks against Huggingface Transformers and Tensorflow Bert + * The paper authors for releasing code and weights, thanks! + * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch + * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT + * Bert reference code checks against Huggingface Transformers and Tensorflow Bert Hacked together by / Copyright 2020, Ross Wightman """ @@ -23,7 +27,7 @@ import logging import math from collections import OrderedDict from functools import partial -from typing import Optional +from typing import Optional, List import torch import torch.nn as nn @@ -32,7 +36,8 @@ import torch.utils.checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ OPENAI_CLIP_MEAN, OPENAI_CLIP_STD -from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ +from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \ + resample_abs_pos_embed from ._builder import build_model_with_cfg from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv from ._pretrained import generate_default_cfgs @@ -449,6 +454,39 @@ def get_init_weights_vit(mode='jax', head_bias: float = 0.): return init_weights_vit_timm +def resize_pos_embed( + posemb, + posemb_new, + num_prefix_tokens=1, + gs_new=(), + interpolation='bicubic', + antialias=False, +): + """ Rescale the grid of position embeddings when loading from state_dict. + + *DEPRECATED* This function is being deprecated in favour of resample_abs_pos_embed + + Adapted from: + https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + """ + ntok_new = posemb_new.shape[1] + if num_prefix_tokens: + posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:] + ntok_new -= num_prefix_tokens + else: + posemb_prefix, posemb_grid = posemb[:, :0], posemb[0] + gs_old = int(math.sqrt(len(posemb_grid))) + if not len(gs_new): # backwards compatibility + gs_new = [int(math.sqrt(ntok_new))] * 2 + assert len(gs_new) >= 2 + _logger.info(f'Resized position embedding: {posemb.shape} ({[gs_old, gs_old]}) to {posemb_new.shape} ({gs_new}).') + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode=interpolation, antialias=antialias, align_corners=False) + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) + posemb = torch.cat([posemb_prefix, posemb_grid], dim=1) + return posemb + + @torch.no_grad() def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): """ Load weights from .npz checkpoints for official Google Brain Flax implementation @@ -468,8 +506,15 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = return torch.from_numpy(w) w = np.load(checkpoint_path) - if not prefix and 'opt/target/embedding/kernel' in w: - prefix = 'opt/target/' + interpolation = 'bilinear' + antialias = False + big_vision = False + if not prefix: + if 'opt/target/embedding/kernel' in w: + prefix = 'opt/target/' + elif 'params/embedding/kernel' in w: + prefix = 'params/' + big_vision = True if hasattr(model.patch_embed, 'backbone'): # hybrid @@ -495,17 +540,33 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = else: embed_conv_w = adapt_input_conv( model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) + if embed_conv_w.shape[-2:] != model.patch_embed.proj.weight.shape[-2:]: + embed_conv_w = resample_patch_embed( + embed_conv_w, + model.patch_embed.proj.weight.shape[-2:], + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + model.patch_embed.proj.weight.copy_(embed_conv_w) model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) if model.cls_token is not None: model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) - pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) + if big_vision: + pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False) + else: + pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) if pos_embed_w.shape != model.pos_embed.shape: - pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights + old_shape = pos_embed_w.shape + num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1) + pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights pos_embed_w, - model.pos_embed, - getattr(model, 'num_prefix_tokens', 1), - model.patch_embed.grid_size + new_size=model.patch_embed.grid_size, + num_prefix_tokens=num_prefix_tokens, + interpolation=interpolation, + antialias=antialias, + verbose=True, ) model.pos_embed.copy_(pos_embed_w) model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) @@ -517,9 +578,10 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) + mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2) for i, block in enumerate(model.blocks.children()): block_prefix = f'{prefix}Transformer/encoderblock_{i}/' - mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' + mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/' block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) block.attn.qkv.weight.copy_(torch.cat([ @@ -529,32 +591,10 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) for r in range(2): - getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) - getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) - block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) - block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) - - -def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()): - # Rescale the grid of position embeddings when loading from state_dict. Adapted from - # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 - _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) - ntok_new = posemb_new.shape[1] - if num_prefix_tokens: - posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:] - ntok_new -= num_prefix_tokens - else: - posemb_prefix, posemb_grid = posemb[:, :0], posemb[0] - gs_old = int(math.sqrt(len(posemb_grid))) - if not len(gs_new): # backwards compatibility - gs_new = [int(math.sqrt(ntok_new))] * 2 - assert len(gs_new) >= 2 - _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) - posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) - posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False) - posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) - posemb = torch.cat([posemb_prefix, posemb_grid], dim=1) - return posemb + getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'])) + getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'])) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'])) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'])) def _convert_openai_clip(state_dict, model): @@ -591,7 +631,13 @@ def _convert_openai_clip(state_dict, model): return out_dict -def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False): +def checkpoint_filter_fn( + state_dict, + model, + adapt_layer_scale=False, + interpolation='bicubic', + antialias=True, +): """ convert patch embedding weight from manual patchify + linear proj to conv""" import re out_dict = {} @@ -603,17 +649,30 @@ def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False): return _convert_openai_clip(state_dict, model) for k, v in state_dict.items(): - if 'patch_embed.proj.weight' in k and len(v.shape) < 4: - # For old models that I trained prior to conv based patchification + if 'patch_embed.proj.weight' in k: O, I, H, W = model.patch_embed.proj.weight.shape - v = v.reshape(O, -1, H, W) + if len(v.shape) < 4: + # For old models that I trained prior to conv based patchification + O, I, H, W = model.patch_embed.proj.weight.shape + v = v.reshape(O, -1, H, W) + if v.shape[-1] != W or v.shape[-2] != H: + v = resample_patch_embed( + v, + (H, W), + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]: # To resize pos embedding when using model at different size from pretrained weights - v = resize_pos_embed( + num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1) + v = resample_abs_pos_embed( v, - model.pos_embed, - 0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1), - model.patch_embed.grid_size + new_size=model.patch_embed.grid_size, + num_prefix_tokens=num_prefix_tokens, + interpolation=interpolation, + antialias=antialias, + verbose=True, ) elif adapt_layer_scale and 'gamma_' in k: # remap layer-scale gamma into sub-module (deit3 models) @@ -641,67 +700,101 @@ default_cfgs = generate_default_cfgs({ # How to train your ViT (augreg) weights, pretrained on 21k FT on in1k 'vit_tiny_patch16_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', + hf_hub_id='timm/', custom_load=True), 'vit_tiny_patch16_384.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + hf_hub_id='timm/', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), 'vit_small_patch32_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', + hf_hub_id='timm/', custom_load=True), 'vit_small_patch32_384.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + hf_hub_id='timm/', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), 'vit_small_patch16_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', + hf_hub_id='timm/', custom_load=True), 'vit_small_patch16_384.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + hf_hub_id='timm/', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), 'vit_base_patch32_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', + hf_hub_id='timm/', custom_load=True), 'vit_base_patch32_384.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + hf_hub_id='timm/', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), 'vit_base_patch16_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz', + hf_hub_id='timm/', custom_load=True), 'vit_base_patch16_384.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', + hf_hub_id='timm/', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), 'vit_base_patch8_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz', + hf_hub_id='timm/', custom_load=True), 'vit_large_patch16_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz', + hf_hub_id='timm/', custom_load=True), 'vit_large_patch16_384.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', + hf_hub_id='timm/', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), # re-finetuned augreg 21k FT on in1k weights 'vit_base_patch16_224.augreg2_in21k_ft_in1k': _cfg( - file='b16_augreg-a-8.pth'), - 'vit_base_patch16_384.augreg2_in21k_ft_in1k': _cfg( - url=''), + hf_hub_id='timm/'), + 'vit_base_patch16_384.augreg2_in21k_ft_in1k': _cfg(), 'vit_base_patch8_224.augreg2_in21k_ft_in1k': _cfg( - url=''), + hf_hub_id='timm/'), # patch models (weights from official Google JAX impl) pretrained on in21k FT on in1k 'vit_base_patch16_224.orig_in21k_ft_in1k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth'), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', + hf_hub_id='timm/'), 'vit_base_patch16_384.orig_in21k_ft_in1k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth'), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth', + hf_hub_id='timm/', + input_size=(3, 384, 384), crop_pct=1.0), 'vit_large_patch32_384.orig_in21k_ft_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', + hf_hub_id='timm/', input_size=(3, 384, 384), crop_pct=1.0), - # How to train your ViT (augreg) weights trained on in1k + # How to train your ViT (augreg) weights trained on in1k only + 'vit_small_patch16_224.augreg_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_small_patch16_384.augreg_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch32_224.augreg_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_base_patch32_384.augreg_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), 'vit_base_patch16_224.augreg_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz', + hf_hub_id='timm/', custom_load=True), 'vit_base_patch16_384.augreg_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', + hf_hub_id='timm/', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), 'vit_large_patch14_224.untrained': _cfg(url=''), @@ -711,77 +804,94 @@ default_cfgs = generate_default_cfgs({ # patch models, imagenet21k (weights from official Google JAX impl) - 'vit_large_patch32_224.v1_in21k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', - num_classes=21843), - 'vit_huge_patch14_224.v1_in21k': _cfg( + 'vit_large_patch32_224.orig_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', + hf_hub_id='timm/', + num_classes=21843), + 'vit_huge_patch14_224.orig_in21k': _cfg( url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', - hf_hub_id='timm/vit_huge_patch14_224_in21k', + hf_hub_id='timm/', custom_load=True, num_classes=21843), # How to train your ViT (augreg) weights, pretrained on in21k 'vit_tiny_patch16_224.augreg_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', custom_load=True, num_classes=21843), 'vit_small_patch32_224.augreg_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', custom_load=True, num_classes=21843), 'vit_small_patch16_224.augreg_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', custom_load=True, num_classes=21843), 'vit_base_patch32_224.augreg_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', custom_load=True, num_classes=21843), 'vit_base_patch16_224.augreg_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', custom_load=True, num_classes=21843), 'vit_base_patch8_224.augreg_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', custom_load=True, num_classes=21843), 'vit_large_patch16_224.augreg_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', + hf_hub_id='timm/', custom_load=True, num_classes=21843), # SAM trained models (https://arxiv.org/abs/2106.01548) 'vit_base_patch32_224.sam': _cfg( - url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz', custom_load=True), + url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz', custom_load=True, + hf_hub_id='timm/'), 'vit_base_patch16_224.sam': _cfg( - url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz', custom_load=True), + url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz', custom_load=True, + hf_hub_id='timm/'), # DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only) 'vit_small_patch16_224.dino': _cfg( url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth', + hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), 'vit_small_patch8_224.dino': _cfg( url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth', + hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), 'vit_base_patch16_224.dino': _cfg( url='https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth', + hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), 'vit_base_patch8_224.dino': _cfg( url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth', + hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), # ViT ImageNet-21K-P pretraining by MILL 'vit_base_patch16_224_miil.in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth', + hf_hub_id='timm/', mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221), 'vit_base_patch16_224_miil.in21k_ft_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_1k_miil_84_4-2deb18e3.pth', + hf_hub_id='timm/', mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear'), # custom timm variants 'vit_base_patch16_rpn_224.in1k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth'), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth', + hf_hub_id='timm/'), 'vit_medium_patch16_gap_240.in12k': _cfg( - hf_hub_id='timm/vit_medium_patch16_gap_240.in12k', + hf_hub_id='timm/', input_size=(3, 240, 240), crop_pct=0.95, num_classes=11821), 'vit_medium_patch16_gap_256.in12k_ft_in1k': _cfg( - hf_hub_id='timm/vit_medium_patch16_gap_256.in12k_ft_in1k', + hf_hub_id='timm/', input_size=(3, 256, 256), crop_pct=0.95), 'vit_medium_patch16_gap_384.in12k_ft_in1k': _cfg( - hf_hub_id='timm/vit_medium_patch16_gap_384.in12k_ft_in1k', + hf_hub_id='timm/', input_size=(3, 384, 384), crop_pct=0.95, crop_mode='squash'), 'vit_base_patch16_gap_224': _cfg(), @@ -808,24 +918,24 @@ default_cfgs = generate_default_cfgs({ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), 'vit_base_patch32_clip_224.laion2b_ft_in1k': _cfg( - hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in1k', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), 'vit_base_patch16_clip_224.laion2b_ft_in1k': _cfg( - hf_hub_id='timm/vit_base_patch16_clip_224.laion2b_ft_in1k', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), 'vit_base_patch16_clip_384.laion2b_ft_in1k': _cfg( - hf_hub_id='timm/vit_base_patch16_clip_384.laion2b_ft_in1k', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'), 'vit_large_patch14_clip_224.laion2b_ft_in1k': _cfg( - hf_hub_id='timm/vit_large_patch14_clip_224.laion2b_ft_in1k', + hf_hub_id='timm/', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0), 'vit_large_patch14_clip_336.laion2b_ft_in1k': _cfg( - hf_hub_id='timm/vit_large_patch14_clip_336.laion2b_ft_in1k', + hf_hub_id='timm/', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), 'vit_huge_patch14_clip_224.laion2b_ft_in1k': _cfg( - hf_hub_id='timm/vit_huge_patch14_clip_224.laion2b_ft_in1k', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), 'vit_huge_patch14_clip_336.laion2b_ft_in1k': _cfg( hf_hub_id='', @@ -833,33 +943,33 @@ default_cfgs = generate_default_cfgs({ crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), 'vit_base_patch32_clip_224.laion2b_ft_in12k_in1k': _cfg( - hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in12k_in1k', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), 'vit_base_patch32_clip_384.laion2b_ft_in12k_in1k': _cfg( - hf_hub_id='timm/vit_base_patch32_clip_384.laion2b_ft_in12k_in1k', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)), 'vit_base_patch32_clip_448.laion2b_ft_in12k_in1k': _cfg( - hf_hub_id='timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)), 'vit_base_patch16_clip_224.laion2b_ft_in12k_in1k': _cfg( - hf_hub_id='timm/vit_base_patch16_clip_224.laion2b_ft_in12k_in1k', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95), 'vit_base_patch16_clip_384.laion2b_ft_in12k_in1k': _cfg( - hf_hub_id='timm/vit_base_patch16_clip_384.laion2b_ft_in12k_in1k', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'), 'vit_large_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg( - hf_hub_id='timm/vit_large_patch14_clip_224.laion2b_ft_in12k_in1k', + hf_hub_id='timm/', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0), 'vit_large_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg( - hf_hub_id='timm/vit_large_patch14_clip_336.laion2b_ft_in12k_in1k', + hf_hub_id='timm/', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), 'vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg( - hf_hub_id='timm/vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), 'vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg( - hf_hub_id='timm/vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), @@ -867,58 +977,58 @@ default_cfgs = generate_default_cfgs({ #hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in12k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), 'vit_base_patch16_clip_224.laion2b_ft_in12k': _cfg( - hf_hub_id='timm/vit_base_patch16_clip_224.laion2b_ft_in12k', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), 'vit_large_patch14_clip_224.laion2b_ft_in12k': _cfg( - hf_hub_id='timm/vit_large_patch14_clip_224.laion2b_ft_in12k', + hf_hub_id='timm/', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=11821), 'vit_huge_patch14_clip_224.laion2b_ft_in12k': _cfg( - hf_hub_id='timm/vit_huge_patch14_clip_224.laion2b_ft_in12k', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821), 'vit_base_patch32_clip_224.openai': _cfg( - hf_hub_id='timm/clip_vit_base_patch32_224.openai', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), 'vit_base_patch16_clip_224.openai': _cfg( - hf_hub_id='timm/clip_vit_base_patch16_224.openai', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), 'vit_large_patch14_clip_224.openai': _cfg( - hf_hub_id='timm/clip_vit_large_patch14_224.openai', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), 'vit_base_patch32_clip_224.openai_ft_in1k': _cfg( - hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in1k', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), 'vit_base_patch16_clip_224.openai_ft_in1k': _cfg( - hf_hub_id='timm/vit_base_patch16_clip_224.openai_ft_in1k', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), 'vit_base_patch16_clip_384.openai_ft_in1k': _cfg( - hf_hub_id='timm/vit_base_patch16_clip_384.openai_ft_in1k', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'), 'vit_large_patch14_clip_224.openai_ft_in1k': _cfg( - hf_hub_id='timm/vit_large_patch14_clip_224.openai_ft_in1k', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), 'vit_base_patch32_clip_224.openai_ft_in12k_in1k': _cfg( #hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), 'vit_base_patch32_clip_384.openai_ft_in12k_in1k': _cfg( - hf_hub_id='timm/vit_base_patch32_clip_384.openai_ft_in12k_in1k', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'), 'vit_base_patch16_clip_224.openai_ft_in12k_in1k': _cfg( - hf_hub_id='timm/vit_base_patch16_clip_224.openai_ft_in12k_in1k', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95), 'vit_base_patch16_clip_384.openai_ft_in12k_in1k': _cfg( - hf_hub_id='timm/vit_base_patch16_clip_384.openai_ft_in12k_in1k', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'), 'vit_large_patch14_clip_224.openai_ft_in12k_in1k': _cfg( - hf_hub_id='timm/vit_large_patch14_clip_224.openai_ft_in12k_in1k', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), 'vit_large_patch14_clip_336.openai_ft_in12k_in1k': _cfg( - hf_hub_id='timm/vit_large_patch14_clip_336.openai_ft_in12k_in1k', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), @@ -926,10 +1036,10 @@ default_cfgs = generate_default_cfgs({ #hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), 'vit_base_patch16_clip_224.openai_ft_in12k': _cfg( - hf_hub_id='timm/vit_base_patch16_clip_224.openai_ft_in12k', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), 'vit_large_patch14_clip_224.openai_ft_in12k': _cfg( - hf_hub_id='timm/vit_large_patch14_clip_224.openai_ft_in12k', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821), # experimental (may be removed) @@ -942,21 +1052,81 @@ default_cfgs = generate_default_cfgs({ # EVA fine-tuned weights from MAE style MIM - EVA-CLIP target pretrain # https://github.com/baaivision/EVA/blob/7ecf2c0a370d97967e86d047d7af9188f78d2df3/eva/README.md#eva-l-learning-better-mim-representations-from-eva-clip 'eva_large_patch14_196.in22k_ft_in22k_in1k': _cfg( - hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_21k_to_1k_ft_88p6.pt', + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_21k_to_1k_ft_88p6.pt', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_size=(3, 196, 196), crop_pct=1.0), 'eva_large_patch14_336.in22k_ft_in22k_in1k': _cfg( - hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_21k_to_1k_ft_89p2.pt', + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_21k_to_1k_ft_89p2.pt', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), 'eva_large_patch14_196.in22k_ft_in1k': _cfg( - hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_1k_ft_88p0.pt', + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_1k_ft_88p0.pt', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_size=(3, 196, 196), crop_pct=1.0), 'eva_large_patch14_336.in22k_ft_in1k': _cfg( - hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_1k_ft_88p65.pt', + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_1k_ft_88p65.pt', + hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), + + 'flexivit_small.1200ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_small.600ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_600ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_small.300ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_300ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + + 'flexivit_base.1200ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_base.600ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_600ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_base.300ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_300ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_base.1000ep_in21k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_1000ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843), + 'flexivit_base.300ep_in21k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_300ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843), + + 'flexivit_large.1200ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_large.600ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_600ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_large.300ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_300ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + + 'flexivit_base.patch16_in21k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/vit_b16_i21k_300ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843), + 'flexivit_base.patch30_in21k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/vit_b30_i21k_300ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843), }) @@ -964,9 +1134,16 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') + if 'flexi' in variant: + # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed + # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation. + _filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False) + else: + _filter_fn = checkpoint_filter_fn + return build_model_with_cfg( VisionTransformer, variant, pretrained, - pretrained_filter_fn=checkpoint_filter_fn, + pretrained_filter_fn=_filter_fn, **kwargs, ) @@ -1396,3 +1573,30 @@ def eva_large_patch14_336(pretrained=False, **kwargs): patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg', **kwargs) model = _create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **model_kwargs) return model + + +@register_model +def flexivit_small(pretrained=False, **kwargs): + """ FlexiViT-Small + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, **kwargs) + model = _create_vision_transformer('flexivit_small', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def flexivit_base(pretrained=False, **kwargs): + """ FlexiViT-Base + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, **kwargs) + model = _create_vision_transformer('flexivit_base', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def flexivit_large(pretrained=False, **kwargs): + """ FlexiViT-Large + """ + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, **kwargs) + model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **model_kwargs) + return model diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index cfdd0a0e..bec7989c 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -27,72 +27,6 @@ from .resnetv2 import ResNetV2, create_resnetv2_stem from .vision_transformer import _create_vision_transformer -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, - 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), - 'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head', - **kwargs - } - - -default_cfgs = generate_default_cfgs({ - # hybrid in-1k models (weights from official JAX impl where they exist) - 'vit_tiny_r_s16_p8_224.augreg_in21k_ft_in1k': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', - custom_load=True, - first_conv='patch_embed.backbone.conv'), - 'vit_tiny_r_s16_p8_384.augreg_in21k_ft_in1k': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', - first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0, custom_load=True), - 'vit_small_r26_s32_224.augreg_in21k_ft_in1k': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz', - custom_load=True, - ), - 'vit_small_r26_s32_384.augreg_in21k_ft_in1k': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', - input_size=(3, 384, 384), crop_pct=1.0, custom_load=True), - 'vit_base_r26_s32_224.untrained': _cfg(), - 'vit_base_r50_s16_384.v1_in21k_ft_in1k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth', - input_size=(3, 384, 384), crop_pct=1.0), - 'vit_large_r50_s32_224.augreg_in21k_ft_in1k': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz', - custom_load=True, - ), - 'vit_large_r50_s32_384.augreg_in21k_ft_in1k': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', - input_size=(3, 384, 384), crop_pct=1.0, custom_load=True, - ), - - # hybrid in-21k models (weights from official Google JAX impl where they exist) - 'vit_tiny_r_s16_p8_224.augreg_in21k': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', - num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv', custom_load=True), - 'vit_small_r26_s32_224.augreg_in21k': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz', - num_classes=21843, crop_pct=0.9, custom_load=True), - 'vit_base_r50_s16_224.v1_in21k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth', - num_classes=21843, crop_pct=0.9), - 'vit_large_r50_s32_224.augreg_in21k': _cfg( - url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz', - num_classes=21843, crop_pct=0.9, custom_load=True), - - # hybrid models (using timm resnet backbones) - 'vit_small_resnet26d_224': _cfg( - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), - 'vit_small_resnet50d_s16_224': _cfg( - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), - 'vit_base_resnet26d_224': _cfg( - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), - 'vit_base_resnet50d_224': _cfg( - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), -}) - - class HybridEmbed(nn.Module): """ CNN Feature Map Embedding Extract feature map from CNN, flatten, project to embedding dim. @@ -166,6 +100,83 @@ def _resnetv2(layers=(3, 4, 9), **kwargs): return backbone +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), + 'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + # hybrid in-1k models (weights from official JAX impl where they exist) + 'vit_tiny_r_s16_p8_224.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', + hf_hub_id='timm/', + custom_load=True, + first_conv='patch_embed.backbone.conv'), + 'vit_tiny_r_s16_p8_384.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + hf_hub_id='timm/', + first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0, custom_load=True), + 'vit_small_r26_s32_224.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz', + hf_hub_id='timm/', + custom_load=True, + ), + 'vit_small_r26_s32_384.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + hf_hub_id='timm/', + input_size=(3, 384, 384), crop_pct=1.0, custom_load=True), + 'vit_base_r26_s32_224.untrained': _cfg(), + 'vit_base_r50_s16_384.orig_in21k_ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth', + hf_hub_id='timm/', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_large_r50_s32_224.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz', + hf_hub_id='timm/', + custom_load=True, + ), + 'vit_large_r50_s32_384.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', + hf_hub_id='timm/', + input_size=(3, 384, 384), crop_pct=1.0, custom_load=True, + ), + + # hybrid in-21k models (weights from official Google JAX impl where they exist) + 'vit_tiny_r_s16_p8_224.augreg_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', + num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv', custom_load=True), + 'vit_small_r26_s32_224.augreg_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', + num_classes=21843, crop_pct=0.9, custom_load=True), + 'vit_base_r50_s16_224.orig_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth', + hf_hub_id='timm/', + num_classes=21843, crop_pct=0.9), + 'vit_large_r50_s32_224.augreg_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', + num_classes=21843, crop_pct=0.9, custom_load=True), + + # hybrid models (using timm resnet backbones) + 'vit_small_resnet26d_224.untrained': _cfg( + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), + 'vit_small_resnet50d_s16_224.untrained': _cfg( + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), + 'vit_base_resnet26d_224.untrained': _cfg( + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), + 'vit_base_resnet50d_224.untrained': _cfg( + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), +}) + + @register_model def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs): """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224. diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index 1a7c2f40..a7cf3e53 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -11,12 +11,12 @@ from typing import Optional, Tuple import torch import torch.nn as nn -import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_ +from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias from ._builder import build_model_with_cfg +from ._pretrained import generate_default_cfgs from ._registry import register_model __all__ = ['VisionTransformerRelPos'] # model_registry will add each entrypoint fn to this @@ -24,216 +24,6 @@ __all__ = ['VisionTransformerRelPos'] # model_registry will add each entrypoint _logger = logging.getLogger(__name__) -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, - 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, - 'first_conv': 'patch_embed.proj', 'classifier': 'head', - **kwargs - } - - -default_cfgs = { - 'vit_relpos_base_patch32_plus_rpn_256': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_replos_base_patch32_plus_rpn_256-sw-dd486f51.pth', - input_size=(3, 256, 256)), - 'vit_relpos_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240)), - - 'vit_relpos_small_patch16_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_small_patch16_224-sw-ec2778b4.pth'), - 'vit_relpos_medium_patch16_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_224-sw-11c174af.pth'), - 'vit_relpos_base_patch16_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth'), - - 'vit_srelpos_small_patch16_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_small_patch16_224-sw-6cdb8849.pth'), - 'vit_srelpos_medium_patch16_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_medium_patch16_224-sw-ad702b8c.pth'), - - 'vit_relpos_medium_patch16_cls_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_cls_224-sw-cfe8e259.pth'), - 'vit_relpos_base_patch16_cls_224': _cfg( - url=''), - 'vit_relpos_base_patch16_clsgap_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_gapcls_224-sw-1a341d6c.pth'), - - 'vit_relpos_small_patch16_rpn_224': _cfg(url=''), - 'vit_relpos_medium_patch16_rpn_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_rpn_224-sw-5d2befd8.pth'), - 'vit_relpos_base_patch16_rpn_224': _cfg(url=''), -} - - -def gen_relative_position_index( - q_size: Tuple[int, int], - k_size: Tuple[int, int] = None, - class_token: bool = False) -> torch.Tensor: - # Adapted with significant modifications from Swin / BeiT codebases - # get pair-wise relative position index for each token inside the window - q_coords = torch.stack(torch.meshgrid([torch.arange(q_size[0]), torch.arange(q_size[1])])).flatten(1) # 2, Wh, Ww - if k_size is None: - k_coords = q_coords - k_size = q_size - else: - # different q vs k sizes is a WIP - k_coords = torch.stack(torch.meshgrid([torch.arange(k_size[0]), torch.arange(k_size[1])])).flatten(1) - relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0) # Wh*Ww, Wh*Ww, 2 - _, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0) - - if class_token: - # handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias - # NOTE not intended or tested with MLP log-coords - max_size = (max(q_size[0], k_size[0]), max(q_size[1], k_size[1])) - num_relative_distance = (2 * max_size[0] - 1) * (2 * max_size[1] - 1) + 3 - relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0]) - relative_position_index[0, 0:] = num_relative_distance - 3 - relative_position_index[0:, 0] = num_relative_distance - 2 - relative_position_index[0, 0] = num_relative_distance - 1 - - return relative_position_index.contiguous() - - -def gen_relative_log_coords( - win_size: Tuple[int, int], - pretrained_win_size: Tuple[int, int] = (0, 0), - mode='swin', -): - assert mode in ('swin', 'cr', 'rw') - # as per official swin-v2 impl, supporting timm specific 'cr' and 'rw' log coords as well - relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32) - relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32) - relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w])) - relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous() # 2*Wh-1, 2*Ww-1, 2 - if mode == 'swin': - if pretrained_win_size[0] > 0: - relative_coords_table[:, :, 0] /= (pretrained_win_size[0] - 1) - relative_coords_table[:, :, 1] /= (pretrained_win_size[1] - 1) - else: - relative_coords_table[:, :, 0] /= (win_size[0] - 1) - relative_coords_table[:, :, 1] /= (win_size[1] - 1) - relative_coords_table *= 8 # normalize to -8, 8 - relative_coords_table = torch.sign(relative_coords_table) * torch.log2( - 1.0 + relative_coords_table.abs()) / math.log2(8) - else: - if mode == 'rw': - # cr w/ window size normalization -> [-1,1] log coords - relative_coords_table[:, :, 0] /= (win_size[0] - 1) - relative_coords_table[:, :, 1] /= (win_size[1] - 1) - relative_coords_table *= 8 # scale to -8, 8 - relative_coords_table = torch.sign(relative_coords_table) * torch.log2( - 1.0 + relative_coords_table.abs()) - relative_coords_table /= math.log2(9) # -> [-1, 1] - else: - # mode == 'cr' - relative_coords_table = torch.sign(relative_coords_table) * torch.log( - 1.0 + relative_coords_table.abs()) - - return relative_coords_table - - -class RelPosMlp(nn.Module): - def __init__( - self, - window_size, - num_heads=8, - hidden_dim=128, - prefix_tokens=0, - mode='cr', - pretrained_window_size=(0, 0) - ): - super().__init__() - self.window_size = window_size - self.window_area = self.window_size[0] * self.window_size[1] - self.prefix_tokens = prefix_tokens - self.num_heads = num_heads - self.bias_shape = (self.window_area,) * 2 + (num_heads,) - if mode == 'swin': - self.bias_act = nn.Sigmoid() - self.bias_gain = 16 - mlp_bias = (True, False) - elif mode == 'rw': - self.bias_act = nn.Tanh() - self.bias_gain = 4 - mlp_bias = True - else: - self.bias_act = nn.Identity() - self.bias_gain = None - mlp_bias = True - - self.mlp = Mlp( - 2, # x, y - hidden_features=hidden_dim, - out_features=num_heads, - act_layer=nn.ReLU, - bias=mlp_bias, - drop=(0.125, 0.) - ) - - self.register_buffer( - "relative_position_index", - gen_relative_position_index(window_size), - persistent=False) - - # get relative_coords_table - self.register_buffer( - "rel_coords_log", - gen_relative_log_coords(window_size, pretrained_window_size, mode=mode), - persistent=False) - - def get_bias(self) -> torch.Tensor: - relative_position_bias = self.mlp(self.rel_coords_log) - if self.relative_position_index is not None: - relative_position_bias = relative_position_bias.view(-1, self.num_heads)[ - self.relative_position_index.view(-1)] # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.view(self.bias_shape) - relative_position_bias = relative_position_bias.permute(2, 0, 1) - relative_position_bias = self.bias_act(relative_position_bias) - if self.bias_gain is not None: - relative_position_bias = self.bias_gain * relative_position_bias - if self.prefix_tokens: - relative_position_bias = F.pad(relative_position_bias, [self.prefix_tokens, 0, self.prefix_tokens, 0]) - return relative_position_bias.unsqueeze(0).contiguous() - - def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None): - return attn + self.get_bias() - - -class RelPosBias(nn.Module): - - def __init__(self, window_size, num_heads, prefix_tokens=0): - super().__init__() - assert prefix_tokens <= 1 - self.window_size = window_size - self.window_area = window_size[0] * window_size[1] - self.bias_shape = (self.window_area + prefix_tokens,) * 2 + (num_heads,) - - num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * prefix_tokens - self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads)) - self.register_buffer( - "relative_position_index", - gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0), - persistent=False, - ) - - self.init_weights() - - def init_weights(self): - trunc_normal_(self.relative_position_bias_table, std=.02) - - def get_bias(self) -> torch.Tensor: - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)] - # win_h * win_w, win_h * win_w, num_heads - relative_position_bias = relative_position_bias.view(self.bias_shape).permute(2, 0, 1) - return relative_position_bias.unsqueeze(0).contiguous() - - def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None): - return attn + self.get_bias() - - class RelPosAttention(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, rel_pos_cls=None, attn_drop=0., proj_drop=0.): super().__init__() @@ -513,6 +303,57 @@ def _create_vision_transformer_relpos(variant, pretrained=False, **kwargs): return model +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + 'vit_relpos_base_patch32_plus_rpn_256.sw_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_replos_base_patch32_plus_rpn_256-sw-dd486f51.pth', + hf_hub_id='timm/', + input_size=(3, 256, 256)), + 'vit_relpos_base_patch16_plus_240.untrained': _cfg(url='', input_size=(3, 240, 240)), + + 'vit_relpos_small_patch16_224.sw_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_small_patch16_224-sw-ec2778b4.pth', + hf_hub_id='timm/'), + 'vit_relpos_medium_patch16_224.sw_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_224-sw-11c174af.pth', + hf_hub_id='timm/'), + 'vit_relpos_base_patch16_224.sw_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth', + hf_hub_id='timm/'), + + 'vit_srelpos_small_patch16_224.sw_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_small_patch16_224-sw-6cdb8849.pth', + hf_hub_id='timm/'), + 'vit_srelpos_medium_patch16_224.sw_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_medium_patch16_224-sw-ad702b8c.pth', + hf_hub_id='timm/'), + + 'vit_relpos_medium_patch16_cls_224.sw_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_cls_224-sw-cfe8e259.pth', + hf_hub_id='timm/'), + 'vit_relpos_base_patch16_cls_224.untrained': _cfg(), + 'vit_relpos_base_patch16_clsgap_224.sw_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_gapcls_224-sw-1a341d6c.pth', + hf_hub_id='timm/'), + + 'vit_relpos_small_patch16_rpn_224.untrained': _cfg(), + 'vit_relpos_medium_patch16_rpn_224.sw_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_rpn_224-sw-5d2befd8.pth', + hf_hub_id='timm/'), + 'vit_relpos_base_patch16_rpn_224.untrained': _cfg(), +}) + + @register_model def vit_relpos_base_patch32_plus_rpn_256(pretrained=False, **kwargs): """ ViT-Base (ViT-B/32+) w/ relative log-coord position and residual post-norm, no class token diff --git a/timm/version.py b/timm/version.py index 0716d38a..c9cc324d 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.8.1dev0' +__version__ = '0.8.2dev0'