Fix a few issues loading pretrained vit/bit npz weights w/ num_classes=0 __init__ arg. Missed a few other small classifier handling detail on Mlp, GhostNet, Levit. Should fix #713

vit_and_bit_test_fixes
Ross Wightman 3 years ago
parent dc422820ec
commit b41cffaa93

@ -147,6 +147,15 @@ def test_model_default_cfgs(model_name, batch_size):
# FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
if 'pruned' not in model_name: # FIXME better pruned model handling
# test classifier + global pool deletion via __init__
model = create_model(model_name, pretrained=False, num_classes=0, global_pool='').eval()
outputs = model.forward(input_tensor)
assert len(outputs.shape) == 4
if not isinstance(model, timm.models.MobileNetV3) and not isinstance(model, timm.models.GhostNet):
# FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
# check classifier name matches default_cfg
classifier = cfg['classifier']
if not isinstance(classifier, (tuple, list)):
@ -193,6 +202,13 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
assert len(outputs.shape) == 2
assert outputs.shape[1] == model.num_features
model = create_model(model_name, pretrained=False, num_classes=0).eval()
outputs = model.forward(input_tensor)
if isinstance(outputs, tuple):
outputs = outputs[0]
assert len(outputs.shape) == 2
assert outputs.shape[1] == model.num_features
# check classifier name matches default_cfg
classifier = cfg['classifier']
if not isinstance(classifier, (tuple, list)):
@ -217,6 +233,7 @@ if 'GITHUB_ACTIONS' not in os.environ:
"""Create that pretrained weights load, verify support for in_chans != 3 while doing so."""
in_chans = 3 if 'pruned' in model_name else 1 # pruning not currently supported with in_chans change
create_model(model_name, pretrained=True, in_chans=in_chans, num_classes=5)
create_model(model_name, pretrained=True, in_chans=in_chans, num_classes=0)
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(pretrained=True, exclude_filters=NON_STD_FILTERS))

@ -182,7 +182,7 @@ class GhostNet(nn.Module):
self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True)
self.act2 = nn.ReLU(inplace=True)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
self.classifier = Linear(out_chs, num_classes)
self.classifier = Linear(out_chs, num_classes) if num_classes > 0 else nn.Identity()
def get_classifier(self):
return self.classifier

@ -542,7 +542,7 @@ def checkpoint_filter_fn(state_dict, model):
state_dict = state_dict['model']
D = model.state_dict()
for k in state_dict.keys():
if D[k].ndim == 4 and state_dict[k].ndim == 2:
if k in D and D[k].ndim == 4 and state_dict[k].ndim == 2:
state_dict[k] = state_dict[k][:, :, None, None]
return state_dict

@ -266,7 +266,7 @@ class MlpMixer(nn.Module):
act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate)
for _ in range(num_blocks)])
self.norm = norm_layer(embed_dim)
self.head = nn.Linear(embed_dim, self.num_classes) # zero init
self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
self.init_weights(nlhb=nlhb)

@ -424,7 +424,8 @@ def _load_weights(model: nn.Module, checkpoint_path: str, prefix: str = 'resnet/
model.stem.conv.weight.copy_(stem_conv_w)
model.norm.weight.copy_(t2p(weights[f'{prefix}group_norm/gamma']))
model.norm.bias.copy_(t2p(weights[f'{prefix}group_norm/beta']))
if model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
if isinstance(model.head.fc, nn.Conv2d) and \
model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
model.head.fc.weight.copy_(t2p(weights[f'{prefix}head/conv2d/kernel']))
model.head.fc.bias.copy_(t2p(weights[f'{prefix}head/conv2d/bias']))
for i, (sname, stage) in enumerate(model.stages.named_children()):

@ -237,7 +237,6 @@ class Visformer(nn.Module):
self.num_features = embed_dim if self.vit_stem else embed_dim * 2
self.norm = norm_layer(self.num_features)
self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
self.head = nn.Linear(self.num_features, num_classes)
# weights init
if self.pos_embed:

@ -448,7 +448,7 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
model.pos_embed.copy_(pos_embed_w)
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
if model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
for i, block in enumerate(model.blocks.children()):

Loading…
Cancel
Save