Update ResNet50 weights to AugMix trained, 78.994 top-1. A few comments re the 'tiered_narrow' (tn) variant.

pull/82/head
Ross Wightman 5 years ago
parent cc0b1f4130
commit 2f41905ba5

@@ -3,7 +3,7 @@
 This started as a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with
 additional dropout and dynamic global avg/max pool.
-ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants added by Ross Wightman
+ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman
 """
 import math
@@ -42,7 +42,7 @@ default_cfgs = {
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26d-69e92c46.pth',
         interpolation='bicubic'),
     'resnet50': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/rw_resnet50-86acaeed.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50_am-6c502b37.pth',
         interpolation='bicubic'),
     'resnet50d': _cfg(
         url='',
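With the default_cfgs URL above now pointing at the AugMix-trained checkpoint, the usual timm entry point picks it up automatically. A minimal sketch, assuming a timm install where 'resnet50' still resolves to this default_cfgs entry; nothing below is taken from the diff itself:

```python
# Minimal sketch: pull the default ResNet50 weights (now the AugMix-trained
# checkpoint referenced above) and run a dummy forward pass.
import torch
import timm

model = timm.create_model('resnet50', pretrained=True)  # downloads default_cfgs['resnet50'] url
model.eval()

# default_cfg carries the eval settings for this entry, e.g. interpolation='bicubic'
print(model.default_cfg)

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```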
@@ -243,13 +243,14 @@ class ResNet(nn.Module):
     variants included in the MXNet Gluon ResNetV1b model. The C and D variants are also discussed in the
     'Bag of Tricks' paper: https://arxiv.org/pdf/1812.01187. The B variant is equivalent to torchvision default.
 
-    ResNet variants:
+    ResNet variants (the same modifications can be used in SE/ResNeXt models as well):
       * normal, b - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet 'v1.5', Gluon v1b
       * c - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64)
       * d - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64), average pool in downsample
       * e - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128), average pool in downsample
       * s - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128)
       * t - 3 layer deep 3x3 stem, stem width = 32 (24, 48, 64), average pool in downsample
+      * tn - 3 layer deep 3x3 stem, stem width = 32 (24, 32, 64), average pool in downsample
 
     ResNeXt
       * normal - 7x7 stem, stem_width = 64, standard cardinality and base widths
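For reference, the letter suffixes documented in the hunk above boil down to a handful of constructor arguments. The mapping below is an illustrative summary only; the keyword names (stem_width, stem_type, avg_down) mirror this file's ResNet class, but the table itself is my reading of the docstring, not code copied from the repo:

```python
# Illustrative only: rough mapping from the documented variant letters to the
# stem/downsample arguments they imply.
VARIANT_STEMS = {
    'b':  dict(stem_width=64, stem_type='',                   avg_down=False),  # single 7x7 stem
    'c':  dict(stem_width=32, stem_type='deep',               avg_down=False),  # (32, 32, 64)
    'd':  dict(stem_width=32, stem_type='deep',               avg_down=True),   # (32, 32, 64) + avg-pool downsample
    'e':  dict(stem_width=64, stem_type='deep',               avg_down=True),   # (64, 64, 128) + avg-pool downsample
    's':  dict(stem_width=64, stem_type='deep',               avg_down=False),  # (64, 64, 128)
    't':  dict(stem_width=32, stem_type='deep_tiered',        avg_down=True),   # (24, 48, 64) + avg-pool downsample
    'tn': dict(stem_width=32, stem_type='deep_tiered_narrow', avg_down=True),   # (24, 32, 64) + avg-pool downsample
}
```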
@@ -285,6 +286,7 @@ class ResNet(nn.Module):
           * '', default - a single 7x7 conv with a width of stem_width
           * 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2
           * 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width//4 * 6, stem_width * 2
+          * 'deep_tiered_narrow' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2
     block_reduce_first: int, default 1
         Reduction factor for first convolution output width of residual blocks,
         1 for all archs except senets, where 2
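The new 'deep_tiered_narrow' option only narrows the middle conv of the deep stem (stem_width instead of stem_width//4 * 6). A minimal sketch of how the three widths fall out of stem_width, with a toy stem built from them; the helper names are hypothetical and this is not the repo's implementation:

```python
import torch
import torch.nn as nn

def deep_stem_widths(stem_width: int, stem_type: str):
    """Compute the three conv widths for the 'deep' stem variants documented above."""
    if stem_type == 'deep_tiered':
        return stem_width // 4 * 3, stem_width // 4 * 6, stem_width * 2  # 32 -> (24, 48, 64)
    if stem_type == 'deep_tiered_narrow':
        return stem_width // 4 * 3, stem_width, stem_width * 2           # 32 -> (24, 32, 64)
    return stem_width, stem_width, stem_width * 2                        # plain 'deep': 32 -> (32, 32, 64)

def make_deep_stem(stem_width: int = 32, stem_type: str = 'deep_tiered_narrow') -> nn.Sequential:
    # Three 3x3 convs replacing the single 7x7; only the first conv strides.
    c1, c2, c3 = deep_stem_widths(stem_width, stem_type)
    return nn.Sequential(
        nn.Conv2d(3, c1, 3, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(c1), nn.ReLU(inplace=True),
        nn.Conv2d(c1, c2, 3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(c2), nn.ReLU(inplace=True),
        nn.Conv2d(c2, c3, 3, stride=1, padding=1, bias=False),
    )

print(deep_stem_widths(32, 'deep_tiered_narrow'))                # (24, 32, 64)
print(make_deep_stem()(torch.randn(1, 3, 224, 224)).shape)       # torch.Size([1, 64, 112, 112])
```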
