Add convnext_base CLIP image tower weights for fine-tuning / features

pull/1643/head
Ross Wightman 1 year ago
parent 65aea97067
commit 42bd8f7bcb

@ -205,6 +205,7 @@ class ConvNeXt(nn.Module):
use_grn=False,
act_layer='gelu',
norm_layer=None,
norm_eps=None,
drop_rate=0.,
drop_path_rate=0.,
):
@ -236,10 +237,15 @@ class ConvNeXt(nn.Module):
if norm_layer is None:
norm_layer = LayerNorm2d
norm_layer_cl = norm_layer if conv_mlp else LayerNorm
if norm_eps is not None:
norm_layer = partial(norm_layer, eps=norm_eps)
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
else:
assert conv_mlp,\
'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
norm_layer_cl = norm_layer
if norm_eps is not None:
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
self.num_classes = num_classes
self.drop_rate = drop_rate
@ -250,7 +256,7 @@ class ConvNeXt(nn.Module):
# NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
self.stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias),
norm_layer(dims[0])
norm_layer(dims[0]),
)
stem_stride = patch_size
else:
@ -376,7 +382,15 @@ def checkpoint_filter_fn(state_dict, model):
return state_dict # non-FB checkpoint
if 'model' in state_dict:
state_dict = state_dict['model']
out_dict = {}
if 'visual.trunk.stem.0.weight' in state_dict:
out_dict = {k.replace('visual.trunk.', ''): v for k, v in state_dict.items() if k.startswith('visual.trunk.')}
if 'visual.head.proj.weight' in state_dict:
out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight']
out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
return out_dict
import re
for k, v in state_dict.items():
k = k.replace('downsample_layers.0.', 'stem.')
@ -395,6 +409,7 @@ def checkpoint_filter_fn(state_dict, model):
model_shape = model.state_dict()[k].shape
v = v.reshape(model_shape)
out_dict[k] = v
return out_dict
@ -685,6 +700,28 @@ default_cfgs = generate_default_cfgs({
num_classes=0),
'convnextv2_small.untrained': _cfg(),
# CLIP based weights, original image tower weights and fine-tunes
'convnext_base.clip_laion2b': _cfg(
hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K',
hf_hub_filename='open_clip_pytorch_model.bin',
input_size=(3, 256, 256), crop_pct=1.0, num_classes=640),
'convnext_base.clip_laion2b_augreg': _cfg(
hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg',
hf_hub_filename='open_clip_pytorch_model.bin',
input_size=(3, 256, 256), crop_pct=1.0, num_classes=640),
'convnext_base.clip_laiona': _cfg(
hf_hub_id='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K',
hf_hub_filename='open_clip_pytorch_model.bin',
input_size=(3, 256, 256), crop_pct=1.0, num_classes=640),
'convnext_base.clip_laiona_320': _cfg(
hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K',
hf_hub_filename='open_clip_pytorch_model.bin',
input_size=(3, 320, 320), crop_pct=1.0, num_classes=640),
'convnext_base.clip_laiona_augreg_320': _cfg(
hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg',
hf_hub_filename='open_clip_pytorch_model.bin',
input_size=(3, 320, 320), crop_pct=1.0, num_classes=640),
})

Loading…
Cancel
Save