You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
pytorch-image-models/timm/data/transforms_factory.py

185 lines
6.4 KiB

""" Transforms Factory
Factory methods for building image transforms for use with TIMM (PyTorch Image Models)
"""
import math
import torch
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor
from timm.data.random_erasing import RandomErasing
def transforms_imagenet_train(
        img_size=224,
        scale=(0.08, 1.0),
        color_jitter=0.4,
        auto_augment=None,
        interpolation='random',
        use_prefetcher=False,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        re_prob=0.,
        re_mode='const',
        re_count=1,
        re_num_splits=0,
        separate=False,
):
    """Build the training-time transform pipeline.

    The pipeline is assembled from three stages:

    * primary: random resized crop + random horizontal flip
    * secondary: an auto-augment policy if `auto_augment` is set, otherwise
      colour jitter (unless `color_jitter` is None)
    * final: conversion to numpy (prefetcher path) or to a normalized tensor,
      optionally followed by random erasing

    If separate==True, the transforms are returned as a tuple of 3 separate transforms
    for use in a mixing dataset that passes
     * all data through the first (primary) transform, called the 'clean' data
     * a portion of the data through the secondary transform
     * normalizes and converts the branches above with the third, final transform

    Args:
        img_size: Target size, either a scalar or a tuple/list of sizes.
        scale: Forwarded to RandomResizedCropAndInterpolation as `scale`.
        color_jitter: Scalar (reused for brightness, contrast and saturation),
            a 3- or 4-element tuple/list passed positionally to ColorJitter,
            or None to disable. Only used when `auto_augment` is falsy.
        auto_augment: Policy spec string. A 'rand' prefix selects
            rand_augment_transform, 'augmix' selects augment_and_mix_transform,
            anything else goes to auto_augment_transform.
        interpolation: Forwarded to the random resized crop; also passed to
            the auto-augment params unless it is falsy or 'random'.
        use_prefetcher: When True the final stage is only ToNumpy().
        mean: Per-channel mean for Normalize and for the auto-augment fill value.
        std: Per-channel std for Normalize.
        re_prob: RandomErasing is appended when this is > 0 (non-prefetcher path).
        re_mode: Forwarded to RandomErasing as `mode`.
        re_count: Forwarded to RandomErasing as `max_count`.
        re_num_splits: Forwarded to RandomErasing as `num_splits`.
        separate: Return the three stages separately instead of one Compose.

    Returns:
        A single transforms.Compose, or a 3-tuple of Compose objects
        (primary, secondary, final) when `separate` is True.
    """
    # Sizes loaded from config files often arrive as lists; treat them exactly
    # like tuples so the min()/scaling logic below does not hit a TypeError.
    if isinstance(img_size, list):
        img_size = tuple(img_size)

    primary_tfl = [
        RandomResizedCropAndInterpolation(
            img_size, scale=scale, interpolation=interpolation),
        transforms.RandomHorizontalFlip(),
    ]

    secondary_tfl = []
    if auto_augment:
        assert isinstance(auto_augment, str)
        # translation magnitude is derived from the shorter side
        img_size_min = min(img_size) if isinstance(img_size, tuple) else img_size
        aa_params = dict(
            translate_const=int(img_size_min * 0.45),
            # mean rescaled from [0, 1] floats to 0-255 ints, clamped at 255
            img_mean=tuple(min(255, round(255 * x)) for x in mean),
        )
        if interpolation and interpolation != 'random':
            aa_params['interpolation'] = _pil_interp(interpolation)
        if auto_augment.startswith('rand'):
            secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
        elif auto_augment.startswith('augmix'):
            aa_params['translate_pct'] = 0.3
            secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)]
        else:
            secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]
    elif color_jitter is not None:
        # color jitter is enabled when not using AA
        if isinstance(color_jitter, (list, tuple)):
            # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
            # or 4 if also augmenting hue
            assert len(color_jitter) in (3, 4)
        else:
            # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
            color_jitter = (float(color_jitter),) * 3
        secondary_tfl += [transforms.ColorJitter(*color_jitter)]

    final_tfl = []
    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        final_tfl += [ToNumpy()]
    else:
        final_tfl += [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std)),
        ]
        # NOTE(review): erasing is kept on the tensor path only, since it is
        # configured for device='cpu' and follows ToTensor/Normalize -- confirm
        # against the upstream file, as the scraped source lost its indentation.
        if re_prob > 0.:
            final_tfl.append(
                RandomErasing(re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device='cpu'))

    if separate:
        return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
    return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
