diff --git a/data/transforms.py b/data/transforms.py index 5fe7da60..e7dcee2f 100644 --- a/data/transforms.py +++ b/data/transforms.py @@ -138,6 +138,18 @@ class ToTensor: return torch.from_numpy(np_img).to(dtype=self.dtype) +def _pil_interp(method): + if method == 'bicubic': + return Image.BICUBIC + elif method == 'lanczos': + return Image.LANCZOS + elif method == 'hamming': + return Image.HAMMING + else: + # default bilinear, do we want to allow nearest? + return Image.BILINEAR + + def transforms_imagenet_train( img_size=224, scale=(0.1, 1.0), @@ -152,7 +164,7 @@ def transforms_imagenet_train( tfl = [ transforms.RandomResizedCrop( img_size, scale=scale, - interpolation=Image.BILINEAR if interpolation == 'bilinear' else Image.BICUBIC), + interpolation=_pil_interp(interpolation)), transforms.RandomHorizontalFlip(), transforms.ColorJitter(*color_jitter), ] @@ -192,7 +204,7 @@ def transforms_imagenet_eval( scale_size = int(math.floor(img_size / crop_pct)) tfl = [ - transforms.Resize(scale_size, Image.BILINEAR if interpolation == 'bilinear' else Image.BICUBIC), + transforms.Resize(scale_size, _pil_interp(interpolation)), transforms.CenterCrop(img_size), ] if use_prefetcher: diff --git a/models/senet.py b/models/senet.py index 7676ae99..d3465be9 100644 --- a/models/senet.py +++ b/models/senet.py @@ -47,6 +47,9 @@ default_cfgs = { _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth'), 'seresnet152': _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth'), + 'seresnext26_32x4d': + _cfg(url='https://www.dropbox.com/s/zaeruz2bejcdhh3/seresnext26_32x4d-65ebdb501.pth?dl=1', + interpolation='bicubic'), 'seresnext50_32x4d': _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth'), 'seresnext101_32x4d':