Merge remote-tracking branch 'origin/master' into bits_and_tpu

pull/1239/head
Ross Wightman 3 years ago
commit 809c7bb1ec

@ -16,9 +16,9 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, macOS-latest]
-python: ['3.8']
-torch: ['1.9.0']
-torchvision: ['0.10.0']
+python: ['3.9']
+torch: ['1.10.0']
+torchvision: ['0.11.1']
runs-on: ${{ matrix.os }}
steps:
@ -30,7 +30,7 @@ jobs:
- name: Install testing dependencies
run: |
python -m pip install --upgrade pip
-pip install pytest pytest-timeout
+pip install pytest pytest-timeout expecttest
- name: Install torch on mac
if: startsWith(matrix.os, 'macOS')
run: pip install --no-cache-dir torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }}

@ -4,9 +4,16 @@ import platform
import os
import fnmatch
try:
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer
has_fx_feature_extraction = True
except ImportError:
has_fx_feature_extraction = False
import timm
from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \
get_model_default_value
from timm.models.fx_features import _leaf_modules, _autowrap_functions
if hasattr(torch._C, '_jit_set_profiling_executor'):
# legacy executor is too slow to compile large models for unit tests
@ -38,6 +45,10 @@ TARGET_JIT_SIZE = 128
MAX_JIT_SIZE = 320
TARGET_FFEAT_SIZE = 96
MAX_FFEAT_SIZE = 256
TARGET_FWD_FX_SIZE = 128
MAX_FWD_FX_SIZE = 224
TARGET_BWD_FX_SIZE = 128
MAX_BWD_FX_SIZE = 224
def _get_input_size(model=None, model_name='', target=None):
@ -297,3 +308,135 @@ def test_model_forward_features(model_name, batch_size):
assert e == o.shape[1]
assert o.shape[0] == batch_size
assert not torch.isnan(o).any()
def _create_fx_model(model, train=False):
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
# node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
train_nodes, eval_nodes = get_graph_node_names(
model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
eval_return_nodes = [eval_nodes[-1]]
train_return_nodes = [train_nodes[-1]]
if train:
tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
graph = tracer.trace(model)
graph_nodes = list(reversed(graph.nodes))
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
graph_node_names = [n.name for n in graph_nodes]
output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
train_return_nodes = [train_nodes[ix] for ix in output_node_indices]
fx_model = create_feature_extractor(
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes,
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
return fx_model
EXCLUDE_FX_FILTERS = []
# not enough memory to run fx on more models than other tests
if 'GITHUB_ACTIONS' in os.environ:
EXCLUDE_FX_FILTERS += [
'beit_large*',
'swin_large*',
'*resnext101_32x32d',
'resnetv2_152x2*',
]
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx(model_name, batch_size):
"""
Symbolically trace each model and run single forward pass through the resulting GraphModule
Also check that the output of a forward pass through the GraphModule is the same as that from the original Module
"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
model = create_model(model_name, pretrained=False)
model.eval()
input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE)
if max(input_size) > MAX_FWD_FX_SIZE:
pytest.skip("Fixed input size model > limit.")
inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs)
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
model = _create_fx_model(model)
fx_outputs = tuple(model(inputs).values())
if isinstance(fx_outputs, tuple):
fx_outputs = torch.cat(fx_outputs)
assert torch.all(fx_outputs == outputs)
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [2])
def test_model_backward_fx(model_name, batch_size):
"""Symbolically trace each model and run single backward pass through the resulting GraphModule"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE)
if max(input_size) > MAX_BWD_FX_SIZE:
pytest.skip("Fixed input size model > limit.")
model = create_model(model_name, pretrained=False, num_classes=42)
num_params = sum([x.numel() for x in model.parameters()])
model.train()
model = _create_fx_model(model, train=True)
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
outputs.mean().backward()
for n, x in model.named_parameters():
assert x.grad is not None, f'No gradient for {n}'
num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
assert outputs.shape[-1] == 42
assert num_params == num_grad, 'Some parameters are missing gradients'
assert not torch.isnan(outputs).any(), 'Output included NaNs'
# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
EXCLUDE_FX_JIT_FILTERS = [
'deit_*_distilled_patch16_224',
'levit*',
'pit_*_distilled_224',
] + EXCLUDE_FX_FILTERS
@pytest.mark.timeout(120)
@pytest.mark.parametrize(
'model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx_torchscript(model_name, batch_size):
"""Symbolically trace each model, script it, and run single forward pass"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
if max(input_size) > MAX_JIT_SIZE:
pytest.skip("Fixed input size model > limit.")
with set_scriptable(True):
model = create_model(model_name, pretrained=False)
model.eval()
model = torch.jit.script(_create_fx_model(model))
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'
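
Note: the FX forward/backward tests above boil down to torchvision's create_feature_extractor. A minimal standalone sketch of the same flow (not part of this commit; assumes torch >= 1.10 and torchvision >= 0.11, and uses 'resnet18' purely as an example of a model that traces without extra leaf modules):

    import torch
    import timm
    from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor

    model = timm.create_model('resnet18', pretrained=False).eval()
    train_nodes, eval_nodes = get_graph_node_names(model)
    # use the final eval-mode node as the single return node, as _create_fx_model does
    fx_model = create_feature_extractor(model, return_nodes=[eval_nodes[-1]])

    x = torch.randn(1, 3, 224, 224)
    fx_out = tuple(fx_model(x).values())[0]
    assert torch.allclose(fx_out, model(x))  # GraphModule output should match the original module

Models that need timm's leaf modules or autowrapped functions pass tracer_kwargs, exactly as _create_fx_model does above.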

@ -1,7 +1,11 @@
import os
-from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST,\
-Places365, ImageNet, ImageFolder
+from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST, ImageNet, ImageFolder
try:
from torchvision.datasets import Places365
has_places365 = True
except ImportError:
has_places365 = False
try:
from torchvision.datasets import INaturalist
has_inaturalist = True
@ -104,6 +108,7 @@ def create_dataset(
split = '2021_valid'
ds = INaturalist(version=split, target_type=target_type, **torch_kwargs)
elif name == 'places365':
assert has_places365, 'Please update to a newer PyTorch and torchvision for Places365 dataset.'
if split in _TRAIN_SYNONYM:
split = 'train-standard'
elif split in _EVAL_SYNONYM:

@ -36,7 +36,7 @@ PREFETCH_SIZE = 2048 # examples to prefetch
def even_split_indices(split, n, num_examples):
partitions = [round(i * num_examples / n) for i in range(n + 1)]
-return [f"{split}[{partitions[i]}:{partitions[i+1]}]" for i in range(n)]
+return [f"{split}[{partitions[i]}:{partitions[i + 1]}]" for i in range(n)]
def get_class_labels(info):
@ -70,6 +70,7 @@ class ParserTfds(Parser):
components.
"""
def __init__(
self,
root,
@ -99,6 +100,7 @@ class ParserTfds(Parser):
download: download and build TFDS dataset if set, otherwise must use tfds CLI
repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1)
seed: common seed for shard shuffle across all distributed/worker instances
input_name: name of Feature to return as data (input)
input_image: image mode if input is an image (currently PIL mode string)
target_name: name of Feature to return as target (label)
target_image: image mode if target is an image (currently PIL mode string)
@ -111,7 +113,7 @@ class ParserTfds(Parser):
self.split = split
self.is_training = is_training
if self.is_training:
-assert batch_size is not None,\
+assert batch_size is not None, \
"Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
self.batch_size = batch_size
self.repeats = repeats
@ -184,7 +186,7 @@ class ParserTfds(Parser):
InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
My understanding is that using split, the underling TFRecord files will shuffle (shuffle_files=True)
between the splits each iteration, but that understanding could be wrong.
I am currently using a mix of InputContext shard assignment and fine-grained sub-splits for distributing
the data across workers. For training InputContext is used to assign shards to nodes unless num_shards
in dataset < total number of workers. Otherwise sub-split API is used for datasets without enough shards or
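
Note: the fine-grained sub-splits mentioned above are the strings produced by even_split_indices a few hunks earlier. A tiny illustration (the numbers are made up):

    def even_split_indices(split, n, num_examples):
        partitions = [round(i * num_examples / n) for i in range(n + 1)]
        return [f"{split}[{partitions[i]}:{partitions[i + 1]}]" for i in range(n)]

    print(even_split_indices('train', 4, 1000))
    # ['train[0:250]', 'train[250:500]', 'train[500:750]', 'train[750:1000]']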

@ -86,9 +86,11 @@ class Attention(nn.Module):
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.k_bias = None
self.v_bias = None
if window_size:
@ -127,13 +129,7 @@ class Attention(nn.Module):
def forward(self, x, rel_pos_bias: Optional[torch.Tensor] = None):
B, N, C = x.shape
-qkv_bias = None
-if self.q_bias is not None:
-if torch.jit.is_scripting():
-# FIXME requires_grad breaks w/ torchscript
-qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias), self.v_bias))
-else:
-qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
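
Note on the qkv_bias change above: registering k_bias as a non-persistent all-zero buffer lets the fused bias be built with one torch.cat that works under both TorchScript and FX tracing, replacing the old requires_grad workaround. A self-contained sketch of the pattern (shapes are illustrative, not BEiT's):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class QKVWithZeroKBias(nn.Module):
        def __init__(self, dim=64):
            super().__init__()
            self.qkv = nn.Linear(dim, dim * 3, bias=False)
            self.q_bias = nn.Parameter(torch.zeros(dim))
            # fixed zeros, excluded from the state_dict, never trained
            self.register_buffer('k_bias', torch.zeros(dim), persistent=False)
            self.v_bias = nn.Parameter(torch.zeros(dim))

        def forward(self, x):
            qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
            return F.linear(x, self.qkv.weight, qkv_bias)

    y = QKVWithZeroKBias()(torch.randn(2, 16, 64))  # -> (2, 16, 192)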

@ -36,6 +36,9 @@ default_cfgs = {
'botnet26t_256': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/botnet26t_c1_256-167a0e9f.pth',
fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
'sebotnet33ts_256': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sebotnet33ts_a1h2_256-957e3c3e.pth',
fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
'botnet50ts_256': _cfg(
url='',
fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
@ -51,7 +54,7 @@ default_cfgs = {
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sehalonet33ts_256-87e053f9.pth',
input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
'halonet50ts': _cfg(
-url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet50ts_a1h_256-c6d7ff15.pth',
+url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet50ts_a1h2_256-f3a3daee.pth',
input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
'eca_halonext26ts': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_halonext26ts_c_256-06906299.pth',
@ -97,6 +100,22 @@ model_cfgs = dict(
self_attn_layer='bottleneck',
self_attn_kwargs=dict()
),
sebotnet33ts=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg('self_attn', d=2, c=1536, s=2, gs=0, br=0.333),
),
stem_chs=64,
stem_type='tiered',
stem_pool='',
act_layer='silu',
num_features=1280,
attn_layer='se',
self_attn_layer='bottleneck',
self_attn_kwargs=dict()
),
botnet50ts=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
@ -322,6 +341,13 @@ def botnet26t_256(pretrained=False, **kwargs):
return _create_byoanet('botnet26t_256', 'botnet26t', pretrained=pretrained, **kwargs)
@register_model
def sebotnet33ts_256(pretrained=False, **kwargs):
""" Bottleneck Transformer w/ a ResNet33-t backbone, SE attn for non Halo blocks, SiLU,
"""
return _create_byoanet('sebotnet33ts_256', 'sebotnet33ts', pretrained=pretrained, **kwargs)
@register_model
def botnet50ts_256(pretrained=False, **kwargs):
""" Bottleneck Transformer w/ ResNet50-T backbone, silu act.

@ -35,7 +35,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, named_apply
from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
-create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple
+create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple, EvoNormSample2d
from .registry import register_model
__all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block']
@ -136,20 +136,26 @@ default_cfgs = {
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext50ts_256-3e0f515e.pth'),
# experimental models, likely to change ot be removed
-'regnetz_b': _cfgr(
+'regnetz_b16': _cfgr(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_b_raa-677d9606.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 224, 224), pool_size=(7, 7), test_input_size=(3, 288, 288), first_conv='stem.conv', crop_pct=0.94),
-'regnetz_c': _cfgr(
+'regnetz_c16': _cfgr(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_c_rab2_256-a54bf36a.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), first_conv='stem.conv', crop_pct=0.94),
-'regnetz_d': _cfgr(
+'regnetz_d32': _cfgr(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_d_rab_256-b8073a89.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95),
'regnetz_d8': _cfgr(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_d8_bh-afc03c55.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=1.0),
'regnetz_e8': _cfgr(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_e8_bh-aace8e6e.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=1.0),
'regnetz_d8_evob': _cfgr(
url='',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95),
-'regnetz_e8': _cfgr(
+'regnetz_d8_evos': _cfgr(
url='',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95),
}
@ -506,7 +512,7 @@ model_cfgs = dict(
),
# experimental models, closer to a RegNetZ than a ResNet. Similar to EfficientNets but w/ groups instead of DW
-regnetz_b=ByoModelCfg(
+regnetz_b16=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3),
ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3),
@ -522,7 +528,7 @@ model_cfgs = dict(
attn_kwargs=dict(rd_ratio=0.25),
block_kwargs=dict(bottle_in=True, linear_out=True),
),
-regnetz_c=ByoModelCfg(
+regnetz_c16=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=4),
ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=4),
@ -538,7 +544,7 @@ model_cfgs = dict(
attn_kwargs=dict(rd_ratio=0.25),
block_kwargs=dict(bottle_in=True, linear_out=True),
),
-regnetz_d=ByoModelCfg(
+regnetz_d32=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=32, br=4),
ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=32, br=4),
@ -589,8 +595,45 @@ model_cfgs = dict(
attn_kwargs=dict(rd_ratio=0.25),
block_kwargs=dict(bottle_in=True, linear_out=True),
),
)
# experimental EvoNorm configs
regnetz_d8_evob=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=8, br=4),
ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=8, br=4),
ByoBlockCfg(type='bottle', d=12, c=256, s=2, gs=8, br=4),
ByoBlockCfg(type='bottle', d=3, c=384, s=2, gs=8, br=4),
),
stem_chs=64,
stem_type='tiered',
stem_pool='',
downsample='',
num_features=1792,
act_layer='silu',
norm_layer='evonormbatch',
attn_layer='se',
attn_kwargs=dict(rd_ratio=0.25),
block_kwargs=dict(bottle_in=True, linear_out=True),
),
regnetz_d8_evos=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=8, br=4),
ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=8, br=4),
ByoBlockCfg(type='bottle', d=12, c=256, s=2, gs=8, br=4),
ByoBlockCfg(type='bottle', d=3, c=384, s=2, gs=8, br=4),
),
stem_chs=64,
stem_type='deep',
stem_pool='',
downsample='',
num_features=1792,
act_layer='silu',
norm_layer=partial(EvoNormSample2d, groups=32),
attn_layer='se',
attn_kwargs=dict(rd_ratio=0.25),
block_kwargs=dict(bottle_in=True, linear_out=True),
),
)
@register_model
def gernet_l(pretrained=False, **kwargs):
@ -779,24 +822,24 @@ def gcresnext50ts(pretrained=False, **kwargs):
@register_model
-def regnetz_b(pretrained=False, **kwargs):
+def regnetz_b16(pretrained=False, **kwargs):
"""
"""
-return _create_byobnet('regnetz_b', pretrained=pretrained, **kwargs)
+return _create_byobnet('regnetz_b16', pretrained=pretrained, **kwargs)
@register_model
-def regnetz_c(pretrained=False, **kwargs):
+def regnetz_c16(pretrained=False, **kwargs):
"""
"""
-return _create_byobnet('regnetz_c', pretrained=pretrained, **kwargs)
+return _create_byobnet('regnetz_c16', pretrained=pretrained, **kwargs)
@register_model
-def regnetz_d(pretrained=False, **kwargs):
+def regnetz_d32(pretrained=False, **kwargs):
"""
"""
-return _create_byobnet('regnetz_d', pretrained=pretrained, **kwargs)
+return _create_byobnet('regnetz_d32', pretrained=pretrained, **kwargs)
@register_model
@ -813,6 +856,20 @@ def regnetz_e8(pretrained=False, **kwargs):
return _create_byobnet('regnetz_e8', pretrained=pretrained, **kwargs)
@register_model
def regnetz_d8_evob(pretrained=False, **kwargs):
"""
"""
return _create_byobnet('regnetz_d8_evob', pretrained=pretrained, **kwargs)
@register_model
def regnetz_d8_evos(pretrained=False, **kwargs):
"""
"""
return _create_byobnet('regnetz_d8_evos', pretrained=pretrained, **kwargs)
def expand_blocks_cfg(stage_blocks_cfg: Union[ByoBlockCfg, Sequence[ByoBlockCfg]]) -> List[ByoBlockCfg]:
if not isinstance(stage_blocks_cfg, Sequence):
stage_blocks_cfg = (stage_blocks_cfg,)

@ -19,6 +19,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
from .registry import register_model
from .layers import _assert
__all__ = [
@ -105,7 +106,7 @@ class ConvRelPosEnc(nn.Module):
def forward(self, q, v, size: Tuple[int, int]):
B, h, N, Ch = q.shape
H, W = size
-assert N == 1 + H * W
+_assert(N == 1 + H * W, '')
# Convolutional relative position encoding.
q_img = q[:, :, 1:, :] # [B, h, H*W, Ch]
@ -177,7 +178,7 @@ class ConvPosEnc(nn.Module):
def forward(self, x, size: Tuple[int, int]):
B, N, C = x.shape
H, W = size
-assert N == 1 + H * W
+_assert(N == 1 + H * W, '')
# Extract CLS token and image tokens.
cls_token, img_tokens = x[:, :1], x[:, 1:] # [B, 1, C], [B, H*W, C]
@ -275,7 +276,7 @@ class ParallelBlock(nn.Module):
""" Feature map interpolation. """ """ Feature map interpolation. """
B, N, C = x.shape B, N, C = x.shape
H, W = size H, W = size
assert N == 1 + H * W _assert(N == 1 + H * W, '')
cls_token = x[:, :1, :] cls_token = x[:, :1, :]
img_tokens = x[:, 1:, :] img_tokens = x[:, 1:, :]

@ -30,6 +30,7 @@ from .helpers import build_model_with_cfg
from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp
from .registry import register_model
from .vision_transformer_hybrid import HybridEmbed
from .fx_features import register_notrace_module
import torch
import torch.nn as nn
@ -56,6 +57,7 @@ default_cfgs = {
}
@register_notrace_module # reason: FX can't symbolically trace control flow in forward method
class GPSA(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
locality_strength=1.):

@ -22,6 +22,7 @@ NOTE: model names have been renamed from originals to represent actual input res
Modifed from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from typing import Tuple
import torch
import torch.nn as nn
@ -31,8 +32,9 @@ from functools import partial
from typing import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg
-from .layers import DropPath, to_2tuple, trunc_normal_
+from .layers import DropPath, to_2tuple, trunc_normal_, _assert
from .registry import register_model
from .vision_transformer import Mlp, Block
@ -116,8 +118,10 @@ class PatchEmbed(nn.Module):
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
-assert H == self.img_size[0] and W == self.img_size[1], \
-f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+_assert(H == self.img_size[0],
+f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
_assert(W == self.img_size[1],
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
x = self.proj(x).flatten(2).transpose(1, 2)
return x
@ -255,6 +259,27 @@ def _compute_num_patches(img_size, patches):
return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)]
@register_notrace_function
def scale_image(x, ss: Tuple[int, int], crop_scale: bool = False): # annotations for torchscript
"""
Pulled out of CrossViT.forward_features to bury conditional logic in a leaf node for FX tracing.
Args:
x (Tensor): input image
ss (tuple[int, int]): height and width to scale to
crop_scale (bool): whether to crop instead of interpolate to achieve the desired scale. Defaults to False
Returns:
Tensor: the "scaled" image batch tensor
"""
H, W = x.shape[-2:]
if H != ss[0] or W != ss[1]:
if crop_scale and ss[0] <= H and ss[1] <= W:
cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
x = x[:, :, cu:cu + ss[0], cl:cl + ss[1]]
else:
x = torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False)
return x
class CrossViT(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
@ -342,17 +367,12 @@ class CrossViT(nn.Module):
range(self.num_branches)])
def forward_features(self, x):
-B, C, H, W = x.shape
+B = x.shape[0]
xs = []
for i, patch_embed in enumerate(self.patch_embed):
x_ = x
ss = self.img_size_scaled[i]
-if H != ss[0] or W != ss[1]:
-if self.crop_scale and ss[0] <= H and ss[1] <= W:
-cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
-x_ = x_[:, :, cu:cu + ss[0], cl:cl + ss[1]]
-else:
-x_ = torch.nn.functional.interpolate(x_, size=ss, mode='bicubic', align_corners=False)
+x_ = scale_image(x_, ss, self.crop_scale)
x_ = patch_embed(x_)
cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script
cls_tokens = cls_tokens.expand(B, -1, -1)

@ -0,0 +1,73 @@
""" PyTorch FX Based Feature Extraction Helpers
Using https://pytorch.org/vision/stable/feature_extraction.html
"""
from typing import Callable
from torch import nn
from .features import _get_feature_info
try:
from torchvision.models.feature_extraction import create_feature_extractor
has_fx_feature_extraction = True
except ImportError:
has_fx_feature_extraction = False
# Layers we went to treat as leaf modules
from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath
from .layers.non_local_attn import BilinearAttnTransform
from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
# BUT modules from timm.models should use the registration mechanism below
_leaf_modules = {
BatchNormAct2d, # reason: flow control for jit scripting
BilinearAttnTransform, # reason: flow control t <= 1
BlurPool2d, # reason: TypeError: F.conv2d received Proxy in groups=x.shape[1]
# Reason: get_same_padding has a max which raises a control flow error
Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
DropPath, # reason: TypeError: rand recieved Proxy in `size` argument
}
try:
from .layers import InplaceAbn
_leaf_modules.add(InplaceAbn)
except ImportError:
pass
def register_notrace_module(module: nn.Module):
"""
Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
"""
_leaf_modules.add(module)
return module
# Functions we want to autowrap (treat them as leaves)
_autowrap_functions = set()
def register_notrace_function(func: Callable):
"""
Decorator for functions which ought not to be traced through
"""
_autowrap_functions.add(func)
return func
class FeatureGraphNet(nn.Module):
def __init__(self, model, out_indices, out_map=None):
super().__init__()
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
self.feature_info = _get_feature_info(model, out_indices)
if out_map is not None:
assert len(out_map) == len(out_indices)
return_nodes = {info['module']: out_map[i] if out_map is not None else info['module']
for i, info in enumerate(self.feature_info) if i in out_indices}
self.graph_module = create_feature_extractor(
model, return_nodes,
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
def forward(self, x):
return list(self.graph_module(x).values())
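
Note: the module above is meant to be used in two ways, sketched below (not part of the commit). The decorator names and FeatureGraphNet come from this file; MyBlock and my_resize are hypothetical stand-ins for code with control flow that FX cannot trace through:

    import torch
    import torch.nn as nn
    from timm.models.fx_features import register_notrace_module, register_notrace_function, FeatureGraphNet

    @register_notrace_module  # treat the whole module as a leaf, don't trace into forward()
    class MyBlock(nn.Module):
        def forward(self, x):
            # data-dependent control flow like this breaks symbolic tracing
            return x if x.shape[-1] % 2 == 0 else x[..., :-1]

    @register_notrace_function  # autowrap the function instead of tracing through it
    def my_resize(x, size: int):
        return torch.nn.functional.interpolate(x, size=(size, size))

    # FeatureGraphNet wraps a full timm model and returns intermediate feature maps, e.g.:
    # feat_net = FeatureGraphNet(timm.create_model('resnet18'), out_indices=(1, 2, 3, 4))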

@ -14,6 +14,7 @@ import torch.nn as nn
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
from .fx_features import FeatureGraphNet
from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf, load_state_dict_from_url
from .layers import Conv2dSame, Linear
@ -477,6 +478,8 @@ def build_model_with_cfg(
feature_cls = feature_cls.lower()
if 'hook' in feature_cls:
feature_cls = FeatureHookNet
elif feature_cls == 'fx':
feature_cls = FeatureGraphNet
else:
assert False, f'Unknown feature class {feature_cls}'
model = feature_cls(model, **feature_cfg)

@ -22,6 +22,7 @@ import torch.nn.functional as F
from .helpers import to_2tuple, make_divisible
from .weight_init import trunc_normal_
from .trace_utils import _assert
def rel_logits_1d(q, rel_k, permute_mask: List[int]):
@ -133,8 +134,8 @@ class BottleneckAttn(nn.Module):
def forward(self, x):
B, C, H, W = x.shape
-assert H == self.pos_embed.height
-assert W == self.pos_embed.width
+_assert(H == self.pos_embed.height, '')
+_assert(W == self.pos_embed.width, '')
x = self.qkv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W
@ -154,5 +155,3 @@ class BottleneckAttn(nn.Module):
out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W
out = self.pool(out)
return out

@ -12,6 +12,8 @@ Hacked together by / Copyright 2020 Ross Wightman
import torch
import torch.nn as nn
from .trace_utils import _assert
class EvoNormBatch2d(nn.Module):
def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None):
@ -19,12 +21,10 @@ class EvoNormBatch2d(nn.Module):
self.apply_act = apply_act # apply activation (non-linearity)
self.momentum = momentum
self.eps = eps
-param_shape = (1, num_features, 1, 1)
-self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
-self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
-if apply_act:
-self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True)
-self.register_buffer('running_var', torch.ones(1, num_features, 1, 1))
+self.weight = nn.Parameter(torch.ones(num_features), requires_grad=True)
+self.bias = nn.Parameter(torch.zeros(num_features), requires_grad=True)
+self.v = nn.Parameter(torch.ones(num_features), requires_grad=True) if apply_act else None
+self.register_buffer('running_var', torch.ones(num_features))
self.reset_parameters()
def reset_parameters(self):
@ -36,33 +36,32 @@ class EvoNormBatch2d(nn.Module):
def forward(self, x):
assert x.dim() == 4, 'expected 4D input'
x_type = x.dtype
running_var = self.running_var.view(1, -1, 1, 1)
if self.training:
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
n = x.numel() / x.shape[1]
-self.running_var.copy_(
-var.detach() * self.momentum * (n / (n - 1)) + self.running_var * (1 - self.momentum))
+running_var = var.detach() * self.momentum * (n / (n - 1)) + running_var * (1 - self.momentum)
+self.running_var.copy_(running_var.view(self.running_var.shape))
else:
-var = self.running_var
+var = running_var
-if self.apply_act:
+if self.v is not None:
-v = self.v.to(dtype=x_type)
+v = self.v.to(dtype=x_type).reshape(1, -1, 1, 1)
d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type)
d = d.max((var + self.eps).sqrt().to(dtype=x_type))
x = x / d
-return x * self.weight + self.bias
+return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
class EvoNormSample2d(nn.Module):
-def __init__(self, num_features, apply_act=True, groups=8, eps=1e-5, drop_block=None):
+def __init__(self, num_features, apply_act=True, groups=32, eps=1e-5, drop_block=None):
super(EvoNormSample2d, self).__init__()
self.apply_act = apply_act # apply activation (non-linearity)
self.groups = groups
self.eps = eps
-param_shape = (1, num_features, 1, 1)
-self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
-self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
-if apply_act:
-self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True)
+self.weight = nn.Parameter(torch.ones(num_features), requires_grad=True)
+self.bias = nn.Parameter(torch.zeros(num_features), requires_grad=True)
+self.v = nn.Parameter(torch.ones(num_features), requires_grad=True) if apply_act else None
self.reset_parameters()
def reset_parameters(self):
@ -72,12 +71,12 @@ class EvoNormSample2d(nn.Module):
nn.init.ones_(self.v)
def forward(self, x):
-assert x.dim() == 4, 'expected 4D input'
+_assert(x.dim() == 4, 'expected 4D input')
B, C, H, W = x.shape
-assert C % self.groups == 0
+_assert(C % self.groups == 0, '')
-if self.apply_act:
+if self.v is not None:
-n = x * (x * self.v).sigmoid()
+n = x * (x * self.v.view(1, -1, 1, 1)).sigmoid()
x = x.reshape(B, self.groups, -1)
x = n.reshape(B, self.groups, -1) / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt()
x = x.reshape(B, C, H, W)
-return x * self.weight + self.bias
+return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
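
Note: after this rework the EvoNorm layers store weight/bias/v as 1-D, BatchNorm-style parameters and only reshape to (1, C, 1, 1) at use. A quick sanity sketch (not part of the commit; the import path assumes the layer is re-exported from timm.models.layers, as the byobnet hunk above does):

    import torch
    from timm.models.layers import EvoNormSample2d

    norm = EvoNormSample2d(64, groups=32)  # groups now defaults to 32
    print(norm.weight.shape, norm.bias.shape, norm.v.shape)  # torch.Size([64]) each
    y = norm(torch.randn(2, 64, 8, 8))
    print(y.shape)  # torch.Size([2, 64, 8, 8])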

@ -16,7 +16,7 @@ The attention mechanism works but it's slow as implemented.
Hacked together by / Copyright 2021 Ross Wightman
"""
-from typing import Tuple, List
+from typing import List
import torch
from torch import nn
@ -24,6 +24,7 @@ import torch.nn.functional as F
from .helpers import make_divisible
from .weight_init import trunc_normal_
from .trace_utils import _assert
def rel_logits_1d(q, rel_k, permute_mask: List[int]):
@ -167,8 +168,8 @@ class HaloAttn(nn.Module):
def forward(self, x):
B, C, H, W = x.shape
-assert H % self.block_size == 0
-assert W % self.block_size == 0
+_assert(H % self.block_size == 0, '')
+_assert(W % self.block_size == 0, '')
num_h_blocks = H // self.block_size
num_w_blocks = W // self.block_size
num_blocks = num_h_blocks * num_w_blocks

@ -10,6 +10,7 @@ from torch.nn import functional as F
from .conv_bn_act import ConvBnAct
from .helpers import make_divisible
from .trace_utils import _assert
class NonLocalAttn(nn.Module):
@ -83,7 +84,7 @@ class BilinearAttnTransform(nn.Module):
def resize_mat(self, x, t: int):
B, C, block_size, block_size1 = x.shape
-assert block_size == block_size1
+_assert(block_size == block_size1, '')
if t <= 1:
return x
x = x.view(B * C, -1, 1, 1)
@ -95,7 +96,8 @@ class BilinearAttnTransform(nn.Module):
return x
def forward(self, x):
-assert x.shape[-1] % self.block_size == 0 and x.shape[-2] % self.block_size == 0
+_assert(x.shape[-1] % self.block_size == 0, '')
_assert(x.shape[-2] % self.block_size == 0, '')
B, C, H, W = x.shape
out = self.conv1(x)
rp = F.adaptive_max_pool2d(out, (self.block_size, 1))

@ -6,7 +6,7 @@ import torch.nn.functional as F
class GroupNorm(nn.GroupNorm):
-def __init__(self, num_channels, num_groups, eps=1e-5, affine=True):
+def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
# NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
super().__init__(num_groups, num_channels, eps=eps, affine=affine)

@ -68,7 +68,7 @@ class BatchNormAct2d(nn.BatchNorm2d):
class GroupNormAct(nn.GroupNorm):
# NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args
-def __init__(self, num_channels, num_groups, eps=1e-5, affine=True,
+def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True,
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None):
super(GroupNormAct, self).__init__(num_groups, num_channels, eps=eps, affine=affine)
if isinstance(act_layer, str):

@ -9,6 +9,7 @@ from torch import nn as nn
from .conv_bn_act import ConvBnAct
from .helpers import make_divisible
from .trace_utils import _assert
def _kernel_valid(k):
@ -34,7 +35,7 @@ class SelectiveKernelAttn(nn.Module):
self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False)
def forward(self, x):
-assert x.shape[1] == self.num_paths
+_assert(x.shape[1] == self.num_paths, '')
x = x.sum(1).mean((2, 3), keepdim=True)
x = self.fc_reduce(x)
x = self.bn(x)

@ -25,8 +25,10 @@ import torch.nn.functional as F
from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg, named_apply
from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_
from .layers import _assert
from .layers import create_conv2d, create_pool2d, to_ntuple
from .registry import register_model
@ -128,8 +130,8 @@ class ConvPool(nn.Module):
""" """
x is expected to have shape (B, C, H, W) x is expected to have shape (B, C, H, W)
""" """
assert x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims' _assert(x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims')
assert x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims' _assert(x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims')
x = self.conv(x) x = self.conv(x)
# Layer norm done over channel dim only # Layer norm done over channel dim only
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
@ -144,8 +146,8 @@ def blockify(x, block_size: int):
block_size (int): edge length of a single square block in units of H, W
"""
B, H, W, C = x.shape
-assert H % block_size == 0, '`block_size` must divide input height evenly'
-assert W % block_size == 0, '`block_size` must divide input width evenly'
+_assert(H % block_size == 0, '`block_size` must divide input height evenly')
+_assert(W % block_size == 0, '`block_size` must divide input width evenly')
grid_height = H // block_size
grid_width = W // block_size
x = x.reshape(B, grid_height, block_size, grid_width, block_size, C)
@ -153,6 +155,7 @@ def blockify(x, block_size: int):
return x # (B, T, N, C)
@register_notrace_function # reason: int receives Proxy
def deblockify(x, block_size: int):
"""blocks to image
Args:

@ -26,6 +26,7 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_module
from .helpers import build_model_with_cfg
from .registry import register_model
from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\
@ -318,6 +319,7 @@ class DownsampleAvg(nn.Module):
return self.conv(self.pool(x))
@register_notrace_module # reason: mul_ causes FX to drop a relevant node. https://github.com/pytorch/pytorch/issues/68301
class NormFreeBlock(nn.Module):
"""Normalization-Free pre-activation block.
"""

@ -15,7 +15,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
-from .layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, create_attn, get_attn, create_classifier
+from .layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, create_classifier
from .registry import register_model
__all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this
@ -89,10 +89,15 @@ default_cfgs = {
interpolation='bicubic'),
'wide_resnet101_2': _cfg(url='https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth'),
# ResNets w/ alternative norm layers
'resnet50_gn': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_gn_a1h2-8fe6c4d0.pth',
crop_pct=0.94, interpolation='bicubic'),
# ResNeXt
'resnext50_32x4d': _cfg(
-url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50_32x4d_ra-d733960d.pth',
-interpolation='bicubic'),
+url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnext50_32x4d_a1h-0146ab0a.pth',
+interpolation='bicubic', crop_pct=0.95),
'resnext50d_32x4d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50d_32x4d-103e99f8.pth',
interpolation='bicubic',
@ -881,6 +886,14 @@ def wide_resnet101_2(pretrained=False, **kwargs):
return _create_resnet('wide_resnet101_2', pretrained, **model_args)
@register_model
def resnet50_gn(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model w/ GroupNorm
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
return _create_resnet('resnet50_gn', pretrained, norm_layer=GroupNorm, **model_args)
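
Note: a hedged usage sketch for the entrypoint registered above (the attribute path assumes the standard timm ResNet/Bottleneck layout):

    import timm
    import torch.nn as nn

    model = timm.create_model('resnet50_gn', pretrained=False)
    # norm_layer=GroupNorm swaps every BatchNorm for GroupNorm (32 groups by default, per the norm.py change earlier in this diff)
    assert isinstance(model.layer1[0].bn1, nn.GroupNorm)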
@register_model
def resnext50_32x4d(pretrained=False, **kwargs):
"""Constructs a ResNeXt50-32x4d model.

@ -120,6 +120,13 @@ default_cfgs = {
interpolation='bicubic'),
'resnetv2_152d': _cfg(
interpolation='bicubic', first_conv='stem.conv1'),
'resnetv2_50d_gn': _cfg(
interpolation='bicubic', first_conv='stem.conv1'),
'resnetv2_50d_evob': _cfg(
interpolation='bicubic', first_conv='stem.conv1'),
'resnetv2_50d_evos': _cfg(
interpolation='bicubic', first_conv='stem.conv1'),
}
@ -639,19 +646,27 @@ def resnetv2_152d(pretrained=False, **kwargs):
stem_type='deep', avg_down=True, **kwargs)
-# @register_model
-# def resnetv2_50ebd(pretrained=False, **kwargs):
-# # FIXME for testing w/ TPU + PyTorch XLA
-# return _create_resnetv2(
-# 'resnetv2_50d', pretrained=pretrained,
-# layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormBatch2d,
-# stem_type='deep', avg_down=True, **kwargs)
-#
-#
-# @register_model
-# def resnetv2_50esd(pretrained=False, **kwargs):
-# # FIXME for testing w/ TPU + PyTorch XLA
-# return _create_resnetv2(
-# 'resnetv2_50d', pretrained=pretrained,
-# layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormSample2d,
-# stem_type='deep', avg_down=True, **kwargs)
# Experimental configs (may change / be removed)
@register_model
def resnetv2_50d_gn(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d_gn', pretrained=pretrained,
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=GroupNormAct,
stem_type='deep', avg_down=True, **kwargs)
@register_model
def resnetv2_50d_evob(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d_evob', pretrained=pretrained,
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormBatch2d,
stem_type='deep', avg_down=True, **kwargs)
@register_model
def resnetv2_50d_evos(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d_evos', pretrained=pretrained,
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormSample2d,
stem_type='deep', avg_down=True, **kwargs)

@ -10,6 +10,7 @@ Changes for timm, feature extraction, and rounded channel variant hacked togethe
Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
from functools import partial
from math import ceil
@ -92,7 +93,7 @@ class LinearBottleneck(nn.Module):
if self.use_shortcut:
if self.drop_path is not None:
x = self.drop_path(x)
-x[:, 0:self.in_channels] += shortcut
+x = torch.cat([x[:, 0:self.in_channels] + shortcut, x[:, self.in_channels:]], dim=1)
return x
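
Note: the rewrite above replaces an in-place slice assignment, which FX tracing and some export backends handle poorly, with an out-of-place torch.cat. A minimal equivalence check (shapes are illustrative):

    import torch

    B, C_in, C_out, H, W = 2, 16, 24, 8, 8
    x = torch.randn(B, C_out, H, W)
    shortcut = torch.randn(B, C_in, H, W)

    old = x.clone()
    old[:, 0:C_in] += shortcut  # previous, in-place version
    new = torch.cat([x[:, 0:C_in] + shortcut, x[:, C_in:]], dim=1)  # traceable version

    assert torch.equal(old, new)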

@ -21,11 +21,14 @@ import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
from .layers import _assert
from .registry import register_model
from .vision_transformer import checkpoint_filter_fn, _init_vit_weights
_logger = logging.getLogger(__name__)
@ -100,6 +103,7 @@ def window_partition(x, window_size: int):
return windows
@register_notrace_function # reason: int argument is a Proxy
def window_reverse(windows, window_size: int, H: int, W: int):
"""
Args:
@ -270,7 +274,7 @@ class SwinTransformerBlock(nn.Module):
def forward(self, x): def forward(self, x):
H, W = self.input_resolution H, W = self.input_resolution
B, L, C = x.shape B, L, C = x.shape
assert L == H * W, "input feature has wrong size" _assert(L == H * W, "input feature has wrong size")
shortcut = x shortcut = x
x = self.norm1(x) x = self.norm1(x)
@ -329,8 +333,8 @@ class PatchMerging(nn.Module):
""" """
H, W = self.input_resolution H, W = self.input_resolution
B, L, C = x.shape B, L, C = x.shape
assert L == H * W, "input feature has wrong size" _assert(L == H * W, "input feature has wrong size")
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." _assert(H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even.")
x = x.view(B, H, W, C) x = x.view(B, H, W, C)

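The assert → _assert swaps above exist because a bare Python assert on a shape-derived condition calls bool() on an FX Proxy and aborts symbolic tracing. The diff imports _assert from timm.models.layers; the sketch below uses torch._assert directly, which to my understanding is the traceable primitive that helper leans on (treat that linkage as an assumption):

    import torch
    import torch.fx as fx

    class WithPlainAssert(torch.nn.Module):
        def forward(self, x):
            # bool(Proxy) raises a TraceError under fx.symbolic_trace
            assert x.shape[1] % 2 == 0, "channels must be even"
            return x * 2

    class WithTraceableAssert(torch.nn.Module):
        def forward(self, x):
            # torch._assert records the check as a graph node instead of evaluating it
            torch._assert(x.shape[1] % 2 == 0, "channels must be even")
            return x * 2

    fx.symbolic_trace(WithTraceableAssert())  # traces fine

    try:
        fx.symbolic_trace(WithPlainAssert())
    except Exception as e:  # fx raises a TraceError here
        print("plain assert broke tracing:", e)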
@@ -9,12 +9,12 @@ https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT
 import math
 import torch
 import torch.nn as nn
-from functools import partial

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.models.helpers import build_model_with_cfg
 from timm.models.layers import Mlp, DropPath, trunc_normal_
 from timm.models.layers.helpers import to_2tuple
+from timm.models.layers import _assert
 from timm.models.registry import register_model
 from timm.models.vision_transformer import resize_pos_embed

@@ -109,7 +109,9 @@ class Block(nn.Module):
         pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed)))
         # outer
         B, N, C = patch_embed.size()
-        patch_embed[:, 1:] = patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))
+        patch_embed = torch.cat(
+            [patch_embed[:, 0:1], patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))],
+            dim=1)
         patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed)))
         patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed)))
         return pixel_embed, patch_embed

@@ -136,8 +138,10 @@ class PixelEmbed(nn.Module):
     def forward(self, x, pixel_pos):
         B, C, H, W = x.shape
-        assert H == self.img_size[0] and W == self.img_size[1], \
-            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        _assert(H == self.img_size[0],
+                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
+        _assert(W == self.img_size[1],
+                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
         x = self.proj(x)
         x = self.unfold(x)
         x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])

@@ -22,9 +22,10 @@ from functools import partial

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .layers import Mlp, DropPath, to_2tuple, trunc_normal_
+from .fx_features import register_notrace_module
 from .registry import register_model
 from .vision_transformer import Attention
-from .helpers import build_model_with_cfg, overlay_external_default_cfg
+from .helpers import build_model_with_cfg


 def _cfg(url='', **kwargs):

@@ -62,6 +63,7 @@ default_cfgs = {
 Size_ = Tuple[int, int]


+@register_notrace_module  # reason: FX can't symbolically trace control flow in forward method
 class LocallyGroupedAttn(nn.Module):
     """ LSA: self attention within a group
     """

@@ -12,7 +12,8 @@ from typing import Union, List, Dict, Any, cast

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
-from .layers import ClassifierHead, ConvBnAct
+from .fx_features import register_notrace_module
+from .layers import ClassifierHead
 from .registry import register_model

 __all__ = [

@@ -52,6 +53,7 @@ cfgs: Dict[str, List[Union[str, int]]] = {
 }


+@register_notrace_module  # reason: FX can't symbolically trace control flow in forward method
 class ConvMlp(nn.Module):

     def __init__(self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0,

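Modules tagged with register_notrace_module above (LocallyGroupedAttn, ConvMlp, and PositionalEncodingFourier further down) contain Python control flow or ops torch.fx cannot symbolically trace, so feature extraction must treat them as opaque leaf calls rather than tracing into them. A generic torch.fx sketch of that idea, using made-up module names rather than timm's registry machinery:

    import torch
    import torch.fx as fx

    class PadToMultiple(torch.nn.Module):
        # data-dependent control flow: fx.symbolic_trace would choke on the `if`
        def forward(self, x):
            if x.shape[-1] % 4 != 0:
                x = torch.nn.functional.pad(x, (0, 4 - x.shape[-1] % 4))
            return x

    class Net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.pad = PadToMultiple()
            self.proj = torch.nn.Linear(8, 8)

        def forward(self, x):
            return self.proj(self.pad(x))

    class LeafAwareTracer(fx.Tracer):
        # keep PadToMultiple as a single call_module node instead of tracing inside it
        def is_leaf_module(self, m, qualname):
            return isinstance(m, PadToMultiple) or super().is_leaf_module(m, qualname)

    net = Net()
    graph = LeafAwareTracer().trace(net)
    gm = fx.GraphModule(net, graph)
    print(gm.graph)  # the pad submodule appears as one opaque call_module node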
@@ -88,6 +88,9 @@ default_cfgs = {
         url='https://storage.googleapis.com/vit_models/augreg/'
             'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
         input_size=(3, 384, 384), crop_pct=1.0),
+    'vit_base_patch8_224': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/'
+            'B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
     'vit_large_patch32_224': _cfg(
         url='',  # no official model weights for this combo, only for in21k
         ),

@@ -118,6 +121,9 @@ default_cfgs = {
     'vit_base_patch16_224_in21k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
         num_classes=21843),
+    'vit_base_patch8_224_in21k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
+        num_classes=21843),
     'vit_large_patch32_224_in21k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
         num_classes=21843),

@@ -640,6 +646,16 @@ def vit_base_patch16_384(pretrained=False, **kwargs):
     return model


+@register_model
+def vit_base_patch8_224(pretrained=False, **kwargs):
+    """ ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+    """
+    model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
+    model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_large_patch32_224(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.

@@ -756,6 +772,18 @@ def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
     return model


+@register_model
+def vit_base_patch8_224_in21k(pretrained=False, **kwargs):
+    """ ViT-Base model (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
+    """
+    model_kwargs = dict(
+        patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
+    model = _create_vision_transformer('vit_base_patch8_224_in21k', pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).

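The vit_base_patch8_224 / vit_base_patch8_224_in21k additions above only register configs and builder functions; using them goes through timm.create_model like any other variant. A quick usage sketch (pretrained=False so nothing is downloaded; the num_patches check assumes timm's PatchEmbed exposes that attribute as it does for the other ViT variants):

    import torch
    import timm

    model = timm.create_model('vit_base_patch8_224', pretrained=False)

    # patch size 8 at 224x224 gives a 28x28 grid: 784 patch tokens (+1 class token)
    assert model.patch_embed.num_patches == (224 // 8) ** 2 == 784

    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 1000]) with the default ImageNet-1k head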
@@ -21,6 +21,7 @@ from .vision_transformer import _cfg, Mlp
 from .registry import register_model
 from .layers import DropPath, trunc_normal_, to_2tuple
 from .cait import ClassAttn
+from .fx_features import register_notrace_module


 def _cfg(url='', **kwargs):

@@ -97,6 +98,7 @@ default_cfgs = {
 }


+@register_notrace_module  # reason: FX can't symbolically trace torch.arange in forward method
 class PositionalEncodingFourier(nn.Module):
     """
     Positional encoding relying on a fourier kernel matching the one used in the "Attention is all of Need" paper.
