Random erasing crash fix and args pass through

pull/1/head
Ross Wightman 6 years ago
parent 9c3859fb9c
commit c328b155e9

@ -18,25 +18,25 @@ class PrefetchLoader:
def __init__(self, def __init__(self,
loader, loader,
random_erasing=0., rand_erase_prob=0.,
rand_erase_pp=False,
mean=IMAGENET_DEFAULT_MEAN, mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD): std=IMAGENET_DEFAULT_STD):
self.loader = loader self.loader = loader
self.random_erasing = random_erasing self.stream = torch.cuda.Stream()
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1) self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1) self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
if random_erasing: if rand_erase_prob:
self.random_erasing = RandomErasingTorch( self.random_erasing = RandomErasingTorch(
probability=random_erasing, per_pixel=False) probability=rand_erase_prob, per_pixel=rand_erase_pp)
else: else:
self.random_erasing = None self.random_erasing = None
def __iter__(self): def __iter__(self):
stream = torch.cuda.Stream()
first = True first = True
for next_input, next_target in self.loader: for next_input, next_target in self.loader:
with torch.cuda.stream(stream): with torch.cuda.stream(self.stream):
next_input = next_input.cuda(non_blocking=True) next_input = next_input.cuda(non_blocking=True)
next_target = next_target.cuda(non_blocking=True) next_target = next_target.cuda(non_blocking=True)
next_input = next_input.float().sub_(self.mean).div_(self.std) next_input = next_input.float().sub_(self.mean).div_(self.std)
@ -48,7 +48,7 @@ class PrefetchLoader:
else: else:
first = False first = False
torch.cuda.current_stream().wait_stream(stream) torch.cuda.current_stream().wait_stream(self.stream)
input = next_input input = next_input
target = next_target target = next_target
@ -68,7 +68,8 @@ def create_loader(
batch_size, batch_size,
is_training=False, is_training=False,
use_prefetcher=True, use_prefetcher=True,
random_erasing=0., rand_erase_prob=0.,
rand_erase_pp=False,
mean=IMAGENET_DEFAULT_MEAN, mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD, std=IMAGENET_DEFAULT_STD,
num_workers=1, num_workers=1,
@ -110,7 +111,8 @@ def create_loader(
if use_prefetcher: if use_prefetcher:
loader = PrefetchLoader( loader = PrefetchLoader(
loader, loader,
random_erasing=random_erasing if is_training else 0., rand_erase_prob=rand_erase_prob if is_training else 0.,
rand_erase_pp=rand_erase_pp,
mean=mean, mean=mean,
std=std) std=std)

@ -110,7 +110,7 @@ class RandomErasingTorch:
h = int(round(math.sqrt(target_area * aspect_ratio))) h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio))) w = int(round(math.sqrt(target_area / aspect_ratio)))
if self.rand_color: if self.rand_color:
c = torch.empty((chan, 1, 1), dtype=batch.dtype).cuda().normal_() c = torch.empty((chan, 1, 1), dtype=batch.dtype).normal_().cuda()
elif not self.per_pixel: elif not self.per_pixel:
c = torch.zeros((chan, 1, 1), dtype=batch.dtype).cuda() c = torch.zeros((chan, 1, 1), dtype=batch.dtype).cuda()
if w < img_w and h < img_h: if w < img_w and h < img_h:
@ -118,7 +118,7 @@ class RandomErasingTorch:
left = random.randint(0, img_w - w) left = random.randint(0, img_w - w)
if self.per_pixel: if self.per_pixel:
img[:, top:top + h, left:left + w] = torch.empty( img[:, top:top + h, left:left + w] = torch.empty(
(chan, h, w), dtype=batch.dtype).cuda().normal_() (chan, h, w), dtype=batch.dtype).normal_().cuda()
else: else:
img[:, top:top + h, left:left + w] = c img[:, top:top + h, left:left + w] = c
break break

@ -61,6 +61,10 @@ parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "step"') help='LR scheduler (default: "step"')
parser.add_argument('--drop', type=float, default=0.0, metavar='DROP', parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
help='Dropout rate (default: 0.1)') help='Dropout rate (default: 0.1)')
parser.add_argument('--reprob', type=float, default=0.4, metavar='PCT',
help='Random erase prob (default: 0.4)')
parser.add_argument('--repp', action='store_true', default=False,
help='Random erase per-pixel (default: False)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR', parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)') help='learning rate (default: 0.01)')
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR', parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
@ -196,7 +200,8 @@ def main():
batch_size=args.batch_size, batch_size=args.batch_size,
is_training=True, is_training=True,
use_prefetcher=True, use_prefetcher=True,
random_erasing=0.3, rand_erase_prob=args.reprob,
rand_erase_pp=args.repp,
mean=data_mean, mean=data_mean,
std=data_std, std=data_std,
num_workers=args.workers, num_workers=args.workers,

Loading…
Cancel
Save