@@ -84,6 +84,14 @@ def resample_patch_embed(
         Resized patch embedding kernel.
     """
     import numpy as np
+    try:
+        import functorch
+        vmap = functorch.vmap
+    except ImportError:
+        if hasattr(torch, 'vmap'):
+            vmap = torch.vmap
+        else:
+            assert False, "functorch or a version of torch with vmap is required for FlexiViT resizing."
 
     assert len(patch_embed.shape) == 4, "Four dimensions expected"
     assert len(new_size) == 2, "New shape should only be hw"
@@ -115,7 +123,7 @@ def resample_patch_embed(
         resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
         return resampled_kernel.reshape(new_size)
 
-    v_resample_kernel = torch.vmap(torch.vmap(resample_kernel, 0, 0), 1, 1)
+    v_resample_kernel = vmap(vmap(resample_kernel, 0, 0), 1, 1)
     return v_resample_kernel(patch_embed)
 
 
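
For context, here is a minimal self-contained sketch of the code path these hunks enable, assuming only that either `functorch` or a `torch` build with `vmap` is installed. The `resize_mat_pinv` below is a random stand-in for the pseudo-inverse resize matrix that `resample_patch_embed` computes earlier in the function (not shown in this diff); the shapes and the nested-`vmap` mechanics match the patched code.

import torch

# Resolve vmap the same way the patch does: prefer the functorch package,
# fall back to torch.vmap on newer torch builds.
try:
    import functorch
    vmap = functorch.vmap
except ImportError:
    vmap = torch.vmap  # assumes a torch version that ships vmap natively

old_size, new_size = (16, 16), (8, 8)
patch_embed = torch.randn(768, 3, *old_size)  # (out_ch, in_ch, H, W) conv kernel

# Stand-in for the pseudo-inverse resize matrix built earlier in the real
# function; only its shape (new_h * new_w, old_h * old_w) matters here.
resize_mat_pinv = torch.randn(new_size[0] * new_size[1], old_size[0] * old_size[1])

def resample_kernel(kernel):
    # Resamples one (H, W) kernel to (new_h, new_w) via the linear resize operator.
    resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
    return resampled_kernel.reshape(new_size)

# Outer vmap batches over dim 1 (input channels), inner vmap over dim 0
# (output channels), so resample_kernel only ever sees a single 2D kernel.
v_resample_kernel = vmap(vmap(resample_kernel, 0, 0), 1, 1)
print(v_resample_kernel(patch_embed).shape)  # torch.Size([768, 3, 8, 8])

The nested `vmap` call mirrors the second hunk: the mapping function stays written for a single kernel, and batching over both channel axes is left to `vmap`, whichever implementation the fallback resolved.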