import torch
from torchvision import transforms
from PIL import Image
import math
import numpy as np

from data.random_erasing import RandomErasingNumpy

DEFAULT_CROP_PCT = 0.875

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)


def resolve_data_config(model, args, default_cfg=None, verbose=True):
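    """Resolve the data/preprocessing config for a model + command-line args.

    Each value is taken from the args when set, falling back to the model's
    default_cfg entry, then to the module-level defaults. Returns a dict with
    'input_size', 'interpolation', 'mean', 'std', and 'crop_pct' keys.
    """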
    new_config = {}
    default_cfg = default_cfg or {}
    if not default_cfg and hasattr(model, 'default_cfg'):
        default_cfg = model.default_cfg

    # Resolve input/image size
    # FIXME grayscale/chans arg to use different # channels?
    in_chans = 3
    input_size = (in_chans, 224, 224)
    if args.img_size is not None:
        # FIXME support passing img_size as tuple, non-square
        assert isinstance(args.img_size, int)
        input_size = (in_chans, args.img_size, args.img_size)
    elif 'input_size' in default_cfg:
        input_size = default_cfg['input_size']
    new_config['input_size'] = input_size

    # resolve interpolation method
    new_config['interpolation'] = 'bilinear'
    if args.interpolation:
        new_config['interpolation'] = args.interpolation
    elif 'interpolation' in default_cfg:
        new_config['interpolation'] = default_cfg['interpolation']

    # resolve dataset + model mean for normalization
    new_config['mean'] = get_mean_by_model(args.model)
    if args.mean is not None:
        mean = tuple(args.mean)
        if len(mean) == 1:
            mean = tuple(list(mean) * in_chans)
        else:
            assert len(mean) == in_chans
        new_config['mean'] = mean
    elif 'mean' in default_cfg:
        new_config['mean'] = default_cfg['mean']

    # resolve dataset + model std deviation for normalization
    new_config['std'] = get_std_by_model(args.model)
    if args.std is not None:
        std = tuple(args.std)
        if len(std) == 1:
            std = tuple(list(std) * in_chans)
        else:
            assert len(std) == in_chans
        new_config['std'] = std
    elif 'std' in default_cfg:
        new_config['std'] = default_cfg['std']

    # resolve default crop percentage
    new_config['crop_pct'] = DEFAULT_CROP_PCT
    if 'crop_pct' in default_cfg:
        new_config['crop_pct'] = default_cfg['crop_pct']

    if verbose:
        print('Data processing configuration for current model + dataset:')
        for n, v in new_config.items():
            print('\t%s: %s' % (n, str(v)))

    return new_config


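# Example (hypothetical): resolving a config from an argparse-style namespace.
# `SimpleNamespace` stands in here for the parsed CLI args the train/validate
# scripts would pass; the attribute names match what resolve_data_config reads.
#
#     from types import SimpleNamespace
#     args = SimpleNamespace(model='dpn92', img_size=None, interpolation='',
#                            mean=None, std=None)
#     config = resolve_data_config(model=None, args=args, verbose=False)
#     # config['mean'] == IMAGENET_DPN_MEAN, config['input_size'] == (3, 224, 224)
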
def get_mean_by_name(name):
    if name == 'dpn':
        return IMAGENET_DPN_MEAN
    elif name == 'inception' or name == 'le':
        return IMAGENET_INCEPTION_MEAN
    else:
        return IMAGENET_DEFAULT_MEAN


def get_std_by_name(name):
    if name == 'dpn':
        return IMAGENET_DPN_STD
    elif name == 'inception' or name == 'le':
        return IMAGENET_INCEPTION_STD
    else:
        return IMAGENET_DEFAULT_STD


def get_mean_by_model(model_name):
    model_name = model_name.lower()
    if 'dpn' in model_name:
        return IMAGENET_DPN_MEAN
    elif 'ception' in model_name or 'nasnet' in model_name:
        return IMAGENET_INCEPTION_MEAN
    else:
        return IMAGENET_DEFAULT_MEAN


def get_std_by_model(model_name):
    model_name = model_name.lower()
    if 'dpn' in model_name:
        return IMAGENET_DPN_STD
    elif 'ception' in model_name or 'nasnet' in model_name:
        return IMAGENET_INCEPTION_STD
    else:
        return IMAGENET_DEFAULT_STD


class ToNumpy:
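    """Convert a PIL image to a CHW uint8 numpy array, without normalization."""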

    def __call__(self, pil_img):
        np_img = np.array(pil_img, dtype=np.uint8)
        if np_img.ndim < 3:
            np_img = np.expand_dims(np_img, axis=-1)
        np_img = np.rollaxis(np_img, 2)  # HWC to CHW
        return np_img


class ToTensor:
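    """Convert a PIL image to a CHW torch tensor of the given dtype.

    Note: unlike torchvision.transforms.ToTensor, pixel values are kept in
    the raw 0-255 range rather than rescaled to [0, 1].
    """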

    def __init__(self, dtype=torch.float32):
        self.dtype = dtype

    def __call__(self, pil_img):
        np_img = np.array(pil_img, dtype=np.uint8)
        if np_img.ndim < 3:
            np_img = np.expand_dims(np_img, axis=-1)
        np_img = np.rollaxis(np_img, 2)  # HWC to CHW
        return torch.from_numpy(np_img).to(dtype=self.dtype)


def _pil_interp(method):
    if method == 'bicubic':
        return Image.BICUBIC
    elif method == 'lanczos':
        return Image.LANCZOS
    elif method == 'hamming':
        return Image.HAMMING
    else:
        # default bilinear, do we want to allow nearest?
        return Image.BILINEAR


def transforms_imagenet_train(
        img_size=224,
        scale=(0.1, 1.0),
        color_jitter=(0.4, 0.4, 0.4),
        interpolation='bilinear',
        random_erasing=0.4,
        use_prefetcher=False,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD
):
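    """Build the training transform pipeline: random resized crop, horizontal
    flip, and color jitter, followed by either numpy conversion (prefetcher
    path) or tensor conversion + normalization with optional random erasing.
    """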
    tfl = [
        transforms.RandomResizedCrop(
            img_size, scale=scale,
            interpolation=_pil_interp(interpolation)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(*color_jitter),
    ]

    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        tfl += [ToNumpy()]
    else:
        tfl += [
            ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std))
        ]
        if random_erasing > 0.:
            tfl.append(RandomErasingNumpy(random_erasing, per_pixel=True))
    return transforms.Compose(tfl)


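# Example (hypothetical) of building the training pipeline directly:
#
#     train_tf = transforms_imagenet_train(img_size=224, interpolation='bicubic')
#     # tensor_img = train_tf(pil_img)  # pil_img: an RGB PIL.Image
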
def transforms_imagenet_eval(
        img_size=224,
        crop_pct=None,
        interpolation='bilinear',
        use_prefetcher=False,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD):
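    """Build the eval transform pipeline: resize by 1/crop_pct, center crop,
    then either numpy conversion (prefetcher path) or tensor conversion +
    normalization.
    """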
    crop_pct = crop_pct or DEFAULT_CROP_PCT

    if isinstance(img_size, tuple):
        assert len(img_size) == 2
        if img_size[-1] == img_size[-2]:
            # fall-back to older behaviour so Resize scales to shortest edge if target is square
            scale_size = int(math.floor(img_size[0] / crop_pct))
        else:
            scale_size = tuple([int(x / crop_pct) for x in img_size])
    else:
        scale_size = int(math.floor(img_size / crop_pct))

    tfl = [
        transforms.Resize(scale_size, _pil_interp(interpolation)),
        transforms.CenterCrop(img_size),
    ]
    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        tfl += [ToNumpy()]
    else:
        tfl += [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std))
        ]

    return transforms.Compose(tfl)
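
# Example (hypothetical): eval pipeline with a non-default crop percentage.
#
#     eval_tf = transforms_imagenet_eval(img_size=224, crop_pct=0.9)
#     # resizes the short edge to int(224 / 0.9) = 248, then center-crops to 224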