Improve creation of data pipeline with prefetch enabled vs disabled, fixup inception_res_v2 and dpn models

pull/1/head
Ross Wightman 6 years ago
parent 2295cf56c2
commit 45cde6f0c7

@@ -1,4 +1,4 @@
 from data.dataset import Dataset
-from data.transforms import transforms_imagenet_eval, transforms_imagenet_train
-from data.utils import fast_collate, PrefetchLoader
+from data.transforms import transforms_imagenet_eval, transforms_imagenet_train, get_model_meanstd
+from data.utils import create_loader
 from data.random_erasing import RandomErasingTorch, RandomErasingNumpy

@@ -54,7 +54,7 @@ class Dataset(data.Dataset):
     def __init__(
             self,
             root,
-            transform):
+            transform=None):
         imgs, _, _ = find_images_and_targets(root)
         if len(imgs) == 0:
@@ -67,7 +67,8 @@ class Dataset(data.Dataset):
     def __getitem__(self, index):
         path, target = self.imgs[index]
         img = Image.open(path).convert('RGB')
-        img = self.transform(img)
+        if self.transform is not None:
+            img = self.transform(img)
         if target is None:
             target = torch.zeros(1).long()
         return img, target
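
With transform now optional, a Dataset can be built bare and have its transform attached afterwards, which is exactly what the new create_loader below relies on. A minimal usage sketch (the path and the torchvision pipeline here are illustrative, not part of this commit):

    from torchvision import transforms
    from data.dataset import Dataset

    # No transform yet; __getitem__ simply skips the conversion step.
    ds = Dataset('/data/imagenet/validation')

    # Attach one later, e.g. for quick manual inspection.
    ds.transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    img, target = ds[0]  # transform is applied now that it is set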

@@ -15,7 +15,38 @@
 IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
 IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
 
-class AsNumpy:
+
+# FIXME replace these mean/std fn with model factory based values from config dict
+def get_model_meanstd(model_name):
+    model_name = model_name.lower()
+    if 'dpn' in model_name:
+        return IMAGENET_DPN_MEAN, IMAGENET_DPN_STD
+    elif 'ception' in model_name:
+        return IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+    else:
+        return IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+
+
+def get_model_mean(model_name):
+    model_name = model_name.lower()
+    if 'dpn' in model_name:
+        return IMAGENET_DPN_MEAN
+    elif 'ception' in model_name:
+        return IMAGENET_INCEPTION_MEAN
+    else:
+        return IMAGENET_DEFAULT_MEAN
+
+
+def get_model_std(model_name):
+    model_name = model_name.lower()
+    if 'dpn' in model_name:
+        return IMAGENET_DPN_STD
+    elif 'ception' in model_name:
+        return IMAGENET_INCEPTION_STD
+    else:
+        return IMAGENET_DEFAULT_STD
+
+
+class ToNumpy:
 
     def __call__(self, pil_img):
         np_img = np.array(pil_img, dtype=np.uint8)
@@ -25,29 +56,79 @@ class AsNumpy:
         return np_img
 
+
+class ToTensor:
+
+    def __init__(self, dtype=torch.float32):
+        self.dtype = dtype
+
+    def __call__(self, pil_img):
+        np_img = np.array(pil_img, dtype=np.uint8)
+        if np_img.ndim < 3:
+            np_img = np.expand_dims(np_img, axis=-1)
+        np_img = np.rollaxis(np_img, 2)  # HWC to CHW
+        return torch.from_numpy(np_img).to(dtype=self.dtype)
+
 
 def transforms_imagenet_train(
         img_size=224,
         scale=(0.1, 1.0),
         color_jitter=(0.4, 0.4, 0.4),
-        random_erasing=0.4):
+        random_erasing=0.4,
+        use_prefetcher=False,
+        mean=IMAGENET_DEFAULT_MEAN,
+        std=IMAGENET_DEFAULT_STD
+):
 
     tfl = [
-        transforms.RandomResizedCrop(img_size, scale=scale),
+        transforms.RandomResizedCrop(
+            img_size, scale=scale, interpolation=Image.BICUBIC),
         transforms.RandomHorizontalFlip(),
         transforms.ColorJitter(*color_jitter),
-        AsNumpy(),
     ]
-    #if random_erasing > 0.:
-    #    tfl.append(RandomErasingNumpy(random_erasing, per_pixel=True))
+    if use_prefetcher:
+        # prefetcher and collate will handle tensor conversion and norm
+        tfl += [ToNumpy()]
+    else:
+        tfl += [
+            ToTensor(),
+            transforms.Normalize(
+                mean=torch.tensor(mean) * 255,
+                std=torch.tensor(std) * 255)
+        ]
+        if random_erasing > 0.:
+            tfl.append(RandomErasingNumpy(random_erasing, per_pixel=True))
     return transforms.Compose(tfl)
 
-def transforms_imagenet_eval(img_size=224, crop_pct=None):
+def transforms_imagenet_eval(
+        img_size=224,
+        crop_pct=None,
+        use_prefetcher=False,
+        mean=IMAGENET_DEFAULT_MEAN,
+        std=IMAGENET_DEFAULT_STD):
     crop_pct = crop_pct or DEFAULT_CROP_PCT
     scale_size = int(math.floor(img_size / crop_pct))
 
-    return transforms.Compose([
+    tfl = [
         transforms.Resize(scale_size, Image.BICUBIC),
         transforms.CenterCrop(img_size),
-        AsNumpy(),
-    ])
+    ]
+    if use_prefetcher:
+        # prefetcher and collate will handle tensor conversion and norm
+        tfl += [ToNumpy()]
+    else:
+        tfl += [
+            transforms.ToTensor(),
+            transforms.Normalize(
+                mean=torch.tensor(mean),
+                std=torch.tensor(std))
+        ]
+        # tfl += [
+        #     ToTensor(),
+        #     transforms.Normalize(
+        #         mean=torch.tensor(mean) * 255,
+        #         std=torch.tensor(std) * 255)
+        # ]
+    return transforms.Compose(tfl)
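
A note on the two Normalize scales above: the custom ToTensor only casts to float and keeps pixel values in the 0-255 range, so the train path multiplies mean and std by 255; the eval path uses stock transforms.ToTensor, which scales to [0, 1], so the constants are used unscaled. The two conventions produce the same result, as this small check shows (the solid-color test image is illustrative):

    import numpy as np
    import torch
    from PIL import Image
    from torchvision import transforms

    pil_img = Image.new('RGB', (8, 8), color=(128, 64, 32))  # stand-in input
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

    # Eval-path convention: scale to [0, 1], normalize with raw constants.
    a = transforms.Normalize(mean, std)(transforms.ToTensor()(pil_img))

    # Train-path convention: cast-only tensor in 0-255 range, constants * 255.
    chw = np.rollaxis(np.array(pil_img, dtype=np.uint8), 2)  # HWC -> CHW
    b = transforms.Normalize(
        torch.tensor(mean) * 255,
        torch.tensor(std) * 255)(torch.from_numpy(chw).float())

    print(torch.allclose(a, b, atol=1e-5))  # True: the pipelines agree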

@@ -1,5 +1,7 @@
 import torch
+import torch.utils.data as tdata
 
 from data.random_erasing import RandomErasingTorch
+from data.transforms import *
 
 
 def fast_collate(batch):
@@ -17,16 +19,17 @@ class PrefetchLoader:
 
     def __init__(self,
                  loader,
                  fp16=False,
-                 random_erasing=True,
-                 mean=[0.485, 0.456, 0.406],
-                 std=[0.229, 0.224, 0.225]):
+                 random_erasing=0.,
+                 mean=IMAGENET_DEFAULT_MEAN,
+                 std=IMAGENET_DEFAULT_STD):
         self.loader = loader
         self.fp16 = fp16
         self.random_erasing = random_erasing
         self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
         self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
         if random_erasing:
-            self.random_erasing = RandomErasingTorch(per_pixel=True)
+            self.random_erasing = RandomErasingTorch(
+                probability=random_erasing, per_pixel=True)
         else:
             self.random_erasing = None
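
The __iter__ body of PrefetchLoader is outside this diff. For context, the usual CUDA prefetch pattern (a hedged sketch of the common recipe, not necessarily this repo's exact code) stages the next uint8 batch onto the GPU on a side stream and normalizes it there, so the host-to-device copy overlaps the training step running on the current batch:

    import torch

    # Sketch of a PrefetchLoader.__iter__; field names mirror __init__ above.
    def __iter__(self):
        stream = torch.cuda.Stream()
        batch_input = batch_target = None
        for next_input, next_target in self.loader:
            with torch.cuda.stream(stream):
                next_input = next_input.cuda(non_blocking=True)
                next_target = next_target.cuda(non_blocking=True)
                # normalize on the GPU with the 0-255 scaled constants
                next_input = next_input.float().sub_(self.mean).div_(self.std)
                if self.random_erasing is not None:
                    next_input = self.random_erasing(next_input)
            if batch_input is not None:
                yield batch_input, batch_target  # copy of next batch overlaps this step
            torch.cuda.current_stream().wait_stream(stream)
            batch_input, batch_target = next_input, next_target
        if batch_input is not None:
            yield batch_input, batch_target
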
@@ -63,3 +66,47 @@ class PrefetchLoader:
 
     def __len__(self):
         return len(self.loader)
+
+
+def create_loader(
+        dataset,
+        img_size,
+        batch_size,
+        is_training=False,
+        use_prefetcher=True,
+        random_erasing=0.,
+        mean=IMAGENET_DEFAULT_MEAN,
+        std=IMAGENET_DEFAULT_STD,
+        num_workers=1,
+):
+    if is_training:
+        transform = transforms_imagenet_train(
+            img_size,
+            use_prefetcher=use_prefetcher,
+            mean=mean,
+            std=std)
+    else:
+        transform = transforms_imagenet_eval(
+            img_size,
+            use_prefetcher=use_prefetcher,
+            mean=mean,
+            std=std)
+
+    dataset.transform = transform
+
+    loader = tdata.DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=is_training,
+        num_workers=num_workers,
+        collate_fn=fast_collate if use_prefetcher else tdata.dataloader.default_collate,
+    )
+    if use_prefetcher:
+        loader = PrefetchLoader(
+            loader,
+            random_erasing=random_erasing if is_training else 0.,
+            mean=mean,
+            std=std)
+
+    return loader
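
A usage sketch for create_loader showing the prefetch toggle (paths and sizes are illustrative):

    from data import Dataset, create_loader, get_model_meanstd

    mean, std = get_model_meanstd('dpn68b')

    # Prefetch enabled: fast_collate keeps batches as uint8 tensors and
    # PrefetchLoader normalizes (and random-erases) on the GPU.
    loader_train = create_loader(
        Dataset('/data/imagenet/train'),
        img_size=224,
        batch_size=256,
        is_training=True,
        use_prefetcher=True,
        random_erasing=0.5,
        mean=mean,
        std=std,
        num_workers=4)

    # Prefetch disabled: the transform itself converts and normalizes, the
    # default collate is used, and batches stay on the CPU.
    loader_eval = create_loader(
        Dataset('/data/imagenet/validation'),
        img_size=224,
        batch_size=256,
        use_prefetcher=False,
        mean=mean,
        std=std)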

@@ -21,15 +21,23 @@ from .adaptive_avgmax_pool import adaptive_avgmax_pool2d
 
 __all__ = ['DPN', 'dpn68', 'dpn92', 'dpn98', 'dpn131', 'dpn107']
 
+# If anyone able to provide direct link hosting, more than happy to fill these out.. -rwightman
 model_urls = {
-    'dpn68': '',
-    'dpn68b_extra': 'dpn68_extra-87733ef7.pth',
+    'dpn68':
+        'http://data.lip6.fr/cadene/pretrainedmodels/dpn68-66bebafa7.pth',
+    'dpn68b_extra':
+        'http://data.lip6.fr/cadene/pretrainedmodels/'
+        'dpn68b_extra-84854c156.pth',
     'dpn92': '',
-    'dpn92_extra': '',
-    'dpn98': '',
-    'dpn131': 'dpn131-89380fa2.pth',
-    'dpn107_extra': 'dpn107_extra-fc014e8ec.pth'
+    'dpn92_extra':
+        'http://data.lip6.fr/cadene/pretrainedmodels/'
+        'dpn92_extra-b040e4a9b.pth',
+    'dpn98':
+        'http://data.lip6.fr/cadene/pretrainedmodels/dpn98-5b90dec4d.pth',
+    'dpn131':
+        'http://data.lip6.fr/cadene/pretrainedmodels/dpn131-71dfe43e0.pth',
+    'dpn107_extra':
+        'http://data.lip6.fr/cadene/pretrainedmodels/'
+        'dpn107_extra-1ac7121e2.pth'
 }
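
The longer checkpoint URLs are split across lines with Python's implicit adjacent-string-literal concatenation, so each dict value is still a single string. Fetching a checkpoint from one of them (requires network access):

    import torch.utils.model_zoo as model_zoo

    url = ('http://data.lip6.fr/cadene/pretrainedmodels/'
           'dpn68b_extra-84854c156.pth')  # identical to model_urls['dpn68b_extra']
    state_dict = model_zoo.load_url(url)  # downloads to the torch cache dir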

@@ -10,7 +10,7 @@ import numpy as np
 from .adaptive_avgmax_pool import *
 
 model_urls = {
-    'imagenet': 'http://webia.lip6.fr/~cadene/Downloads/inceptionresnetv2-d579a627.pth'
+    'imagenet': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth'
 }
@@ -267,7 +267,7 @@ class InceptionResnetV2(nn.Module):
         self.block8 = Block8(noReLU=True)
         self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1)
         self.num_features = 1536
-        self.classif = nn.Linear(1536, num_classes)
+        self.last_linear = nn.Linear(1536, num_classes)
 
     def get_classifier(self):
-        return self.classif
+        return self.last_linear
@@ -277,9 +277,16 @@ class InceptionResnetV2(nn.Module):
         self.num_classes = num_classes
-        del self.classif
+        del self.last_linear
         if num_classes:
-            self.classif = torch.nn.Linear(1536, num_classes)
+            self.last_linear = torch.nn.Linear(1536, num_classes)
         else:
-            self.classif = None
+            self.last_linear = None
+
+    def trim_classifier(self, trim=1):
+        self.num_classes -= trim
+        new_last_linear = nn.Linear(1536, self.num_classes)
+        new_last_linear.weight.data = self.last_linear.weight.data[trim:]
+        new_last_linear.bias.data = self.last_linear.bias.data[trim:]
+        self.last_linear = new_last_linear
 
     def forward_features(self, x, pool=True):
         x = self.conv2d_1a(x)
@@ -298,7 +305,8 @@ class InceptionResnetV2(nn.Module):
         x = self.block8(x)
         x = self.conv2d_7b(x)
         if pool:
-            x = adaptive_avgmax_pool2d(x, self.global_pool, count_include_pad=False)
+            x = adaptive_avgmax_pool2d(x, self.global_pool)
+            #x = F.avg_pool2d(x, 8, count_include_pad=False)
             x = x.view(x.size(0), -1)
         return x
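
The count_include_pad argument is dropped from the pooling call above; it belongs to fixed-kernel F.avg_pool2d (as in the commented-out line), not to adaptive pooling, which targets an output size. The helper itself lives in .adaptive_avgmax_pool and is not part of this diff; a hypothetical minimal stand-in for how such a pool_type dispatch can look:

    import torch.nn.functional as F

    # Hypothetical stand-in; the real helper supports more modes.
    def adaptive_avgmax_pool2d(x, pool_type='avg', output_size=1):
        if pool_type == 'max':
            return F.adaptive_max_pool2d(x, output_size)
        elif pool_type == 'avgmax':
            # mean of avg- and max-pooled features, same output shape
            return 0.5 * (F.adaptive_avg_pool2d(x, output_size) +
                          F.adaptive_max_pool2d(x, output_size))
        return F.adaptive_avg_pool2d(x, output_size)
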
@@ -306,20 +314,23 @@ class InceptionResnetV2(nn.Module):
         x = self.forward_features(x, pool=True)
         if self.drop_rate > 0:
             x = F.dropout(x, p=self.drop_rate, training=self.training)
-        x = self.classif(x)
+        x = self.last_linear(x)
         return x
 
-def inception_resnet_v2(pretrained=False, num_classes=1001, **kwargs):
+def inception_resnet_v2(pretrained=False, num_classes=1000, **kwargs):
     r"""InceptionResnetV2 model architecture from the
     `"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>`_ paper.
 
     Args:
-        pretrained ('string'): If True, returns a model pre-trained on ImageNet
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
     """
-    model = InceptionResnetV2(num_classes=num_classes, **kwargs)
+    extra_class = 1 if pretrained else 0
+    model = InceptionResnetV2(num_classes=num_classes + extra_class, **kwargs)
     if pretrained:
         print('Loading pretrained from %s' % model_urls['imagenet'])
         model.load_state_dict(model_zoo.load_url(model_urls['imagenet']))
+        model.trim_classifier()
     return model
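
The pretrained checkpoint was exported with 1001 outputs (index 0 is the TF background class), so the factory builds the model one class wider, loads the weights, then calls trim_classifier to slice the first row off the classifier weight and bias, leaving a standard 1000-class head. In effect (module path assumed for illustration):

    import torch.utils.model_zoo as model_zoo
    from models.inception_resnet_v2 import InceptionResnetV2, model_urls  # assumed path

    model = InceptionResnetV2(num_classes=1000 + 1)   # width matches the checkpoint
    model.load_state_dict(model_zoo.load_url(model_urls['imagenet']))
    model.trim_classifier(trim=1)                     # drop the background class row
    assert model.last_linear.out_features == 1000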

@@ -44,7 +44,7 @@ model_config_dict = {
     'dpn68b_extra': {
         'model_name': 'dpn68b', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'},
     'inception_resnet_v2': {
-        'model_name': 'inception_resnet_v2', 'num_classes': 1001, 'input_size': 299, 'normalizer': 'le'},
+        'model_name': 'inception_resnet_v2', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'},
 }

@@ -93,34 +93,33 @@ def main():
     batch_size = args.batch_size
     torch.manual_seed(args.seed)
 
-    dataset_train = Dataset(
-        os.path.join(args.data, 'train'),
-        transform=transforms_imagenet_train())
+    data_mean, data_std = get_model_meanstd(args.model)
 
-    loader_train = data.DataLoader(
+    dataset_train = Dataset(os.path.join(args.data, 'train'))
+
+    loader_train = create_loader(
         dataset_train,
+        img_size=args.img_size,
         batch_size=batch_size,
-        shuffle=True,
+        is_training=True,
+        use_prefetcher=True,
+        random_erasing=0.5,
+        mean=data_mean,
+        std=data_std,
         num_workers=args.workers,
-        collate_fn=fast_collate
     )
-    loader_train = PrefetchLoader(
-        loader_train, random_erasing=True,
-    )
 
-    dataset_eval = Dataset(
-        os.path.join(args.data, 'validation'),
-        transform=transforms_imagenet_eval())
+    dataset_eval = Dataset(os.path.join(args.data, 'validation'))
 
-    loader_eval = data.DataLoader(
+    loader_eval = create_loader(
         dataset_eval,
+        img_size=args.img_size,
         batch_size=4 * args.batch_size,
-        shuffle=False,
+        is_training=False,
+        use_prefetcher=True,
+        mean=data_mean,
+        std=data_std,
         num_workers=args.workers,
-        collate_fn=fast_collate,
-    )
-    loader_eval = PrefetchLoader(
-        loader_eval, random_erasing=False,
     )
 
     model = model_factory.create_model(

@@ -9,11 +9,9 @@ import torch
 import torch.backends.cudnn as cudnn
 import torch.nn as nn
 import torch.nn.parallel
-import torch.utils.data as data
 
-from models import create_model, transforms_imagenet_eval
-from dataset import Dataset
+from models import create_model
+from data import Dataset, create_loader, get_model_meanstd
 
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')

@@ -80,14 +78,15 @@ def main():
     cudnn.benchmark = True
 
-    dataset = Dataset(
-        args.data,
-        transforms_imagenet_eval(args.model, args.img_size))
-
-    loader = data.DataLoader(
-        dataset,
-        batch_size=args.batch_size, shuffle=False,
-        num_workers=args.workers, pin_memory=True)
+    data_mean, data_std = get_model_meanstd(args.model)
+    loader = create_loader(
+        Dataset(args.data),
+        img_size=args.img_size,
+        batch_size=args.batch_size,
+        use_prefetcher=True,
+        mean=data_mean,
+        std=data_std,
+        num_workers=args.workers)
 
     batch_time = AverageMeter()
     losses = AverageMeter()
