Add support to split random erasing blocks into randomly selected number with --recount arg. Fix random selection of aspect ratios.

pull/31/head
Ross Wightman 5 years ago
parent 6946281fde
commit 66634d2200

@ -20,6 +20,7 @@ class PrefetchLoader:
loader, loader,
rand_erase_prob=0., rand_erase_prob=0.,
rand_erase_mode='const', rand_erase_mode='const',
rand_erase_count=1,
mean=IMAGENET_DEFAULT_MEAN, mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD, std=IMAGENET_DEFAULT_STD,
fp16=False): fp16=False):
@ -32,7 +33,7 @@ class PrefetchLoader:
self.std = self.std.half() self.std = self.std.half()
if rand_erase_prob > 0.: if rand_erase_prob > 0.:
self.random_erasing = RandomErasing( self.random_erasing = RandomErasing(
probability=rand_erase_prob, mode=rand_erase_mode) probability=rand_erase_prob, mode=rand_erase_mode, max_count=rand_erase_count)
else: else:
self.random_erasing = None self.random_erasing = None
@ -94,6 +95,7 @@ def create_loader(
use_prefetcher=True, use_prefetcher=True,
rand_erase_prob=0., rand_erase_prob=0.,
rand_erase_mode='const', rand_erase_mode='const',
rand_erase_count=1,
color_jitter=0.4, color_jitter=0.4,
interpolation='bilinear', interpolation='bilinear',
mean=IMAGENET_DEFAULT_MEAN, mean=IMAGENET_DEFAULT_MEAN,
@ -160,6 +162,7 @@ def create_loader(
loader, loader,
rand_erase_prob=rand_erase_prob if is_training else 0., rand_erase_prob=rand_erase_prob if is_training else 0.,
rand_erase_mode=rand_erase_mode, rand_erase_mode=rand_erase_mode,
rand_erase_count=rand_erase_count,
mean=mean, mean=mean,
std=std, std=std,
fp16=fp16) fp16=fp16)

@ -33,17 +33,20 @@ class RandomErasing:
'const' - erase block is constant color of 0 for all channels 'const' - erase block is constant color of 0 for all channels
'rand' - erase block is same per-cannel random (normal) color 'rand' - erase block is same per-cannel random (normal) color
'pixel' - erase block is per-pixel random (normal) color 'pixel' - erase block is per-pixel random (normal) color
max_count: maximum number of erasing blocks per image, area per box is scaled by count.
per-image count is randomly chosen between 1 and this value.
""" """
def __init__( def __init__(
self, self,
probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3, probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3,
mode='const', device='cuda'): mode='const', max_count=1, device='cuda'):
self.probability = probability self.probability = probability
self.sl = sl self.sl = sl
self.sh = sh self.sh = sh
self.min_aspect = min_aspect self.min_aspect = min_aspect
self.max_count = 8 self.min_count = 1
self.max_count = max_count
mode = mode.lower() mode = mode.lower()
self.rand_color = False self.rand_color = False
self.per_pixel = False self.per_pixel = False
@ -59,11 +62,13 @@ class RandomErasing:
if random.random() > self.probability: if random.random() > self.probability:
return return
area = img_h * img_w area = img_h * img_w
count = random.randint(1, self.max_count) count = self.min_count if self.min_count == self.max_count else \
random.randint(self.min_count, self.max_count)
for _ in range(count): for _ in range(count):
for attempt in range(10): for attempt in range(10):
target_area = random.uniform(self.sl / count, self.sh / count) * area target_area = random.uniform(self.sl, self.sh) * area / count
aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect) log_ratio = (math.log(self.min_aspect), math.log(1 / self.min_aspect))
aspect_ratio = math.exp(random.uniform(*log_ratio))
h = int(round(math.sqrt(target_area * aspect_ratio))) h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio))) w = int(round(math.sqrt(target_area / aspect_ratio)))
if w < img_w and h < img_h: if w < img_w and h < img_h:

@ -107,24 +107,31 @@ class RandomResizedCropAndInterpolation(object):
for attempt in range(10): for attempt in range(10):
target_area = random.uniform(*scale) * area target_area = random.uniform(*scale) * area
aspect_ratio = random.uniform(*ratio) log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
aspect_ratio = math.exp(random.uniform(*log_ratio))
w = int(round(math.sqrt(target_area * aspect_ratio))) w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio))) h = int(round(math.sqrt(target_area / aspect_ratio)))
if random.random() < 0.5 and min(ratio) <= (h / w) <= max(ratio):
w, h = h, w
if w <= img.size[0] and h <= img.size[1]: if w <= img.size[0] and h <= img.size[1]:
i = random.randint(0, img.size[1] - h) i = random.randint(0, img.size[1] - h)
j = random.randint(0, img.size[0] - w) j = random.randint(0, img.size[0] - w)
return i, j, h, w return i, j, h, w
# Fallback # Fallback to central crop
w = min(img.size[0], img.size[1]) in_ratio = img.size[0] / img.size[1]
i = (img.size[1] - w) // 2 if in_ratio < min(ratio):
w = img.size[0]
h = int(round(w / min(ratio)))
elif in_ratio > max(ratio):
h = img.size[1]
w = int(round(h * max(ratio)))
else: # whole image
w = img.size[0]
h = img.size[1]
i = (img.size[1] - h) // 2
j = (img.size[0] - w) // 2 j = (img.size[0] - w) // 2
return i, j, w, w return i, j, h, w
def __call__(self, img): def __call__(self, img):
""" """

@ -91,6 +91,8 @@ parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
help='Random erase prob (default: 0.)') help='Random erase prob (default: 0.)')
parser.add_argument('--remode', type=str, default='const', parser.add_argument('--remode', type=str, default='const',
help='Random erase mode (default: "const")') help='Random erase mode (default: "const")')
parser.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
parser.add_argument('--mixup', type=float, default=0.0, parser.add_argument('--mixup', type=float, default=0.0,
help='mixup alpha, mixup enabled if > 0. (default: 0.)') help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
@ -273,6 +275,7 @@ def main():
use_prefetcher=args.prefetcher, use_prefetcher=args.prefetcher,
rand_erase_prob=args.reprob, rand_erase_prob=args.reprob,
rand_erase_mode=args.remode, rand_erase_mode=args.remode,
rand_erase_count=args.recount,
color_jitter=args.color_jitter, color_jitter=args.color_jitter,
interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'], interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'],
mean=data_config['mean'], mean=data_config['mean'],

Loading…
Cancel
Save