diff --git a/timm/data/random_erasing.py b/timm/data/random_erasing.py
index c16725ae..e66f7b95 100644
--- a/timm/data/random_erasing.py
+++ b/timm/data/random_erasing.py
@@ -6,12 +6,13 @@ import torch
 
 def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'):
     # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
     # paths, flip the order so normal is run on CPU if this becomes a problem
-    # ie torch.empty(patch_size, dtype=dtype).normal_().to(device=device)
+    # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
+    # will revert back to doing normal_() on GPU when it's in next release
     if per_pixel:
         return torch.empty(
-            patch_size, dtype=dtype, device=device).normal_()
+            patch_size, dtype=dtype).normal_().to(device=device)
     elif rand_color:
-        return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_()
+        return torch.empty((patch_size[0], 1, 1), dtype=dtype).normal_().to(device=device)
     else:
         return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
diff --git a/timm/data/transforms.py b/timm/data/transforms.py
index 1e1b054a..13a6ff01 100644
--- a/timm/data/transforms.py
+++ b/timm/data/transforms.py
@@ -171,7 +171,6 @@ def transforms_imagenet_train(
     else:
         # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
         color_jitter = (float(color_jitter),) * 3
-        print(*color_jitter)
 
     tfl = [
         RandomResizedCropAndInterpolation(
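
Note on the random_erasing.py hunk above: on PyTorch builds affected by https://github.com/pytorch/pytorch/issues/19508, filling a CUDA tensor in place with normal_() could raise illegal memory access errors, so the patch samples the noise on the CPU and then copies it to the device, trading an extra host-to-device copy for stability. Below is a minimal standalone sketch of the two paths; the function names and the cuda-availability fallback are illustrative assumptions, not part of the patch.

import torch

def sample_noise_cpu_first(patch_size, dtype=torch.float32, device='cuda'):
    # Workaround path taken by the patch: sample normal noise on the CPU,
    # then move the result to the target device in a second step.
    return torch.empty(patch_size, dtype=dtype).normal_().to(device=device)

def sample_noise_on_device(patch_size, dtype=torch.float32, device='cuda'):
    # Direct path the comment says will return once the upstream fix ships:
    # allocate on the device and fill in place, avoiding the extra copy.
    return torch.empty(patch_size, dtype=dtype, device=device).normal_()

if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'  # fallback for CPU-only machines
    noise = sample_noise_cpu_first((3, 16, 16), device=device)
    print(noise.shape, noise.device)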