From a41de1f666f9187e70845bbcf5b092f40acaf097 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 28 Oct 2021 17:35:01 -0700
Subject: [PATCH] Add interpolation mode handling to transforms. Removes
 InterpolationMode warning. Works for torchvision versions w/ and w/o
 InterpolationMode enum. Fix #738.

---
Review note: in the fallback branch taken when torchvision has no
InterpolationMode, this revision assigns `_str_to_torch_interpolation = {}`
where the patch previously assigned `_pil_interpolation_to_torch = {}`.
The old name is not read anywhere in this patch, while
`_str_to_torch_interpolation` was only defined in the `if` branch. Both
branches now define the same two lookup tables. Hunk line counts and the
diffstat below are unchanged by this one-line edit.

 timm/data/transforms.py         | 65 +++++++++++++++++++++++----------
 timm/data/transforms_factory.py |  8 ++--
 2 files changed, 50 insertions(+), 23 deletions(-)

diff --git a/timm/data/transforms.py b/timm/data/transforms.py
index 4220304f..45c078f3 100644
--- a/timm/data/transforms.py
+++ b/timm/data/transforms.py
@@ -1,5 +1,10 @@
 import torch
 import torchvision.transforms.functional as F
+try:
+    from torchvision.transforms.functional import InterpolationMode
+    has_interpolation_mode = True
+except ImportError:
+    has_interpolation_mode = False
 from PIL import Image
 import warnings
 import math
@@ -31,28 +36,50 @@ class ToTensor:
 
 
 _pil_interpolation_to_str = {
-    Image.NEAREST: 'PIL.Image.NEAREST',
-    Image.BILINEAR: 'PIL.Image.BILINEAR',
-    Image.BICUBIC: 'PIL.Image.BICUBIC',
-    Image.LANCZOS: 'PIL.Image.LANCZOS',
-    Image.HAMMING: 'PIL.Image.HAMMING',
-    Image.BOX: 'PIL.Image.BOX',
+    Image.NEAREST: 'nearest',
+    Image.BILINEAR: 'bilinear',
+    Image.BICUBIC: 'bicubic',
+    Image.BOX: 'box',
+    Image.HAMMING: 'hamming',
+    Image.LANCZOS: 'lanczos',
 }
+_str_to_pil_interpolation = {b: a for a, b in _pil_interpolation_to_str.items()}
 
 
-def _pil_interp(method):
-    if method == 'bicubic':
-        return Image.BICUBIC
-    elif method == 'lanczos':
-        return Image.LANCZOS
-    elif method == 'hamming':
-        return Image.HAMMING
+if has_interpolation_mode:
+    _torch_interpolation_to_str = {
+        InterpolationMode.NEAREST: 'nearest',
+        InterpolationMode.BILINEAR: 'bilinear',
+        InterpolationMode.BICUBIC: 'bicubic',
+        InterpolationMode.BOX: 'box',
+        InterpolationMode.HAMMING: 'hamming',
+        InterpolationMode.LANCZOS: 'lanczos',
+    }
+    _str_to_torch_interpolation = {b: a for a, b in _torch_interpolation_to_str.items()}
+else:
+    _str_to_torch_interpolation = {}
+    _torch_interpolation_to_str = {}
+
+
+def str_to_pil_interp(mode_str):
+    return _str_to_pil_interpolation[mode_str]
+
+
+def str_to_interp_mode(mode_str):
+    if has_interpolation_mode:
+        return _str_to_torch_interpolation[mode_str]
+    else:
+        return _str_to_pil_interpolation[mode_str]
+
+
+def interp_mode_to_str(mode):
+    if has_interpolation_mode:
+        return _torch_interpolation_to_str[mode]
     else:
-        # default bilinear, do we want to allow nearest?
-        return Image.BILINEAR
+        return _pil_interpolation_to_str[mode]
 
 
-_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
+_RANDOM_INTERPOLATION = (str_to_interp_mode('bilinear'), str_to_interp_mode('bicubic'))
 
 
 class RandomResizedCropAndInterpolation:
@@ -82,7 +109,7 @@ class RandomResizedCropAndInterpolation:
         if interpolation == 'random':
             self.interpolation = _RANDOM_INTERPOLATION
         else:
-            self.interpolation = _pil_interp(interpolation)
+            self.interpolation = str_to_interp_mode(interpolation)
         self.scale = scale
         self.ratio = ratio
 
@@ -146,9 +173,9 @@ class RandomResizedCropAndInterpolation:
 
     def __repr__(self):
         if isinstance(self.interpolation, (tuple, list)):
-            interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation])
+            interpolate_str = ' '.join([interp_mode_to_str(x) for x in self.interpolation])
         else:
-            interpolate_str = _pil_interpolation_to_str[self.interpolation]
+            interpolate_str = interp_mode_to_str(self.interpolation)
         format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
         format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
         format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py
index df6e0de0..d4815d95 100644
--- a/timm/data/transforms_factory.py
+++ b/timm/data/transforms_factory.py
@@ -10,7 +10,7 @@ from torchvision import transforms
 
 from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
 from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
-from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor
+from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, ToNumpy
 from timm.data.random_erasing import RandomErasing
 
 
@@ -25,7 +25,7 @@ def transforms_noaug_train(
         # random interpolation not supported with no-aug
         interpolation = 'bilinear'
     tfl = [
-        transforms.Resize(img_size, _pil_interp(interpolation)),
+        transforms.Resize(img_size, interpolation=str_to_interp_mode(interpolation)),
         transforms.CenterCrop(img_size)
     ]
     if use_prefetcher:
@@ -87,7 +87,7 @@ def transforms_imagenet_train(
             img_mean=tuple([min(255, round(255 * x)) for x in mean]),
         )
         if interpolation and interpolation != 'random':
-            aa_params['interpolation'] = _pil_interp(interpolation)
+            aa_params['interpolation'] = str_to_pil_interp(interpolation)
         if auto_augment.startswith('rand'):
             secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
         elif auto_augment.startswith('augmix'):
@@ -147,7 +147,7 @@ def transforms_imagenet_eval(
             scale_size = int(math.floor(img_size / crop_pct))
 
     tfl = [
-        transforms.Resize(scale_size, _pil_interp(interpolation)),
+        transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)),
         transforms.CenterCrop(img_size),
     ]
     if use_prefetcher: