Better vmap compat across recent torch versions

pull/1612/head
Ross Wightman 2 years ago
parent 130458988a
commit 7c846d9970

@@ -84,6 +84,14 @@ def resample_patch_embed(
         Resized patch embedding kernel.
     """
     import numpy as np
+    try:
+        import functorch
+        vmap = functorch.vmap
+    except ImportError:
+        if hasattr(torch, 'vmap'):
+            vmap = torch.vmap
+        else:
+            assert False, "functorch or a version of torch with vmap is required for FlexiViT resizing."

     assert len(patch_embed.shape) == 4, "Four dimensions expected"
     assert len(new_size) == 2, "New shape should only be hw"
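
Whichever branch of the fallback wins, the resolved `vmap` has the same calling convention. A minimal sketch exercising the binding outside this diff (assumes either the older standalone functorch package or torch >= 2.0, where `torch.vmap` is exposed natively):

```python
import torch

try:
    from functorch import vmap  # older standalone package
except ImportError:
    from torch import vmap      # native in torch >= 2.0

# Either binding vectorizes a per-sample function over a batch dimension.
double = vmap(lambda x: x * 2)  # in_dims=0, out_dims=0 by default
print(double(torch.arange(6.).reshape(3, 2)).shape)  # torch.Size([3, 2])
```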
@@ -115,7 +123,7 @@ def resample_patch_embed(
         resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
         return resampled_kernel.reshape(new_size)

-    v_resample_kernel = torch.vmap(torch.vmap(resample_kernel, 0, 0), 1, 1)
+    v_resample_kernel = vmap(vmap(resample_kernel, 0, 0), 1, 1)
     return v_resample_kernel(patch_embed)
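
The nested `vmap` maps `resample_kernel` over the first two dimensions of the 4D kernel, so it only ever sees a single 2D spatial slice. A toy sketch of the same double-vmap pattern (hypothetical `fn`, not the timm code; assumes torch >= 2.0's `torch.vmap`):

```python
import torch

def fn(kernel):
    # Operates on a single 2D (H, W) slice; vmap adds the channel dims back.
    return kernel * 2

# Inner vmap maps over dim 0 (output channels), outer over dim 1 (input
# channels), matching vmap(vmap(resample_kernel, 0, 0), 1, 1) above.
v_fn = torch.vmap(torch.vmap(fn, 0, 0), 1, 1)

patch_embed = torch.randn(8, 3, 16, 16)  # (out_ch, in_ch, H, W)
assert v_fn(patch_embed).shape == patch_embed.shape
```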

@@ -1 +1 @@
-__version__ = '0.8.2dev0'
+__version__ = '0.8.3dev0'
