You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/timm/data/transforms.py

343 lines
12 KiB

import math
import numbers
import random
import warnings
from typing import List, Sequence
import torch
import torchvision.transforms.functional as F
try:
from torchvision.transforms.functional import InterpolationMode
has_interpolation_mode = True
except ImportError:
has_interpolation_mode = False
from PIL import Image
import numpy as np
class ToNumpy:
    """Convert an image-like input to a channels-first ``uint8`` NumPy array."""

    def __call__(self, pil_img):
        """Return ``pil_img`` as a ``uint8`` array with axis 2 moved to the front.

        Args:
            pil_img: Anything ``np.array`` can convert (typically a PIL image).

        Returns:
            numpy.ndarray: ``uint8`` array laid out channels-first (CHW).
        """
        pixels = np.array(pil_img, dtype=np.uint8)
        if pixels.ndim < 3:
            # Inputs without a channel axis get a trailing one appended.
            pixels = pixels[..., np.newaxis]
        # HWC -> CHW
        return np.moveaxis(pixels, 2, 0)
class ToTensor:
    """Convert an image-like input to a channels-first torch tensor.

    Pixel values are cast to ``dtype`` as-is; no rescaling is applied.

    Args:
        dtype: Target torch dtype of the produced tensor.
    """

    def __init__(self, dtype=torch.float32):
        self.dtype = dtype

    def __call__(self, pil_img):
        """Return ``pil_img`` as a CHW tensor of ``self.dtype``.

        Args:
            pil_img: Anything ``np.array`` can convert (typically a PIL image).

        Returns:
            torch.Tensor: Channels-first tensor cast to ``self.dtype``.
        """
        pixels = np.array(pil_img, dtype=np.uint8)
        if pixels.ndim < 3:
            # Inputs without a channel axis get a trailing one appended.
            pixels = pixels[..., np.newaxis]
        # HWC -> CHW
        chw = np.moveaxis(pixels, 2, 0)
        tensor = torch.from_numpy(chw)
        return tensor.to(dtype=self.dtype)
# Interpolation names shared by the PIL and torchvision lookup tables below.
_INTERPOLATION_NAMES = ('nearest', 'bilinear', 'bicubic', 'box', 'hamming', 'lanczos')

# Newer Pillow releases expose the resampling filters on the ``Image.Resampling``
# enum, while older releases only provide the top-level attributes
# (e.g. ``Image.BILINEAR``). Use the enum whenever it is available.
_pil_filters = Image.Resampling if hasattr(Image, "Resampling") else Image

# PIL filter -> name, and the inverse name -> PIL filter.
_pil_interpolation_to_str = {
    getattr(_pil_filters, name.upper()): name for name in _INTERPOLATION_NAMES
}
_str_to_pil_interpolation = {
    name: mode for mode, name in _pil_interpolation_to_str.items()
}

if has_interpolation_mode:
    # torchvision InterpolationMode -> name, and the inverse mapping.
    _torch_interpolation_to_str = {
        getattr(InterpolationMode, name.upper()): name for name in _INTERPOLATION_NAMES
    }
    _str_to_torch_interpolation = {
        name: mode for mode, name in _torch_interpolation_to_str.items()
    }
else:
    # InterpolationMode could not be imported: leave the torch tables empty.
    _pil_interpolation_to_torch = {}
    _torch_interpolation_to_str = {}
def str_to_pil_interp(mode_str):
    """Look up the PIL resampling filter registered under ``mode_str``.

    Raises:
        KeyError: If ``mode_str`` is not a known interpolation name.
    """
    pil_mode = _str_to_pil_interpolation[mode_str]
    return pil_mode
def str_to_interp_mode(mode_str):
    """Translate an interpolation name into the mode object used for resizing.

    Returns the torchvision ``InterpolationMode`` member when that enum was
    importable, otherwise the matching PIL filter.

    Raises:
        KeyError: If ``mode_str`` is not a known interpolation name.
    """
    # Only the selected table is evaluated, so the torch table is never touched
    # when InterpolationMode is unavailable.
    table = _str_to_torch_interpolation if has_interpolation_mode else _str_to_pil_interpolation
    return table[mode_str]
def interp_mode_to_str(mode):
    """Translate an interpolation mode object back into its string name.

    ``mode`` is looked up in the torchvision table when ``InterpolationMode``
    was importable, otherwise in the PIL table.

    Raises:
        KeyError: If ``mode`` is not present in the selected table.
    """
    table = _torch_interpolation_to_str if has_interpolation_mode else _pil_interpolation_to_str
    return table[mode]
