diff --git a/timm/layers/patch_embed.py b/timm/layers/patch_embed.py index b7416260..764519f2 100644 --- a/timm/layers/patch_embed.py +++ b/timm/layers/patch_embed.py @@ -84,6 +84,14 @@ def resample_patch_embed( Resized patch embedding kernel. """ import numpy as np + try: + import functorch + vmap = functorch.vmap + except ImportError: + if hasattr(torch, 'vmap'): + vmap = torch.vmap + else: + assert False, "functorch or a version of torch with vmap is required for FlexiViT resizing." assert len(patch_embed.shape) == 4, "Four dimensions expected" assert len(new_size) == 2, "New shape should only be hw" @@ -115,7 +123,7 @@ def resample_patch_embed( resampled_kernel = resize_mat_pinv @ kernel.reshape(-1) return resampled_kernel.reshape(new_size) - v_resample_kernel = torch.vmap(torch.vmap(resample_kernel, 0, 0), 1, 1) + v_resample_kernel = vmap(vmap(resample_kernel, 0, 0), 1, 1) return v_resample_kernel(patch_embed) diff --git a/timm/version.py b/timm/version.py index c9cc324d..d1a9131b 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.8.2dev0' +__version__ = '0.8.3dev0'