diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 1136f306..9e0a4aac 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -16,9 +16,9 @@ jobs:
     strategy:
       matrix:
         os: [ubuntu-latest, macOS-latest]
-        python: ['3.8']
-        torch: ['1.9.0']
-        torchvision: ['0.10.0']
+        python: ['3.9']
+        torch: ['1.10.0']
+        torchvision: ['0.11.1']
     runs-on: ${{ matrix.os }}
 
     steps:
@@ -30,7 +30,7 @@ jobs:
     - name: Install testing dependencies
       run: |
         python -m pip install --upgrade pip
-        pip install pytest pytest-timeout
+        pip install pytest pytest-timeout expecttest
     - name: Install torch on mac
       if: startsWith(matrix.os, 'macOS')
       run: pip install --no-cache-dir torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }}
diff --git a/tests/test_models.py b/tests/test_models.py
index c0d0e901..f55247ee 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -4,9 +4,16 @@ import platform
 import os
 import fnmatch
 
+try:
+    from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer
+    has_fx_feature_extraction = True
+except ImportError:
+    has_fx_feature_extraction = False
+
 import timm
 from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \
     get_model_default_value
+from timm.models.fx_features import _leaf_modules, _autowrap_functions
 
 if hasattr(torch._C, '_jit_set_profiling_executor'):
     # legacy executor is too slow to compile large models for unit tests
@@ -38,6 +45,10 @@ TARGET_JIT_SIZE = 128
 MAX_JIT_SIZE = 320
 TARGET_FFEAT_SIZE = 96
 MAX_FFEAT_SIZE = 256
+TARGET_FWD_FX_SIZE = 128
+MAX_FWD_FX_SIZE = 224
+TARGET_BWD_FX_SIZE = 128
+MAX_BWD_FX_SIZE = 224
 
 
 def _get_input_size(model=None, model_name='', target=None):
@@ -297,3 +308,135 @@ def test_model_forward_features(model_name, batch_size):
         assert e == o.shape[1]
         assert o.shape[0] == batch_size
         assert not torch.isnan(o).any()
+
+
+def _create_fx_model(model, train=False):
+    # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
+    # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
+    # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
+    train_nodes, eval_nodes = get_graph_node_names(
+        model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
+
+    eval_return_nodes = [eval_nodes[-1]]
+    train_return_nodes = [train_nodes[-1]]
+    if train:
+        tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
+        graph = tracer.trace(model)
+        graph_nodes = list(reversed(graph.nodes))
+        output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
+        graph_node_names = [n.name for n in graph_nodes]
+        output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
+        train_return_nodes = [train_nodes[ix] for ix in output_node_indices]
+
+    fx_model = create_feature_extractor(
+        model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes,
+        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
+    return fx_model
+
+
+EXCLUDE_FX_FILTERS = []
+# not enough memory to run fx on more models than other tests
+if 'GITHUB_ACTIONS' in os.environ:
+    EXCLUDE_FX_FILTERS += [
+        'beit_large*',
+        'swin_large*',
+        '*resnext101_32x32d',
+        'resnetv2_152x2*',
+    ]
+
+
+@pytest.mark.timeout(120)
+@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS))
+@pytest.mark.parametrize('batch_size', [1])
+def test_model_forward_fx(model_name, batch_size):
+    """
+    Symbolically trace each model and run a single forward pass through the resulting GraphModule.
+    Also check that the output of a forward pass through the GraphModule is the same as that from the original Module.
+    """
+    if not has_fx_feature_extraction:
+        pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
+
+    model = create_model(model_name, pretrained=False)
+    model.eval()
+
+    input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE)
+    if max(input_size) > MAX_FWD_FX_SIZE:
+        pytest.skip("Fixed input size model > limit.")
+    inputs = torch.randn((batch_size, *input_size))
+    outputs = model(inputs)
+    if isinstance(outputs, tuple):
+        outputs = torch.cat(outputs)
+
+    model = _create_fx_model(model)
+    fx_outputs = tuple(model(inputs).values())
+    if isinstance(fx_outputs, tuple):
+        fx_outputs = torch.cat(fx_outputs)
+
+    assert torch.all(fx_outputs == outputs)
+    assert outputs.shape[0] == batch_size
+    assert not torch.isnan(outputs).any(), 'Output included NaNs'
+
+
+@pytest.mark.timeout(120)
+@pytest.mark.parametrize('model_name', list_models(
+    exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))
+@pytest.mark.parametrize('batch_size', [2])
+def test_model_backward_fx(model_name, batch_size):
+    """Symbolically trace each model and run a single backward pass through the resulting GraphModule."""
+    if not has_fx_feature_extraction:
+        pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
+
+    input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE)
+    if max(input_size) > MAX_BWD_FX_SIZE:
+        pytest.skip("Fixed input size model > limit.")
+
+    model = create_model(model_name, pretrained=False, num_classes=42)
+    num_params = sum([x.numel() for x in model.parameters()])
+    model.train()
+
+    model = _create_fx_model(model, train=True)
+    outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
+    if isinstance(outputs, tuple):
+        outputs = torch.cat(outputs)
+    outputs.mean().backward()
+    for n, x in model.named_parameters():
+        assert x.grad is not None, f'No gradient for {n}'
+    num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
+
+    assert outputs.shape[-1] == 42
+    assert num_params == num_grad, 'Some parameters are missing gradients'
+    assert not torch.isnan(outputs).any(), 'Output included NaNs'
+
+
+# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
+EXCLUDE_FX_JIT_FILTERS = [
+    'deit_*_distilled_patch16_224',
+    'levit*',
+    'pit_*_distilled_224',
+] + EXCLUDE_FX_FILTERS
+
+
+@pytest.mark.timeout(120)
+@pytest.mark.parametrize(
+    'model_name', list_models(
+        exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True))
+@pytest.mark.parametrize('batch_size', [1])
+def test_model_forward_fx_torchscript(model_name, batch_size):
+    """Symbolically trace each model, script it, and run a single forward pass."""
+    if not has_fx_feature_extraction:
+        pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
+
+    input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
+    if max(input_size) > MAX_JIT_SIZE:
+        pytest.skip("Fixed input size model > limit.")
+
+    with set_scriptable(True):
+        model = create_model(model_name, pretrained=False)
+    model.eval()
+
+    model = torch.jit.script(_create_fx_model(model))
+    outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
+    if isinstance(outputs, tuple):
+        outputs = torch.cat(outputs)
+
+    assert outputs.shape[0] == batch_size
+    assert not torch.isnan(outputs).any(), 'Output included NaNs'
diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py
index 03b03cf5..e86bcc29 100644
--- a/timm/data/dataset_factory.py
+++ b/timm/data/dataset_factory.py
@@ -1,7 +1,11 @@
 import os
 
-from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST,\
-    Places365, ImageNet, ImageFolder
+from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST, ImageNet, ImageFolder
+try:
+    from torchvision.datasets import Places365
+    has_places365 = True
+except ImportError:
+    has_places365 = False
 try:
     from torchvision.datasets import INaturalist
     has_inaturalist = True
@@ -104,6 +108,7 @@ def create_dataset(
                 split = '2021_valid'
             ds = INaturalist(version=split, target_type=target_type, **torch_kwargs)
         elif name == 'places365':
+            assert has_places365, 'Please update to a newer PyTorch and torchvision for Places365 dataset.'
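# NOTE: the pattern just above — try/except ImportError at module import, assert at point of use —
# mirrors the FX and INaturalist guards elsewhere in this diff. A minimal standalone sketch of the
# idiom (`load_places365` is a hypothetical wrapper, for illustration only):
#
#     try:
#         from torchvision.datasets import Places365  # needs a recent torchvision
#         has_places365 = True
#     except ImportError:
#         has_places365 = False
#
#     def load_places365(**kwargs):
#         assert has_places365, 'Please update torchvision for Places365 support.'
#         return Places365(**kwargs)
#
# Failing at use rather than at import keeps the rest of the library importable on older stacks.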
            if split in _TRAIN_SYNONYM:
                split = 'train-standard'
            elif split in _EVAL_SYNONYM:
diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/parsers/parser_tfds.py
index dd24c55c..990d786b 100644
--- a/timm/data/parsers/parser_tfds.py
+++ b/timm/data/parsers/parser_tfds.py
@@ -36,7 +36,7 @@ PREFETCH_SIZE = 2048  # examples to prefetch
 
 def even_split_indices(split, n, num_examples):
     partitions = [round(i * num_examples / n) for i in range(n + 1)]
-    return [f"{split}[{partitions[i]}:{partitions[i+1]}]" for i in range(n)]
+    return [f"{split}[{partitions[i]}:{partitions[i + 1]}]" for i in range(n)]
 
 
 def get_class_labels(info):
@@ -70,6 +70,7 @@ class ParserTfds(Parser):
     components.
 
     """
+
     def __init__(
             self,
             root,
@@ -99,6 +100,7 @@ class ParserTfds(Parser):
            download: download and build TFDS dataset if set, otherwise must use tfds CLI
            repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1)
            seed: common seed for shard shuffle across all distributed/worker instances
+           input_name: name of Feature to return as data (input)
            input_image: image mode if input is an image (currently PIL mode string)
            target_name: name of Feature to return as target (label)
            target_image: image mode if target is an image (currently PIL mode string)
@@ -111,7 +113,7 @@ class ParserTfds(Parser):
         self.split = split
         self.is_training = is_training
         if self.is_training:
-            assert batch_size is not None,\
+            assert batch_size is not None, \
                 "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
         self.batch_size = batch_size
         self.repeats = repeats
@@ -184,7 +186,7 @@ class ParserTfds(Parser):
         InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used. My understanding
         is that using split, the underling TFRecord files will shuffle (shuffle_files=True) between the splits each
         iteration, but that understanding could be wrong.
-
+
         I am currently using a mix of InputContext shard assignment and fine-grained sub-splits for distributing
         the data across workers. For training InputContext is used to assign shards to nodes unless num_shards
         in dataset < total number of workers.
         Otherwise sub-split API is used for datasets without enough shards or
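# NOTE: a quick worked example of `even_split_indices` above, with illustrative numbers only —
# it partitions a TFDS split into n contiguous sub-split strings:
#
#     even_split_indices('train', n=4, num_examples=1000)
#     # -> ['train[0:250]', 'train[250:500]', 'train[500:750]', 'train[750:1000]']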
diff --git a/timm/models/beit.py b/timm/models/beit.py
index 199c2a4b..f644b657 100644
--- a/timm/models/beit.py
+++ b/timm/models/beit.py
@@ -86,9 +86,11 @@ class Attention(nn.Module):
         self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
         if qkv_bias:
             self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+            self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
             self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
         else:
             self.q_bias = None
+            self.k_bias = None
             self.v_bias = None
 
         if window_size:
@@ -127,13 +129,7 @@ class Attention(nn.Module):
 
     def forward(self, x, rel_pos_bias: Optional[torch.Tensor] = None):
         B, N, C = x.shape
-        qkv_bias = None
-        if self.q_bias is not None:
-            if torch.jit.is_scripting():
-                # FIXME requires_grad breaks w/ torchscript
-                qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias), self.v_bias))
-            else:
-                qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+        qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
         qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
         qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py
index b05cd91a..7fc7f82e 100644
--- a/timm/models/byoanet.py
+++ b/timm/models/byoanet.py
@@ -36,6 +36,9 @@ default_cfgs = {
     'botnet26t_256': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/botnet26t_c1_256-167a0e9f.pth',
         fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
+    'sebotnet33ts_256': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sebotnet33ts_a1h2_256-957e3c3e.pth',
+        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
     'botnet50ts_256': _cfg(
         url='',
         fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
@@ -51,7 +54,7 @@ default_cfgs = {
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sehalonet33ts_256-87e053f9.pth',
         input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
     'halonet50ts': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet50ts_a1h_256-c6d7ff15.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet50ts_a1h2_256-f3a3daee.pth',
         input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
     'eca_halonext26ts': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_halonext26ts_c_256-06906299.pth',
@@ -97,6 +100,22 @@ model_cfgs = dict(
         self_attn_layer='bottleneck',
         self_attn_kwargs=dict()
     ),
+    sebotnet33ts=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
+            interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=512, s=2, gs=0, br=0.25),
+            interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=1024, s=2, gs=0, br=0.25),
+            ByoBlockCfg('self_attn', d=2, c=1536, s=2, gs=0, br=0.333),
+        ),
+        stem_chs=64,
+        stem_type='tiered',
+        stem_pool='',
+        act_layer='silu',
+        num_features=1280,
+        attn_layer='se',
+        self_attn_layer='bottleneck',
+        self_attn_kwargs=dict()
+    ),
     botnet50ts=ByoModelCfg(
         blocks=(
             ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
@@ -322,6 +341,13 @@ def botnet26t_256(pretrained=False, **kwargs):
     return _create_byoanet('botnet26t_256', 'botnet26t', pretrained=pretrained, **kwargs)
 
 
+@register_model
+def sebotnet33ts_256(pretrained=False, **kwargs):
+    """ Bottleneck Transformer w/ a ResNet33-t backbone, SE attn for non Halo blocks, SiLU,
+    """
+    return _create_byoanet('sebotnet33ts_256', 'sebotnet33ts', pretrained=pretrained, **kwargs)
+
+
 @register_model
 def botnet50ts_256(pretrained=False, **kwargs):
     """ Bottleneck Transformer w/ ResNet50-T backbone, silu act.
diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py
index d7253bdf..fa57943a 100644
--- a/timm/models/byobnet.py
+++ b/timm/models/byobnet.py
@@ -35,7 +35,7 @@ import torch.nn as nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg, named_apply
 from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
-    create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple
+    create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple, EvoNormSample2d
 from .registry import register_model
 
 __all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block']
@@ -136,20 +136,26 @@ default_cfgs = {
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext50ts_256-3e0f515e.pth'),
 
     # experimental models, likely to change ot be removed
-    'regnetz_b': _cfgr(
+    'regnetz_b16': _cfgr(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_b_raa-677d9606.pth',
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
         input_size=(3, 224, 224), pool_size=(7, 7), test_input_size=(3, 288, 288), first_conv='stem.conv', crop_pct=0.94),
-    'regnetz_c': _cfgr(
+    'regnetz_c16': _cfgr(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_c_rab2_256-a54bf36a.pth',
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), first_conv='stem.conv', crop_pct=0.94),
-    'regnetz_d': _cfgr(
+    'regnetz_d32': _cfgr(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_d_rab_256-b8073a89.pth',
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95),
     'regnetz_d8': _cfgr(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_d8_bh-afc03c55.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=1.0),
+    'regnetz_e8': _cfgr(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_e8_bh-aace8e6e.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=1.0),
+    'regnetz_d8_evob': _cfgr(
         url='',
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95),
-    'regnetz_e8': _cfgr(
+    'regnetz_d8_evos': _cfgr(
         url='',
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95),
 }
@@ -506,7 +512,7 @@ model_cfgs = dict(
     ),
 
     # experimental models, closer to a RegNetZ than a ResNet. Similar to EfficientNets but w/ groups instead of DW
-    regnetz_b=ByoModelCfg(
+    regnetz_b16=ByoModelCfg(
         blocks=(
             ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3),
             ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3),
@@ -522,7 +528,7 @@ model_cfgs = dict(
         attn_kwargs=dict(rd_ratio=0.25),
         block_kwargs=dict(bottle_in=True, linear_out=True),
     ),
-    regnetz_c=ByoModelCfg(
+    regnetz_c16=ByoModelCfg(
         blocks=(
             ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=4),
             ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=4),
@@ -538,7 +544,7 @@ model_cfgs = dict(
         attn_kwargs=dict(rd_ratio=0.25),
         block_kwargs=dict(bottle_in=True, linear_out=True),
     ),
-    regnetz_d=ByoModelCfg(
+    regnetz_d32=ByoModelCfg(
         blocks=(
             ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=32, br=4),
             ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=32, br=4),
@@ -589,8 +595,45 @@ model_cfgs = dict(
         attn_kwargs=dict(rd_ratio=0.25),
         block_kwargs=dict(bottle_in=True, linear_out=True),
     ),
-)
 
+    # experimental EvoNorm configs
+    regnetz_d8_evob=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=8, br=4),
+            ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=8, br=4),
+            ByoBlockCfg(type='bottle', d=12, c=256, s=2, gs=8, br=4),
+            ByoBlockCfg(type='bottle', d=3, c=384, s=2, gs=8, br=4),
+        ),
+        stem_chs=64,
+        stem_type='tiered',
+        stem_pool='',
+        downsample='',
+        num_features=1792,
+        act_layer='silu',
+        norm_layer='evonormbatch',
+        attn_layer='se',
+        attn_kwargs=dict(rd_ratio=0.25),
+        block_kwargs=dict(bottle_in=True, linear_out=True),
+    ),
+    regnetz_d8_evos=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=8, br=4),
+            ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=8, br=4),
+            ByoBlockCfg(type='bottle', d=12, c=256, s=2, gs=8, br=4),
+            ByoBlockCfg(type='bottle', d=3, c=384, s=2, gs=8, br=4),
+        ),
+        stem_chs=64,
+        stem_type='deep',
+        stem_pool='',
+        downsample='',
+        num_features=1792,
+        act_layer='silu',
+        norm_layer=partial(EvoNormSample2d, groups=32),
+        attn_layer='se',
+        attn_kwargs=dict(rd_ratio=0.25),
+        block_kwargs=dict(bottle_in=True, linear_out=True),
+    ),
+)
 
 
 @register_model
 def gernet_l(pretrained=False, **kwargs):
@@ -779,24 +822,24 @@ def gcresnext50ts(pretrained=False, **kwargs):
 
 
 @register_model
-def regnetz_b(pretrained=False, **kwargs):
+def regnetz_b16(pretrained=False, **kwargs):
     """
     """
-    return _create_byobnet('regnetz_b', pretrained=pretrained, **kwargs)
+    return _create_byobnet('regnetz_b16', pretrained=pretrained, **kwargs)
 
 
 @register_model
-def regnetz_c(pretrained=False, **kwargs):
+def regnetz_c16(pretrained=False, **kwargs):
     """
     """
-    return _create_byobnet('regnetz_c', pretrained=pretrained, **kwargs)
+    return _create_byobnet('regnetz_c16', pretrained=pretrained, **kwargs)
 
 
 @register_model
-def regnetz_d(pretrained=False, **kwargs):
+def regnetz_d32(pretrained=False, **kwargs):
     """
     """
-    return _create_byobnet('regnetz_d', pretrained=pretrained, **kwargs)
+    return _create_byobnet('regnetz_d32', pretrained=pretrained, **kwargs)
 
 
 @register_model
@@ -813,6 +856,20 @@ def regnetz_e8(pretrained=False, **kwargs):
     return _create_byobnet('regnetz_e8', pretrained=pretrained, **kwargs)
 
 
+@register_model
+def regnetz_d8_evob(pretrained=False, **kwargs):
+    """
+    """
+    return _create_byobnet('regnetz_d8_evob', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def regnetz_d8_evos(pretrained=False, **kwargs):
+    """
+    """
+    return _create_byobnet('regnetz_d8_evos', pretrained=pretrained, **kwargs)
+
+
 def expand_blocks_cfg(stage_blocks_cfg: Union[ByoBlockCfg, Sequence[ByoBlockCfg]]) -> List[ByoBlockCfg]:
     if not isinstance(stage_blocks_cfg, Sequence):
         stage_blocks_cfg = (stage_blocks_cfg,)
diff --git a/timm/models/coat.py b/timm/models/coat.py
index f071715a..18ff8ab9 100644
--- a/timm/models/coat.py
+++ b/timm/models/coat.py
@@ -19,6 +19,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg, overlay_external_default_cfg
 from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
 from .registry import register_model
+from .layers import _assert
 
 
 __all__ = [
@@ -105,7 +106,7 @@ class ConvRelPosEnc(nn.Module):
     def forward(self, q, v, size: Tuple[int, int]):
         B, h, N, Ch = q.shape
         H, W = size
-        assert N == 1 + H * W
+        _assert(N == 1 + H * W, '')
 
         # Convolutional relative position encoding.
         q_img = q[:, :, 1:, :]  # [B, h, H*W, Ch]
@@ -177,7 +178,7 @@ class ConvPosEnc(nn.Module):
     def forward(self, x, size: Tuple[int, int]):
         B, N, C = x.shape
         H, W = size
-        assert N == 1 + H * W
+        _assert(N == 1 + H * W, '')
 
         # Extract CLS token and image tokens.
         cls_token, img_tokens = x[:, :1], x[:, 1:]  # [B, 1, C], [B, H*W, C]
@@ -275,7 +276,7 @@ class ParallelBlock(nn.Module):
         """ Feature map interpolation. """
         B, N, C = x.shape
         H, W = size
-        assert N == 1 + H * W
+        _assert(N == 1 + H * W, '')
 
         cls_token = x[:, :1, :]
         img_tokens = x[:, 1:, :]
diff --git a/timm/models/convit.py b/timm/models/convit.py
index f58249ec..6ef1da72 100644
--- a/timm/models/convit.py
+++ b/timm/models/convit.py
@@ -30,6 +30,7 @@ from .helpers import build_model_with_cfg
 from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp
 from .registry import register_model
 from .vision_transformer_hybrid import HybridEmbed
+from .fx_features import register_notrace_module
 
 import torch
 import torch.nn as nn
@@ -56,6 +57,7 @@ default_cfgs = {
 }
 
 
+@register_notrace_module  # reason: FX can't symbolically trace control flow in forward method
 class GPSA(nn.Module):
     def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
                  locality_strength=1.):
diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py
index 6e0160f9..ddc4f64c 100644
--- a/timm/models/crossvit.py
+++ b/timm/models/crossvit.py
@@ -22,6 +22,7 @@ NOTE: model names have been renamed from originals to represent actual input res
 
 Modifed from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
 
 """
+from typing import Tuple
 
 import torch
 import torch.nn as nn
@@ -31,8 +32,9 @@ from functools import partial
 from typing import List
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .fx_features import register_notrace_function
 from .helpers import build_model_with_cfg
-from .layers import DropPath, to_2tuple, trunc_normal_
+from .layers import DropPath, to_2tuple, trunc_normal_, _assert
 from .registry import register_model
 from .vision_transformer import Mlp, Block
 
@@ -116,8 +118,10 @@ class PatchEmbed(nn.Module):
     def forward(self, x):
         B, C, H, W = x.shape
         # FIXME look at relaxing size constraints
-        assert H == self.img_size[0] and W == self.img_size[1], \
-            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
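# NOTE: `_assert` (imported from timm.models.layers) replaces bare `assert` statements like the ones
# removed above because FX symbolic tracing cannot evaluate a Proxy-valued condition as a bool.
# `torch._assert` is a symbolically traceable function, so the check records as a graph node instead
# of breaking the trace. A sketch of the helper's likely shape, assuming it simply wraps/re-exports
# torch._assert with a plain-Python fallback:
#
#     try:
#         from torch import _assert  # symbolically traceable, torch >= 1.8
#     except ImportError:
#         def _assert(condition: bool, message: str = ''):
#             assert condition, message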
+        _assert(H == self.img_size[0],
+                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
+        _assert(W == self.img_size[1],
+                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
         x = self.proj(x).flatten(2).transpose(1, 2)
         return x
 
@@ -255,6 +259,27 @@ def _compute_num_patches(img_size, patches):
     return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)]
 
 
+@register_notrace_function
+def scale_image(x, ss: Tuple[int, int], crop_scale: bool = False):  # annotations for torchscript
+    """
+    Pulled out of CrossViT.forward_features to bury conditional logic in a leaf node for FX tracing.
+    Args:
+        x (Tensor): input image
+        ss (tuple[int, int]): height and width to scale to
+        crop_scale (bool): whether to crop instead of interpolate to achieve the desired scale. Defaults to False
+    Returns:
+        Tensor: the "scaled" image batch tensor
+    """
+    H, W = x.shape[-2:]
+    if H != ss[0] or W != ss[1]:
+        if crop_scale and ss[0] <= H and ss[1] <= W:
+            cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
+            x = x[:, :, cu:cu + ss[0], cl:cl + ss[1]]
+        else:
+            x = torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False)
+    return x
+
+
 class CrossViT(nn.Module):
     """ Vision Transformer with support for patch or hybrid CNN input stage
     """
@@ -342,17 +367,12 @@ class CrossViT(nn.Module):
                                          range(self.num_branches)])
 
     def forward_features(self, x):
-        B, C, H, W = x.shape
+        B = x.shape[0]
         xs = []
         for i, patch_embed in enumerate(self.patch_embed):
             x_ = x
             ss = self.img_size_scaled[i]
-            if H != ss[0] or W != ss[1]:
-                if self.crop_scale and ss[0] <= H and ss[1] <= W:
-                    cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
-                    x_ = x_[:, :, cu:cu + ss[0], cl:cl + ss[1]]
-                else:
-                    x_ = torch.nn.functional.interpolate(x_, size=ss, mode='bicubic', align_corners=False)
+            x_ = scale_image(x_, ss, self.crop_scale)
             x_ = patch_embed(x_)
             cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1  # hard-coded for torch jit script
             cls_tokens = cls_tokens.expand(B, -1, -1)
diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py
new file mode 100644
index 00000000..5a25ee3e
--- /dev/null
+++ b/timm/models/fx_features.py
@@ -0,0 +1,73 @@
+""" PyTorch FX Based Feature Extraction Helpers
+Using https://pytorch.org/vision/stable/feature_extraction.html
+"""
+from typing import Callable
+from torch import nn
+
+from .features import _get_feature_info
+
+try:
+    from torchvision.models.feature_extraction import create_feature_extractor
+    has_fx_feature_extraction = True
+except ImportError:
+    has_fx_feature_extraction = False
+
+# Layers we want to treat as leaf modules
+from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath
+from .layers.non_local_attn import BilinearAttnTransform
+from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
+
+# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
+# BUT modules from timm.models should use the registration mechanism below
+_leaf_modules = {
+    BatchNormAct2d,  # reason: flow control for jit scripting
+    BilinearAttnTransform,  # reason: flow control t <= 1
+    BlurPool2d,  # reason: TypeError: F.conv2d received Proxy in groups=x.shape[1]
+    # Reason: get_same_padding has a max which raises a control flow error
+    Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
+    CondConv2d,  # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
+    DropPath,  # reason: TypeError: rand received Proxy in `size` argument
+}
+
+try:
+    from .layers import InplaceAbn
+    _leaf_modules.add(InplaceAbn)
+except ImportError:
+    pass
+
+
+def register_notrace_module(module: nn.Module):
+    """
+    Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
+    """
+    _leaf_modules.add(module)
+    return module
+
+
+# Functions we want to autowrap (treat them as leaves)
+_autowrap_functions = set()
+
+
+def register_notrace_function(func: Callable):
+    """
+    Decorator for functions which ought not to be traced through
+    """
+    _autowrap_functions.add(func)
+    return func
+
+
+class FeatureGraphNet(nn.Module):
+    def __init__(self, model, out_indices, out_map=None):
+        super().__init__()
+        assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
+        self.feature_info = _get_feature_info(model, out_indices)
+        if out_map is not None:
+            assert len(out_map) == len(out_indices)
+        return_nodes = {info['module']: out_map[i] if out_map is not None else info['module']
+                        for i, info in enumerate(self.feature_info) if i in out_indices}
+        self.graph_module = create_feature_extractor(
+            model, return_nodes,
+            tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
+
+    def forward(self, x):
+        return list(self.graph_module(x).values())
diff --git a/timm/models/helpers.py b/timm/models/helpers.py
index 9dafeefa..3ea8c8b7 100644
--- a/timm/models/helpers.py
+++ b/timm/models/helpers.py
@@ -14,6 +14,7 @@ import torch.nn as nn
 
 from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
+from .fx_features import FeatureGraphNet
 from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf, load_state_dict_from_url
 from .layers import Conv2dSame, Linear
 
@@ -477,6 +478,8 @@ def build_model_with_cfg(
             feature_cls = feature_cls.lower()
             if 'hook' in feature_cls:
                 feature_cls = FeatureHookNet
+            elif feature_cls == 'fx':
+                feature_cls = FeatureGraphNet
             else:
                 assert False, f'Unknown feature class {feature_cls}'
         model = feature_cls(model, **feature_cfg)
diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py
index f55fd989..c3db464e 100644
--- a/timm/models/layers/bottleneck_attn.py
+++ b/timm/models/layers/bottleneck_attn.py
@@ -22,6 +22,7 @@ import torch.nn.functional as F
 
 from .helpers import to_2tuple, make_divisible
 from .weight_init import trunc_normal_
+from .trace_utils import _assert
 
 
 def rel_logits_1d(q, rel_k, permute_mask: List[int]):
@@ -133,8 +134,8 @@ class BottleneckAttn(nn.Module):
 
     def forward(self, x):
         B, C, H, W = x.shape
-        assert H == self.pos_embed.height
-        assert W == self.pos_embed.width
+        _assert(H == self.pos_embed.height, '')
+        _assert(W == self.pos_embed.width, '')
 
         x = self.qkv(x)  # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W
 
@@ -154,5 +155,3 @@ class BottleneckAttn(nn.Module):
         out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W)  # B, dim_out, H, W
         out = self.pool(out)
         return out
-
-
diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py
index 9023afd0..8c08e49f 100644
--- a/timm/models/layers/evo_norm.py
+++ b/timm/models/layers/evo_norm.py
@@ -12,6 +12,8 @@ Hacked together by / Copyright 2020 Ross Wightman
 import torch
 import torch.nn as nn
 
+from .trace_utils import _assert
+
 
 class EvoNormBatch2d(nn.Module):
     def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None):
         super(EvoNormBatch2d, self).__init__()
         self.apply_act = apply_act  # apply activation (non-linearity)
         self.momentum = momentum
         self.eps = eps
-        param_shape = (1, num_features, 1, 1)
-        self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
-        self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
-        if apply_act:
-            self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True)
-        self.register_buffer('running_var', torch.ones(1, num_features, 1, 1))
+        self.weight = nn.Parameter(torch.ones(num_features), requires_grad=True)
+        self.bias = nn.Parameter(torch.zeros(num_features), requires_grad=True)
+        self.v = nn.Parameter(torch.ones(num_features), requires_grad=True) if apply_act else None
+        self.register_buffer('running_var', torch.ones(num_features))
         self.reset_parameters()
 
     def reset_parameters(self):
@@ -36,33 +36,32 @@ class EvoNormBatch2d(nn.Module):
     def forward(self, x):
         assert x.dim() == 4, 'expected 4D input'
         x_type = x.dtype
+        running_var = self.running_var.view(1, -1, 1, 1)
         if self.training:
             var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
             n = x.numel() / x.shape[1]
-            self.running_var.copy_(
-                var.detach() * self.momentum * (n / (n - 1)) + self.running_var * (1 - self.momentum))
+            running_var = var.detach() * self.momentum * (n / (n - 1)) + running_var * (1 - self.momentum)
+            self.running_var.copy_(running_var.view(self.running_var.shape))
         else:
-            var = self.running_var
+            var = running_var
 
-        if self.apply_act:
-            v = self.v.to(dtype=x_type)
+        if self.v is not None:
+            v = self.v.to(dtype=x_type).reshape(1, -1, 1, 1)
             d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type)
             d = d.max((var + self.eps).sqrt().to(dtype=x_type))
             x = x / d
-        return x * self.weight + self.bias
+        return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
 
 
 class EvoNormSample2d(nn.Module):
-    def __init__(self, num_features, apply_act=True, groups=8, eps=1e-5, drop_block=None):
+    def __init__(self, num_features, apply_act=True, groups=32, eps=1e-5, drop_block=None):
         super(EvoNormSample2d, self).__init__()
         self.apply_act = apply_act  # apply activation (non-linearity)
         self.groups = groups
         self.eps = eps
-        param_shape = (1, num_features, 1, 1)
-        self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
-        self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
-        if apply_act:
-            self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True)
+        self.weight = nn.Parameter(torch.ones(num_features), requires_grad=True)
+        self.bias = nn.Parameter(torch.zeros(num_features), requires_grad=True)
+        self.v = nn.Parameter(torch.ones(num_features), requires_grad=True) if apply_act else None
         self.reset_parameters()
 
     def reset_parameters(self):
@@ -72,12 +71,12 @@ class EvoNormSample2d(nn.Module):
             nn.init.ones_(self.v)
 
     def forward(self, x):
-        assert x.dim() == 4, 'expected 4D input'
+        _assert(x.dim() == 4, 'expected 4D input')
         B, C, H, W = x.shape
-        assert C % self.groups == 0
-        if self.apply_act:
-            n = x * (x * self.v).sigmoid()
+        _assert(C % self.groups == 0, '')
+        if self.v is not None:
+            n = x * (x * self.v.view(1, -1, 1, 1)).sigmoid()
             x = x.reshape(B, self.groups, -1)
             x = n.reshape(B, self.groups, -1) / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt()
             x = x.reshape(B, C, H, W)
-        return x * self.weight + self.bias
+        return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py
index 4149e812..f2ac64f8 100644
--- a/timm/models/layers/halo_attn.py
+++ b/timm/models/layers/halo_attn.py
@@ -16,7 +16,7 @@ The attention mechanism works but it's slow as implemented.
 
 Hacked together by / Copyright 2021 Ross Wightman
 """
-from typing import Tuple, List
+from typing import List
 
 import torch
 from torch import nn
@@ -24,6 +24,7 @@ import torch.nn.functional as F
 
 from .helpers import make_divisible
 from .weight_init import trunc_normal_
+from .trace_utils import _assert
 
 
 def rel_logits_1d(q, rel_k, permute_mask: List[int]):
@@ -167,8 +168,8 @@ class HaloAttn(nn.Module):
 
     def forward(self, x):
         B, C, H, W = x.shape
-        assert H % self.block_size == 0
-        assert W % self.block_size == 0
+        _assert(H % self.block_size == 0, '')
+        _assert(W % self.block_size == 0, '')
         num_h_blocks = H // self.block_size
         num_w_blocks = W // self.block_size
         num_blocks = num_h_blocks * num_w_blocks
diff --git a/timm/models/layers/non_local_attn.py b/timm/models/layers/non_local_attn.py
index a537d60e..881fa36d 100644
--- a/timm/models/layers/non_local_attn.py
+++ b/timm/models/layers/non_local_attn.py
@@ -10,6 +10,7 @@ from torch.nn import functional as F
 
 from .conv_bn_act import ConvBnAct
 from .helpers import make_divisible
+from .trace_utils import _assert
 
 
 class NonLocalAttn(nn.Module):
@@ -83,7 +84,7 @@ class BilinearAttnTransform(nn.Module):
 
     def resize_mat(self, x, t: int):
         B, C, block_size, block_size1 = x.shape
-        assert block_size == block_size1
+        _assert(block_size == block_size1, '')
         if t <= 1:
             return x
         x = x.view(B * C, -1, 1, 1)
@@ -95,7 +96,8 @@ class BilinearAttnTransform(nn.Module):
         return x
 
     def forward(self, x):
-        assert x.shape[-1] % self.block_size == 0 and x.shape[-2] % self.block_size == 0
+        _assert(x.shape[-1] % self.block_size == 0, '')
+        _assert(x.shape[-2] % self.block_size == 0, '')
         B, C, H, W = x.shape
         out = self.conv1(x)
         rp = F.adaptive_max_pool2d(out, (self.block_size, 1))
diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py
index aace107b..85297420 100644
--- a/timm/models/layers/norm.py
+++ b/timm/models/layers/norm.py
@@ -6,7 +6,7 @@ import torch.nn.functional as F
 
 
 class GroupNorm(nn.GroupNorm):
-    def __init__(self, num_channels, num_groups, eps=1e-5, affine=True):
+    def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
         # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
         super().__init__(num_groups, num_channels, eps=eps, affine=affine)
 
diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py
index 02cabe88..2e15181f 100644
--- a/timm/models/layers/norm_act.py
+++ b/timm/models/layers/norm_act.py
@@ -68,7 +68,7 @@ class BatchNormAct2d(nn.BatchNorm2d):
 
 class GroupNormAct(nn.GroupNorm):
     # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args
-    def __init__(self, num_channels, num_groups, eps=1e-5, affine=True,
+    def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True,
                  apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None):
         super(GroupNormAct, self).__init__(num_groups, num_channels, eps=eps, affine=affine)
         if isinstance(act_layer, str):
diff --git a/timm/models/layers/selective_kernel.py b/timm/models/layers/selective_kernel.py
index f28b8d2e..1aeb9294 100644
--- a/timm/models/layers/selective_kernel.py
+++ b/timm/models/layers/selective_kernel.py
@@ -9,6 +9,7 @@ from torch import nn as nn
 
 from .conv_bn_act import ConvBnAct
 from .helpers import make_divisible
+from .trace_utils import _assert
 
 
 def _kernel_valid(k):
@@ -34,7 +35,7 @@ class SelectiveKernelAttn(nn.Module):
         self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False)
 
     def forward(self, x):
-        assert x.shape[1] == self.num_paths
+        _assert(x.shape[1] == self.num_paths, '')
         x = x.sum(1).mean((2, 3), keepdim=True)
         x = self.fc_reduce(x)
         x = self.bn(x)
diff --git a/timm/models/nest.py b/timm/models/nest.py
index 9a477bf9..22cf6099 100644
--- a/timm/models/nest.py
+++ b/timm/models/nest.py
@@ -25,8 +25,10 @@ import torch.nn.functional as F
 from torch import nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .fx_features import register_notrace_function
 from .helpers import build_model_with_cfg, named_apply
 from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_
+from .layers import _assert
 from .layers import create_conv2d, create_pool2d, to_ntuple
 from .registry import register_model
 
@@ -128,8 +130,8 @@ class ConvPool(nn.Module):
         """
         x is expected to have shape (B, C, H, W)
         """
-        assert x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims'
-        assert x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims'
+        _assert(x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims')
+        _assert(x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims')
         x = self.conv(x)
         # Layer norm done over channel dim only
         x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
@@ -144,8 +146,8 @@ def blockify(x, block_size: int):
         block_size (int): edge length of a single square block in units of H, W
     """
     B, H, W, C = x.shape
-    assert H % block_size == 0, '`block_size` must divide input height evenly'
-    assert W % block_size == 0, '`block_size` must divide input width evenly'
+    _assert(H % block_size == 0, '`block_size` must divide input height evenly')
+    _assert(W % block_size == 0, '`block_size` must divide input width evenly')
     grid_height = H // block_size
     grid_width = W // block_size
     x = x.reshape(B, grid_height, block_size, grid_width, block_size, C)
     x = x.transpose(2, 3).reshape(B, grid_height * grid_width, -1, C)
     return x  # (B, T, N, C)
 
 
+@register_notrace_function  # reason: int receives Proxy
 def deblockify(x, block_size: int):
     """blocks to image
     Args:
diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py
index 4e0f2b21..973cbd66 100644
--- a/timm/models/nfnet.py
+++ b/timm/models/nfnet.py
@@ -26,6 +26,7 @@ import torch
 import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .fx_features import register_notrace_module
 from .helpers import build_model_with_cfg
 from .registry import register_model
 from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\
@@ -318,6 +319,7 @@ class DownsampleAvg(nn.Module):
         return self.conv(self.pool(x))
 
 
+@register_notrace_module  # reason: mul_ causes FX to drop a relevant node. https://github.com/pytorch/pytorch/issues/68301
 class NormFreeBlock(nn.Module):
     """Normalization-Free pre-activation block.
     """
""" diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 1c7cbba2..bbcae9a3 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -15,7 +15,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg -from .layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, create_attn, get_attn, create_classifier +from .layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, create_classifier from .registry import register_model __all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this @@ -89,10 +89,15 @@ default_cfgs = { interpolation='bicubic'), 'wide_resnet101_2': _cfg(url='https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth'), + # ResNets w/ alternative norm layers + 'resnet50_gn': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_gn_a1h2-8fe6c4d0.pth', + crop_pct=0.94, interpolation='bicubic'), + # ResNeXt 'resnext50_32x4d': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50_32x4d_ra-d733960d.pth', - interpolation='bicubic'), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnext50_32x4d_a1h-0146ab0a.pth', + interpolation='bicubic', crop_pct=0.95), 'resnext50d_32x4d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50d_32x4d-103e99f8.pth', interpolation='bicubic', @@ -881,6 +886,14 @@ def wide_resnet101_2(pretrained=False, **kwargs): return _create_resnet('wide_resnet101_2', pretrained, **model_args) +@register_model +def resnet50_gn(pretrained=False, **kwargs): + """Constructs a ResNet-50 model w/ GroupNorm + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('resnet50_gn', pretrained, norm_layer=GroupNorm, **model_args) + + @register_model def resnext50_32x4d(pretrained=False, **kwargs): """Constructs a ResNeXt50-32x4d model. 
diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py
index 43940cc3..e38eaf5e 100644
--- a/timm/models/resnetv2.py
+++ b/timm/models/resnetv2.py
@@ -120,6 +120,13 @@ default_cfgs = {
         interpolation='bicubic'),
     'resnetv2_152d': _cfg(
         interpolation='bicubic', first_conv='stem.conv1'),
+
+    'resnetv2_50d_gn': _cfg(
+        interpolation='bicubic', first_conv='stem.conv1'),
+    'resnetv2_50d_evob': _cfg(
+        interpolation='bicubic', first_conv='stem.conv1'),
+    'resnetv2_50d_evos': _cfg(
+        interpolation='bicubic', first_conv='stem.conv1'),
 }
 
 
@@ -639,19 +646,27 @@ def resnetv2_152d(pretrained=False, **kwargs):
         stem_type='deep', avg_down=True, **kwargs)
 
 
-# @register_model
-# def resnetv2_50ebd(pretrained=False, **kwargs):
-#     # FIXME for testing w/ TPU + PyTorch XLA
-#     return _create_resnetv2(
-#         'resnetv2_50d', pretrained=pretrained,
-#         layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormBatch2d,
-#         stem_type='deep', avg_down=True, **kwargs)
-#
-#
-# @register_model
-# def resnetv2_50esd(pretrained=False, **kwargs):
-#     # FIXME for testing w/ TPU + PyTorch XLA
-#     return _create_resnetv2(
-#         'resnetv2_50d', pretrained=pretrained,
-#         layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormSample2d,
-#         stem_type='deep', avg_down=True, **kwargs)
+# Experimental configs (may change / be removed)
+
+@register_model
+def resnetv2_50d_gn(pretrained=False, **kwargs):
+    return _create_resnetv2(
+        'resnetv2_50d_gn', pretrained=pretrained,
+        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=GroupNormAct,
+        stem_type='deep', avg_down=True, **kwargs)
+
+
+@register_model
+def resnetv2_50d_evob(pretrained=False, **kwargs):
+    return _create_resnetv2(
+        'resnetv2_50d_evob', pretrained=pretrained,
+        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormBatch2d,
+        stem_type='deep', avg_down=True, **kwargs)
+
+
+@register_model
+def resnetv2_50d_evos(pretrained=False, **kwargs):
+    return _create_resnetv2(
+        'resnetv2_50d_evos', pretrained=pretrained,
+        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormSample2d,
+        stem_type='deep', avg_down=True, **kwargs)
diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py
index 279780be..f27ce5d8 100644
--- a/timm/models/rexnet.py
+++ b/timm/models/rexnet.py
@@ -10,6 +10,7 @@ Changes for timm, feature extraction, and rounded channel variant hacked togethe
 Copyright 2020 Ross Wightman
 """
+import torch
 import torch.nn as nn
 from functools import partial
 from math import ceil
@@ -92,7 +93,7 @@ class LinearBottleneck(nn.Module):
         if self.use_shortcut:
             if self.drop_path is not None:
                 x = self.drop_path(x)
-            x[:, 0:self.in_channels] += shortcut
+            x = torch.cat([x[:, 0:self.in_channels] + shortcut, x[:, self.in_channels:]], dim=1)
         return x
 
 
diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py
index 822aeef8..92057902 100644
--- a/timm/models/swin_transformer.py
+++ b/timm/models/swin_transformer.py
@@ -21,11 +21,14 @@ import torch.nn as nn
 import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .fx_features import register_notrace_function
 from .helpers import build_model_with_cfg, overlay_external_default_cfg
 from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
+from .layers import _assert
 from .registry import register_model
 from .vision_transformer import checkpoint_filter_fn, _init_vit_weights
 
+
 _logger = logging.getLogger(__name__)
 
 
@@ -100,6 +103,7 @@ def window_partition(x, window_size: int):
     return windows
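# NOTE: on the decorator added just below — `window_reverse` computes
# `B = int(windows.shape[0] / (H * W / window_size / window_size))`, and under FX symbolic tracing
# `windows.shape[0]` is a Proxy that the builtin `int()` cannot consume, so tracing fails.
# `register_notrace_function` adds it to `_autowrap_functions`, keeping the whole function as one
# opaque call_function node instead of tracing into its body. The same failure mode in miniature
# (illustrative sketch, not part of the patch):
#
#     import torch.fx
#
#     def flatten_batch(x):
#         return x.reshape(int(x.shape[0]), -1)  # int(Proxy) raises during symbolic_trace
#
#     torch.fx.symbolic_trace(flatten_batch)  # fails; autowrapping flatten_batch avoids the trace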
 
 
+@register_notrace_function  # reason: int argument is a Proxy
 def window_reverse(windows, window_size: int, H: int, W: int):
     """
     Args:
@@ -270,7 +274,7 @@ class SwinTransformerBlock(nn.Module):
     def forward(self, x):
         H, W = self.input_resolution
         B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"
+        _assert(L == H * W, "input feature has wrong size")
 
         shortcut = x
         x = self.norm1(x)
@@ -329,8 +333,8 @@ class PatchMerging(nn.Module):
         """
         H, W = self.input_resolution
         B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"
-        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
+        _assert(L == H * W, "input feature has wrong size")
+        _assert(H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even.")
 
         x = x.view(B, H, W, C)
 
diff --git a/timm/models/tnt.py b/timm/models/tnt.py
index 9829653c..d52f9ce6 100644
--- a/timm/models/tnt.py
+++ b/timm/models/tnt.py
@@ -9,12 +9,12 @@ https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT
 import math
 import torch
 import torch.nn as nn
-from functools import partial
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.models.helpers import build_model_with_cfg
 from timm.models.layers import Mlp, DropPath, trunc_normal_
 from timm.models.layers.helpers import to_2tuple
+from timm.models.layers import _assert
 from timm.models.registry import register_model
 from timm.models.vision_transformer import resize_pos_embed
 
@@ -109,7 +109,9 @@ class Block(nn.Module):
             pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed)))
         # outer
         B, N, C = patch_embed.size()
-        patch_embed[:, 1:] = patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))
+        patch_embed = torch.cat(
+            [patch_embed[:, 0:1], patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))],
+            dim=1)
         patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed)))
         patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed)))
         return pixel_embed, patch_embed
@@ -136,8 +138,10 @@ class PixelEmbed(nn.Module):
 
     def forward(self, x, pixel_pos):
         B, C, H, W = x.shape
-        assert H == self.img_size[0] and W == self.img_size[1], \
-            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        _assert(H == self.img_size[0],
+                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
+        _assert(W == self.img_size[1],
+                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
         x = self.proj(x)
         x = self.unfold(x)
         x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
diff --git a/timm/models/twins.py b/timm/models/twins.py
index 4aed09d9..67a939d4 100644
--- a/timm/models/twins.py
+++ b/timm/models/twins.py
@@ -22,9 +22,10 @@ from functools import partial
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .layers import Mlp, DropPath, to_2tuple, trunc_normal_
+from .fx_features import register_notrace_module
 from .registry import register_model
 from .vision_transformer import Attention
-from .helpers import build_model_with_cfg, overlay_external_default_cfg
+from .helpers import build_model_with_cfg
 
 
 def _cfg(url='', **kwargs):
@@ -62,6 +63,7 @@ default_cfgs = {
 Size_ = Tuple[int, int]
 
 
+@register_notrace_module  # reason: FX can't symbolically trace control flow in forward method
 class LocallyGroupedAttn(nn.Module):
     """ LSA: self attention within a group
     """
diff --git a/timm/models/vgg.py b/timm/models/vgg.py
index 8bea03e7..11f6d0ea 100644
--- a/timm/models/vgg.py
+++ b/timm/models/vgg.py
@@ -12,7 +12,8 @@ from typing import Union, List, Dict, Any, cast
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
-from .layers import ClassifierHead, ConvBnAct
+from .fx_features import register_notrace_module
+from .layers import ClassifierHead
 from .registry import register_model
 
 __all__ = [
@@ -52,6 +53,7 @@ cfgs: Dict[str, List[Union[str, int]]] = {
 }
 
 
+@register_notrace_module  # reason: FX can't symbolically trace control flow in forward method
 class ConvMlp(nn.Module):
 
     def __init__(self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0,
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 94ae2666..6e568abf 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -88,6 +88,9 @@ default_cfgs = {
         url='https://storage.googleapis.com/vit_models/augreg/'
             'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
         input_size=(3, 384, 384), crop_pct=1.0),
+    'vit_base_patch8_224': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/'
+            'B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
     'vit_large_patch32_224': _cfg(
         url='',  # no official model weights for this combo, only for in21k
         ),
@@ -118,6 +121,9 @@ default_cfgs = {
     'vit_base_patch16_224_in21k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
         num_classes=21843),
+    'vit_base_patch8_224_in21k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
+        num_classes=21843),
     'vit_large_patch32_224_in21k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
         num_classes=21843),
@@ -640,6 +646,16 @@ def vit_base_patch16_384(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def vit_base_patch8_224(pretrained=False, **kwargs):
+    """ ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+    """
+    model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
+    model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_large_patch32_224(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
@@ -756,6 +772,18 @@ def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def vit_base_patch8_224_in21k(pretrained=False, **kwargs):
+    """ ViT-Base model (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
+    """
+    model_kwargs = dict(
+        patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
+    model = _create_vision_transformer('vit_base_patch8_224_in21k', pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
diff --git a/timm/models/xcit.py b/timm/models/xcit.py
index 2942ed8a..ac5e802c 100644
--- a/timm/models/xcit.py
+++ b/timm/models/xcit.py
@@ -21,6 +21,7 @@ from .vision_transformer import _cfg, Mlp
 from .registry import register_model
 from .layers import DropPath, trunc_normal_, to_2tuple
 from .cait import ClassAttn
+from .fx_features import register_notrace_module
 
 
 def _cfg(url='', **kwargs):
@@ -97,6 +98,7 @@ default_cfgs = {
 }
 
 
+@register_notrace_module  # reason: FX can't symbolically trace torch.arange in forward method
 class PositionalEncodingFourier(nn.Module):
     """
     Positional encoding relying on a fourier kernel matching the one used in the "Attention is all of Need" paper.
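Putting the pieces together, a hedged sketch (not part of the diff) of driving the new FX plumbing
end to end; `resnet50` is an arbitrary choice of model here:

    import torch
    import timm
    from timm.models.fx_features import _leaf_modules, _autowrap_functions
    from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor

    model = timm.create_model('resnet50', pretrained=False).eval()
    # timm's registered leaf modules / autowrapped functions keep untraceable control flow out of the graph
    tracer_kwargs = dict(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
    train_nodes, eval_nodes = get_graph_node_names(model, tracer_kwargs=tracer_kwargs)
    extractor = create_feature_extractor(model, return_nodes=[eval_nodes[-1]], tracer_kwargs=tracer_kwargs)
    out = extractor(torch.randn(1, 3, 224, 224))  # dict of {node_name: Tensor}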