@ -4,59 +4,50 @@ Factory methods for building image transforms for use with TIMM (PyTorch Image M
Hacked together by / Copyright 2020 Ross Wightman
"""
import math
from typing import Union , Tuple
import torch
from torchvision import transforms
from timm . data . constants import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD , DEFAULT_CROP_PCT
from timm . data . auto_augment import rand_augment_transform , augment_and_mix_transform , auto_augment_transform
from timm . data . transforms import _pil_interp , RandomResizedCropAndInterpolation , ToNumpy , ToTensor
from timm . data . config import PreprocessCfg , AugCfg
from timm . data . constants import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD , DEFAULT_CROP_PCT
from timm . data . random_erasing import RandomErasing
from timm . data . transforms import RandomResizedCropAndInterpolation , ToNumpy , ToTensorNormalize
def transforms_noaug_train (
img_size = 224 ,
img_size : Union [ int , Tuple [ int ] ] = 224 ,
interpolation = ' bilinear ' ,
use_prefetcher = False ,
mean = IMAGENET_DEFAULT_MEAN ,
std = IMAGENET_DEFAULT_STD ,
normalize = False ,
) :
if interpolation == ' random ' :
# random interpolation not supported with no-aug
interpolation = ' bilinear '
tfl = [
transforms . Resize ( img_size , _pil_interp ( interpolation ) ) ,
transforms . Resize ( img_size , transforms. InterpolationMode ( interpolation ) ) ,
transforms . CenterCrop ( img_size )
]
if use_prefetcher :
# prefetcher and collate will handle tensor conversion and norm
tfl + = [ ToNumpy ( ) ]
else :
if normalize :
tfl + = [
transforms . ToTensor ( ) ,
transforms . Normalize (
mean = torch . tensor ( mean ) ,
std = torch . tensor ( std ) )
transforms . Normalize ( mean = torch . tensor ( mean ) , std = torch . tensor ( std ) )
]
else :
# (pre)fetcher and collate will handle tensor conversion and normalize
tfl + = [ ToNumpy ( ) ]
return transforms . Compose ( tfl )
def transforms_imagenet_train (
img_size = 224 ,
scale = None ,
ratio = None ,
hflip = 0.5 ,
vflip = 0. ,
color_jitter = 0.4 ,
auto_augment = None ,
img_size : Union [ int , Tuple [ int ] ] = 224 ,
interpolation = ' random ' ,
use_prefetcher = False ,
mean = IMAGENET_DEFAULT_MEAN ,
std = IMAGENET_DEFAULT_STD ,
re_prob = 0. ,
re_mode = ' const ' ,
re_count = 1 ,
re_num_splits = 0 ,
aug_cfg = AugCfg ( ) ,
normalize = False ,
separate = False ,
) :
"""
@ -66,18 +57,24 @@ def transforms_imagenet_train(
* a portion of the data through the secondary transform
* normalizes and converts the branches above with the third , final transform
"""
scale = tuple ( scale or ( 0.08 , 1.0 ) ) # default imagenet scale range
ratio = tuple ( ratio or ( 3. / 4. , 4. / 3. ) ) # default imagenet ratio range
scale_range = tuple ( aug_cfg . scale_range or ( 0.08 , 1.0 ) ) # default imagenet scale range
ratio_range = tuple ( aug_cfg . ratio_range or ( 3. / 4. , 4. / 3. ) ) # default imagenet ratio range
# 'primary' train transforms include random resize + crop w/ optional horizontal and vertical flipping aug.
# This is the core of standard ImageNet ResNet and Inception pre-processing
primary_tfl = [
RandomResizedCropAndInterpolation ( img_size , scale = scale , ratio = ratio , interpolation = interpolation ) ]
if hflip > 0. :
primary_tfl + = [ transforms . RandomHorizontalFlip ( p = hflip) ]
if vflip > 0. :
primary_tfl + = [ transforms . RandomVerticalFlip ( p = vflip) ]
RandomResizedCropAndInterpolation ( img_size , scale = scale _range , ratio = ratio _range , interpolation = interpolation ) ]
if aug_cfg. hflip_prob > 0. :
primary_tfl + = [ transforms . RandomHorizontalFlip ( p = aug_cfg. hflip_prob ) ]
if aug_cfg. vflip_prob > 0. :
primary_tfl + = [ transforms . RandomVerticalFlip ( p = aug_cfg. vflip_prob ) ]
# 'secondary' transform stage includes either color jitter (could add lighting too) or auto-augmentations
# such as AutoAugment, RandAugment, AugMix, etc
secondary_tfl = [ ]
if auto_augment :
assert isinstance ( auto_augment , str )
if aug_cfg . auto_augment :
aa = aug_cfg . auto_augment
assert isinstance ( aa , str )
if isinstance ( img_size , ( tuple , list ) ) :
img_size_min = min ( img_size )
else :
@ -87,58 +84,63 @@ def transforms_imagenet_train(
img_mean = tuple ( [ min ( 255 , round ( 255 * x ) ) for x in mean ] ) ,
)
if interpolation and interpolation != ' random ' :
aa_params [ ' interpolation ' ] = _pil_interp( interpolation)
if a uto_ augment . startswith ( ' rand ' ) :
secondary_tfl + = [ rand_augment_transform ( a uto_ augment , aa_params ) ]
elif a uto_ augment . startswith ( ' augmix ' ) :
aa_params [ ' interpolation ' ] = interpolation
if a a. startswith ( ' rand ' ) :
secondary_tfl + = [ rand_augment_transform ( a a, aa_params ) ]
elif a a. startswith ( ' augmix ' ) :
aa_params [ ' translate_pct ' ] = 0.3
secondary_tfl + = [ augment_and_mix_transform ( a uto_ augment , aa_params ) ]
secondary_tfl + = [ augment_and_mix_transform ( a a, aa_params ) ]
else :
secondary_tfl + = [ auto_augment_transform ( a uto_ augment , aa_params ) ]
elif color_jitter is not None :
secondary_tfl + = [ auto_augment_transform ( a a, aa_params ) ]
elif aug_cfg. color_jitter is not None :
# color jitter is enabled when not using AA
if isinstance ( color_jitter , ( list , tuple ) ) :
cj = aug_cfg . color_jitter
if isinstance ( cj , ( list , tuple ) ) :
# color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
# or 4 if also augmenting hue
assert len ( c olor_ jitter ) in ( 3 , 4 )
assert len ( c j) in ( 3 , 4 )
else :
# if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
c olor_ jitter = ( float ( c olor_ jitter ) , ) * 3
secondary_tfl + = [ transforms . ColorJitter ( * c olor_ jitter ) ]
c j = ( float ( c j) , ) * 3
secondary_tfl + = [ transforms . ColorJitter ( * c j) ]
# 'final' transform stage includes normalization, followed by optional random erasing and tensor conversion
final_tfl = [ ]
if use_prefetcher :
# prefetcher and collate will handle tensor conversion and norm
final_tfl + = [ ToNumpy ( ) ]
else :
if normalize :
final_tfl + = [
transforms . ToTensor ( ) ,
transforms . Normalize (
mean = torch . tensor ( mean ) ,
std = torch . tensor ( std ) )
ToTensorNormalize ( mean = mean , std = std )
]
if re_prob > 0. :
final_tfl . append (
RandomErasing ( re_prob , mode = re_mode , max_count = re_count , num_splits = re_num_splits , device = ' cpu ' ) )
if aug_cfg . re_prob > 0. :
final_tfl . append ( RandomErasing (
aug_cfg . re_prob ,
mode = aug_cfg . re_mode ,
count = aug_cfg . re_count ,
num_splits = aug_cfg . num_aug_splits ) )
else :
# when normalize disabled, (pre)fetcher and collate will handle tensor conversion and normalize
final_tfl + = [ ToNumpy ( ) ]
if separate :
# return each transform stage separately
return transforms . Compose ( primary_tfl ) , transforms . Compose ( secondary_tfl ) , transforms . Compose ( final_tfl )
else :
return transforms . Compose ( primary_tfl + secondary_tfl + final_tfl )
def transforms_imagenet_eval (
img_size = 224 ,
img_size : Union [ int , Tuple [ int ] ] = 224 ,
crop_pct = None ,
interpolation = ' bilinear ' ,
use_prefetcher = False ,
mean = IMAGENET_DEFAULT_MEAN ,
std = IMAGENET_DEFAULT_STD ) :
std = IMAGENET_DEFAULT_STD ,
normalize = False ,
) :
crop_pct = crop_pct or DEFAULT_CROP_PCT
if isinstance ( img_size , ( tuple , list ) ) :
assert len ( img_size ) == 2
if img_size [ - 1 ] == img_size [ - 2 ] :
# FIXME handle case where img is square and we want non aspect preserving resize
# fall-back to older behaviour so Resize scales to shortest edge if target is square
scale_size = int ( math . floor ( img_size [ 0 ] / crop_pct ) )
else :
@ -147,27 +149,87 @@ def transforms_imagenet_eval(
scale_size = int ( math . floor ( img_size / crop_pct ) )
tfl = [
transforms . Resize ( scale_size , _pil_interp ( interpolation ) ) ,
transforms . Resize ( scale_size , transforms. InterpolationMode ( interpolation ) ) ,
transforms . CenterCrop ( img_size ) ,
]
if use_prefetcher :
# prefetcher and collate will handle tensor conversion and norm
tfl + = [ ToNumpy ( ) ]
else :
if normalize :
tfl + = [
transforms . ToTensor ( ) ,
transforms . Normalize (
mean = torch . tensor ( mean ) ,
std = torch . tensor ( std ) )
ToTensorNormalize ( mean = mean , std = std )
]
else :
# (pre)fetcher and collate will handle tensor conversion and normalize
tfl + = [ ToNumpy ( ) ]
return transforms . Compose ( tfl )
def create_transform_v2 (
cfg = PreprocessCfg ( ) ,
is_training = False ,
normalize = False ,
separate = False ,
tf_preprocessing = False ,
) :
"""
Args :
cfg : Pre - processing configuration
is_training ( bool ) : Create transform for training pre - processing
tf_preprocessing ( bool ) : Use Tensorflow pre - processing ( for validation )
normalize ( bool ) : Enable normalization in transforms ( otherwise handled by fetcher / pre - fetcher )
separate ( bool ) : Return transforms separated into stages ( for train )
Returns :
"""
input_size = cfg . input_size
if isinstance ( input_size , ( tuple , list ) ) :
img_size = input_size [ - 2 : ]
else :
img_size = input_size
if tf_preprocessing :
assert not normalize , " Expecting normalization to be handled in (pre)fetcher w/ TF preprocessing "
assert not separate , " Separate transforms not supported for TF preprocessing "
from timm . data . tf_preprocessing import TfPreprocessTransform
transform = TfPreprocessTransform (
is_training = is_training , size = img_size , interpolation = cfg . interpolation )
else :
if is_training and cfg . aug is None :
assert not separate , " Cannot perform split augmentation with no_aug "
transform = transforms_noaug_train (
img_size ,
interpolation = cfg . interpolation ,
normalize = normalize ,
mean = cfg . mean ,
std = cfg . std )
elif is_training :
transform = transforms_imagenet_train (
img_size ,
interpolation = cfg . interpolation ,
mean = cfg . mean ,
std = cfg . std ,
aug_cfg = cfg . aug ,
normalize = normalize ,
separate = separate )
else :
assert not separate , " Separate transforms not supported for validation preprocessing "
transform = transforms_imagenet_eval (
img_size ,
interpolation = cfg . interpolation ,
crop_pct = cfg . crop_pct ,
mean = cfg . mean ,
std = cfg . std ,
normalize = normalize ,
)
return transform
def create_transform (
input_size ,
is_training = False ,
use_fetcher = False ,
use_ pre fetcher= False ,
no_aug = False ,
scale = None ,
ratio = None ,
@ -191,7 +253,8 @@ def create_transform(
else :
img_size = input_size
if tf_preprocessing and use_fetcher :
normalize_in_transform = not use_prefetcher
if tf_preprocessing and use_prefetcher :
assert not separate , " Separate transforms not supported for TF preprocessing "
from timm . data . tf_preprocessing import TfPreprocessTransform
transform = TfPreprocessTransform (
@ -202,35 +265,41 @@ def create_transform(
transform = transforms_noaug_train (
img_size ,
interpolation = interpolation ,
use_prefetcher = use_fetcher ,
mean = mean ,
std = std )
std = std ,
normalize = normalize_in_transform ,
)
elif is_training :
transform = transforms_imagenet_train (
img_size ,
scale = scale ,
ratio = ratio ,
hflip = hflip ,
vflip = vflip ,
aug_cfg = AugCfg (
scale_range = scale ,
ratio_range = ratio ,
hflip_prob = hflip ,
vflip_prob = vflip ,
color_jitter = color_jitter ,
auto_augment = auto_augment ,
interpolation = interpolation ,
use_prefetcher = use_fetcher ,
mean = mean ,
std = std ,
re_prob = re_prob ,
re_mode = re_mode ,
re_count = re_count ,
re_num_splits = re_num_splits ,
separate = separate )
num_aug_splits = re_num_splits ,
)
transform = transforms_imagenet_train (
img_size ,
interpolation = interpolation ,
mean = mean ,
std = std ,
aug_cfg = aug_cfg ,
normalize = normalize_in_transform ,
separate = separate
)
else :
assert not separate , " Separate transforms not supported for validation preprocessing "
assert not separate , " Separate transforms not supported for validation pre - processing"
transform = transforms_imagenet_eval (
img_size ,
interpolation = interpolation ,
use_prefetcher = use_fetcher ,
mean = mean ,
std = std ,
crop_pct = crop_pct )
crop_pct = crop_pct ,
normalize = normalize_in_transform ,
)
return transform