From 45cde6f0c79a192612505def739bedb435fbeb73 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 11 Mar 2019 22:17:42 -0700 Subject: [PATCH] Improve creation of data pipeline with prefetch enabled vs disabled, fixup inception_res_v2 and dpn models --- data/__init__.py | 4 +- data/dataset.py | 5 +- data/transforms.py | 101 ++++++++++++++++++++++++++++++---- data/utils.py | 55 ++++++++++++++++-- models/dpn.py | 22 +++++--- models/inception_resnet_v2.py | 27 ++++++--- models/model_factory.py | 2 +- train.py | 35 ++++++------ validate.py | 23 ++++---- 9 files changed, 210 insertions(+), 64 deletions(-) diff --git a/data/__init__.py b/data/__init__.py index 07868e03..e5289973 100644 --- a/data/__init__.py +++ b/data/__init__.py @@ -1,4 +1,4 @@ from data.dataset import Dataset -from data.transforms import transforms_imagenet_eval, transforms_imagenet_train -from data.utils import fast_collate, PrefetchLoader +from data.transforms import transforms_imagenet_eval, transforms_imagenet_train, get_model_meanstd +from data.utils import create_loader from data.random_erasing import RandomErasingTorch, RandomErasingNumpy \ No newline at end of file diff --git a/data/dataset.py b/data/dataset.py index e269e60f..acd3fc71 100644 --- a/data/dataset.py +++ b/data/dataset.py @@ -54,7 +54,7 @@ class Dataset(data.Dataset): def __init__( self, root, - transform): + transform=None): imgs, _, _ = find_images_and_targets(root) if len(imgs) == 0: @@ -67,7 +67,8 @@ class Dataset(data.Dataset): def __getitem__(self, index): path, target = self.imgs[index] img = Image.open(path).convert('RGB') - img = self.transform(img) + if self.transform is not None: + img = self.transform(img) if target is None: target = torch.zeros(1).long() return img, target diff --git a/data/transforms.py b/data/transforms.py index 80491ef6..f1222e2b 100644 --- a/data/transforms.py +++ b/data/transforms.py @@ -15,7 +15,38 @@ IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406] IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225] 
-class AsNumpy: +# FIXME replace these mean/std fn with model factory based values from config dict +def get_model_meanstd(model_name): + model_name = model_name.lower() + if 'dpn' in model_name: + return IMAGENET_DPN_MEAN, IMAGENET_DPN_STD + elif 'ception' in model_name: + return IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD + else: + return IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + + +def get_model_mean(model_name): + model_name = model_name.lower() + if 'dpn' in model_name: + return IMAGENET_DPN_MEAN + elif 'ception' in model_name: + return IMAGENET_INCEPTION_MEAN + else: + return IMAGENET_DEFAULT_MEAN + + +def get_model_std(model_name): + model_name = model_name.lower() + if 'dpn' in model_name: + return IMAGENET_DPN_STD + elif 'ception' in model_name: + return IMAGENET_INCEPTION_STD + else: + return IMAGENET_DEFAULT_STD + + +class ToNumpy: def __call__(self, pil_img): np_img = np.array(pil_img, dtype=np.uint8) @@ -25,29 +56,79 @@ class AsNumpy: return np_img +class ToTensor: + + def __init__(self, dtype=torch.float32): + self.dtype = dtype + + def __call__(self, pil_img): + np_img = np.array(pil_img, dtype=np.uint8) + if np_img.ndim < 3: + np_img = np.expand_dims(np_img, axis=-1) + np_img = np.rollaxis(np_img, 2) # HWC to CHW + return torch.from_numpy(np_img).to(dtype=self.dtype) + + def transforms_imagenet_train( img_size=224, scale=(0.1, 1.0), color_jitter=(0.4, 0.4, 0.4), - random_erasing=0.4): + random_erasing=0.4, + use_prefetcher=False, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD +): tfl = [ - transforms.RandomResizedCrop(img_size, scale=scale), + transforms.RandomResizedCrop( + img_size, scale=scale, interpolation=Image.BICUBIC), transforms.RandomHorizontalFlip(), transforms.ColorJitter(*color_jitter), - AsNumpy(), ] - #if random_erasing > 0.: - # tfl.append(RandomErasingNumpy(random_erasing, per_pixel=True)) + + if use_prefetcher: + # prefetcher and collate will handle tensor conversion and norm + tfl += [ToNumpy()] + else: + tfl 
+= [ + ToTensor(), + transforms.Normalize( + mean=torch.tensor(mean) * 255, + std=torch.tensor(std) * 255) + ] + if random_erasing > 0.: + tfl.append(RandomErasingNumpy(random_erasing, per_pixel=True)) return transforms.Compose(tfl) -def transforms_imagenet_eval(img_size=224, crop_pct=None): +def transforms_imagenet_eval( + img_size=224, + crop_pct=None, + use_prefetcher=False, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD): crop_pct = crop_pct or DEFAULT_CROP_PCT scale_size = int(math.floor(img_size / crop_pct)) - return transforms.Compose([ + tfl = [ transforms.Resize(scale_size, Image.BICUBIC), transforms.CenterCrop(img_size), - AsNumpy(), - ]) + ] + if use_prefetcher: + # prefetcher and collate will handle tensor conversion and norm + tfl += [ToNumpy()] + else: + tfl += [ + transforms.ToTensor(), + transforms.Normalize( + mean=torch.tensor(mean), + std=torch.tensor(std)) + ] + # tfl += [ + # ToTensor(), + # transforms.Normalize( + # mean=torch.tensor(mean) * 255, + # std=torch.tensor(std) * 255) + # ] + + return transforms.Compose(tfl) diff --git a/data/utils.py b/data/utils.py index f4afa60e..bda111f3 100644 --- a/data/utils.py +++ b/data/utils.py @@ -1,5 +1,7 @@ import torch +import torch.utils.data as tdata from data.random_erasing import RandomErasingTorch +from data.transforms import * def fast_collate(batch): @@ -17,16 +19,17 @@ class PrefetchLoader: def __init__(self, loader, fp16=False, - random_erasing=True, - mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]): + random_erasing=0., + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD): self.loader = loader self.fp16 = fp16 self.random_erasing = random_erasing self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1) self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1) if random_erasing: - self.random_erasing = RandomErasingTorch(per_pixel=True) + self.random_erasing = RandomErasingTorch( + probability=random_erasing, per_pixel=True) else: 
self.random_erasing = None @@ -63,3 +66,47 @@ class PrefetchLoader: def __len__(self): return len(self.loader) + + +def create_loader( + dataset, + img_size, + batch_size, + is_training=False, + use_prefetcher=True, + random_erasing=0., + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_workers=1, +): + + if is_training: + transform = transforms_imagenet_train( + img_size, + use_prefetcher=use_prefetcher, + mean=mean, + std=std) + else: + transform = transforms_imagenet_eval( + img_size, + use_prefetcher=use_prefetcher, + mean=mean, + std=std) + + dataset.transform = transform + + loader = tdata.DataLoader( + dataset, + batch_size=batch_size, + shuffle=is_training, + num_workers=num_workers, + collate_fn=fast_collate if use_prefetcher else tdata.dataloader.default_collate, + ) + if use_prefetcher: + loader = PrefetchLoader( + loader, + random_erasing=random_erasing if is_training else 0., + mean=mean, + std=std) + + return loader diff --git a/models/dpn.py b/models/dpn.py index 57e48d3b..e8f84da2 100644 --- a/models/dpn.py +++ b/models/dpn.py @@ -21,15 +21,23 @@ from .adaptive_avgmax_pool import adaptive_avgmax_pool2d __all__ = ['DPN', 'dpn68', 'dpn92', 'dpn98', 'dpn131', 'dpn107'] -# If anyone able to provide direct link hosting, more than happy to fill these out.. 
-rwightman model_urls = { - 'dpn68': '', - 'dpn68b_extra': 'dpn68_extra-87733ef7.pth', + 'dpn68': + 'http://data.lip6.fr/cadene/pretrainedmodels/dpn68-66bebafa7.pth', + 'dpn68b_extra': + 'http://data.lip6.fr/cadene/pretrainedmodels/' + 'dpn68b_extra-84854c156.pth', 'dpn92': '', - 'dpn92_extra': '', - 'dpn98': '', - 'dpn131': 'dpn131-89380fa2.pth', - 'dpn107_extra': 'dpn107_extra-fc014e8ec.pth' + 'dpn92_extra': + 'http://data.lip6.fr/cadene/pretrainedmodels/' + 'dpn92_extra-b040e4a9b.pth', + 'dpn98': + 'http://data.lip6.fr/cadene/pretrainedmodels/dpn98-5b90dec4d.pth', + 'dpn131': + 'http://data.lip6.fr/cadene/pretrainedmodels/dpn131-71dfe43e0.pth', + 'dpn107_extra': + 'http://data.lip6.fr/cadene/pretrainedmodels/' + 'dpn107_extra-1ac7121e2.pth' } diff --git a/models/inception_resnet_v2.py b/models/inception_resnet_v2.py index d364a3b6..c8bcf208 100644 --- a/models/inception_resnet_v2.py +++ b/models/inception_resnet_v2.py @@ -10,7 +10,7 @@ import numpy as np from .adaptive_avgmax_pool import * model_urls = { - 'imagenet': 'http://webia.lip6.fr/~cadene/Downloads/inceptionresnetv2-d579a627.pth' + 'imagenet': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth' } @@ -267,7 +267,7 @@ class InceptionResnetV2(nn.Module): self.block8 = Block8(noReLU=True) self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1) self.num_features = 1536 - self.classif = nn.Linear(1536, num_classes) + self.last_linear = nn.Linear(1536, num_classes) def get_classifier(self): return self.classif @@ -277,9 +277,16 @@ class InceptionResnetV2(nn.Module): self.num_classes = num_classes del self.classif if num_classes: - self.classif = torch.nn.Linear(1536, num_classes) + self.last_linear = torch.nn.Linear(1536, num_classes) else: - self.classif = None + self.last_linear = None + + def trim_classifier(self, trim=1): + self.num_classes -= trim + new_last_linear = nn.Linear(1536, self.num_classes) + new_last_linear.weight.data = self.last_linear.weight.data[trim:] + 
new_last_linear.bias.data = self.last_linear.bias.data[trim:] + self.last_linear = new_last_linear def forward_features(self, x, pool=True): x = self.conv2d_1a(x) @@ -298,7 +305,8 @@ class InceptionResnetV2(nn.Module): x = self.block8(x) x = self.conv2d_7b(x) if pool: - x = adaptive_avgmax_pool2d(x, self.global_pool, count_include_pad=False) + x = adaptive_avgmax_pool2d(x, self.global_pool) + #x = F.avg_pool2d(x, 8, count_include_pad=False) x = x.view(x.size(0), -1) return x @@ -306,20 +314,23 @@ class InceptionResnetV2(nn.Module): x = self.forward_features(x, pool=True) if self.drop_rate > 0: x = F.dropout(x, p=self.drop_rate, training=self.training) - x = self.classif(x) + x = self.last_linear(x) return x -def inception_resnet_v2(pretrained=False, num_classes=1001, **kwargs): +def inception_resnet_v2(pretrained=False, num_classes=1000, **kwargs): r"""InceptionResnetV2 model architecture from the `"InceptionV4, Inception-ResNet..." `_ paper. Args: pretrained ('string'): If True, returns a model pre-trained on ImageNet """ - model = InceptionResnetV2(num_classes=num_classes, **kwargs) + extra_class = 1 if pretrained else 0 + model = InceptionResnetV2(num_classes=num_classes + extra_class, **kwargs) if pretrained: print('Loading pretrained from %s' % model_urls['imagenet']) model.load_state_dict(model_zoo.load_url(model_urls['imagenet'])) + model.trim_classifier() + return model diff --git a/models/model_factory.py b/models/model_factory.py index 813ff43a..a40c7638 100644 --- a/models/model_factory.py +++ b/models/model_factory.py @@ -44,7 +44,7 @@ model_config_dict = { 'dpn68b_extra': { 'model_name': 'dpn68b', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'}, 'inception_resnet_v2': { - 'model_name': 'inception_resnet_v2', 'num_classes': 1001, 'input_size': 299, 'normalizer': 'le'}, + 'model_name': 'inception_resnet_v2', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'}, } diff --git a/train.py b/train.py index 90cabead..195dfee8 100644 --- 
a/train.py +++ b/train.py @@ -93,34 +93,33 @@ def main(): batch_size = args.batch_size torch.manual_seed(args.seed) - dataset_train = Dataset( - os.path.join(args.data, 'train'), - transform=transforms_imagenet_train()) + data_mean, data_std = get_model_meanstd(args.model) - loader_train = data.DataLoader( + dataset_train = Dataset(os.path.join(args.data, 'train')) + + loader_train = create_loader( dataset_train, + img_size=args.img_size, batch_size=batch_size, - shuffle=True, + is_training=True, + use_prefetcher=True, + random_erasing=0.5, + mean=data_mean, + std=data_std, num_workers=args.workers, - collate_fn=fast_collate ) - loader_train = PrefetchLoader( - loader_train, random_erasing=True, - ) - dataset_eval = Dataset( - os.path.join(args.data, 'validation'), - transform=transforms_imagenet_eval()) + dataset_eval = Dataset(os.path.join(args.data, 'validation')) - loader_eval = data.DataLoader( + loader_eval = create_loader( dataset_eval, + img_size=args.img_size, batch_size=4 * args.batch_size, - shuffle=False, + is_training=False, + use_prefetcher=True, + mean=data_mean, + std=data_std, num_workers=args.workers, - collate_fn=fast_collate, - ) - loader_eval = PrefetchLoader( - loader_eval, random_erasing=False, ) model = model_factory.create_model( diff --git a/validate.py b/validate.py index e08e1d95..f1b6df1b 100644 --- a/validate.py +++ b/validate.py @@ -9,11 +9,9 @@ import torch import torch.backends.cudnn as cudnn import torch.nn as nn import torch.nn.parallel -import torch.utils.data as data - -from models import create_model, transforms_imagenet_eval -from dataset import Dataset +from models import create_model +from data import Dataset, create_loader, get_model_meanstd parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') @@ -80,14 +78,15 @@ def main(): cudnn.benchmark = True - dataset = Dataset( - args.data, - transforms_imagenet_eval(args.model, args.img_size)) - - loader = data.DataLoader( - dataset, - 
batch_size=args.batch_size, shuffle=False, - num_workers=args.workers, pin_memory=True) + data_mean, data_std = get_model_meanstd(args.model) + loader = create_loader( + Dataset(args.data), + img_size=args.img_size, + batch_size=args.batch_size, + use_prefetcher=True, + mean=data_mean, + std=data_std, + num_workers=args.workers) batch_time = AverageMeter() losses = AverageMeter()