Minor adjustment, mutable default arg, extra check of valid len...

pull/660/head
Ross Wightman 4 years ago
parent be0abfbcce
commit 30b9880d06

@ -352,7 +352,7 @@ def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = Fa
nn.init.ones_(m.weight) nn.init.ones_(m.weight)
def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=[]): def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
# Rescale the grid of position embeddings when loading from state_dict. Adapted from # Rescale the grid of position embeddings when loading from state_dict. Adapted from
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
_logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
@ -363,8 +363,9 @@ def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=[]):
else: else:
posemb_tok, posemb_grid = posemb[:, :0], posemb[0] posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
gs_old = int(math.sqrt(len(posemb_grid))) gs_old = int(math.sqrt(len(posemb_grid)))
if not len(gs_new): # backwards compatibility if not len(gs_new): # backwards compatibility
gs_new = [int(math.sqrt(ntok_new))]*2 gs_new = [int(math.sqrt(ntok_new))] * 2
assert len(gs_new) >= 2
_logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear') posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear')

Loading…
Cancel
Save