Add some Nvidia performance enhancements (prefetch loader, fast collate), and refactor some of training and model fact/transforms
parent
9d927a389a
commit
2295cf56c2
@ -0,0 +1,4 @@
|
|||||||
|
from data.dataset import Dataset
|
||||||
|
from data.transforms import transforms_imagenet_eval, transforms_imagenet_train
|
||||||
|
from data.utils import fast_collate, PrefetchLoader
|
||||||
|
from data.random_erasing import RandomErasingTorch, RandomErasingNumpy
|
@ -0,0 +1,131 @@
|
|||||||
|
from __future__ import absolute_import
|
||||||
|
|
||||||
|
#from torchvision.transforms import *
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
import random
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class RandomErasingNumpy:
|
||||||
|
""" Randomly selects a rectangle region in an image and erases its pixels.
|
||||||
|
'Random Erasing Data Augmentation' by Zhong et al.
|
||||||
|
See https://arxiv.org/pdf/1708.04896.pdf
|
||||||
|
|
||||||
|
This 'Numpy' variant of RandomErasing is intended to be applied on a per
|
||||||
|
image basis after transforming the image to uint8 numpy array in
|
||||||
|
range 0-255 prior to tensor conversion and normalization
|
||||||
|
Args:
|
||||||
|
probability: The probability that the Random Erasing operation will be performed.
|
||||||
|
sl: Minimum proportion of erased area against input image.
|
||||||
|
sh: Maximum proportion of erased area against input image.
|
||||||
|
r1: Minimum aspect ratio of erased area.
|
||||||
|
mean: Erasing value.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3,
|
||||||
|
per_pixel=False, rand_color=False,
|
||||||
|
pl=0, ph=255, mean=[255 * 0.485, 255 * 0.456, 255 * 0.406],
|
||||||
|
out_type=np.uint8):
|
||||||
|
self.probability = probability
|
||||||
|
if not per_pixel and not rand_color:
|
||||||
|
self.mean = np.array(mean).round().astype(out_type)
|
||||||
|
else:
|
||||||
|
self.mean = None
|
||||||
|
self.sl = sl
|
||||||
|
self.sh = sh
|
||||||
|
self.min_aspect = min_aspect
|
||||||
|
self.pl = pl
|
||||||
|
self.ph = ph
|
||||||
|
self.per_pixel = per_pixel # per pixel random, bounded by [pl, ph]
|
||||||
|
self.rand_color = rand_color # per block random, bounded by [pl, ph]
|
||||||
|
self.out_type = out_type
|
||||||
|
|
||||||
|
def __call__(self, img):
|
||||||
|
if random.random() > self.probability:
|
||||||
|
return img
|
||||||
|
|
||||||
|
chan, img_h, img_w = img.shape
|
||||||
|
area = img_h * img_w
|
||||||
|
for attempt in range(100):
|
||||||
|
target_area = random.uniform(self.sl, self.sh) * area
|
||||||
|
aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect)
|
||||||
|
|
||||||
|
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||||
|
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||||
|
if self.rand_color:
|
||||||
|
c = np.random.randint(self.pl, self.ph + 1, (chan,), self.out_type)
|
||||||
|
elif not self.per_pixel:
|
||||||
|
c = self.mean[:chan]
|
||||||
|
if w < img_w and h < img_h:
|
||||||
|
top = random.randint(0, img_h - h)
|
||||||
|
left = random.randint(0, img_w - w)
|
||||||
|
if self.per_pixel:
|
||||||
|
img[:, top:top + h, left:left + w] = np.random.randint(
|
||||||
|
self.pl, self.ph + 1, (chan, h, w), self.out_type)
|
||||||
|
else:
|
||||||
|
img[:, top:top + h, left:left + w] = c
|
||||||
|
return img
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
class RandomErasingTorch:
|
||||||
|
""" Randomly selects a rectangle region in an image and erases its pixels.
|
||||||
|
'Random Erasing Data Augmentation' by Zhong et al.
|
||||||
|
See https://arxiv.org/pdf/1708.04896.pdf
|
||||||
|
|
||||||
|
This 'Torch' variant of RandomErasing is intended to be applied to a full batch
|
||||||
|
tensor after it has been normalized by dataset mean and std.
|
||||||
|
Args:
|
||||||
|
probability: The probability that the Random Erasing operation will be performed.
|
||||||
|
sl: Minimum proportion of erased area against input image.
|
||||||
|
sh: Maximum proportion of erased area against input image.
|
||||||
|
r1: Minimum aspect ratio of erased area.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3,
|
||||||
|
per_pixel=False, rand_color=False,
|
||||||
|
device='cuda'):
|
||||||
|
self.probability = probability
|
||||||
|
self.sl = sl
|
||||||
|
self.sh = sh
|
||||||
|
self.min_aspect = min_aspect
|
||||||
|
self.per_pixel = per_pixel # per pixel random, bounded by [pl, ph]
|
||||||
|
self.rand_color = rand_color # per block random, bounded by [pl, ph]
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
def __call__(self, batch):
|
||||||
|
batch_size, chan, img_h, img_w = batch.size()
|
||||||
|
area = img_h * img_w
|
||||||
|
for i in range(batch_size):
|
||||||
|
if random.random() > self.probability:
|
||||||
|
continue
|
||||||
|
img = batch[i]
|
||||||
|
for attempt in range(100):
|
||||||
|
target_area = random.uniform(self.sl, self.sh) * area
|
||||||
|
aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect)
|
||||||
|
|
||||||
|
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||||
|
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||||
|
if self.rand_color:
|
||||||
|
c = torch.empty(chan, dtype=batch.dtype, device=self.device).normal_()
|
||||||
|
elif not self.per_pixel:
|
||||||
|
c = torch.zeros(chan, dtype=batch.dtype, device=self.device)
|
||||||
|
if w < img_w and h < img_h:
|
||||||
|
top = random.randint(0, img_h - h)
|
||||||
|
left = random.randint(0, img_w - w)
|
||||||
|
if self.per_pixel:
|
||||||
|
img[:, top:top + h, left:left + w] = torch.empty(
|
||||||
|
(chan, h, w), dtype=batch.dtype, device=self.device).normal_()
|
||||||
|
else:
|
||||||
|
img[:, top:top + h, left:left + w] = c
|
||||||
|
break
|
||||||
|
|
||||||
|
return batch
|
@ -0,0 +1,53 @@
|
|||||||
|
import torch
|
||||||
|
from torchvision import transforms
|
||||||
|
from PIL import Image
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
from data.random_erasing import RandomErasingNumpy
|
||||||
|
|
||||||
|
DEFAULT_CROP_PCT = 0.875
|
||||||
|
|
||||||
|
IMAGENET_DPN_MEAN = [124 / 255, 117 / 255, 104 / 255]
|
||||||
|
IMAGENET_DPN_STD = [1 / (.0167 * 255)] * 3
|
||||||
|
IMAGENET_INCEPTION_MEAN = [0.5, 0.5, 0.5]
|
||||||
|
IMAGENET_INCEPTION_STD = [0.5, 0.5, 0.5]
|
||||||
|
IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
|
||||||
|
IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
|
||||||
|
|
||||||
|
|
||||||
|
class AsNumpy:
|
||||||
|
|
||||||
|
def __call__(self, pil_img):
|
||||||
|
np_img = np.array(pil_img, dtype=np.uint8)
|
||||||
|
if np_img.ndim < 3:
|
||||||
|
np_img = np.expand_dims(np_img, axis=-1)
|
||||||
|
np_img = np.rollaxis(np_img, 2) # HWC to CHW
|
||||||
|
return np_img
|
||||||
|
|
||||||
|
|
||||||
|
def transforms_imagenet_train(
|
||||||
|
img_size=224,
|
||||||
|
scale=(0.1, 1.0),
|
||||||
|
color_jitter=(0.4, 0.4, 0.4),
|
||||||
|
random_erasing=0.4):
|
||||||
|
|
||||||
|
tfl = [
|
||||||
|
transforms.RandomResizedCrop(img_size, scale=scale),
|
||||||
|
transforms.RandomHorizontalFlip(),
|
||||||
|
transforms.ColorJitter(*color_jitter),
|
||||||
|
AsNumpy(),
|
||||||
|
]
|
||||||
|
#if random_erasing > 0.:
|
||||||
|
# tfl.append(RandomErasingNumpy(random_erasing, per_pixel=True))
|
||||||
|
return transforms.Compose(tfl)
|
||||||
|
|
||||||
|
|
||||||
|
def transforms_imagenet_eval(img_size=224, crop_pct=None):
|
||||||
|
crop_pct = crop_pct or DEFAULT_CROP_PCT
|
||||||
|
scale_size = int(math.floor(img_size / crop_pct))
|
||||||
|
|
||||||
|
return transforms.Compose([
|
||||||
|
transforms.Resize(scale_size, Image.BICUBIC),
|
||||||
|
transforms.CenterCrop(img_size),
|
||||||
|
AsNumpy(),
|
||||||
|
])
|
@ -0,0 +1,65 @@
|
|||||||
|
import torch
|
||||||
|
from data.random_erasing import RandomErasingTorch
|
||||||
|
|
||||||
|
|
||||||
|
def fast_collate(batch):
|
||||||
|
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
|
||||||
|
batch_size = len(targets)
|
||||||
|
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
|
||||||
|
for i in range(batch_size):
|
||||||
|
tensor[i] += torch.from_numpy(batch[i][0])
|
||||||
|
|
||||||
|
return tensor, targets
|
||||||
|
|
||||||
|
|
||||||
|
class PrefetchLoader:
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
loader,
|
||||||
|
fp16=False,
|
||||||
|
random_erasing=True,
|
||||||
|
mean=[0.485, 0.456, 0.406],
|
||||||
|
std=[0.229, 0.224, 0.225]):
|
||||||
|
self.loader = loader
|
||||||
|
self.fp16 = fp16
|
||||||
|
self.random_erasing = random_erasing
|
||||||
|
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
|
||||||
|
self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
|
||||||
|
if random_erasing:
|
||||||
|
self.random_erasing = RandomErasingTorch(per_pixel=True)
|
||||||
|
else:
|
||||||
|
self.random_erasing = None
|
||||||
|
|
||||||
|
if self.fp16:
|
||||||
|
self.mean = self.mean.half()
|
||||||
|
self.std = self.std.half()
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
stream = torch.cuda.Stream()
|
||||||
|
first = True
|
||||||
|
|
||||||
|
for next_input, next_target in self.loader:
|
||||||
|
with torch.cuda.stream(stream):
|
||||||
|
next_input = next_input.cuda(non_blocking=True)
|
||||||
|
next_target = next_target.cuda(non_blocking=True)
|
||||||
|
if self.fp16:
|
||||||
|
next_input = next_input.half()
|
||||||
|
else:
|
||||||
|
next_input = next_input.float()
|
||||||
|
next_input = next_input.sub_(self.mean).div_(self.std)
|
||||||
|
if self.random_erasing is not None:
|
||||||
|
next_input = self.random_erasing(next_input)
|
||||||
|
|
||||||
|
if not first:
|
||||||
|
yield input, target
|
||||||
|
else:
|
||||||
|
first = False
|
||||||
|
|
||||||
|
torch.cuda.current_stream().wait_stream(stream)
|
||||||
|
input = next_input
|
||||||
|
target = next_target
|
||||||
|
|
||||||
|
yield input, target
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.loader)
|
@ -1,2 +1,2 @@
|
|||||||
from .model_factory import create_model
|
from .model_factory import create_model
|
||||||
from .transforms import transforms_imagenet_eval, transforms_imagenet_train
|
|
||||||
|
@ -1,61 +0,0 @@
|
|||||||
from __future__ import absolute_import
|
|
||||||
|
|
||||||
from torchvision.transforms import *
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
import random
|
|
||||||
import math
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class RandomErasing:
|
|
||||||
""" Randomly selects a rectangle region in an image and erases its pixels.
|
|
||||||
'Random Erasing Data Augmentation' by Zhong et al.
|
|
||||||
See https://arxiv.org/pdf/1708.04896.pdf
|
|
||||||
Args:
|
|
||||||
probability: The probability that the Random Erasing operation will be performed.
|
|
||||||
sl: Minimum proportion of erased area against input image.
|
|
||||||
sh: Maximum proportion of erased area against input image.
|
|
||||||
r1: Minimum aspect ratio of erased area.
|
|
||||||
mean: Erasing value.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3,
|
|
||||||
per_pixel=False, random=False,
|
|
||||||
pl=0, ph=1., mean=[0.485, 0.456, 0.406]):
|
|
||||||
self.probability = probability
|
|
||||||
self.mean = torch.tensor(mean)
|
|
||||||
self.sl = sl
|
|
||||||
self.sh = sh
|
|
||||||
self.min_aspect = min_aspect
|
|
||||||
self.pl = pl
|
|
||||||
self.ph = ph
|
|
||||||
self.per_pixel = per_pixel # per pixel random, bounded by [pl, ph]
|
|
||||||
self.random = random # per block random, bounded by [pl, ph]
|
|
||||||
|
|
||||||
def __call__(self, img):
|
|
||||||
if random.random() > self.probability:
|
|
||||||
return img
|
|
||||||
|
|
||||||
chan, img_h, img_w = img.size()
|
|
||||||
area = img_h * img_w
|
|
||||||
for attempt in range(100):
|
|
||||||
target_area = random.uniform(self.sl, self.sh) * area
|
|
||||||
aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect)
|
|
||||||
|
|
||||||
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
|
||||||
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
|
||||||
c = torch.empty((chan)).uniform_(self.pl, self.ph) if self.random else self.mean[:chan]
|
|
||||||
if w < img_w and h < img_h:
|
|
||||||
top = random.randint(0, img_h - h)
|
|
||||||
left = random.randint(0, img_w - w)
|
|
||||||
if self.per_pixel:
|
|
||||||
img[:, top:top + h, left:left + w] = torch.empty((chan, h, w)).uniform_(self.pl, self.ph)
|
|
||||||
else:
|
|
||||||
img[:, top:top + h, left:left + w] = c
|
|
||||||
return img
|
|
||||||
|
|
||||||
return img
|
|
@ -1,80 +0,0 @@
|
|||||||
import torch
|
|
||||||
from torchvision import transforms
|
|
||||||
from PIL import Image
|
|
||||||
import math
|
|
||||||
from models.random_erasing import RandomErasing
|
|
||||||
|
|
||||||
DEFAULT_CROP_PCT = 0.875
|
|
||||||
|
|
||||||
IMAGENET_DPN_MEAN = [124 / 255, 117 / 255, 104 / 255]
|
|
||||||
IMAGENET_DPN_STD = [1 / (.0167 * 255)] * 3
|
|
||||||
IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
|
|
||||||
IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
|
|
||||||
|
|
||||||
|
|
||||||
class LeNormalize(object):
|
|
||||||
"""Normalize to -1..1 in Google Inception style
|
|
||||||
"""
|
|
||||||
def __call__(self, tensor):
|
|
||||||
for t in tensor:
|
|
||||||
t.sub_(0.5).mul_(2.0)
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
|
|
||||||
def transforms_imagenet_train(
|
|
||||||
model_name,
|
|
||||||
img_size=224,
|
|
||||||
scale=(0.1, 1.0),
|
|
||||||
color_jitter=(0.4, 0.4, 0.4),
|
|
||||||
random_erasing=0.4):
|
|
||||||
if 'dpn' in model_name:
|
|
||||||
normalize = transforms.Normalize(
|
|
||||||
mean=IMAGENET_DPN_MEAN,
|
|
||||||
std=IMAGENET_DPN_STD)
|
|
||||||
elif 'inception' in model_name:
|
|
||||||
normalize = LeNormalize()
|
|
||||||
else:
|
|
||||||
normalize = transforms.Normalize(
|
|
||||||
mean=IMAGENET_DEFAULT_MEAN,
|
|
||||||
std=IMAGENET_DEFAULT_STD)
|
|
||||||
|
|
||||||
tfl = [
|
|
||||||
transforms.RandomResizedCrop(img_size, scale=scale),
|
|
||||||
transforms.RandomHorizontalFlip(),
|
|
||||||
transforms.ColorJitter(*color_jitter),
|
|
||||||
transforms.ToTensor()]
|
|
||||||
if random_erasing > 0.:
|
|
||||||
tfl.append(RandomErasing(random_erasing, per_pixel=True))
|
|
||||||
return transforms.Compose(tfl + [normalize])
|
|
||||||
|
|
||||||
|
|
||||||
def transforms_imagenet_eval(model_name, img_size=224, crop_pct=None):
|
|
||||||
crop_pct = crop_pct or DEFAULT_CROP_PCT
|
|
||||||
if 'dpn' in model_name:
|
|
||||||
if crop_pct is None:
|
|
||||||
# Use default 87.5% crop for model's native img_size
|
|
||||||
# but use 100% crop for larger than native as it
|
|
||||||
# improves test time results across all models.
|
|
||||||
if img_size == 224:
|
|
||||||
scale_size = int(math.floor(img_size / DEFAULT_CROP_PCT))
|
|
||||||
else:
|
|
||||||
scale_size = img_size
|
|
||||||
else:
|
|
||||||
scale_size = int(math.floor(img_size / crop_pct))
|
|
||||||
normalize = transforms.Normalize(
|
|
||||||
mean=IMAGENET_DPN_MEAN,
|
|
||||||
std=IMAGENET_DPN_STD)
|
|
||||||
elif 'inception' in model_name:
|
|
||||||
scale_size = int(math.floor(img_size / crop_pct))
|
|
||||||
normalize = LeNormalize()
|
|
||||||
else:
|
|
||||||
scale_size = int(math.floor(img_size / crop_pct))
|
|
||||||
normalize = transforms.Normalize(
|
|
||||||
mean=IMAGENET_DEFAULT_MEAN,
|
|
||||||
std=IMAGENET_DEFAULT_STD)
|
|
||||||
|
|
||||||
return transforms.Compose([
|
|
||||||
transforms.Resize(scale_size, Image.BICUBIC),
|
|
||||||
transforms.CenterCrop(img_size),
|
|
||||||
transforms.ToTensor(),
|
|
||||||
normalize])
|
|
@ -0,0 +1,2 @@
|
|||||||
|
from optim.adabound import AdaBound
|
||||||
|
from optim.nadam import Nadam
|
Loading…
Reference in new issue