Random erasing crash fix and args pass through

pull/1/head
Ross Wightman 5 years ago
parent 9c3859fb9c
commit c328b155e9

@ -18,25 +18,25 @@ class PrefetchLoader:
def __init__(self,
loader,
random_erasing=0.,
rand_erase_prob=0.,
rand_erase_pp=False,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD):
self.loader = loader
self.random_erasing = random_erasing
self.stream = torch.cuda.Stream()
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
if random_erasing:
if rand_erase_prob:
self.random_erasing = RandomErasingTorch(
probability=random_erasing, per_pixel=False)
probability=rand_erase_prob, per_pixel=rand_erase_pp)
else:
self.random_erasing = None
def __iter__(self):
stream = torch.cuda.Stream()
first = True
for next_input, next_target in self.loader:
with torch.cuda.stream(stream):
with torch.cuda.stream(self.stream):
next_input = next_input.cuda(non_blocking=True)
next_target = next_target.cuda(non_blocking=True)
next_input = next_input.float().sub_(self.mean).div_(self.std)
@ -48,7 +48,7 @@ class PrefetchLoader:
else:
first = False
torch.cuda.current_stream().wait_stream(stream)
torch.cuda.current_stream().wait_stream(self.stream)
input = next_input
target = next_target
@ -68,7 +68,8 @@ def create_loader(
batch_size,
is_training=False,
use_prefetcher=True,
random_erasing=0.,
rand_erase_prob=0.,
rand_erase_pp=False,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
num_workers=1,
@ -110,7 +111,8 @@ def create_loader(
if use_prefetcher:
loader = PrefetchLoader(
loader,
random_erasing=random_erasing if is_training else 0.,
rand_erase_prob=rand_erase_prob if is_training else 0.,
rand_erase_pp=rand_erase_pp,
mean=mean,
std=std)

@ -110,7 +110,7 @@ class RandomErasingTorch:
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
if self.rand_color:
c = torch.empty((chan, 1, 1), dtype=batch.dtype).cuda().normal_()
c = torch.empty((chan, 1, 1), dtype=batch.dtype).normal_().cuda()
elif not self.per_pixel:
c = torch.zeros((chan, 1, 1), dtype=batch.dtype).cuda()
if w < img_w and h < img_h:
@ -118,7 +118,7 @@ class RandomErasingTorch:
left = random.randint(0, img_w - w)
if self.per_pixel:
img[:, top:top + h, left:left + w] = torch.empty(
(chan, h, w), dtype=batch.dtype).cuda().normal_()
(chan, h, w), dtype=batch.dtype).normal_().cuda()
else:
img[:, top:top + h, left:left + w] = c
break

@ -61,6 +61,10 @@ parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "step"')
parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
help='Dropout rate (default: 0.1)')
parser.add_argument('--reprob', type=float, default=0.4, metavar='PCT',
help='Random erase prob (default: 0.4)')
parser.add_argument('--repp', action='store_true', default=False,
help='Random erase per-pixel (default: False)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
@ -196,7 +200,8 @@ def main():
batch_size=args.batch_size,
is_training=True,
use_prefetcher=True,
random_erasing=0.3,
rand_erase_prob=args.reprob,
rand_erase_pp=args.repp,
mean=data_mean,
std=data_std,
num_workers=args.workers,

Loading…
Cancel
Save