Improve test crop for ViT models. Small now 77.85, added base weights at 79.35 top-1.

pull/268/head
Ross Wightman
parent d4db9e7977
commit 27a93e9de7

README.md
@@ -2,6 +2,9 @@
 ## What's New
+### Oct 21, 2020
+* Weights added for Vision Transformer (ViT) models. 77.86 top-1 for 'small' and 79.35 for 'base'. Thanks to [Christof](https://www.kaggle.com/christofhenkel) for training the base model w/ lots of GPUs.
+
 ### Oct 13, 2020
 * Initial impl of Vision Transformer models. Both patch and hybrid (CNN backbone) variants. Currently trying to train...
 * Adafactor and AdaHessian (FP32 only, no AMP) optimizers

timm/models/vision_transformer.py
@@ -39,7 +39,7 @@ def _cfg(url='', **kwargs):
     return {
         'url': url,
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
-        'crop_pct': 1.0, 'interpolation': 'bicubic',
+        'crop_pct': .9, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
         'first_conv': '', 'classifier': 'head',
         **kwargs
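
The `crop_pct` change is the "test crop" improvement from the commit message: at eval time the shorter image side is resized to `img_size / crop_pct` before a center crop, so moving from `1.0` to `.9` means the ViT models now get a 248px bicubic resize followed by a 224px center crop instead of cropping from an exact 224px resize. A minimal sketch of that preprocessing (not timm's exact transform code; the helper name and torchvision usage here are illustrative, using the PIL-constant interpolation API of 2020-era torchvision):

```python
import math
from PIL import Image
from torchvision import transforms

def vit_eval_transform(img_size=224, crop_pct=0.9):
    # Resize the shorter side so a center crop of img_size covers crop_pct
    # of it: floor(224 / 0.9) -> 248px resize, then a 224px center crop.
    scale_size = int(math.floor(img_size / crop_pct))
    return transforms.Compose([
        transforms.Resize(scale_size, interpolation=Image.BICUBIC),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
    ])
```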
@@ -51,7 +51,9 @@ default_cfgs = {
     'vit_small_patch16_224': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
     ),
-    'vit_base_patch16_224': _cfg(),
+    'vit_base_patch16_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_base_p16_224-4e355ebd.pth'
+    ),
     'vit_base_patch16_384': _cfg(input_size=(3, 384, 384)),
     'vit_base_patch32_384': _cfg(input_size=(3, 384, 384)),
     'vit_large_patch16_224': _cfg(),
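
With a download `url` in its `default_cfg`, the base model can now be created with pretrained weights through timm's model factory, the same as the small model. A short usage sketch (assuming a timm install that includes this commit):

```python
import timm

# Pulls vit_base_p16_224-4e355ebd.pth from the release URL on first use,
# caches it locally, and loads it into the model.
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.eval()
```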
@@ -283,6 +285,9 @@ def vit_small_patch16_224(pretrained=False, **kwargs):
 
 @register_model
 def vit_base_patch16_224(pretrained=False, **kwargs):
     model = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
     model.default_cfg = default_cfgs['vit_base_patch16_224']
+    if pretrained:
+        load_pretrained(
+            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
     return model
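
The `num_classes` / `in_chans` passthrough mirrors the other pretrained entrypoints in the file: `load_pretrained` uses those values to adapt or drop checkpoint tensors that no longer match the constructed model, such as the classifier head when fine-tuning for a different label count. A hedged example (the 10-class value is purely illustrative):

```python
# Reuse the ImageNet backbone weights but re-initialize the classifier head
# for a 10-class fine-tuning task; the mismatched head weights are skipped.
model = vit_base_patch16_224(pretrained=True, num_classes=10)
```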
