# NOTE: page-extraction residue kept for reference (not part of the module):
# "You cannot select more than 25 topics. Topics must start with a letter or number,
# can include dashes ('-') and can be up to 35 characters long."
# pytorch-image-models/timm/data/transforms.py
#
# 198 lines
# 6.6 KiB

import torch
import torchvision.transforms.functional as F
try:
from torchvision.transforms.functional import InterpolationMode
has_interpolation_mode = True
except ImportError:
has_interpolation_mode = False
from PIL import Image
import warnings
import math
import random
import numpy as np
class ToNumpy:
    """Convert an image to a channels-first (CHW) ``uint8`` NumPy array."""

    def __call__(self, pil_img):
        """Return ``pil_img`` as a ``uint8`` array with the channel axis first.

        Args:
            pil_img: image (or any array-like ``np.array`` accepts) laid out
                as HW or HWC.

        Returns:
            np.ndarray: ``uint8`` array in CHW layout; inputs with fewer than
            three dimensions get a trailing channel axis before the move.
        """
        pixels = np.array(pil_img, dtype=np.uint8)
        if pixels.ndim < 3:
            # Grayscale-style input: append a channel axis so axis 2 exists.
            pixels = pixels[..., np.newaxis]
        # HWC to CHW
        return np.moveaxis(pixels, 2, 0)
class ToTensor:
    """Convert an image to a channels-first (CHW) torch tensor of a chosen dtype.

    Pixel values are read as ``uint8`` and cast to ``dtype``; they are not
    rescaled.

    Args:
        dtype: torch dtype of the returned tensor (default: ``torch.float32``).
    """

    def __init__(self, dtype=torch.float32):
        self.dtype = dtype

    def __call__(self, pil_img):
        """Return ``pil_img`` as a CHW tensor cast to ``self.dtype``.

        Args:
            pil_img: image (or any array-like ``np.array`` accepts) laid out
                as HW or HWC.

        Returns:
            torch.Tensor: tensor in CHW layout with dtype ``self.dtype``.
        """
        pixels = np.array(pil_img, dtype=np.uint8)
        if pixels.ndim < 3:
            # Grayscale-style input: append a channel axis so axis 2 exists.
            pixels = pixels[..., np.newaxis]
        # HWC to CHW
        chw = np.moveaxis(pixels, 2, 0)
        tensor = torch.from_numpy(chw)
        return tensor.to(dtype=self.dtype)
# Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
# favor of the Image.Resampling enum. The top-level resampling attributes will be
# removed in Pillow 10.
# Pick whichever namespace this Pillow exposes, then build the table once.
_pil_resampling = Image.Resampling if hasattr(Image, "Resampling") else Image

# PIL resampling constant -> lower-case mode name.
_pil_interpolation_to_str = {
    getattr(_pil_resampling, mode_name.upper()): mode_name
    for mode_name in ('nearest', 'bilinear', 'bicubic', 'box', 'hamming', 'lanczos')
}

# Inverse table: lower-case mode name -> PIL resampling constant.
_str_to_pil_interpolation = {
    mode_name: pil_mode for pil_mode, mode_name in _pil_interpolation_to_str.items()
}
# Lookup tables between torchvision InterpolationMode members and mode names.
# They are only populated when torchvision provides InterpolationMode (see the
# guarded import at the top of the file).
if has_interpolation_mode:
    # InterpolationMode member -> lower-case mode name.
    _torch_interpolation_to_str = {
        InterpolationMode.NEAREST: 'nearest',
        InterpolationMode.BILINEAR: 'bilinear',
        InterpolationMode.BICUBIC: 'bicubic',
        InterpolationMode.BOX: 'box',
        InterpolationMode.HAMMING: 'hamming',
        InterpolationMode.LANCZOS: 'lanczos',
    }
    # Inverse table: lower-case mode name -> InterpolationMode member.
    _str_to_torch_interpolation = {b: a for a, b in _torch_interpolation_to_str.items()}
else:
    # Kept for backward compatibility with code that may import this name.
    _pil_interpolation_to_torch = {}
    _torch_interpolation_to_str = {}
    # Define the inverse table in this branch too, so the name exists whether
    # or not InterpolationMode is available.
    _str_to_torch_interpolation = {}
def str_to_pil_interp(mode_str):
    """Look up the PIL resampling constant for an interpolation mode name.

    Args:
        mode_str: mode name used as a key of ``_str_to_pil_interpolation``
            (e.g. ``'bilinear'``).

    Returns:
        The matching PIL resampling constant.

    Raises:
        KeyError: if ``mode_str`` is not a known mode name.
    """
    pil_mode = _str_to_pil_interpolation[mode_str]
    return pil_mode
