Add interpolation mode handling to transforms. Removes the InterpolationMode warning. Works with torchvision versions both with and without the InterpolationMode enum. Fixes #738.

more_datasets
Ross Wightman 3 years ago
parent ed41d32637
commit a41de1f666

@ -1,5 +1,10 @@
import torch import torch
import torchvision.transforms.functional as F import torchvision.transforms.functional as F
try:
from torchvision.transforms.functional import InterpolationMode
has_interpolation_mode = True
except ImportError:
has_interpolation_mode = False
from PIL import Image from PIL import Image
import warnings import warnings
import math import math
@ -31,28 +36,50 @@ class ToTensor:
_pil_interpolation_to_str = { _pil_interpolation_to_str = {
Image.NEAREST: 'PIL.Image.NEAREST', Image.NEAREST: 'nearest',
Image.BILINEAR: 'PIL.Image.BILINEAR', Image.BILINEAR: 'bilinear',
Image.BICUBIC: 'PIL.Image.BICUBIC', Image.BICUBIC: 'bicubic',
Image.LANCZOS: 'PIL.Image.LANCZOS', Image.BOX: 'box',
Image.HAMMING: 'PIL.Image.HAMMING', Image.HAMMING: 'hamming',
Image.BOX: 'PIL.Image.BOX', Image.LANCZOS: 'lanczos',
} }
_str_to_pil_interpolation = {b: a for a, b in _pil_interpolation_to_str.items()}
def _pil_interp(method): if has_interpolation_mode:
if method == 'bicubic': _torch_interpolation_to_str = {
return Image.BICUBIC InterpolationMode.NEAREST: 'nearest',
elif method == 'lanczos': InterpolationMode.BILINEAR: 'bilinear',
return Image.LANCZOS InterpolationMode.BICUBIC: 'bicubic',
elif method == 'hamming': InterpolationMode.BOX: 'box',
return Image.HAMMING InterpolationMode.HAMMING: 'hamming',
InterpolationMode.LANCZOS: 'lanczos',
}
_str_to_torch_interpolation = {b: a for a, b in _torch_interpolation_to_str.items()}
else:
_pil_interpolation_to_torch = {}
_torch_interpolation_to_str = {}
def str_to_pil_interp(mode_str):
    """Translate an interpolation name into its PIL resampling constant.

    The lookup goes through ``_str_to_pil_interpolation``, the inverse of
    ``_pil_interpolation_to_str``, so the accepted names are the string
    values registered there (e.g. ``'bilinear'``, ``'bicubic'``).

    Raises:
        KeyError: if ``mode_str`` is not a registered interpolation name.
    """
    pil_mode = _str_to_pil_interpolation[mode_str]
    return pil_mode
def str_to_interp_mode(mode_str):
    """Translate an interpolation name into the mode object used for resizing.

    When torchvision's ``InterpolationMode`` enum was importable
    (``has_interpolation_mode`` is true) the enum member is returned;
    otherwise the PIL constant registered under the same name is returned.

    Raises:
        KeyError: if ``mode_str`` is not present in the selected table.
    """
    # The conditional expression only evaluates the chosen table, so the
    # torch table is never touched when the enum is unavailable.
    table = _str_to_torch_interpolation if has_interpolation_mode else _str_to_pil_interpolation
    return table[mode_str]
def interp_mode_to_str(mode):
if has_interpolation_mode:
return _torch_interpolation_to_str[mode]
else: else:
# default bilinear, do we want to allow nearest? return _pil_interpolation_to_str[mode]
return Image.BILINEAR
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) _RANDOM_INTERPOLATION = (str_to_interp_mode('bilinear'), str_to_interp_mode('bicubic'))
class RandomResizedCropAndInterpolation: class RandomResizedCropAndInterpolation:
@ -82,7 +109,7 @@ class RandomResizedCropAndInterpolation:
if interpolation == 'random': if interpolation == 'random':
self.interpolation = _RANDOM_INTERPOLATION self.interpolation = _RANDOM_INTERPOLATION
else: else:
self.interpolation = _pil_interp(interpolation) self.interpolation = str_to_interp_mode(interpolation)
self.scale = scale self.scale = scale
self.ratio = ratio self.ratio = ratio
@ -146,9 +173,9 @@ class RandomResizedCropAndInterpolation:
def __repr__(self): def __repr__(self):
if isinstance(self.interpolation, (tuple, list)): if isinstance(self.interpolation, (tuple, list)):
interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation]) interpolate_str = ' '.join([interp_mode_to_str(x) for x in self.interpolation])
else: else:
interpolate_str = _pil_interpolation_to_str[self.interpolation] interpolate_str = interp_mode_to_str(self.interpolation)
format_string = self.__class__.__name__ + '(size={0}'.format(self.size) format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))

@ -10,7 +10,7 @@ from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, ToNumpy
from timm.data.random_erasing import RandomErasing from timm.data.random_erasing import RandomErasing
@ -25,7 +25,7 @@ def transforms_noaug_train(
# random interpolation not supported with no-aug # random interpolation not supported with no-aug
interpolation = 'bilinear' interpolation = 'bilinear'
tfl = [ tfl = [
transforms.Resize(img_size, _pil_interp(interpolation)), transforms.Resize(img_size, interpolation=str_to_interp_mode(interpolation)),
transforms.CenterCrop(img_size) transforms.CenterCrop(img_size)
] ]
if use_prefetcher: if use_prefetcher:
@ -87,7 +87,7 @@ def transforms_imagenet_train(
img_mean=tuple([min(255, round(255 * x)) for x in mean]), img_mean=tuple([min(255, round(255 * x)) for x in mean]),
) )
if interpolation and interpolation != 'random': if interpolation and interpolation != 'random':
aa_params['interpolation'] = _pil_interp(interpolation) aa_params['interpolation'] = str_to_pil_interp(interpolation)
if auto_augment.startswith('rand'): if auto_augment.startswith('rand'):
secondary_tfl += [rand_augment_transform(auto_augment, aa_params)] secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
elif auto_augment.startswith('augmix'): elif auto_augment.startswith('augmix'):
@ -147,7 +147,7 @@ def transforms_imagenet_eval(
scale_size = int(math.floor(img_size / crop_pct)) scale_size = int(math.floor(img_size / crop_pct))
tfl = [ tfl = [
transforms.Resize(scale_size, _pil_interp(interpolation)), transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)),
transforms.CenterCrop(img_size), transforms.CenterCrop(img_size),
] ]
if use_prefetcher: if use_prefetcher:

Loading…
Cancel
Save