Some transform/data/loader refactoring, hopefully didn't break things

* factor out data related constants to own file
* move data related config helpers to own file
* add a variant of RandomResizeCrop that randomizes interpolation method
* remove old Numpy version of RandomErasing
* cleanup torch version of RandomErasing and use it in either GPU loader batch mode or single image cpu Transform
pull/2/head
Ross Wightman 6 years ago
parent e3377b0409
commit 76539d905e

@ -1,4 +1,5 @@
from data.constants import *
from data.config import resolve_data_config
from data.dataset import Dataset from data.dataset import Dataset
from data.transforms import * from data.transforms import *
from data.loader import create_loader from data.loader import create_loader
from data.random_erasing import RandomErasingTorch, RandomErasingNumpy

@ -0,0 +1,101 @@
from data.constants import *
def resolve_data_config(model, args, default_cfg={}, verbose=True):
new_config = {}
default_cfg = default_cfg
if not default_cfg and hasattr(model, 'default_cfg'):
default_cfg = model.default_cfg
# Resolve input/image size
# FIXME grayscale/chans arg to use different # channels?
in_chans = 3
input_size = (in_chans, 224, 224)
if args.img_size is not None:
# FIXME support passing img_size as tuple, non-square
assert isinstance(args.img_size, int)
input_size = (in_chans, args.img_size, args.img_size)
elif 'input_size' in default_cfg:
input_size = default_cfg['input_size']
new_config['input_size'] = input_size
# resolve interpolation method
new_config['interpolation'] = 'bilinear'
if args.interpolation:
new_config['interpolation'] = args.interpolation
elif 'interpolation' in default_cfg:
new_config['interpolation'] = default_cfg['interpolation']
# resolve dataset + model mean for normalization
new_config['mean'] = get_mean_by_model(args.model)
if args.mean is not None:
mean = tuple(args.mean)
if len(mean) == 1:
mean = tuple(list(mean) * in_chans)
else:
assert len(mean) == in_chans
new_config['mean'] = mean
elif 'mean' in default_cfg:
new_config['mean'] = default_cfg['mean']
# resolve dataset + model std deviation for normalization
new_config['std'] = get_std_by_model(args.model)
if args.std is not None:
std = tuple(args.std)
if len(std) == 1:
std = tuple(list(std) * in_chans)
else:
assert len(std) == in_chans
new_config['std'] = std
elif 'std' in default_cfg:
new_config['std'] = default_cfg['std']
# resolve default crop percentage
new_config['crop_pct'] = DEFAULT_CROP_PCT
if 'crop_pct' in default_cfg:
new_config['crop_pct'] = default_cfg['crop_pct']
if verbose:
print('Data processing configuration for current model + dataset:')
for n, v in new_config.items():
print('\t%s: %s' % (n, str(v)))
return new_config
def get_mean_by_name(name):
if name == 'dpn':
return IMAGENET_DPN_MEAN
elif name == 'inception' or name == 'le':
return IMAGENET_INCEPTION_MEAN
else:
return IMAGENET_DEFAULT_MEAN
def get_std_by_name(name):
if name == 'dpn':
return IMAGENET_DPN_STD
elif name == 'inception' or name == 'le':
return IMAGENET_INCEPTION_STD
else:
return IMAGENET_DEFAULT_STD
def get_mean_by_model(model_name):
model_name = model_name.lower()
if 'dpn' in model_name:
return IMAGENET_DPN_STD
elif 'ception' in model_name or 'nasnet' in model_name:
return IMAGENET_INCEPTION_MEAN
else:
return IMAGENET_DEFAULT_MEAN
def get_std_by_model(model_name):
model_name = model_name.lower()
if 'dpn' in model_name:
return IMAGENET_DEFAULT_STD
elif 'ception' in model_name or 'nasnet' in model_name:
return IMAGENET_INCEPTION_STD
else:
return IMAGENET_DEFAULT_STD

@ -0,0 +1,7 @@
DEFAULT_CROP_PCT = 0.875
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)

@ -1,6 +1,4 @@
import torch
import torch.utils.data import torch.utils.data
from data.random_erasing import RandomErasingTorch
from data.transforms import * from data.transforms import *
from data.distributed_sampler import OrderedDistributedSampler from data.distributed_sampler import OrderedDistributedSampler
@ -27,7 +25,7 @@ class PrefetchLoader:
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1) self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1) self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
if rand_erase_prob > 0.: if rand_erase_prob > 0.:
self.random_erasing = RandomErasingTorch( self.random_erasing = RandomErasing(
probability=rand_erase_prob, per_pixel=rand_erase_pp) probability=rand_erase_prob, per_pixel=rand_erase_pp)
else: else:
self.random_erasing = None self.random_erasing = None

