Update README.md. Small tweak to head_dist handling.

pull/533/head
Ross Wightman 4 years ago
parent 9071568f0e
commit acbd698c83

@@ -23,6 +23,22 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
## What's New
### April 1, 2021
* Add snazzy `benchmark.py` script for bulk `timm` model benchmarking of train and/or inference
* Add Pooling-based Vision Transformer (PiT) models (from https://github.com/naver-ai/pit)
* Merged distilled variant into main for torchscript compatibility
* Some `timm` cleanup/style tweaks and weights have hub download support
* Cleanup Vision Transformer (ViT) models
  * Merge distilled (DeiT) model into main so that torchscript works (see the sketch after this list)
* Support updated weight init (defaults to old still) that closer matches original JAX impl (possibly better training from scratch)
* Separate hybrid model defs into different file and add several new model defs to fiddle with, support patch_size != 1 for hybrids
  * Fix fine-tuning num_classes changes (PiT and ViT) and pos_embed resizing (ViT) with distilled variants
* nn.Sequential for block stack (does not break downstream compat)
* TnT (Transformer-in-Transformer) models contributed by author (from https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT)
* Add RegNetY-160 weights from DeiT teacher model
* Add new NFNet-L0 w/ SE attn (rename `nfnet_l0b`->`nfnet_l0`) weights 82.75 top-1 @ 288x288
* Some fixes/improvements for TFDS dataset wrapper
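
For the distilled ViT/DeiT merge called out above, here is a minimal sketch of creating and scripting one of the merged variants. The model name and the averaged eval-mode output are assumptions based on the DeiT convention; verify both against the `timm` version you have installed.

```python
import torch
import timm

# Assumption: a timm build that includes the merged distilled DeiT variants
# described above; check timm.list_models('*distilled*') for the names your
# install actually provides.
model = timm.create_model('deit_base_distilled_patch16_224', pretrained=False, num_classes=10)
model.eval()

# The distilled variant keeps both classifier heads (head and head_dist).
# Folding it into the main VisionTransformer class was motivated by
# torchscript compatibility, so scripting is expected to work here.
scripted = torch.jit.script(model)
with torch.no_grad():
    out = scripted(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 10]); eval mode averages the two heads (DeiT convention)
```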
### March 17, 2021
* Add new ECA-NFNet-L0 (rename `nfnet_l0c`->`eca_nfnet_l0`) weights trained by myself.
  * 82.6 top-1 @ 288x288, 82.8 @ 320x320, trained at 224x224
@@ -189,6 +205,7 @@ A full version of the list below with source links can be found in the [document
* NFNet-F - https://arxiv.org/abs/2102.06171
* NF-RegNet / NF-ResNet - https://arxiv.org/abs/2101.08692
* PNasNet - https://arxiv.org/abs/1712.00559
* Pooling-based Vision Transformer (PiT) - https://arxiv.org/abs/2103.16302
* RegNet - https://arxiv.org/abs/2003.13678
* RepVGG - https://arxiv.org/abs/2101.03697
* ResNet/ResNeXt
@@ -204,6 +221,7 @@ A full version of the list below with source links can be found in the [document
* ReXNet - https://arxiv.org/abs/2007.00992
* SelecSLS - https://arxiv.org/abs/1907.00837
* Selective Kernel Networks - https://arxiv.org/abs/1903.06586
* Transformer-iN-Transformer (TNT) - https://arxiv.org/abs/2103.00112
* TResNet - https://arxiv.org/abs/2003.13630
* Vision Transformer - https://arxiv.org/abs/2010.11929
* VovNet V2 and V1 - https://arxiv.org/abs/1911.06667

@@ -337,7 +337,7 @@ class VisionTransformer(nn.Module):
     def reset_classifier(self, num_classes, global_pool=''):
         self.num_classes = num_classes
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
-        if self.head_dist is not None:
+        if self.num_tokens == 2:
             self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

     def forward_features(self, x):
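
The tweak above replaces the `head_dist is not None` check with `num_tokens == 2`: distilled models are identified by carrying a distillation token alongside the class token, so `reset_classifier` rebuilds both heads for them. A minimal sketch of the expected behaviour, assuming a distilled DeiT model name that may differ in your `timm` version:

```python
import torch.nn as nn
import timm

# Assumption: a distilled DeiT variant backed by the merged VisionTransformer
# class, where num_tokens == 2 (class token + distillation token).
model = timm.create_model('deit_small_distilled_patch16_224', pretrained=False)

# Re-target the classifier for a 10-class fine-tune; with the change above the
# distillation head is rebuilt alongside the main head.
model.reset_classifier(10)
assert isinstance(model.head, nn.Linear) and model.head.out_features == 10
assert isinstance(model.head_dist, nn.Linear) and model.head_dist.out_features == 10

# num_classes == 0 collapses both heads to nn.Identity for feature extraction.
model.reset_classifier(0)
assert isinstance(model.head, nn.Identity) and isinstance(model.head_dist, nn.Identity)
```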
