@@ -352,7 +352,7 @@ def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = Fa
         nn.init.ones_(m.weight)
 
 
-def resize_pos_embed(posemb, posemb_new, num_tokens=1):
+def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=[]):
     # Rescale the grid of position embeddings when loading from state_dict. Adapted from
     # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
     _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
@@ -363,11 +363,12 @@ def resize_pos_embed(posemb, posemb_new, num_tokens=1):
     else:
         posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
     gs_old = int(math.sqrt(len(posemb_grid)))
-    gs_new = int(math.sqrt(ntok_new))
-    _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new)
+    if not len(gs_new):  # backwards compatibility
+        gs_new = [int(math.sqrt(ntok_new))]*2
+    _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
     posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
-    posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear')
-    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)
+    posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear')
+    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
     posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
     return posemb
 
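The example below is a minimal sketch of what the new gs_new argument enables, not part of the patch itself: resizing a square pretrained position-embedding grid to a non-square one. The shapes assume a ViT-B/16 checkpoint (14x14 patch grid plus one class token, embedding dim 768), and the import assumes resize_pos_embed stays exposed from timm.models.vision_transformer.

    import torch
    from timm.models.vision_transformer import resize_pos_embed

    # Pretrained ViT-B/16 positions: 14*14 patches + 1 class token, dim 768.
    posemb_old = torch.randn(1, 197, 768)
    # Target model uses a non-square 12x20 patch grid (e.g. a 192x320 input at patch size 16).
    posemb_new = torch.zeros(1, 1 + 12 * 20, 768)
    resized = resize_pos_embed(posemb_old, posemb_new, num_tokens=1, gs_new=[12, 20])
    print(resized.shape)  # torch.Size([1, 241, 768])

Before this change, gs_new was always derived as int(math.sqrt(ntok_new)), so the 240 grid tokens above would have been interpolated to a 15x15 grid (225 tokens) and the concatenated result could not match the target pos_embed shape.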
@@ -385,7 +386,8 @@ def checkpoint_filter_fn(state_dict, model):
             v = v.reshape(O, -1, H, W)
         elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
             # To resize pos embedding when using model at different size from pretrained weights
-            v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1))
+            v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1),
+                model.patch_embed.grid_size)
         out_dict[k] = v
     return out_dict
 
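As a usage sketch (the model name and input size here are illustrative assumptions, not taken from the patch): because checkpoint_filter_fn now forwards model.patch_embed.grid_size, square pretrained weights can be loaded into a ViT built for a non-square input, with the pos_embed interpolated to the model's actual grid.

    import timm

    # A 192x320 input with 16x16 patches gives a 12x20 grid; loading the
    # square 224x224 pretrained weights goes through the resize path above.
    model = timm.create_model('vit_base_patch16_224', pretrained=True, img_size=(192, 320))
    print(model.patch_embed.grid_size)  # (12, 20)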