Merge pull request #714 from rwightman/vit_and_bit_test_fixes

Fix a few issues loading pretrained vit/bit npz weights...
Ross Wightman committed via GitHub
commit 7606bdf9e8

@@ -23,6 +23,9 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
 ## What's New
+### June 23, 2021
+* Reproduce gMLP model training, `gmlp_s16_224` trained to 79.6 top-1, matching [paper](https://arxiv.org/abs/2105.08050).
 ### June 20, 2021
 * Release Vision Transformer 'AugReg' weights from [How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers](https://arxiv.org/abs/2106.10270)
 * .npz weight loading support added, can load any of the 50K+ weights from the [AugReg series](https://console.cloud.google.com/storage/browser/vit_models/augreg)

@@ -147,6 +147,15 @@ def test_model_default_cfgs(model_name, batch_size):
         # FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ
         assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
 
+    if 'pruned' not in model_name:  # FIXME better pruned model handling
+        # test classifier + global pool deletion via __init__
+        model = create_model(model_name, pretrained=False, num_classes=0, global_pool='').eval()
+        outputs = model.forward(input_tensor)
+        assert len(outputs.shape) == 4
+        if not isinstance(model, timm.models.MobileNetV3) and not isinstance(model, timm.models.GhostNet):
+            # FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ
+            assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
+
     # check classifier name matches default_cfg
     classifier = cfg['classifier']
     if not isinstance(classifier, (tuple, list)):
@@ -193,6 +202,13 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
     assert len(outputs.shape) == 2
     assert outputs.shape[1] == model.num_features
 
+    model = create_model(model_name, pretrained=False, num_classes=0).eval()
+    outputs = model.forward(input_tensor)
+    if isinstance(outputs, tuple):
+        outputs = outputs[0]
+    assert len(outputs.shape) == 2
+    assert outputs.shape[1] == model.num_features
+
     # check classifier name matches default_cfg
     classifier = cfg['classifier']
     if not isinstance(classifier, (tuple, list)):
@@ -217,6 +233,7 @@ if 'GITHUB_ACTIONS' not in os.environ:
         """Create that pretrained weights load, verify support for in_chans != 3 while doing so."""
         in_chans = 3 if 'pruned' in model_name else 1  # pruning not currently supported with in_chans change
         create_model(model_name, pretrained=True, in_chans=in_chans, num_classes=5)
+        create_model(model_name, pretrained=True, in_chans=in_chans, num_classes=0)
 
     @pytest.mark.timeout(120)
     @pytest.mark.parametrize('model_name', list_models(pretrained=True, exclude_filters=NON_STD_FILTERS))
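
For reference, the pattern these new test cases exercise, as a standalone sketch (the model name is just an example; any non-pruned timm model works the same way):

```python
import torch
import timm

x = torch.randn(1, 3, 224, 224)

# num_classes=0 removes the classifier: forward() returns pooled features.
pooled = timm.create_model('resnet50', pretrained=False, num_classes=0).eval()(x)

# num_classes=0 plus global_pool='' also removes pooling: forward() returns a 4D feature map.
fmap = timm.create_model('resnet50', pretrained=False, num_classes=0, global_pool='').eval()(x)

print(pooled.shape)  # e.g. (1, 2048)
print(fmap.shape)    # e.g. (1, 2048, 7, 7) for a 224x224 input
```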

@@ -182,7 +182,7 @@ class GhostNet(nn.Module):
         self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True)
         self.act2 = nn.ReLU(inplace=True)
         self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
-        self.classifier = Linear(out_chs, num_classes)
+        self.classifier = Linear(out_chs, num_classes) if num_classes > 0 else nn.Identity()
 
     def get_classifier(self):
         return self.classifier
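
With the classifier falling back to nn.Identity, GhostNet now supports the same headless construction the tests above check. A brief sketch (ghostnet_100 is a registered timm variant):

```python
import torch
import timm

# num_classes=0 now yields classifier = nn.Identity(), so forward() returns the
# pooled and flattened pre-classifier features rather than class logits.
model = timm.create_model('ghostnet_100', pretrained=False, num_classes=0).eval()
features = model(torch.randn(1, 3, 224, 224))
print(features.shape)  # e.g. (1, 1280) pre-classifier features
```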

@@ -542,7 +542,7 @@ def checkpoint_filter_fn(state_dict, model):
         state_dict = state_dict['model']
     D = model.state_dict()
     for k in state_dict.keys():
-        if D[k].ndim == 4 and state_dict[k].ndim == 2:
+        if k in D and D[k].ndim == 4 and state_dict[k].ndim == 2:
             state_dict[k] = state_dict[k][:, :, None, None]
     return state_dict
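
The added `k in D` guard lets a checkpoint carry keys the model no longer has (e.g. a deleted classifier) without raising a KeyError. A minimal, self-contained sketch of the same filtering logic; the function and variable names here are illustrative, not timm's:

```python
import torch

def filter_checkpoint(state_dict, model_state):
    # Skip keys the model doesn't have; expand (out, in) linear weights to
    # (out, in, 1, 1) so they can be loaded into 1x1 conv layers.
    for k, v in list(state_dict.items()):
        if k in model_state and model_state[k].ndim == 4 and v.ndim == 2:
            state_dict[k] = v[:, :, None, None]
    return state_dict

# Example: a 2D linear weight destined for a 1x1 conv in the model.
ckpt = {'head.weight': torch.zeros(10, 64)}
tgt = {'head.weight': torch.zeros(10, 64, 1, 1)}
print(filter_checkpoint(ckpt, tgt)['head.weight'].shape)  # torch.Size([10, 64, 1, 1])
```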

@@ -129,7 +129,9 @@ default_cfgs = dict(
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
 
     gmlp_ti16_224=_cfg(),
-    gmlp_s16_224=_cfg(),
+    gmlp_s16_224=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmlp_s16_224_raa-10536d42.pth',
+    ),
     gmlp_b16_224=_cfg(),
 )
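
This cfg change is what makes the gMLP-S weights mentioned in the README load by default. A quick usage sketch using the standard timm API, nothing beyond what the new cfg provides:

```python
import torch
import timm

# Downloads gmlp_s16_224_raa-10536d42.pth via the url added to the default cfg above.
model = timm.create_model('gmlp_s16_224', pretrained=True).eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # (1, 1000) ImageNet-1k logits
```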
@@ -266,7 +268,7 @@ class MlpMixer(nn.Module):
                 act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate)
             for _ in range(num_blocks)])
         self.norm = norm_layer(embed_dim)
-        self.head = nn.Linear(embed_dim, self.num_classes)  # zero init
+        self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
 
         self.init_weights(nlhb=nlhb)

@@ -424,7 +424,8 @@ def _load_weights(model: nn.Module, checkpoint_path: str, prefix: str = 'resnet/
     model.stem.conv.weight.copy_(stem_conv_w)
     model.norm.weight.copy_(t2p(weights[f'{prefix}group_norm/gamma']))
     model.norm.bias.copy_(t2p(weights[f'{prefix}group_norm/beta']))
-    if model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
+    if isinstance(getattr(model.head, 'fc', None), nn.Conv2d) and \
+            model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
         model.head.fc.weight.copy_(t2p(weights[f'{prefix}head/conv2d/kernel']))
         model.head.fc.bias.copy_(t2p(weights[f'{prefix}head/conv2d/bias']))
     for i, (sname, stage) in enumerate(model.stages.named_children()):
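
The extra isinstance/getattr guard matters when the classifier has been removed (with num_classes=0 there is no head.fc conv). A hedged sketch of loading a BiT .npz checkpoint into such a model by calling the function from this hunk directly; the model name is a real timm BiT variant, while the .npz path is a placeholder you would download yourself:

```python
import timm
from timm.models.resnetv2 import _load_weights

# Build a headless BiT backbone; with the guard above, the classifier weights
# in the .npz are simply skipped instead of crashing on a missing head.fc.
model = timm.create_model('resnetv2_50x1_bitm', pretrained=False, num_classes=0)
_load_weights(model, 'BiT-M-R50x1.npz')  # placeholder path; prefix defaults to 'resnet/'
```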

@@ -237,7 +237,6 @@ class Visformer(nn.Module):
         self.num_features = embed_dim if self.vit_stem else embed_dim * 2
         self.norm = norm_layer(self.num_features)
         self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
-        self.head = nn.Linear(self.num_features, num_classes)
 
         # weights init
         if self.pos_embed:
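
The deleted line was overwriting the classifier that create_classifier had already built, which is why num_classes=0 previously broke for Visformer. A small sketch of what create_classifier returns (the feature width is chosen arbitrarily for illustration):

```python
from timm.models.layers import create_classifier

# With num_classes=0 the returned classifier is an nn.Identity, so no separate
# nn.Linear should be assigned over it afterwards.
global_pool, head = create_classifier(768, num_classes=0, pool_type='avg')
print(type(head).__name__)  # Identity

global_pool, head = create_classifier(768, num_classes=1000, pool_type='avg')
print(type(head).__name__)  # Linear
```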

@@ -6,7 +6,7 @@ A PyTorch implement of Vision Transformers as described in:
     - https://arxiv.org/abs/2010.11929
 
 `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
-    - https://arxiv.org/abs/2106.TODO
+    - https://arxiv.org/abs/2106.10270
 
 The official jax code is released and available at https://github.com/google-research/vision_transformer
@@ -448,9 +448,12 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
     model.pos_embed.copy_(pos_embed_w)
     model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
     model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
-    if model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
+    if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
         model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
         model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
+    if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
+        model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
+        model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
     for i, block in enumerate(model.blocks.children()):
         block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
         mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
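
Taken together, the two guards above let .npz checkpoints load into ViTs built without a classifier head and/or without a representation layer. A hedged sketch of driving the loader from this hunk directly; the .npz filename is a placeholder for one of the AugReg checkpoints:

```python
import timm
from timm.models.vision_transformer import _load_weights

# Headless ViT: head is nn.Identity, so the classifier weights in the .npz are skipped;
# if the checkpoint carries pre_logits weights and the model has a representation
# layer, those are now copied as well.
model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=0)
_load_weights(model, 'augreg_vit_b16.npz')  # placeholder path to a downloaded AugReg .npz
```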
@@ -673,6 +676,7 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
 def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Tiny (Vit-Ti/16).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
     model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
@@ -683,6 +687,7 @@ def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
     """ ViT-Small (ViT-S/16)
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
     model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
@@ -693,6 +698,7 @@ def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
 def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Small (ViT-S/16)
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
     model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
@@ -703,9 +709,10 @@ def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
     """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(
-        patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+        patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
@@ -714,9 +721,10 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
 def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
@@ -725,6 +733,7 @@ def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
     """
     model_kwargs = dict(
         patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
@@ -736,9 +745,10 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
 def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
     model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
@@ -747,7 +757,7 @@ def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
     """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
-    NOTE: converted weights not currently available, too large for github release hosting.
+    NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
     """
     model_kwargs = dict(
         patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
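
The updated NOTEs describe two different in21k checkpoint layouts. A hedged sketch of what that means for the constructed modules (attribute names follow timm's VisionTransformer):

```python
import timm

# B/16 in21k: valid 21k classifier head and no representation (pre-logits) layer,
# so representation_size is no longer passed and pre_logits is an Identity.
b16 = timm.create_model('vit_base_patch16_224_in21k', pretrained=False)
print(type(b16.pre_logits).__name__)  # Identity
print(b16.head.out_features)          # 21843 ImageNet-21k classes

# L/32 in21k: keeps the representation layer, but the 21k head is zero'd out in the
# original weights, so the .npz loader above leaves the randomly initialized head alone.
l32 = timm.create_model('vit_large_patch32_224_in21k', pretrained=False)
print(type(l32.pre_logits.fc).__name__)  # Linear (representation layer present)
```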