def str_to_interp_mode(mode_str):
    """Look up the interpolation mode object for a mode name.

    Uses the torchvision ``InterpolationMode`` table when that enum was
    importable, otherwise falls back to the PIL resampling table.

    Args:
        mode_str: mode name (e.g. ``'bilinear'``).

    Returns:
        The ``InterpolationMode`` member or PIL resampling constant.

    Raises:
        KeyError: if ``mode_str`` is not in the selected table.
    """
    # Only the selected branch is evaluated, so the torch table is never
    # touched when InterpolationMode is unavailable.
    lookup = _str_to_torch_interpolation if has_interpolation_mode else _str_to_pil_interpolation
    return lookup[mode_str]
def interp_mode_to_str(mode):
    """Look up the mode name for an interpolation mode object.

    Args:
        mode: a torchvision ``InterpolationMode`` member when that enum was
            importable, otherwise a PIL resampling constant.

    Returns:
        str: the lower-case mode name (e.g. ``'bicubic'``).

    Raises:
        KeyError: if ``mode`` is not in the selected table.
    """
    if not has_interpolation_mode:
        return _pil_interpolation_to_str[mode]
    return _torch_interpolation_to_str[mode]
# Candidate modes that RandomResizedCropAndInterpolation samples from on each
# call when constructed with interpolation='random'.
_RANDOM_INTERPOLATION = tuple(
    str_to_interp_mode(mode_name) for mode_name in ('bilinear', 'bicubic')
)
class RandomResizedCropAndInterpolation:
    """Crop the given PIL Image to random size and aspect ratio with random interpolation.

    A crop of random size (default: of 0.08 to 1.0) of the original size and a random
    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
    is finally resized to given size.
    This is popularly used to train the Inception networks.

    Args:
        size: expected output size of each edge; a single value is expanded to
            ``(size, size)``, a list/tuple is used as given.
        scale: (min, max) range of size of the origin size cropped
        ratio: (min, max) range of aspect ratio of the origin aspect ratio cropped
        interpolation: interpolation mode name such as ``'bilinear'`` or
            ``'bicubic'``, or ``'random'`` to choose one of the modes in
            ``_RANDOM_INTERPOLATION`` on every call. Default: ``'bilinear'``.
    """

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
                 interpolation='bilinear'):
        if isinstance(size, (list, tuple)):
            self.size = tuple(size)
        else:
            self.size = (size, size)

        # Reversed ranges are reported but not rejected.
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("range should be of kind (min, max)")

        if interpolation == 'random':
            # A tuple here means "pick one per call" (see __call__).
            self.interpolation = _RANDOM_INTERPOLATION
        else:
            self.interpolation = str_to_interp_mode(interpolation)
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img, scale, ratio):
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image): Image to be cropped; only ``img.size`` (width, height)
                is read.
            scale (tuple): range of size of the origin size cropped
            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
            sized crop, where (i, j) is the top/left offset and (h, w) the
            crop height/width.
        """
        img_w, img_h = img.size[0], img.size[1]
        area = img_w * img_h
        # The aspect ratio is sampled uniformly in log space; the bounds do not
        # change between attempts, so compute them once.
        log_ratio = (math.log(ratio[0]), math.log(ratio[1]))

        # Try a few random crops and keep the first one that fits the image.
        for _ in range(10):
            target_area = random.uniform(*scale) * area
            aspect_ratio = math.exp(random.uniform(*log_ratio))

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if w <= img_w and h <= img_h:
                i = random.randint(0, img_h - h)
                j = random.randint(0, img_w - w)
                return i, j, h, w

        # Fallback to central crop, clamping the aspect ratio into range.
        in_ratio = img_w / img_h
        if in_ratio < min(ratio):
            w = img_w
            h = int(round(w / min(ratio)))
        elif in_ratio > max(ratio):
            h = img_h
            w = int(round(h * max(ratio)))
        else:  # whole image
            w = img_w
            h = img_h
        i = (img_h - h) // 2
        j = (img_w - w) // 2
        return i, j, h, w

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped and resized.

        Returns:
            PIL Image: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        if isinstance(self.interpolation, (tuple, list)):
            # Several candidate modes: draw one for this call.
            interpolation = random.choice(self.interpolation)
        else:
            interpolation = self.interpolation
        return F.resized_crop(img, i, j, h, w, self.size, interpolation)

    def __repr__(self):
        if isinstance(self.interpolation, (tuple, list)):
            interpolate_str = ' '.join(interp_mode_to_str(x) for x in self.interpolation)
        else:
            interpolate_str = interp_mode_to_str(self.interpolation)
        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
        format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
        format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
        format_string += ', interpolation={0})'.format(interpolate_str)
        return format_string