def transforms_imagenet_eval(
        img_size=224,
        crop_pct=None,
        interpolation='bilinear',
        use_prefetcher=False,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD):
    """Build the evaluation-time transform pipeline (resize, then center crop).

    The image is first resized to `img_size / crop_pct` and then center
    cropped to `img_size`.

    Args:
        img_size: Target crop size, either a scalar or a 2-element tuple/list.
        crop_pct: Ratio of crop size to resize size. Falls back to
            DEFAULT_CROP_PCT when None (or any other falsy value).
        interpolation: Name resolved through _pil_interp for the resize.
        use_prefetcher: When True the pipeline ends with ToNumpy() instead of
            ToTensor() + Normalize().
        mean: Per-channel mean for Normalize.
        std: Per-channel std for Normalize.

    Returns:
        A transforms.Compose with the assembled pipeline.
    """
    crop_pct = crop_pct or DEFAULT_CROP_PCT

    # Sizes loaded from config files often arrive as lists; without this a
    # list would fall into the scalar branch and fail on `img_size / crop_pct`.
    if isinstance(img_size, list):
        img_size = tuple(img_size)

    if isinstance(img_size, tuple):
        assert len(img_size) == 2
        if img_size[-1] == img_size[-2]:
            # fall-back to older behaviour so Resize scales to shortest edge if target is square
            scale_size = int(math.floor(img_size[0] / crop_pct))
        else:
            scale_size = tuple(int(x / crop_pct) for x in img_size)
    else:
        scale_size = int(math.floor(img_size / crop_pct))

    tfl = [
        transforms.Resize(scale_size, _pil_interp(interpolation)),
        transforms.CenterCrop(img_size),
    ]
    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        tfl += [ToNumpy()]
    else:
        tfl += [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std)),
        ]

    return transforms.Compose(tfl)
def create_transform(
        input_size,
        is_training=False,
        use_prefetcher=False,
        color_jitter=0.4,
        auto_augment=None,
        interpolation='bilinear',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        re_prob=0.,
        re_mode='const',
        re_count=1,
        re_num_splits=0,
        crop_pct=None,
        tf_preprocessing=False,
        separate=False):
    """Create a train or eval transform for the given input size.

    Dispatches to one of three builders:

    * TfPreprocessTransform when both `tf_preprocessing` and `use_prefetcher`
      are set
    * transforms_imagenet_train when `is_training` is True
    * transforms_imagenet_eval otherwise

    Args:
        input_size: Scalar size, or a tuple/list whose last two entries are
            used as the image size.
        is_training: Select the training pipeline instead of the eval one.
        use_prefetcher: Forwarded to the selected builder; also required for
            the TF preprocessing path.
        color_jitter: Forwarded to the training builder.
        auto_augment: Forwarded to the training builder.
        interpolation: Forwarded to whichever builder is selected.
        mean: Forwarded to the train/eval builders.
        std: Forwarded to the train/eval builders.
        re_prob: Forwarded to the training builder.
        re_mode: Forwarded to the training builder.
        re_count: Forwarded to the training builder.
        re_num_splits: Forwarded to the training builder.
        crop_pct: Forwarded to the eval builder.
        tf_preprocessing: Request TfPreprocessTransform (only honoured
            together with `use_prefetcher`).
        separate: Ask the training builder for separate stage transforms;
            asserted False on the TF and eval paths.

    Returns:
        The transform built by the selected path (a 3-tuple of transforms when
        `separate` is True on the training path).
    """
    if isinstance(input_size, (tuple, list)):
        # Keep only the trailing two entries. Always hand on a tuple: the
        # train/eval builders test `isinstance(..., tuple)`, so a list slice
        # would be mistaken for a scalar size.
        img_size = tuple(input_size[-2:])
    else:
        img_size = input_size

    # NOTE(review): tf_preprocessing without use_prefetcher silently falls
    # through to the standard transforms below -- confirm this is intended.
    if tf_preprocessing and use_prefetcher:
        assert not separate, "Separate transforms not supported for TF preprocessing"
        # imported lazily so the TF preprocessing module is only loaded on this path
        from timm.data.tf_preprocessing import TfPreprocessTransform
        transform = TfPreprocessTransform(
            is_training=is_training, size=img_size, interpolation=interpolation)
    elif is_training:
        transform = transforms_imagenet_train(
            img_size,
            color_jitter=color_jitter,
            auto_augment=auto_augment,
            interpolation=interpolation,
            use_prefetcher=use_prefetcher,
            mean=mean,
            std=std,
            re_prob=re_prob,
            re_mode=re_mode,
            re_count=re_count,
            re_num_splits=re_num_splits,
            separate=separate)
    else:
        assert not separate, "Separate transforms not supported for validation preprocessing"
        transform = transforms_imagenet_eval(
            img_size,
            interpolation=interpolation,
            use_prefetcher=use_prefetcher,
            mean=mean,
            std=std,
            crop_pct=crop_pct)

    return transform