diff --git a/data/__init__.py b/data/__init__.py index f1e1c182..418d064a 100644 --- a/data/__init__.py +++ b/data/__init__.py @@ -1,4 +1,5 @@ +from data.constants import * +from data.config import resolve_data_config from data.dataset import Dataset from data.transforms import * from data.loader import create_loader -from data.random_erasing import RandomErasingTorch, RandomErasingNumpy \ No newline at end of file diff --git a/data/config.py b/data/config.py new file mode 100644 index 00000000..5b15dba8 --- /dev/null +++ b/data/config.py @@ -0,0 +1,101 @@ +from data.constants import * + + +def resolve_data_config(model, args, default_cfg={}, verbose=True): + new_config = {} + default_cfg = default_cfg + if not default_cfg and hasattr(model, 'default_cfg'): + default_cfg = model.default_cfg + + # Resolve input/image size + # FIXME grayscale/chans arg to use different # channels? + in_chans = 3 + input_size = (in_chans, 224, 224) + if args.img_size is not None: + # FIXME support passing img_size as tuple, non-square + assert isinstance(args.img_size, int) + input_size = (in_chans, args.img_size, args.img_size) + elif 'input_size' in default_cfg: + input_size = default_cfg['input_size'] + new_config['input_size'] = input_size + + # resolve interpolation method + new_config['interpolation'] = 'bilinear' + if args.interpolation: + new_config['interpolation'] = args.interpolation + elif 'interpolation' in default_cfg: + new_config['interpolation'] = default_cfg['interpolation'] + + # resolve dataset + model mean for normalization + new_config['mean'] = get_mean_by_model(args.model) + if args.mean is not None: + mean = tuple(args.mean) + if len(mean) == 1: + mean = tuple(list(mean) * in_chans) + else: + assert len(mean) == in_chans + new_config['mean'] = mean + elif 'mean' in default_cfg: + new_config['mean'] = default_cfg['mean'] + + # resolve dataset + model std deviation for normalization + new_config['std'] = get_std_by_model(args.model) + if args.std is not None: + std = tuple(args.std) + if len(std) == 1: + std = tuple(list(std) * in_chans) + else: + assert len(std) == in_chans + new_config['std'] = std + elif 'std' in default_cfg: + new_config['std'] = default_cfg['std'] + + # resolve default crop percentage + new_config['crop_pct'] = DEFAULT_CROP_PCT + if 'crop_pct' in default_cfg: + new_config['crop_pct'] = default_cfg['crop_pct'] + + if verbose: + print('Data processing configuration for current model + dataset:') + for n, v in new_config.items(): + print('\t%s: %s' % (n, str(v))) + + return new_config + + +def get_mean_by_name(name): + if name == 'dpn': + return IMAGENET_DPN_MEAN + elif name == 'inception' or name == 'le': + return IMAGENET_INCEPTION_MEAN + else: + return IMAGENET_DEFAULT_MEAN + + +def get_std_by_name(name): + if name == 'dpn': + return IMAGENET_DPN_STD + elif name == 'inception' or name == 'le': + return IMAGENET_INCEPTION_STD + else: + return IMAGENET_DEFAULT_STD + + +def get_mean_by_model(model_name): + model_name = model_name.lower() + if 'dpn' in model_name: + return IMAGENET_DPN_STD + elif 'ception' in model_name or 'nasnet' in model_name: + return IMAGENET_INCEPTION_MEAN + else: + return IMAGENET_DEFAULT_MEAN + + +def get_std_by_model(model_name): + model_name = model_name.lower() + if 'dpn' in model_name: + return IMAGENET_DEFAULT_STD + elif 'ception' in model_name or 'nasnet' in model_name: + return IMAGENET_INCEPTION_STD + else: + return IMAGENET_DEFAULT_STD diff --git a/data/constants.py b/data/constants.py new file mode 100644 index 00000000..d6d4a01b --- /dev/null +++ b/data/constants.py @@ -0,0 +1,7 @@ +DEFAULT_CROP_PCT = 0.875 +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) +IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) +IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) +IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) +IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) diff --git a/data/loader.py b/data/loader.py index 2402b093..31663c9c 100644 --- a/data/loader.py +++ b/data/loader.py @@ -1,6 +1,4 @@ -import torch import torch.utils.data -from data.random_erasing import RandomErasingTorch from data.transforms import * from data.distributed_sampler import OrderedDistributedSampler @@ -27,7 +25,7 @@ class PrefetchLoader: self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1) self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1) if rand_erase_prob > 0.: - self.random_erasing = RandomErasingTorch( + self.random_erasing = RandomErasing( probability=rand_erase_prob, per_pixel=rand_erase_pp) else: self.random_erasing = None diff --git a/data/random_erasing.py b/data/random_erasing.py index 668a3831..8434179c 100644 --- a/data/random_erasing.py +++ b/data/random_erasing.py @@ -2,125 +2,68 @@ from __future__ import absolute_import import random import math -import numpy as np import torch -class RandomErasingNumpy: +def _get_patch(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'): + if per_pixel: + return torch.empty( + patch_size, dtype=dtype, device=device).normal_() + elif rand_color: + return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_() + else: + return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) + + +class RandomErasing: """ Randomly selects a rectangle region in an image and erases its pixels. 'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/pdf/1708.04896.pdf - This 'Numpy' variant of RandomErasing is intended to be applied on a per - image basis after transforming the image to uint8 numpy array in - range 0-255 prior to tensor conversion and normalization + This variant of RandomErasing is intended to be applied to either a batch + or single image tensor after it has been normalized by dataset mean and std. Args: probability: The probability that the Random Erasing operation will be performed. sl: Minimum proportion of erased area against input image. sh: Maximum proportion of erased area against input image. - r1: Minimum aspect ratio of erased area. - mean: Erasing value. + min_aspect: Minimum aspect ratio of erased area. + per_pixel: random value for each pixel in the erase region, precedence over rand_color + rand_color: random color for whole erase region, 0 if neither this or per_pixel set """ def __init__( self, probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3, - per_pixel=False, rand_color=False, - pl=0, ph=255, mean=[255 * 0.485, 255 * 0.456, 255 * 0.406], - out_type=np.uint8): + per_pixel=False, rand_color=False, device='cuda'): self.probability = probability - if not per_pixel and not rand_color: - self.mean = np.array(mean).round().astype(out_type) - else: - self.mean = None self.sl = sl self.sh = sh self.min_aspect = min_aspect - self.pl = pl - self.ph = ph self.per_pixel = per_pixel # per pixel random, bounded by [pl, ph] self.rand_color = rand_color # per block random, bounded by [pl, ph] - self.out_type = out_type + self.device = device - def __call__(self, img): + def _erase(self, img, chan, img_h, img_w, dtype): if random.random() > self.probability: - return img - - chan, img_h, img_w = img.shape + return area = img_h * img_w for attempt in range(100): target_area = random.uniform(self.sl, self.sh) * area aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect) - h = int(round(math.sqrt(target_area * aspect_ratio))) w = int(round(math.sqrt(target_area / aspect_ratio))) - if self.rand_color: - c = np.random.randint(self.pl, self.ph + 1, (chan,), self.out_type) - elif not self.per_pixel: - c = self.mean[:chan] if w < img_w and h < img_h: top = random.randint(0, img_h - h) left = random.randint(0, img_w - w) - if self.per_pixel: - img[:, top:top + h, left:left + w] = np.random.randint( - self.pl, self.ph + 1, (chan, h, w), self.out_type) - else: - img[:, top:top + h, left:left + w] = c - return img - - return img - - -class RandomErasingTorch: - """ Randomly selects a rectangle region in an image and erases its pixels. - 'Random Erasing Data Augmentation' by Zhong et al. - See https://arxiv.org/pdf/1708.04896.pdf + img[:, top:top + h, left:left + w] = _get_patch( + self.per_pixel, self.rand_color, (chan, h, w), dtype=dtype, device=self.device) + break - This 'Torch' variant of RandomErasing is intended to be applied to a full batch - tensor after it has been normalized by dataset mean and std. - Args: - probability: The probability that the Random Erasing operation will be performed. - sl: Minimum proportion of erased area against input image. - sh: Maximum proportion of erased area against input image. - r1: Minimum aspect ratio of erased area. - """ - - def __init__( - self, - probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3, - per_pixel=False, rand_color=False): - self.probability = probability - self.sl = sl - self.sh = sh - self.min_aspect = min_aspect - self.per_pixel = per_pixel # per pixel random, bounded by [pl, ph] - self.rand_color = rand_color # per block random, bounded by [pl, ph] - - def __call__(self, batch): - batch_size, chan, img_h, img_w = batch.size() - area = img_h * img_w - for i in range(batch_size): - if random.random() > self.probability: - continue - img = batch[i] - for attempt in range(100): - target_area = random.uniform(self.sl, self.sh) * area - aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect) - - h = int(round(math.sqrt(target_area * aspect_ratio))) - w = int(round(math.sqrt(target_area / aspect_ratio))) - if self.rand_color: - c = torch.empty((chan, 1, 1), dtype=batch.dtype).normal_().cuda() - elif not self.per_pixel: - c = torch.zeros((chan, 1, 1), dtype=batch.dtype).cuda() - if w < img_w and h < img_h: - top = random.randint(0, img_h - h) - left = random.randint(0, img_w - w) - if self.per_pixel: - img[:, top:top + h, left:left + w] = torch.empty( - (chan, h, w), dtype=batch.dtype).normal_().cuda() - else: - img[:, top:top + h, left:left + w] = c - break - - return batch + def __call__(self, input): + if len(input.size()) == 3: + self._erase(input, *input.size(), input.dtype) + else: + batch_size, chan, img_h, img_w = input.size() + for i in range(batch_size): + self._erase(input[i], chan, img_h, img_w, input.dtype) + return input diff --git a/data/transforms.py b/data/transforms.py index bebbd15a..01141086 100644 --- a/data/transforms.py +++ b/data/transforms.py @@ -1,118 +1,14 @@ import torch from torchvision import transforms +import torchvision.transforms.functional as F from PIL import Image +import warnings import math +import random import numpy as np -from data.random_erasing import RandomErasingNumpy - -DEFAULT_CROP_PCT = 0.875 - -IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) -IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) -IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) -IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) -IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) -IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) - - -def resolve_data_config(model, args, default_cfg={}, verbose=True): - new_config = {} - default_cfg = default_cfg - if not default_cfg and hasattr(model, 'default_cfg'): - default_cfg = model.default_cfg - - # Resolve input/image size - # FIXME grayscale/chans arg to use different # channels? - in_chans = 3 - input_size = (in_chans, 224, 224) - if args.img_size is not None: - # FIXME support passing img_size as tuple, non-square - assert isinstance(args.img_size, int) - input_size = (in_chans, args.img_size, args.img_size) - elif 'input_size' in default_cfg: - input_size = default_cfg['input_size'] - new_config['input_size'] = input_size - - # resolve interpolation method - new_config['interpolation'] = 'bilinear' - if args.interpolation: - new_config['interpolation'] = args.interpolation - elif 'interpolation' in default_cfg: - new_config['interpolation'] = default_cfg['interpolation'] - - # resolve dataset + model mean for normalization - new_config['mean'] = get_mean_by_model(args.model) - if args.mean is not None: - mean = tuple(args.mean) - if len(mean) == 1: - mean = tuple(list(mean) * in_chans) - else: - assert len(mean) == in_chans - new_config['mean'] = mean - elif 'mean' in default_cfg: - new_config['mean'] = default_cfg['mean'] - - # resolve dataset + model std deviation for normalization - new_config['std'] = get_std_by_model(args.model) - if args.std is not None: - std = tuple(args.std) - if len(std) == 1: - std = tuple(list(std) * in_chans) - else: - assert len(std) == in_chans - new_config['std'] = std - elif 'std' in default_cfg: - new_config['std'] = default_cfg['std'] - - # resolve default crop percentage - new_config['crop_pct'] = DEFAULT_CROP_PCT - if 'crop_pct' in default_cfg: - new_config['crop_pct'] = default_cfg['crop_pct'] - - if verbose: - print('Data processing configuration for current model + dataset:') - for n, v in new_config.items(): - print('\t%s: %s' % (n, str(v))) - - return new_config - -def get_mean_by_name(name): - if name == 'dpn': - return IMAGENET_DPN_MEAN - elif name == 'inception' or name == 'le': - return IMAGENET_INCEPTION_MEAN - else: - return IMAGENET_DEFAULT_MEAN - - -def get_std_by_name(name): - if name == 'dpn': - return IMAGENET_DPN_STD - elif name == 'inception' or name == 'le': - return IMAGENET_INCEPTION_STD - else: - return IMAGENET_DEFAULT_STD - - -def get_mean_by_model(model_name): - model_name = model_name.lower() - if 'dpn' in model_name: - return IMAGENET_DPN_STD - elif 'ception' in model_name or 'nasnet' in model_name: - return IMAGENET_INCEPTION_MEAN - else: - return IMAGENET_DEFAULT_MEAN - - -def get_std_by_model(model_name): - model_name = model_name.lower() - if 'dpn' in model_name: - return IMAGENET_DEFAULT_STD - elif 'ception' in model_name or 'nasnet' in model_name: - return IMAGENET_INCEPTION_STD - else: - return IMAGENET_DEFAULT_STD +from data import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from data.random_erasing import RandomErasing class ToNumpy: @@ -138,6 +34,16 @@ class ToTensor: return torch.from_numpy(np_img).to(dtype=self.dtype) +_pil_interpolation_to_str = { + Image.NEAREST: 'PIL.Image.NEAREST', + Image.BILINEAR: 'PIL.Image.BILINEAR', + Image.BICUBIC: 'PIL.Image.BICUBIC', + Image.LANCZOS: 'PIL.Image.LANCZOS', + Image.HAMMING: 'PIL.Image.HAMMING', + Image.BOX: 'PIL.Image.BOX', +} + + def _pil_interp(method): if method == 'bicubic': return Image.BICUBIC @@ -150,21 +56,118 @@ def _pil_interp(method): return Image.BILINEAR +RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) + + +class RandomResizedCropAndInterpolation(object): + """Crop the given PIL Image to random size and aspect ratio with random interpolation. + + A crop of random size (default: of 0.08 to 1.0) of the original size and a random + aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop + is finally resized to given size. + This is popularly used to train the Inception networks. + + Args: + size: expected output size of each edge + scale: range of size of the origin size cropped + ratio: range of aspect ratio of the origin aspect ratio cropped + interpolation: Default: PIL.Image.BILINEAR + """ + + def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), + interpolation='bilinear'): + if isinstance(size, tuple): + self.size = size + else: + self.size = (size, size) + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + warnings.warn("range should be of kind (min, max)") + + if interpolation == 'random': + self.interpolation = RANDOM_INTERPOLATION + else: + self.interpolation = _pil_interp(interpolation) + self.scale = scale + self.ratio = ratio + + @staticmethod + def get_params(img, scale, ratio): + """Get parameters for ``crop`` for a random sized crop. + + Args: + img (PIL Image): Image to be cropped. + scale (tuple): range of size of the origin size cropped + ratio (tuple): range of aspect ratio of the origin aspect ratio cropped + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. + """ + area = img.size[0] * img.size[1] + + for attempt in range(10): + target_area = random.uniform(*scale) * area + aspect_ratio = random.uniform(*ratio) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if random.random() < 0.5 and min(ratio) <= (h / w) <= max(ratio): + w, h = h, w + + if w <= img.size[0] and h <= img.size[1]: + i = random.randint(0, img.size[1] - h) + j = random.randint(0, img.size[0] - w) + return i, j, h, w + + # Fallback + w = min(img.size[0], img.size[1]) + i = (img.size[1] - w) // 2 + j = (img.size[0] - w) // 2 + return i, j, w, w + + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be cropped and resized. + + Returns: + PIL Image: Randomly cropped and resized image. + """ + i, j, h, w = self.get_params(img, self.scale, self.ratio) + if isinstance(self.interpolation, (tuple, list)): + interpolation = random.choice(self.interpolation) + else: + interpolation = self.interpolation + return F.resized_crop(img, i, j, h, w, self.size, interpolation) + + def __repr__(self): + if isinstance(self.interpolation, (tuple, list)): + interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation]) + else: + interpolate_str = _pil_interpolation_to_str[self.interpolation] + format_string = self.__class__.__name__ + '(size={0}'.format(self.size) + format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) + format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) + format_string += ', interpolation={0})'.format(interpolate_str) + return format_string + + def transforms_imagenet_train( img_size=224, - scale=(0.1, 1.0), + scale=(0.08, 1.0), color_jitter=(0.4, 0.4, 0.4), - interpolation='bilinear', + interpolation='random', random_erasing=0.4, + random_erasing_pp=True, use_prefetcher=False, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD ): tfl = [ - transforms.RandomResizedCrop( - img_size, scale=scale, - interpolation=_pil_interp(interpolation)), + RandomResizedCropAndInterpolation( + img_size, scale=scale, interpolation=interpolation), transforms.RandomHorizontalFlip(), transforms.ColorJitter(*color_jitter), ] @@ -174,13 +177,13 @@ def transforms_imagenet_train( tfl += [ToNumpy()] else: tfl += [ - ToTensor(), + transforms.ToTensor(), transforms.Normalize( mean=torch.tensor(mean), std=torch.tensor(std)) ] if random_erasing > 0.: - tfl.append(RandomErasingNumpy(random_erasing, per_pixel=True)) + tfl.append(RandomErasing(random_erasing, per_pixel=random_erasing_pp, device='cpu')) return transforms.Compose(tfl) diff --git a/models/densenet.py b/models/densenet.py index 55e726cd..8a2e4369 100644 --- a/models/densenet.py +++ b/models/densenet.py @@ -9,7 +9,7 @@ from collections import OrderedDict from models.helpers import load_pretrained from models.adaptive_avgmax_pool import * -from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD import re __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] diff --git a/models/dpn.py b/models/dpn.py index fde1a08e..77bd5c15 100644 --- a/models/dpn.py +++ b/models/dpn.py @@ -17,7 +17,7 @@ from collections import OrderedDict from models.helpers import load_pretrained from models.adaptive_avgmax_pool import select_adaptive_pool2d -from data.transforms import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD +from data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD __all__ = ['DPN', 'dpn68', 'dpn92', 'dpn98', 'dpn131', 'dpn107'] diff --git a/models/genmobilenet.py b/models/genmobilenet.py index d562f568..3f2b78f1 100644 --- a/models/genmobilenet.py +++ b/models/genmobilenet.py @@ -24,7 +24,7 @@ import torch.nn.functional as F from models.helpers import load_pretrained from models.adaptive_avgmax_pool import SelectAdaptivePool2d from models.conv2d_same import sconv2d -from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD __all__ = ['GenMobileNet', 'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_140', 'semnasnet_050', 'semnasnet_075', 'semnasnet_100', 'semnasnet_140', 'mnasnet_small', diff --git a/models/inception_resnet_v2.py b/models/inception_resnet_v2.py index 175e4691..8768b593 100644 --- a/models/inception_resnet_v2.py +++ b/models/inception_resnet_v2.py @@ -7,8 +7,7 @@ import torch.nn as nn import torch.nn.functional as F from models.helpers import load_pretrained from models.adaptive_avgmax_pool import * -from data.transforms import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD - +from data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD default_cfgs = { 'inception_resnet_v2': { diff --git a/models/inception_v4.py b/models/inception_v4.py index 9e2175f6..6ed47b83 100644 --- a/models/inception_v4.py +++ b/models/inception_v4.py @@ -7,8 +7,7 @@ import torch.nn as nn import torch.nn.functional as F from models.helpers import load_pretrained from models.adaptive_avgmax_pool import * -from data.transforms import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD - +from data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD default_cfgs = { 'inception_v4': { diff --git a/models/resnet.py b/models/resnet.py index 514564ea..a63fb785 100644 --- a/models/resnet.py +++ b/models/resnet.py @@ -10,7 +10,7 @@ import torch.nn.functional as F import math from models.helpers import load_pretrained from models.adaptive_avgmax_pool import SelectAdaptivePool2d -from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d', 'resnext152_32x4d'] diff --git a/models/senet.py b/models/senet.py index d3465be9..07c70d72 100644 --- a/models/senet.py +++ b/models/senet.py @@ -17,7 +17,7 @@ import torch.nn.functional as F from models.helpers import load_pretrained from models.adaptive_avgmax_pool import SelectAdaptivePool2d -from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD __all__ = ['SENet', 'senet154', 'seresnet50', 'seresnet101', 'seresnet152', 'seresnext50_32x4d', 'seresnext101_32x4d'] diff --git a/train.py b/train.py index b9ac1891..d3f7aaf1 100644 --- a/train.py +++ b/train.py @@ -10,7 +10,7 @@ try: except ImportError: has_apex = False -from data import * +from data import Dataset, create_loader, resolve_data_config from models import create_model, resume_checkpoint from utils import * from loss import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy @@ -224,7 +224,7 @@ def main(): use_prefetcher=True, rand_erase_prob=args.reprob, rand_erase_pp=args.repp, - interpolation=data_config['interpolation'], + interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], num_workers=args.workers,