@ -2,125 +2,68 @@ from __future__ import absolute_import
import random import random
import math import math
import numpy as np
import torch import torch
class RandomErasingNumpy: def _get_patch(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'):
""" Randomly selects a rectangle region in an image and erases its pixels. if per_pixel:
'Random Erasing Data Augmentation' by Zhong et al. return torch.empty(
See https://arxiv.org/pdf/1708.04896.pdf patch_size, dtype=dtype, device=device).normal_()
elif rand_color:
This 'Numpy' variant of RandomErasing is intended to be applied on a per return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_()
image basis after transforming the image to uint8 numpy array in
range 0-255 prior to tensor conversion and normalization
Args:
probability: The probability that the Random Erasing operation will be performed.
sl: Minimum proportion of erased area against input image.
sh: Maximum proportion of erased area against input image.
r1: Minimum aspect ratio of erased area.
mean: Erasing value.
"""
def __init__(
self,
probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3,
per_pixel=False, rand_color=False,
pl=0, ph=255, mean=[255 * 0.485, 255 * 0.456, 255 * 0.406],
out_type=np.uint8):
self.probability = probability
if not per_pixel and not rand_color:
self.mean = np.array(mean).round().astype(out_type)
else: else:
self.mean = None return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
self.sl = sl
self.sh = sh
self.min_aspect = min_aspect
self.pl = pl
self.ph = ph
self.per_pixel = per_pixel # per pixel random, bounded by [pl, ph]
self.rand_color = rand_color # per block random, bounded by [pl, ph]
self.out_type = out_type
def __call__(self, img):
if random.random() > self.probability:
return img
chan, img_h, img_w = img.shape class RandomErasing:
area = img_h * img_w
for attempt in range(100):
target_area = random.uniform(self.sl, self.sh) * area
aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect)
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
if self.rand_color:
c = np.random.randint(self.pl, self.ph + 1, (chan,), self.out_type)
elif not self.per_pixel:
c = self.mean[:chan]
if w < img_w and h < img_h:
top = random.randint(0, img_h - h)
left = random.randint(0, img_w - w)
if self.per_pixel:
img[:, top:top + h, left:left + w] = np.random.randint(
self.pl, self.ph + 1, (chan, h, w), self.out_type)
else:
img[:, top:top + h, left:left + w] = c
return img
return img
class RandomErasingTorch:
""" Randomly selects a rectangle region in an image and erases its pixels. """ Randomly selects a rectangle region in an image and erases its pixels.
'Random Erasing Data Augmentation' by Zhong et al. 'Random Erasing Data Augmentation' by Zhong et al.
See https://arxiv.org/pdf/1708.04896.pdf See https://arxiv.org/pdf/1708.04896.pdf
This 'Torch' variant of RandomErasing is intended to be applied to a full batch This variant of RandomErasing is intended to be applied to either a batch
tensor after it has been normalized by dataset mean and std. or single image tensor after it has been normalized by dataset mean and std.
Args: Args:
probability: The probability that the Random Erasing operation will be performed. probability: The probability that the Random Erasing operation will be performed.
sl: Minimum proportion of erased area against input image. sl: Minimum proportion of erased area against input image.
sh: Maximum proportion of erased area against input image. sh: Maximum proportion of erased area against input image.
r1: Minimum aspect ratio of erased area. min_aspect: Minimum aspect ratio of erased area.
per_pixel: random value for each pixel in the erase region, precedence over rand_color
rand_color: random color for whole erase region, 0 if neither this or per_pixel set
""" """
def __init__( def __init__(
self, self,
probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3, probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3,
per_pixel=False, rand_color=False): per_pixel=False, rand_color=False, device='cuda'):
self.probability = probability self.probability = probability
self.sl = sl self.sl = sl
self.sh = sh self.sh = sh
self.min_aspect = min_aspect self.min_aspect = min_aspect
self.per_pixel = per_pixel # per pixel random, bounded by [pl, ph] self.per_pixel = per_pixel # per pixel random, bounded by [pl, ph]
self.rand_color = rand_color # per block random, bounded by [pl, ph] self.rand_color = rand_color # per block random, bounded by [pl, ph]
self.device = device
def __call__(self, batch): def _erase(self, img, chan, img_h, img_w, dtype):
batch_size, chan, img_h, img_w = batch.size()
area = img_h * img_w
for i in range(batch_size):
if random.random() > self.probability: if random.random() > self.probability:
continue return
img = batch[i] area = img_h * img_w
for attempt in range(100): for attempt in range(100):
target_area = random.uniform(self.sl, self.sh) * area target_area = random.uniform(self.sl, self.sh) * area
aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect) aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect)
h = int(round(math.sqrt(target_area * aspect_ratio))) h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio))) w = int(round(math.sqrt(target_area / aspect_ratio)))
if self.rand_color:
c = torch.empty((chan, 1, 1), dtype=batch.dtype).normal_().cuda()
elif not self.per_pixel:
c = torch.zeros((chan, 1, 1), dtype=batch.dtype).cuda()
if w < img_w and h < img_h: if w < img_w and h < img_h:
top = random.randint(0, img_h - h) top = random.randint(0, img_h - h)
left = random.randint(0, img_w - w) left = random.randint(0, img_w - w)
if self.per_pixel: img[:, top:top + h, left:left + w] = _get_patch(
img[:, top:top + h, left:left + w] = torch.empty( self.per_pixel, self.rand_color, (chan, h, w), dtype=dtype, device=self.device)
(chan, h, w), dtype=batch.dtype).normal_().cuda()
else:
img[:, top:top + h, left:left + w] = c
break break
return batch def __call__(self, input):
if len(input.size()) == 3:
self._erase(input, *input.size(), input.dtype)
else:
batch_size, chan, img_h, img_w = input.size()
for i in range(batch_size):
self._erase(input[i], chan, img_h, img_w, input.dtype)
return input

