diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index 121a3fc6..1b51ccb4 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -36,11 +36,16 @@ _HPARAMS_DEFAULT = dict( img_mean=_FILL, ) -_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) +if hasattr(Image, "Resampling"): + _RANDOM_INTERPOLATION = (Image.Resampling.BILINEAR, Image.Resampling.BICUBIC) + _DEFAULT_INTERPOLATION = Image.Resampling.BICUBIC +else: + _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) + _DEFAULT_INTERPOLATION = Image.BICUBIC def _interpolation(kwargs): - interpolation = kwargs.pop('resample', Image.BILINEAR) + interpolation = kwargs.pop('resample', _DEFAULT_INTERPOLATION) if isinstance(interpolation, (list, tuple)): return random.choice(interpolation) else: