|
|
@ -494,7 +494,7 @@ def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
|
|
|
|
assert len(gs_new) >= 2
|
|
|
|
assert len(gs_new) >= 2
|
|
|
|
_logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
|
|
|
|
_logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
|
|
|
|
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
|
|
|
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
|
|
|
posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear')
|
|
|
|
posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
|
|
|
|
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
|
|
|
|
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
|
|
|
|
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
|
|
|
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
|
|
|
return posemb
|
|
|
|
return posemb
|
|
|
|