diff --git a/README.md b/README.md index 6b41d772..07c71a76 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,9 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor ## What's New +### June 23, 2021 +* Reproduce gMLP model training, `gmlp_s16_224` trained to 79.6 top-1, matching [paper](https://arxiv.org/abs/2105.08050). + ### June 20, 2021 * Release Vision Transformer 'AugReg' weights from [How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers](https://arxiv.org/abs/2106.10270) * .npz weight loading support added, can load any of the 50K+ weights from the [AugReg series](https://console.cloud.google.com/storage/browser/vit_models/augreg) diff --git a/tests/test_models.py b/tests/test_models.py index 0a770784..5c8b02db 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -147,6 +147,15 @@ def test_model_default_cfgs(model_name, batch_size): # FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2] + if 'pruned' not in model_name: # FIXME better pruned model handling + # test classifier + global pool deletion via __init__ + model = create_model(model_name, pretrained=False, num_classes=0, global_pool='').eval() + outputs = model.forward(input_tensor) + assert len(outputs.shape) == 4 + if not isinstance(model, timm.models.MobileNetV3) and not isinstance(model, timm.models.GhostNet): + # FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ + assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2] + # check classifier name matches default_cfg classifier = cfg['classifier'] if not isinstance(classifier, (tuple, list)): @@ -193,6 +202,13 @@ def test_model_default_cfgs_non_std(model_name, batch_size): assert len(outputs.shape) == 2 assert outputs.shape[1] == model.num_features + model = create_model(model_name, pretrained=False, num_classes=0).eval() + outputs = model.forward(input_tensor) + if isinstance(outputs, tuple): + outputs = outputs[0] + assert len(outputs.shape) == 2 + assert outputs.shape[1] == model.num_features + # check classifier name matches default_cfg classifier = cfg['classifier'] if not isinstance(classifier, (tuple, list)): @@ -217,6 +233,7 @@ if 'GITHUB_ACTIONS' not in os.environ: """Create that pretrained weights load, verify support for in_chans != 3 while doing so.""" in_chans = 3 if 'pruned' in model_name else 1 # pruning not currently supported with in_chans change create_model(model_name, pretrained=True, in_chans=in_chans, num_classes=5) + create_model(model_name, pretrained=True, in_chans=in_chans, num_classes=0) @pytest.mark.timeout(120) @pytest.mark.parametrize('model_name', list_models(pretrained=True, exclude_filters=NON_STD_FILTERS)) diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index a73047c5..3b6f90a4 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -182,7 +182,7 @@ class GhostNet(nn.Module): self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True) self.act2 = nn.ReLU(inplace=True) self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled - self.classifier = Linear(out_chs, num_classes) + self.classifier = Linear(out_chs, num_classes) if num_classes > 0 else nn.Identity() def get_classifier(self): return self.classifier diff --git a/timm/models/levit.py b/timm/models/levit.py index fa35f41f..9987e4ba 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -542,7 +542,7 @@ def checkpoint_filter_fn(state_dict, model): state_dict = state_dict['model'] D = model.state_dict() for k in state_dict.keys(): - if D[k].ndim == 4 and state_dict[k].ndim == 2: + if k in D and D[k].ndim == 4 and state_dict[k].ndim == 2: state_dict[k] = state_dict[k][:, :, None, None] return state_dict diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index 7a87eb36..f128b9c9 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -129,7 +129,9 @@ default_cfgs = dict( mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), gmlp_ti16_224=_cfg(), - gmlp_s16_224=_cfg(), + gmlp_s16_224=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmlp_s16_224_raa-10536d42.pth', + ), gmlp_b16_224=_cfg(), ) @@ -266,7 +268,7 @@ class MlpMixer(nn.Module): act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate) for _ in range(num_blocks)]) self.norm = norm_layer(embed_dim) - self.head = nn.Linear(embed_dim, self.num_classes) # zero init + self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() self.init_weights(nlhb=nlhb) diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index a3c89532..b96d7742 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -424,7 +424,8 @@ def _load_weights(model: nn.Module, checkpoint_path: str, prefix: str = 'resnet/ model.stem.conv.weight.copy_(stem_conv_w) model.norm.weight.copy_(t2p(weights[f'{prefix}group_norm/gamma'])) model.norm.bias.copy_(t2p(weights[f'{prefix}group_norm/beta'])) - if model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]: + if isinstance(getattr(model.head, 'fc', None), nn.Conv2d) and \ + model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]: model.head.fc.weight.copy_(t2p(weights[f'{prefix}head/conv2d/kernel'])) model.head.fc.bias.copy_(t2p(weights[f'{prefix}head/conv2d/bias'])) for i, (sname, stage) in enumerate(model.stages.named_children()): diff --git a/timm/models/visformer.py b/timm/models/visformer.py index 16631027..7740f381 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -237,7 +237,6 @@ class Visformer(nn.Module): self.num_features = embed_dim if self.vit_stem else embed_dim * 2 self.norm = norm_layer(self.num_features) self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) - self.head = nn.Linear(self.num_features, num_classes) # weights init if self.pos_embed: diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 89fba7de..9ec45868 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -6,7 +6,7 @@ A PyTorch implement of Vision Transformers as described in: - https://arxiv.org/abs/2010.11929 `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` - - https://arxiv.org/abs/2106.TODO + - https://arxiv.org/abs/2106.10270 The official jax code is released and available at https://github.com/google-research/vision_transformer @@ -448,9 +448,12 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = model.pos_embed.copy_(pos_embed_w) model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) - if model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: + if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) + if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: + model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) + model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) for i, block in enumerate(model.blocks.children()): block_prefix = f'{prefix}Transformer/encoderblock_{i}/' mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' @@ -673,6 +676,7 @@ def vit_large_patch16_384(pretrained=False, **kwargs): def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs): """ ViT-Tiny (Vit-Ti/16). ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer """ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs) @@ -683,6 +687,7 @@ def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs): def vit_small_patch32_224_in21k(pretrained=False, **kwargs): """ ViT-Small (ViT-S/16) ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer """ model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs) @@ -693,6 +698,7 @@ def vit_small_patch32_224_in21k(pretrained=False, **kwargs): def vit_small_patch16_224_in21k(pretrained=False, **kwargs): """ ViT-Small (ViT-S/16) ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer """ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs) @@ -703,9 +709,10 @@ def vit_small_patch16_224_in21k(pretrained=False, **kwargs): def vit_base_patch32_224_in21k(pretrained=False, **kwargs): """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer """ model_kwargs = dict( - patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs) + patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -714,9 +721,10 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs): def vit_base_patch16_224_in21k(pretrained=False, **kwargs): """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer """ model_kwargs = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs) + patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -725,6 +733,7 @@ def vit_base_patch16_224_in21k(pretrained=False, **kwargs): def vit_large_patch32_224_in21k(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights """ model_kwargs = dict( patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs) @@ -736,9 +745,10 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs): def vit_large_patch16_224_in21k(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer """ model_kwargs = dict( - patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs) + patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -747,7 +757,7 @@ def vit_large_patch16_224_in21k(pretrained=False, **kwargs): def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. - NOTE: converted weights not currently available, too large for github release hosting. + NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights """ model_kwargs = dict( patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)