@ -1,118 +1,14 @@
import torch import torch
from torchvision import transforms from torchvision import transforms
import torchvision.transforms.functional as F
from PIL import Image from PIL import Image
import warnings
import math import math
import random
import numpy as np import numpy as np
from data.random_erasing import RandomErasingNumpy
DEFAULT_CROP_PCT = 0.875
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
def resolve_data_config(model, args, default_cfg={}, verbose=True):
new_config = {}
default_cfg = default_cfg
if not default_cfg and hasattr(model, 'default_cfg'):
default_cfg = model.default_cfg
# Resolve input/image size
# FIXME grayscale/chans arg to use different # channels?
in_chans = 3
input_size = (in_chans, 224, 224)
if args.img_size is not None:
# FIXME support passing img_size as tuple, non-square
assert isinstance(args.img_size, int)
input_size = (in_chans, args.img_size, args.img_size)
elif 'input_size' in default_cfg:
input_size = default_cfg['input_size']
new_config['input_size'] = input_size
# resolve interpolation method
new_config['interpolation'] = 'bilinear'
if args.interpolation:
new_config['interpolation'] = args.interpolation
elif 'interpolation' in default_cfg:
new_config['interpolation'] = default_cfg['interpolation']
# resolve dataset + model mean for normalization
new_config['mean'] = get_mean_by_model(args.model)
if args.mean is not None:
mean = tuple(args.mean)
if len(mean) == 1:
mean = tuple(list(mean) * in_chans)
else:
assert len(mean) == in_chans
new_config['mean'] = mean
elif 'mean' in default_cfg:
new_config['mean'] = default_cfg['mean']
# resolve dataset + model std deviation for normalization
new_config['std'] = get_std_by_model(args.model)
if args.std is not None:
std = tuple(args.std)
if len(std) == 1:
std = tuple(list(std) * in_chans)
else:
assert len(std) == in_chans
new_config['std'] = std
elif 'std' in default_cfg:
new_config['std'] = default_cfg['std']
# resolve default crop percentage
new_config['crop_pct'] = DEFAULT_CROP_PCT
if 'crop_pct' in default_cfg:
new_config['crop_pct'] = default_cfg['crop_pct']
if verbose:
print('Data processing configuration for current model + dataset:')
for n, v in new_config.items():
print('\t%s: %s' % (n, str(v)))
return new_config
from data import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def get_mean_by_name(name): from data.random_erasing import RandomErasing
if name == 'dpn':
return IMAGENET_DPN_MEAN
elif name == 'inception' or name == 'le':
return IMAGENET_INCEPTION_MEAN
else:
return IMAGENET_DEFAULT_MEAN
def get_std_by_name(name):
if name == 'dpn':
return IMAGENET_DPN_STD
elif name == 'inception' or name == 'le':
return IMAGENET_INCEPTION_STD
else:
return IMAGENET_DEFAULT_STD
def get_mean_by_model(model_name):
model_name = model_name.lower()
if 'dpn' in model_name:
return IMAGENET_DPN_STD
elif 'ception' in model_name or 'nasnet' in model_name:
return IMAGENET_INCEPTION_MEAN
else:
return IMAGENET_DEFAULT_MEAN
def get_std_by_model(model_name):
model_name = model_name.lower()
if 'dpn' in model_name:
return IMAGENET_DEFAULT_STD
elif 'ception' in model_name or 'nasnet' in model_name:
return IMAGENET_INCEPTION_STD
else:
return IMAGENET_DEFAULT_STD
class ToNumpy: class ToNumpy:
@ -138,6 +34,16 @@ class ToTensor:
return torch.from_numpy(np_img).to(dtype=self.dtype) return torch.from_numpy(np_img).to(dtype=self.dtype)
_pil_interpolation_to_str = {
Image.NEAREST: 'PIL.Image.NEAREST',
Image.BILINEAR: 'PIL.Image.BILINEAR',
Image.BICUBIC: 'PIL.Image.BICUBIC',
Image.LANCZOS: 'PIL.Image.LANCZOS',
Image.HAMMING: 'PIL.Image.HAMMING',
Image.BOX: 'PIL.Image.BOX',
}
def _pil_interp(method): def _pil_interp(method):
if method == 'bicubic': if method == 'bicubic':
return Image.BICUBIC return Image.BICUBIC
@ -150,21 +56,118 @@ def _pil_interp(method):
return Image.BILINEAR return Image.BILINEAR
RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
class RandomResizedCropAndInterpolation(object):
"""Crop the given PIL Image to random size and aspect ratio with random interpolation.
A crop of random size (default: of 0.08 to 1.0) of the original size and a random
aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
is finally resized to given size.
This is popularly used to train the Inception networks.
Args:
size: expected output size of each edge
scale: range of size of the origin size cropped
ratio: range of aspect ratio of the origin aspect ratio cropped
interpolation: Default: PIL.Image.BILINEAR
"""
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
interpolation='bilinear'):
if isinstance(size, tuple):
self.size = size
else:
self.size = (size, size)
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("range should be of kind (min, max)")
if interpolation == 'random':
self.interpolation = RANDOM_INTERPOLATION
else:
self.interpolation = _pil_interp(interpolation)
self.scale = scale
self.ratio = ratio
@staticmethod
def get_params(img, scale, ratio):
"""Get parameters for ``crop`` for a random sized crop.
Args:
img (PIL Image): Image to be cropped.
scale (tuple): range of size of the origin size cropped
ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
sized crop.
"""
area = img.size[0] * img.size[1]
for attempt in range(10):
target_area = random.uniform(*scale) * area
aspect_ratio = random.uniform(*ratio)
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if random.random() < 0.5 and min(ratio) <= (h / w) <= max(ratio):
w, h = h, w
if w <= img.size[0] and h <= img.size[1]:
i = random.randint(0, img.size[1] - h)
j = random.randint(0, img.size[0] - w)
return i, j, h, w
# Fallback
w = min(img.size[0], img.size[1])
i = (img.size[1] - w) // 2
j = (img.size[0] - w) // 2
return i, j, w, w
def __call__(self, img):
"""
Args:
img (PIL Image): Image to be cropped and resized.
Returns:
PIL Image: Randomly cropped and resized image.
"""
i, j, h, w = self.get_params(img, self.scale, self.ratio)
if isinstance(self.interpolation, (tuple, list)):
interpolation = random.choice(self.interpolation)
else:
interpolation = self.interpolation
return F.resized_crop(img, i, j, h, w, self.size, interpolation)
def __repr__(self):
if isinstance(self.interpolation, (tuple, list)):
interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation])
else:
interpolate_str = _pil_interpolation_to_str[self.interpolation]
format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
format_string += ', interpolation={0})'.format(interpolate_str)
return format_string
def transforms_imagenet_train( def transforms_imagenet_train(
img_size=224, img_size=224,
scale=(0.1, 1.0), scale=(0.08, 1.0),
color_jitter=(0.4, 0.4, 0.4), color_jitter=(0.4, 0.4, 0.4),
interpolation='bilinear', interpolation='random',
random_erasing=0.4, random_erasing=0.4,
random_erasing_pp=True,
use_prefetcher=False, use_prefetcher=False,
mean=IMAGENET_DEFAULT_MEAN, mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD std=IMAGENET_DEFAULT_STD
): ):
tfl = [ tfl = [
transforms.RandomResizedCrop( RandomResizedCropAndInterpolation(
img_size, scale=scale, img_size, scale=scale, interpolation=interpolation),
interpolation=_pil_interp(interpolation)),
transforms.RandomHorizontalFlip(), transforms.RandomHorizontalFlip(),
transforms.ColorJitter(*color_jitter), transforms.ColorJitter(*color_jitter),
] ]
@ -174,13 +177,13 @@ def transforms_imagenet_train(
tfl += [ToNumpy()] tfl += [ToNumpy()]
else: else:
tfl += [ tfl += [
ToTensor(), transforms.ToTensor(),
transforms.Normalize( transforms.Normalize(
mean=torch.tensor(mean), mean=torch.tensor(mean),
std=torch.tensor(std)) std=torch.tensor(std))
] ]
if random_erasing > 0.: if random_erasing > 0.:
tfl.append(RandomErasingNumpy(random_erasing, per_pixel=True)) tfl.append(RandomErasing(random_erasing, per_pixel=random_erasing_pp, device='cpu'))
return transforms.Compose(tfl) return transforms.Compose(tfl)

@ -9,7 +9,7 @@ from collections import OrderedDict
from models.helpers import load_pretrained from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import * from models.adaptive_avgmax_pool import *
from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import re import re
__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']

@ -17,7 +17,7 @@ from collections import OrderedDict
from models.helpers import load_pretrained from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import select_adaptive_pool2d from models.adaptive_avgmax_pool import select_adaptive_pool2d
from data.transforms import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD from data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD
__all__ = ['DPN', 'dpn68', 'dpn92', 'dpn98', 'dpn131', 'dpn107'] __all__ = ['DPN', 'dpn68', 'dpn92', 'dpn98', 'dpn131', 'dpn107']

@ -24,7 +24,7 @@ import torch.nn.functional as F
from models.helpers import load_pretrained from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import SelectAdaptivePool2d from models.adaptive_avgmax_pool import SelectAdaptivePool2d
from models.conv2d_same import sconv2d from models.conv2d_same import sconv2d
from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['GenMobileNet', 'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_140', __all__ = ['GenMobileNet', 'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_140',
'semnasnet_050', 'semnasnet_075', 'semnasnet_100', 'semnasnet_140', 'mnasnet_small', 'semnasnet_050', 'semnasnet_075', 'semnasnet_100', 'semnasnet_140', 'mnasnet_small',

@ -7,8 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from models.helpers import load_pretrained from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import * from models.adaptive_avgmax_pool import *
from data.transforms import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
default_cfgs = { default_cfgs = {
'inception_resnet_v2': { 'inception_resnet_v2': {

@ -7,8 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from models.helpers import load_pretrained from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import * from models.adaptive_avgmax_pool import *
from data.transforms import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
default_cfgs = { default_cfgs = {
'inception_v4': { 'inception_v4': {

@ -10,7 +10,7 @@ import torch.nn.functional as F
import math import math
from models.helpers import load_pretrained from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import SelectAdaptivePool2d from models.adaptive_avgmax_pool import SelectAdaptivePool2d
from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d', 'resnext152_32x4d'] 'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d', 'resnext152_32x4d']

@ -17,7 +17,7 @@ import torch.nn.functional as F
from models.helpers import load_pretrained from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import SelectAdaptivePool2d from models.adaptive_avgmax_pool import SelectAdaptivePool2d
from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['SENet', 'senet154', 'seresnet50', 'seresnet101', 'seresnet152', __all__ = ['SENet', 'senet154', 'seresnet50', 'seresnet101', 'seresnet152',
'seresnext50_32x4d', 'seresnext101_32x4d'] 'seresnext50_32x4d', 'seresnext101_32x4d']

@ -10,7 +10,7 @@ try:
except ImportError: except ImportError:
has_apex = False has_apex = False
from data import * from data import Dataset, create_loader, resolve_data_config
from models import create_model, resume_checkpoint from models import create_model, resume_checkpoint
from utils import * from utils import *
from loss import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy from loss import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
@ -224,7 +224,7 @@ def main():
use_prefetcher=True, use_prefetcher=True,
rand_erase_prob=args.reprob, rand_erase_prob=args.reprob,
rand_erase_pp=args.repp, rand_erase_pp=args.repp,
interpolation=data_config['interpolation'], interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'],
mean=data_config['mean'], mean=data_config['mean'],
std=data_config['std'], std=data_config['std'],
num_workers=args.workers, num_workers=args.workers,

Loading…
Cancel
Save