From acbd698c83ef020c0f0ca3471e3945bd8611ebe3 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 1 Apr 2021 17:49:05 -0700
Subject: [PATCH] Update README.md with updates. Small tweak to head_dist handling.

---
 README.md                          | 18 ++++++++++++++++++
 timm/models/vision_transformer.py |  2 +-
 2 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index a644d0d0..3f212d7d 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,22 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
 
 ## What's New
 
+### April 1, 2021
+* Add snazzy `benchmark.py` script for bulk `timm` model benchmarking of train and/or inference
+* Add Pooling-based Vision Transformer (PiT) models (from https://github.com/naver-ai/pit)
+  * Merged distilled variant into main for torchscript compatibility
+  * Some `timm` cleanup/style tweaks and weights have hub download support
+* Cleanup Vision Transformer (ViT) models
+  * Merge distilled (DeiT) model into main so that torchscript can work
+  * Support updated weight init (defaults to old still) that more closely matches original JAX impl (possibly better training from scratch)
+  * Separate hybrid model defs into different file and add several new model defs to fiddle with, support patch_size != 1 for hybrids
+  * Fix fine-tuning num_classes changes (PiT and ViT) and pos_embed resizing (ViT) with distilled variants
+  * nn.Sequential for block stack (does not break downstream compat)
+* TnT (Transformer-in-Transformer) models contributed by author (from https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT)
+* Add RegNetY-160 weights from DeiT teacher model
+* Add new NFNet-L0 w/ SE attn (rename `nfnet_l0b`->`nfnet_l0`) weights, 82.75 top-1 @ 288x288
+* Some fixes/improvements for TFDS dataset wrapper
+
 ### March 17, 2021
 * Add new ECA-NFNet-L0 (rename `nfnet_l0c`->`eca_nfnet_l0`) weights trained by myself.
   * 82.6 top-1 @ 288x288, 82.8 @ 320x320, trained at 224x224
@@ -189,6 +205,7 @@ A full version of the list below with source links can be found in the [document
 * NFNet-F - https://arxiv.org/abs/2102.06171
 * NF-RegNet / NF-ResNet - https://arxiv.org/abs/2101.08692
 * PNasNet - https://arxiv.org/abs/1712.00559
+* Pooling-based Vision Transformer (PiT) - https://arxiv.org/abs/2103.16302
 * RegNet - https://arxiv.org/abs/2003.13678
 * RepVGG - https://arxiv.org/abs/2101.03697
 * ResNet/ResNeXt
@@ -204,6 +221,7 @@ A full version of the list below with source links can be found in the [document
 * ReXNet - https://arxiv.org/abs/2007.00992
 * SelecSLS - https://arxiv.org/abs/1907.00837
 * Selective Kernel Networks - https://arxiv.org/abs/1903.06586
+* Transformer-iN-Transformer (TNT) - https://arxiv.org/abs/2103.00112
 * TResNet - https://arxiv.org/abs/2003.13630
 * Vision Transformer - https://arxiv.org/abs/2010.11929
 * VovNet V2 and V1 - https://arxiv.org/abs/1911.06667
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 7a7afbff..cd73cc11 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -337,7 +337,7 @@ class VisionTransformer(nn.Module):
     def reset_classifier(self, num_classes, global_pool=''):
         self.num_classes = num_classes
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
-        if self.head_dist is not None:
+        if self.num_tokens == 2:
             self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
 
     def forward_features(self, x):
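
Reviewer note on the `reset_classifier` tweak: after the DeiT merge, `head_dist` is `None` only for non-distilled models, so the old check still worked, but keying the rebuild off the structural `num_tokens` flag (2 whenever a distillation token is present) no longer depends on what `head_dist` happens to hold at call time. A minimal sketch of the resulting behavior; `TinyViTHead` is a hypothetical stand-in assuming the merged constructor pattern (`num_tokens = 2 if distilled else 1`), not the actual timm class:

```python
import torch.nn as nn

class TinyViTHead(nn.Module):
    """Hypothetical reduction of VisionTransformer's classifier-head handling."""
    def __init__(self, embed_dim=768, num_classes=1000, distilled=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_classes = num_classes
        self.num_tokens = 2 if distilled else 1  # class token (+ distillation token)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        # Distilled models always define head_dist (Identity when num_classes == 0),
        # so num_tokens is the sturdier signal of "this model is distilled".
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        # Patched check: rebuild the distillation head whenever the model is
        # structurally distilled, regardless of head_dist's current value.
        if self.num_tokens == 2:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

m = TinyViTHead(distilled=True)
m.reset_classifier(10)  # both heads now project to 10 classes
assert m.head.out_features == 10 and m.head_dist.out_features == 10
```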
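
The fine-tuning fixes in the changelog mean a distilled variant can be re-headed directly at creation time. A usage sketch; the model names are real `timm` identifiers from this release, but that pretrained weights download cleanly for all of them is an assumption here:

```python
import timm

# Re-head a distilled DeiT for a 10-class task; with the fix, both `head` and
# `head_dist` are rebuilt, and the pretrained pos_embed is resized when needed.
model = timm.create_model('deit_base_distilled_patch16_224', pretrained=True, num_classes=10)

# New PiT and TNT families from this release; weights arrive via hub download.
pit = timm.create_model('pit_b_224', pretrained=True)
tnt = timm.create_model('tnt_s_patch16_224', pretrained=True)
```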
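
The "merge distilled variant into main" items are what make `torch.jit.script` viable on these models. A quick smoke-test sketch; that scripting succeeds end-to-end at this commit is inferred from the changelog, not demonstrated by the diff itself:

```python
import torch
import timm

model = timm.create_model('deit_tiny_distilled_patch16_224', pretrained=False)
model.eval()  # in eval (and when scripted), the two head outputs are averaged

scripted = torch.jit.script(model)  # merged distilled variant now scripts
out = scripted(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 1000])
```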