diff --git a/timm/data/random_erasing.py b/timm/data/random_erasing.py index e66f7b95..2a105128 100644 --- a/timm/data/random_erasing.py +++ b/timm/data/random_erasing.py @@ -43,6 +43,7 @@ class RandomErasing: self.sl = sl self.sh = sh self.min_aspect = min_aspect + self.max_count = 8 mode = mode.lower() self.rand_color = False self.per_pixel = False @@ -58,18 +59,20 @@ class RandomErasing: if random.random() > self.probability: return area = img_h * img_w - for attempt in range(100): - target_area = random.uniform(self.sl, self.sh) * area - aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect) - h = int(round(math.sqrt(target_area * aspect_ratio))) - w = int(round(math.sqrt(target_area / aspect_ratio))) - if w < img_w and h < img_h: - top = random.randint(0, img_h - h) - left = random.randint(0, img_w - w) - img[:, top:top + h, left:left + w] = _get_pixels( - self.per_pixel, self.rand_color, (chan, h, w), - dtype=dtype, device=self.device) - break + count = random.randint(1, self.max_count) + for _ in range(count): + for attempt in range(10): + target_area = random.uniform(self.sl / count, self.sh / count) * area + aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < img_w and h < img_h: + top = random.randint(0, img_h - h) + left = random.randint(0, img_w - w) + img[:, top:top + h, left:left + w] = _get_pixels( + self.per_pixel, self.rand_color, (chan, h, w), + dtype=dtype, device=self.device) + break def __call__(self, input): if len(input.size()) == 3: