diff --git a/timm/data/loader.py b/timm/data/loader.py index db902a06..2a416b31 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -20,6 +20,7 @@ class PrefetchLoader: loader, rand_erase_prob=0., rand_erase_mode='const', + rand_erase_count=1, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, fp16=False): @@ -32,7 +33,7 @@ class PrefetchLoader: self.std = self.std.half() if rand_erase_prob > 0.: self.random_erasing = RandomErasing( - probability=rand_erase_prob, mode=rand_erase_mode) + probability=rand_erase_prob, mode=rand_erase_mode, max_count=rand_erase_count) else: self.random_erasing = None @@ -135,6 +136,7 @@ def create_loader( use_prefetcher=True, rand_erase_prob=0., rand_erase_mode='const', + rand_erase_count=1, color_jitter=0.4, interpolation='bilinear', mean=IMAGENET_DEFAULT_MEAN, @@ -184,6 +186,7 @@ def create_loader( loader, rand_erase_prob=rand_erase_prob if is_training else 0., rand_erase_mode=rand_erase_mode, + rand_erase_count=rand_erase_count, mean=mean, std=std, fp16=fp16) diff --git a/timm/data/random_erasing.py b/timm/data/random_erasing.py index e66f7b95..e944f22c 100644 --- a/timm/data/random_erasing.py +++ b/timm/data/random_erasing.py @@ -33,16 +33,20 @@ class RandomErasing: 'const' - erase block is constant color of 0 for all channels 'rand' - erase block is same per-cannel random (normal) color 'pixel' - erase block is per-pixel random (normal) color + max_count: maximum number of erasing blocks per image, area per box is scaled by count. + per-image count is randomly chosen between 1 and this value. """ def __init__( self, probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3, - mode='const', device='cuda'): + mode='const', max_count=1, device='cuda'): self.probability = probability self.sl = sl self.sh = sh self.min_aspect = min_aspect + self.min_count = 1 + self.max_count = max_count mode = mode.lower() self.rand_color = False self.per_pixel = False @@ -58,18 +62,22 @@ class RandomErasing: if random.random() > self.probability: return area = img_h * img_w - for attempt in range(100): - target_area = random.uniform(self.sl, self.sh) * area - aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect) - h = int(round(math.sqrt(target_area * aspect_ratio))) - w = int(round(math.sqrt(target_area / aspect_ratio))) - if w < img_w and h < img_h: - top = random.randint(0, img_h - h) - left = random.randint(0, img_w - w) - img[:, top:top + h, left:left + w] = _get_pixels( - self.per_pixel, self.rand_color, (chan, h, w), - dtype=dtype, device=self.device) - break + count = self.min_count if self.min_count == self.max_count else \ + random.randint(self.min_count, self.max_count) + for _ in range(count): + for attempt in range(10): + target_area = random.uniform(self.sl, self.sh) * area / count + log_ratio = (math.log(self.min_aspect), math.log(1 / self.min_aspect)) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < img_w and h < img_h: + top = random.randint(0, img_h - h) + left = random.randint(0, img_w - w) + img[:, top:top + h, left:left + w] = _get_pixels( + self.per_pixel, self.rand_color, (chan, h, w), + dtype=dtype, device=self.device) + break def __call__(self, input): if len(input.size()) == 3: diff --git a/timm/data/transforms.py b/timm/data/transforms.py index 13a6ff01..93796a04 100644 --- a/timm/data/transforms.py +++ b/timm/data/transforms.py @@ -107,24 +107,31 @@ class RandomResizedCropAndInterpolation(object): for attempt in range(10): target_area = random.uniform(*scale) * area - aspect_ratio = random.uniform(*ratio) + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) w = int(round(math.sqrt(target_area * aspect_ratio))) h = int(round(math.sqrt(target_area / aspect_ratio))) - if random.random() < 0.5 and min(ratio) <= (h / w) <= max(ratio): - w, h = h, w - if w <= img.size[0] and h <= img.size[1]: i = random.randint(0, img.size[1] - h) j = random.randint(0, img.size[0] - w) return i, j, h, w - # Fallback - w = min(img.size[0], img.size[1]) - i = (img.size[1] - w) // 2 + # Fallback to central crop + in_ratio = img.size[0] / img.size[1] + if in_ratio < min(ratio): + w = img.size[0] + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = img.size[1] + w = int(round(h * max(ratio))) + else: # whole image + w = img.size[0] + h = img.size[1] + i = (img.size[1] - h) // 2 j = (img.size[0] - w) // 2 - return i, j, w, w + return i, j, h, w def __call__(self, img): """ diff --git a/train.py b/train.py index 51006a0d..c795927d 100644 --- a/train.py +++ b/train.py @@ -91,6 +91,8 @@ parser.add_argument('--reprob', type=float, default=0., metavar='PCT', help='Random erase prob (default: 0.)') parser.add_argument('--remode', type=str, default='const', help='Random erase mode (default: "const")') +parser.add_argument('--recount', type=int, default=1, + help='Random erase count (default: 1)') parser.add_argument('--mixup', type=float, default=0.0, help='mixup alpha, mixup enabled if > 0. (default: 0.)') parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', @@ -273,6 +275,7 @@ def main(): use_prefetcher=args.prefetcher, rand_erase_prob=args.reprob, rand_erase_mode=args.remode, + rand_erase_count=args.recount, color_jitter=args.color_jitter, interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'], mean=data_config['mean'],