Add seresnet26_32x4d cfg and weights + interpolation str->PIL enum fn

pull/1/head
Ross Wightman 6 years ago
parent 71afec86d3
commit 9e296dbffb

@ -138,6 +138,18 @@ class ToTensor:
return torch.from_numpy(np_img).to(dtype=self.dtype)
def _pil_interp(method):
if method == 'bicubic':
return Image.BICUBIC
elif method == 'lanczos':
return Image.LANCZOS
elif method == 'hamming':
return Image.HAMMING
else:
# default bilinear, do we want to allow nearest?
return Image.BILINEAR
def transforms_imagenet_train(
img_size=224,
scale=(0.1, 1.0),
@ -152,7 +164,7 @@ def transforms_imagenet_train(
tfl = [
transforms.RandomResizedCrop(
img_size, scale=scale,
interpolation=Image.BILINEAR if interpolation == 'bilinear' else Image.BICUBIC),
interpolation=_pil_interp(interpolation)),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(*color_jitter),
]
@ -192,7 +204,7 @@ def transforms_imagenet_eval(
scale_size = int(math.floor(img_size / crop_pct))
tfl = [
transforms.Resize(scale_size, Image.BILINEAR if interpolation == 'bilinear' else Image.BICUBIC),
transforms.Resize(scale_size, _pil_interp(interpolation)),
transforms.CenterCrop(img_size),
]
if use_prefetcher:

@ -47,6 +47,9 @@ default_cfgs = {
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth'),
'seresnet152':
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth'),
'seresnext26_32x4d':
_cfg(url='https://www.dropbox.com/s/zaeruz2bejcdhh3/seresnext26_32x4d-65ebdb501.pth?dl=1',
interpolation='bicubic'),
'seresnext50_32x4d':
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth'),
'seresnext101_32x4d':

Loading…
Cancel
Save