From 20b2d4b69dae2ec185a77a50cf1d38d55d94b657 Mon Sep 17 00:00:00 2001 From: Ying Jin Date: Sun, 11 Jul 2021 22:08:07 -0700 Subject: [PATCH] Use bicubic interpolation in resize_pos_embed() --- timm/models/nest.py | 2 +- timm/models/vision_transformer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/nest.py b/timm/models/nest.py index 191d15c7..fe0645cc 100644 --- a/timm/models/nest.py +++ b/timm/models/nest.py @@ -377,7 +377,7 @@ def resize_pos_embed(posemb, posemb_new): size_new = int(math.sqrt(num_blocks_new*seq_length_new)) # First change to (1, C, H, W) posemb = deblockify(posemb, int(math.sqrt(seq_length_old))).permute(0, 3, 1, 2) - posemb = F.interpolate(posemb, size=[size_new, size_new], mode='bilinear') + posemb = F.interpolate(posemb, size=[size_new, size_new], mode='bicubic', align_corners=False) # Now change to new (1, T, N, C) posemb = blockify(posemb.permute(0, 2, 3, 1), int(math.sqrt(seq_length_new))) return posemb diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index e0c904f7..e3bcb6fe 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -494,7 +494,7 @@ def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): assert len(gs_new) >= 2 _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) - posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear') + posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False) posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) posemb = torch.cat([posemb_tok, posemb_grid], dim=1) return posemb