diff --git a/data/loader.py b/data/loader.py index 23ae0c7c..f9849d30 100644 --- a/data/loader.py +++ b/data/loader.py @@ -18,25 +18,25 @@ class PrefetchLoader: def __init__(self, loader, - random_erasing=0., + rand_erase_prob=0., + rand_erase_pp=False, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD): self.loader = loader - self.random_erasing = random_erasing + self.stream = torch.cuda.Stream() self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1) self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1) - if random_erasing: + if rand_erase_prob: self.random_erasing = RandomErasingTorch( - probability=random_erasing, per_pixel=False) + probability=rand_erase_prob, per_pixel=rand_erase_pp) else: self.random_erasing = None def __iter__(self): - stream = torch.cuda.Stream() first = True for next_input, next_target in self.loader: - with torch.cuda.stream(stream): + with torch.cuda.stream(self.stream): next_input = next_input.cuda(non_blocking=True) next_target = next_target.cuda(non_blocking=True) next_input = next_input.float().sub_(self.mean).div_(self.std) @@ -48,7 +48,7 @@ class PrefetchLoader: else: first = False - torch.cuda.current_stream().wait_stream(stream) + torch.cuda.current_stream().wait_stream(self.stream) input = next_input target = next_target @@ -68,7 +68,8 @@ def create_loader( batch_size, is_training=False, use_prefetcher=True, - random_erasing=0., + rand_erase_prob=0., + rand_erase_pp=False, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_workers=1, @@ -110,7 +111,8 @@ def create_loader( if use_prefetcher: loader = PrefetchLoader( loader, - random_erasing=random_erasing if is_training else 0., + rand_erase_prob=rand_erase_prob if is_training else 0., + rand_erase_pp=rand_erase_pp, mean=mean, std=std) diff --git a/data/random_erasing.py b/data/random_erasing.py index 81253311..668a3831 100644 --- a/data/random_erasing.py +++ b/data/random_erasing.py @@ -110,7 +110,7 @@ class RandomErasingTorch: h = int(round(math.sqrt(target_area * aspect_ratio))) w = int(round(math.sqrt(target_area / aspect_ratio))) if self.rand_color: - c = torch.empty((chan, 1, 1), dtype=batch.dtype).cuda().normal_() + c = torch.empty((chan, 1, 1), dtype=batch.dtype).normal_().cuda() elif not self.per_pixel: c = torch.zeros((chan, 1, 1), dtype=batch.dtype).cuda() if w < img_w and h < img_h: @@ -118,7 +118,7 @@ class RandomErasingTorch: left = random.randint(0, img_w - w) if self.per_pixel: img[:, top:top + h, left:left + w] = torch.empty( - (chan, h, w), dtype=batch.dtype).cuda().normal_() + (chan, h, w), dtype=batch.dtype).normal_().cuda() else: img[:, top:top + h, left:left + w] = c break diff --git a/train.py b/train.py index d2931ca9..84a562ff 100644 --- a/train.py +++ b/train.py @@ -61,6 +61,10 @@ parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER', help='LR scheduler (default: "step"') parser.add_argument('--drop', type=float, default=0.0, metavar='DROP', help='Dropout rate (default: 0.1)') +parser.add_argument('--reprob', type=float, default=0.4, metavar='PCT', + help='Random erase prob (default: 0.4)') +parser.add_argument('--repp', action='store_true', default=False, + help='Random erase per-pixel (default: False)') parser.add_argument('--lr', type=float, default=0.01, metavar='LR', help='learning rate (default: 0.01)') parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR', @@ -196,7 +200,8 @@ def main(): batch_size=args.batch_size, is_training=True, use_prefetcher=True, - random_erasing=0.3, + rand_erase_prob=args.reprob, + rand_erase_pp=args.repp, mean=data_mean, std=data_std, num_workers=args.workers,