diff --git a/timm/data/random_erasing.py b/timm/data/random_erasing.py index e944f22c..5eed1387 100644 --- a/timm/data/random_erasing.py +++ b/timm/data/random_erasing.py @@ -7,12 +7,10 @@ def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device=' # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() # paths, flip the order so normal is run on CPU if this becomes a problem # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 - # will revert back to doing normal_() on GPU when it's in next release if per_pixel: - return torch.empty( - patch_size, dtype=dtype).normal_().to(device=device) + return torch.empty(patch_size, dtype=dtype, device=device).normal_() elif rand_color: - return torch.empty((patch_size[0], 1, 1), dtype=dtype).normal_().to(device=device) + return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_() else: return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)