Switch random erasing to doing normal_() on CPU to avoid instability, remove a debug print

commit 65a634626f (pull/16/head)
parent c6b32cbe73
Ross Wightman, 6 years ago

@@ -6,12 +6,13 @@ import torch
 def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'):
     # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
     # paths, flip the order so normal is run on CPU if this becomes a problem
-    # ie torch.empty(patch_size, dtype=dtype).normal_().to(device=device)
+    # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
+    # will revert back to doing normal_() on GPU when it's in next release
     if per_pixel:
         return torch.empty(
-            patch_size, dtype=dtype, device=device).normal_()
+            patch_size, dtype=dtype).normal_().to(device=device)
     elif rand_color:
-        return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_()
+        return torch.empty((patch_size[0], 1, 1), dtype=dtype).normal_().to(device=device)
     else:
         return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
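
For context, a minimal sketch of the workaround pattern this hunk adopts: fill a CPU tensor with normal_() and then transfer it, rather than calling normal_() in place on a CUDA tensor, which could trigger the illegal memory access tracked in pytorch/pytorch#19508. The patch shape and dtype below are illustrative, not the values used by the caller, and a CUDA device is assumed to be available.

import torch

patch_size = (3, 16, 16)  # illustrative patch shape

# Affected pattern: allocate on the GPU and fill in place with normal_()
# pixels = torch.empty(patch_size, dtype=torch.float32, device='cuda').normal_()

# Workaround applied by this commit: sample on CPU first, then move to the device
pixels = torch.empty(patch_size, dtype=torch.float32).normal_().to(device='cuda')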

@@ -171,7 +171,6 @@ def transforms_imagenet_train(
     else:
         # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
         color_jitter = (float(color_jitter),) * 3
-    print(*color_jitter)
     tfl = [
         RandomResizedCropAndInterpolation(
