diff --git a/timm/data/loader.py b/timm/data/loader.py index 1198d5e5..815a19da 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -20,6 +20,7 @@ class PrefetchLoader: loader, rand_erase_prob=0., rand_erase_mode='const', + rand_erase_count=1, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, fp16=False): @@ -32,7 +33,7 @@ class PrefetchLoader: self.std = self.std.half() if rand_erase_prob > 0.: self.random_erasing = RandomErasing( - probability=rand_erase_prob, mode=rand_erase_mode) + probability=rand_erase_prob, mode=rand_erase_mode, max_count=rand_erase_count) else: self.random_erasing = None @@ -94,6 +95,7 @@ def create_loader( use_prefetcher=True, rand_erase_prob=0., rand_erase_mode='const', + rand_erase_count=1, color_jitter=0.4, interpolation='bilinear', mean=IMAGENET_DEFAULT_MEAN, @@ -160,6 +162,7 @@ def create_loader( loader, rand_erase_prob=rand_erase_prob if is_training else 0., rand_erase_mode=rand_erase_mode, + rand_erase_count=rand_erase_count, mean=mean, std=std, fp16=fp16) diff --git a/timm/data/random_erasing.py b/timm/data/random_erasing.py index 2a105128..e944f22c 100644 --- a/timm/data/random_erasing.py +++ b/timm/data/random_erasing.py @@ -33,17 +33,20 @@ class RandomErasing: 'const' - erase block is constant color of 0 for all channels 'rand' - erase block is same per-cannel random (normal) color 'pixel' - erase block is per-pixel random (normal) color + max_count: maximum number of erasing blocks per image, area per box is scaled by count. + per-image count is randomly chosen between 1 and this value. """ def __init__( self, probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3, - mode='const', device='cuda'): + mode='const', max_count=1, device='cuda'): self.probability = probability self.sl = sl self.sh = sh self.min_aspect = min_aspect - self.max_count = 8 + self.min_count = 1 + self.max_count = max_count mode = mode.lower() self.rand_color = False self.per_pixel = False @@ -59,11 +62,13 @@ class RandomErasing: if random.random() > self.probability: return area = img_h * img_w - count = random.randint(1, self.max_count) + count = self.min_count if self.min_count == self.max_count else \ + random.randint(self.min_count, self.max_count) for _ in range(count): for attempt in range(10): - target_area = random.uniform(self.sl / count, self.sh / count) * area - aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect) + target_area = random.uniform(self.sl, self.sh) * area / count + log_ratio = (math.log(self.min_aspect), math.log(1 / self.min_aspect)) + aspect_ratio = math.exp(random.uniform(*log_ratio)) h = int(round(math.sqrt(target_area * aspect_ratio))) w = int(round(math.sqrt(target_area / aspect_ratio))) if w < img_w and h < img_h: diff --git a/timm/data/transforms.py b/timm/data/transforms.py index 13a6ff01..93796a04 100644 --- a/timm/data/transforms.py +++ b/timm/data/transforms.py @@ -107,24 +107,31 @@ class RandomResizedCropAndInterpolation(object): for attempt in range(10): target_area = random.uniform(*scale) * area - aspect_ratio = random.uniform(*ratio) + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) w = int(round(math.sqrt(target_area * aspect_ratio))) h = int(round(math.sqrt(target_area / aspect_ratio))) - if random.random() < 0.5 and min(ratio) <= (h / w) <= max(ratio): - w, h = h, w - if w <= img.size[0] and h <= img.size[1]: i = random.randint(0, img.size[1] - h) j = random.randint(0, img.size[0] - w) return i, j, h, w - # Fallback - w = min(img.size[0], img.size[1]) - i = (img.size[1] - w) // 2 + # Fallback to central crop + in_ratio = img.size[0] / img.size[1] + if in_ratio < min(ratio): + w = img.size[0] + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = img.size[1] + w = int(round(h * max(ratio))) + else: # whole image + w = img.size[0] + h = img.size[1] + i = (img.size[1] - h) // 2 j = (img.size[0] - w) // 2 - return i, j, w, w + return i, j, h, w def __call__(self, img): """ diff --git a/train.py b/train.py index 51006a0d..c795927d 100644 --- a/train.py +++ b/train.py @@ -91,6 +91,8 @@ parser.add_argument('--reprob', type=float, default=0., metavar='PCT', help='Random erase prob (default: 0.)') parser.add_argument('--remode', type=str, default='const', help='Random erase mode (default: "const")') +parser.add_argument('--recount', type=int, default=1, + help='Random erase count (default: 1)') parser.add_argument('--mixup', type=float, default=0.0, help='mixup alpha, mixup enabled if > 0. (default: 0.)') parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', @@ -273,6 +275,7 @@ def main(): use_prefetcher=args.prefetcher, rand_erase_prob=args.reprob, rand_erase_mode=args.remode, + rand_erase_count=args.recount, color_jitter=args.color_jitter, interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'], mean=data_config['mean'],