# Candidate modes sampled from when a transform is built with interpolation='random'.
_RANDOM_INTERPOLATION = tuple(str_to_interp_mode(name) for name in ('bilinear', 'bicubic'))
def _setup_size(size, error_msg):
if isinstance(size, numbers.Number):
return int(size), int(size)
if isinstance(size, Sequence) and len(size) == 1:
return size[0], size[0]
if len(size) != 2:
raise ValueError(error_msg)
return size
class RandomResizedCropAndInterpolation:
    """Crop a random region of a PIL image and resize it to a fixed size.

    The crop covers a random fraction of the image area (default: 0.08 to 1.0)
    with a random width/height aspect ratio (default: 3/4 to 4/3, sampled
    log-uniformly). The crop is then resized to ``size``, optionally choosing
    the interpolation mode at random on every call.

    This is popularly used to train the Inception networks.

    Args:
        size: expected output size of each edge; a single value is used for
            both edges, a list/tuple is used as given.
        scale: (min, max) fraction of the source image area to crop.
        ratio: (min, max) width/height aspect ratio of the crop.
        interpolation: interpolation name (default: ``'bilinear'``), or
            ``'random'`` to pick one of ``_RANDOM_INTERPOLATION`` per call.
    """

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
                 interpolation='bilinear'):
        if isinstance(size, (list, tuple)):
            self.size = tuple(size)
        else:
            self.size = (size, size)

        # Reversed ranges are only warned about, not rejected.
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("range should be of kind (min, max)")

        if interpolation == 'random':
            self.interpolation = _RANDOM_INTERPOLATION
        else:
            self.interpolation = str_to_interp_mode(interpolation)
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img, scale, ratio):
        """Get parameters for ``crop`` for a random sized crop.

        Up to 10 random (area, aspect ratio) draws are tried; the first crop
        that fits inside the image is returned. If none fits, a centered crop
        with the aspect ratio clamped into ``ratio`` is used instead.

        Args:
            img (PIL Image): Image to be cropped; only ``img.size`` (w, h) is read.
            scale (tuple): (min, max) fraction of the source image area to crop.
            ratio (tuple): (min, max) width/height aspect ratio of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop, where (i, j) is the top-left (row, column).
        """
        img_w, img_h = img.size[0], img.size[1]
        area = img_w * img_h
        # Loop-invariant: the log-space bounds do not change between attempts.
        log_ratio = (math.log(ratio[0]), math.log(ratio[1]))

        for _ in range(10):
            target_area = random.uniform(*scale) * area
            aspect_ratio = math.exp(random.uniform(*log_ratio))

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))
            if w <= img_w and h <= img_h:
                i = random.randint(0, img_h - h)
                j = random.randint(0, img_w - w)
                return i, j, h, w

        # Fallback to central crop
        in_ratio = img_w / img_h
        if in_ratio < min(ratio):
            # Image is too tall: keep the full width and limit the height.
            w = img_w
            h = int(round(w / min(ratio)))
        elif in_ratio > max(ratio):
            # Image is too wide: keep the full height and limit the width.
            h = img_h
            w = int(round(h * max(ratio)))
        else:  # whole image
            w = img_w
            h = img_h
        i = (img_h - h) // 2
        j = (img_w - w) // 2
        return i, j, h, w

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped and resized.

        Returns:
            PIL Image: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        if isinstance(self.interpolation, (tuple, list)):
            # Several candidate modes configured: choose one per call.
            interpolation = random.choice(self.interpolation)
        else:
            interpolation = self.interpolation
        return F.resized_crop(img, i, j, h, w, self.size, interpolation)

    def __repr__(self):
        if isinstance(self.interpolation, (tuple, list)):
            interpolate_str = ' '.join(interp_mode_to_str(x) for x in self.interpolation)
        else:
            interpolate_str = interp_mode_to_str(self.interpolation)
        scale_str = tuple(round(s, 4) for s in self.scale)
        ratio_str = tuple(round(r, 4) for r in self.ratio)
        return (
            f'{self.__class__.__name__}(size={self.size}'
            f', scale={scale_str}'
            f', ratio={ratio_str}'
            f', interpolation={interpolate_str})'
        )
def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor:
    """Center crop the image to ``output_size``, padding first where it is too small.

    If the image is torch Tensor, it is expected to have [..., H, W] shape,
    where ... means an arbitrary number of leading dimensions. Along any edge
    where the image is smaller than the requested size, it is padded with
    ``fill`` (split between both sides) before cropping.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        output_size (sequence or int): (height, width) of the crop box. If int
            or sequence with single int, it is used for both directions.
        fill (int, Tuple[int]): Padding color

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    if isinstance(output_size, numbers.Number):
        side = int(output_size)
        output_size = (side, side)
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        output_size = (output_size[0], output_size[0])

    _, image_height, image_width = F.get_dimensions(img)
    crop_height, crop_width = output_size

    # Positive values mean the image is too small along that axis.
    missing_w = crop_width - image_width
    missing_h = crop_height - image_height

    if missing_w > 0 or missing_h > 0:
        # The odd pixel, if any, goes to the right / bottom side.
        pad_left = missing_w // 2 if missing_w > 0 else 0
        pad_top = missing_h // 2 if missing_h > 0 else 0
        pad_right = (missing_w + 1) // 2 if missing_w > 0 else 0
        pad_bottom = (missing_h + 1) // 2 if missing_h > 0 else 0
        img = F.pad(img, [pad_left, pad_top, pad_right, pad_bottom], fill=fill)

        _, image_height, image_width = F.get_dimensions(img)
        if crop_width == image_width and crop_height == image_height:
            # Padding alone produced the requested size; nothing left to crop.
            return img

    crop_top = int(round((image_height - crop_height) / 2.0))
    crop_left = int(round((image_width - crop_width) / 2.0))
    return F.crop(img, crop_top, crop_left, crop_height, crop_width)
class CenterCropOrPad(torch.nn.Module):
    """Module wrapper around ``center_crop_or_pad``.

    If the image is torch Tensor, it is expected to have [..., H, W] shape,
    where ... means an arbitrary number of leading dimensions. Images smaller
    than the requested size along any edge are padded with ``fill`` and then
    center cropped.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made. If provided a sequence of length 1, it will be interpreted
            as (size[0], size[0]).
        fill: Padding value forwarded to ``center_crop_or_pad``.
    """

    def __init__(self, size, fill=0):
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
        self.fill = fill

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image.
        """
        result = center_crop_or_pad(img, self.size, fill=self.fill)
        return result

    def __repr__(self) -> str:
        return "{}(size={})".format(type(self).__name__, self.size)
class ResizeKeepRatio:
    """Resize an image while keeping its aspect ratio.

    The scale factor is interpolated between fitting the shorter-ratio side
    (``longest=0``) and fitting the longer-ratio side (``longest=1``) to the
    target size.

    Args:
        size: target size; a single value is used for both edges, a list/tuple
            is used as given (unpacked as (h, w) in ``get_params``).
        longest: blend weight between the min (0.) and max (1.) of the
            per-axis source/target ratios.
        interpolation: interpolation name, resolved via ``str_to_interp_mode``.
        fill: stored on the instance; not used by the methods of this class.
    """

    def __init__(
            self,
            size,
            longest=0.,
            interpolation='bilinear',
            fill=0,
    ):
        if isinstance(size, (list, tuple)):
            self.size = tuple(size)
        else:
            self.size = (size, size)
        self.interpolation = str_to_interp_mode(interpolation)
        self.longest = float(longest)
        self.fill = fill

    @staticmethod
    def get_params(img, target_size, longest):
        """Compute the aspect-preserving size to resize ``img`` to.

        Args:
            img (PIL Image): Image to be resized; only ``img.size`` (w, h) is read.
            target_size (Tuple[int, int]): Target (h, w).
            longest (float): Blend weight between the smaller (0.) and larger
                (1.) of the height and width source/target ratios.

        Returns:
            list: ``[h, w]`` to be passed to ``resize``.
        """
        source_size = img.size[::-1]  # h, w
        h, w = source_size
        target_h, target_w = target_size
        ratio_h = h / target_h
        ratio_w = w / target_w
        # One scale factor for both axes keeps the aspect ratio intact.
        ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
        size = [round(x / ratio) for x in source_size]
        return size

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be resized.

        Returns:
            PIL Image: Image resized to the size computed by ``get_params``.
        """
        size = self.get_params(img, self.size, self.longest)
        img = F.resize(img, size, self.interpolation)
        return img

    def __repr__(self):
        interpolate_str = interp_mode_to_str(self.interpolation)
        # Single closing parenthesis at the very end of the representation.
        return (
            f'{self.__class__.__name__}(size={self.size}'
            f', interpolation={interpolate_str}'
            f', longest={self.longest:.3f})'
        )