diff --git a/timm/models/layers/pool2d_same.py b/timm/models/layers/pool2d_same.py index 51242619..9d1b1cb4 100644 --- a/timm/models/layers/pool2d_same.py +++ b/timm/models/layers/pool2d_same.py @@ -44,6 +44,7 @@ class MaxPool2dSame(nn.MaxPool2d): def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False, count_include_pad=True): kernel_size = tup_pair(kernel_size) stride = tup_pair(stride) + dilation = tup_pair(dilation) super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode, count_include_pad) def forward(self, x): diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py index 4e23eb99..bc802ba2 100644 --- a/timm/models/nasnet.py +++ b/timm/models/nasnet.py @@ -13,7 +13,7 @@ default_cfgs = { 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/nasnetalarge-a1897284.pth', 'input_size': (3, 331, 331), 'pool_size': (11, 11), - 'crop_pct': 0.875, + 'crop_pct': 0.911, 'interpolation': 'bicubic', 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py index db558401..dc0f078d 100644 --- a/timm/models/pnasnet.py +++ b/timm/models/pnasnet.py @@ -24,7 +24,7 @@ default_cfgs = { 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/pnasnet5large-bf079911.pth', 'input_size': (3, 331, 331), 'pool_size': (11, 11), - 'crop_pct': 0.875, + 'crop_pct': 0.911, 'interpolation': 'bicubic', 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), diff --git a/timm/models/resnet.py b/timm/models/resnet.py index e3b5f12f..1781faee 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -521,20 +521,23 @@ class ResNet(nn.Module): def _create_resnet_with_cfg(variant, default_cfg, pretrained=False, **kwargs): assert isinstance(default_cfg, dict) - load_strict, features = True, False + features = False out_indices = None if kwargs.pop('features_only', False): - load_strict, features = False, True + features = True kwargs.pop('num_classes', 0) out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4)) + pruned = kwargs.pop('pruned', False) + model = ResNet(**kwargs) model.default_cfg = copy.deepcopy(default_cfg) - if kwargs.pop('pruned', False): + + if pruned: model = adapt_model_from_file(model, variant) if pretrained: load_pretrained( model, - num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=load_strict) + num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=not features) if features: model = FeatureNet(model, out_indices=out_indices) return model