From 42bd8f7bcbc7741cacc70864a54dc231116f5a05 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 14 Jan 2023 21:16:29 -0800 Subject: [PATCH] Add convnext_base CLIP image tower weights for fine-tuning / features --- timm/models/convnext.py | 39 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 26a3d560..b814119c 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -205,6 +205,7 @@ class ConvNeXt(nn.Module): use_grn=False, act_layer='gelu', norm_layer=None, + norm_eps=None, drop_rate=0., drop_path_rate=0., ): @@ -236,10 +237,15 @@ class ConvNeXt(nn.Module): if norm_layer is None: norm_layer = LayerNorm2d norm_layer_cl = norm_layer if conv_mlp else LayerNorm + if norm_eps is not None: + norm_layer = partial(norm_layer, eps=norm_eps) + norm_layer_cl = partial(norm_layer_cl, eps=norm_eps) else: assert conv_mlp,\ 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input' norm_layer_cl = norm_layer + if norm_eps is not None: + norm_layer_cl = partial(norm_layer_cl, eps=norm_eps) self.num_classes = num_classes self.drop_rate = drop_rate @@ -250,7 +256,7 @@ class ConvNeXt(nn.Module): # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4 self.stem = nn.Sequential( nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias), - norm_layer(dims[0]) + norm_layer(dims[0]), ) stem_stride = patch_size else: @@ -376,7 +382,15 @@ def checkpoint_filter_fn(state_dict, model): return state_dict # non-FB checkpoint if 'model' in state_dict: state_dict = state_dict['model'] + out_dict = {} + if 'visual.trunk.stem.0.weight' in state_dict: + out_dict = {k.replace('visual.trunk.', ''): v for k, v in state_dict.items() if k.startswith('visual.trunk.')} + if 'visual.head.proj.weight' in state_dict: + out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight'] + out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0]) + return out_dict + import re for k, v in state_dict.items(): k = k.replace('downsample_layers.0.', 'stem.') @@ -395,6 +409,7 @@ def checkpoint_filter_fn(state_dict, model): model_shape = model.state_dict()[k].shape v = v.reshape(model_shape) out_dict[k] = v + return out_dict @@ -685,6 +700,28 @@ default_cfgs = generate_default_cfgs({ num_classes=0), 'convnextv2_small.untrained': _cfg(), + + # CLIP based weights, original image tower weights and fine-tunes + 'convnext_base.clip_laion2b': _cfg( + hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K', + hf_hub_filename='open_clip_pytorch_model.bin', + input_size=(3, 256, 256), crop_pct=1.0, num_classes=640), + 'convnext_base.clip_laion2b_augreg': _cfg( + hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg', + hf_hub_filename='open_clip_pytorch_model.bin', + input_size=(3, 256, 256), crop_pct=1.0, num_classes=640), + 'convnext_base.clip_laiona': _cfg( + hf_hub_id='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K', + hf_hub_filename='open_clip_pytorch_model.bin', + input_size=(3, 256, 256), crop_pct=1.0, num_classes=640), + 'convnext_base.clip_laiona_320': _cfg( + hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K', + hf_hub_filename='open_clip_pytorch_model.bin', + input_size=(3, 320, 320), crop_pct=1.0, num_classes=640), + 'convnext_base.clip_laiona_augreg_320': _cfg( + hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg', + hf_hub_filename='open_clip_pytorch_model.bin', + input_size=(3, 320, 320), crop_pct=1.0, num_classes=640), })