Merge remote-tracking branch 'origin/re-exp' into opt

pull/31/head
Ross Wightman 5 years ago
commit f37e633e9b

@ -20,6 +20,7 @@ class PrefetchLoader:
loader,
rand_erase_prob=0.,
rand_erase_mode='const',
rand_erase_count=1,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
fp16=False):
@ -32,7 +33,7 @@ class PrefetchLoader:
self.std = self.std.half()
if rand_erase_prob > 0.:
self.random_erasing = RandomErasing(
probability=rand_erase_prob, mode=rand_erase_mode)
probability=rand_erase_prob, mode=rand_erase_mode, max_count=rand_erase_count)
else:
self.random_erasing = None
@ -135,6 +136,7 @@ def create_loader(
use_prefetcher=True,
rand_erase_prob=0.,
rand_erase_mode='const',
rand_erase_count=1,
color_jitter=0.4,
interpolation='bilinear',
mean=IMAGENET_DEFAULT_MEAN,
@ -184,6 +186,7 @@ def create_loader(
loader,
rand_erase_prob=rand_erase_prob if is_training else 0.,
rand_erase_mode=rand_erase_mode,
rand_erase_count=rand_erase_count,
mean=mean,
std=std,
fp16=fp16)

@ -33,16 +33,20 @@ class RandomErasing:
'const' - erase block is constant color of 0 for all channels
'rand' - erase block is same per-cannel random (normal) color
'pixel' - erase block is per-pixel random (normal) color
max_count: maximum number of erasing blocks per image, area per box is scaled by count.
per-image count is randomly chosen between 1 and this value.
"""
def __init__(
self,
probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3,
mode='const', device='cuda'):
mode='const', max_count=1, device='cuda'):
self.probability = probability
self.sl = sl
self.sh = sh
self.min_aspect = min_aspect
self.min_count = 1
self.max_count = max_count
mode = mode.lower()
self.rand_color = False
self.per_pixel = False
@ -58,18 +62,22 @@ class RandomErasing:
if random.random() > self.probability:
return
area = img_h * img_w
for attempt in range(100):
target_area = random.uniform(self.sl, self.sh) * area
aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect)
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
if w < img_w and h < img_h:
top = random.randint(0, img_h - h)
left = random.randint(0, img_w - w)
img[:, top:top + h, left:left + w] = _get_pixels(
self.per_pixel, self.rand_color, (chan, h, w),
dtype=dtype, device=self.device)
break
count = self.min_count if self.min_count == self.max_count else \
random.randint(self.min_count, self.max_count)
for _ in range(count):
for attempt in range(10):
target_area = random.uniform(self.sl, self.sh) * area / count
log_ratio = (math.log(self.min_aspect), math.log(1 / self.min_aspect))
aspect_ratio = math.exp(random.uniform(*log_ratio))
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
if w < img_w and h < img_h:
top = random.randint(0, img_h - h)
left = random.randint(0, img_w - w)
img[:, top:top + h, left:left + w] = _get_pixels(
self.per_pixel, self.rand_color, (chan, h, w),
dtype=dtype, device=self.device)
break
def __call__(self, input):
if len(input.size()) == 3:

@ -107,24 +107,31 @@ class RandomResizedCropAndInterpolation(object):
for attempt in range(10):
target_area = random.uniform(*scale) * area
aspect_ratio = random.uniform(*ratio)
log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
aspect_ratio = math.exp(random.uniform(*log_ratio))
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if random.random() < 0.5 and min(ratio) <= (h / w) <= max(ratio):
w, h = h, w
if w <= img.size[0] and h <= img.size[1]:
i = random.randint(0, img.size[1] - h)
j = random.randint(0, img.size[0] - w)
return i, j, h, w
# Fallback
w = min(img.size[0], img.size[1])
i = (img.size[1] - w) // 2
# Fallback to central crop
in_ratio = img.size[0] / img.size[1]
if in_ratio < min(ratio):
w = img.size[0]
h = int(round(w / min(ratio)))
elif in_ratio > max(ratio):
h = img.size[1]
w = int(round(h * max(ratio)))
else: # whole image
w = img.size[0]
h = img.size[1]
i = (img.size[1] - h) // 2
j = (img.size[0] - w) // 2
return i, j, w, w
return i, j, h, w
def __call__(self, img):
"""

@ -91,6 +91,8 @@ parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
help='Random erase prob (default: 0.)')
parser.add_argument('--remode', type=str, default='const',
help='Random erase mode (default: "const")')
parser.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
parser.add_argument('--mixup', type=float, default=0.0,
help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
@ -273,6 +275,7 @@ def main():
use_prefetcher=args.prefetcher,
rand_erase_prob=args.reprob,
rand_erase_mode=args.remode,
rand_erase_count=args.recount,
color_jitter=args.color_jitter,
interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'],
mean=data_config['mean'],

Loading…
Cancel
Save