Fix a few issues that came up in tests

pull/175/head
Ross Wightman 4 years ago
parent d23a2697d0
commit d0113f9cdb

@ -44,6 +44,7 @@ class MaxPool2dSame(nn.MaxPool2d):
def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False, count_include_pad=True):
kernel_size = tup_pair(kernel_size)
stride = tup_pair(stride)
dilation = tup_pair(dilation)
super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode, count_include_pad)
def forward(self, x):

@ -13,7 +13,7 @@ default_cfgs = {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/nasnetalarge-a1897284.pth',
'input_size': (3, 331, 331),
'pool_size': (11, 11),
'crop_pct': 0.875,
'crop_pct': 0.911,
'interpolation': 'bicubic',
'mean': (0.5, 0.5, 0.5),
'std': (0.5, 0.5, 0.5),

@ -24,7 +24,7 @@ default_cfgs = {
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/pnasnet5large-bf079911.pth',
'input_size': (3, 331, 331),
'pool_size': (11, 11),
'crop_pct': 0.875,
'crop_pct': 0.911,
'interpolation': 'bicubic',
'mean': (0.5, 0.5, 0.5),
'std': (0.5, 0.5, 0.5),

@ -521,20 +521,23 @@ class ResNet(nn.Module):
def _create_resnet_with_cfg(variant, default_cfg, pretrained=False, **kwargs):
assert isinstance(default_cfg, dict)
load_strict, features = True, False
features = False
out_indices = None
if kwargs.pop('features_only', False):
load_strict, features = False, True
features = True
kwargs.pop('num_classes', 0)
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4))
pruned = kwargs.pop('pruned', False)
model = ResNet(**kwargs)
model.default_cfg = copy.deepcopy(default_cfg)
if kwargs.pop('pruned', False):
if pruned:
model = adapt_model_from_file(model, variant)
if pretrained:
load_pretrained(
model,
num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=load_strict)
num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), strict=not features)
if features:
model = FeatureNet(model, out_indices=out_indices)
return model

Loading…
Cancel
Save