|
|
@ -205,6 +205,7 @@ class ConvNeXt(nn.Module):
|
|
|
|
use_grn=False,
|
|
|
|
use_grn=False,
|
|
|
|
act_layer='gelu',
|
|
|
|
act_layer='gelu',
|
|
|
|
norm_layer=None,
|
|
|
|
norm_layer=None,
|
|
|
|
|
|
|
|
norm_eps=None,
|
|
|
|
drop_rate=0.,
|
|
|
|
drop_rate=0.,
|
|
|
|
drop_path_rate=0.,
|
|
|
|
drop_path_rate=0.,
|
|
|
|
):
|
|
|
|
):
|
|
|
@ -236,10 +237,15 @@ class ConvNeXt(nn.Module):
|
|
|
|
if norm_layer is None:
|
|
|
|
if norm_layer is None:
|
|
|
|
norm_layer = LayerNorm2d
|
|
|
|
norm_layer = LayerNorm2d
|
|
|
|
norm_layer_cl = norm_layer if conv_mlp else LayerNorm
|
|
|
|
norm_layer_cl = norm_layer if conv_mlp else LayerNorm
|
|
|
|
|
|
|
|
if norm_eps is not None:
|
|
|
|
|
|
|
|
norm_layer = partial(norm_layer, eps=norm_eps)
|
|
|
|
|
|
|
|
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
assert conv_mlp,\
|
|
|
|
assert conv_mlp,\
|
|
|
|
'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
|
|
|
|
'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
|
|
|
|
norm_layer_cl = norm_layer
|
|
|
|
norm_layer_cl = norm_layer
|
|
|
|
|
|
|
|
if norm_eps is not None:
|
|
|
|
|
|
|
|
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
|
|
|
|
|
|
|
|
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.drop_rate = drop_rate
|
|
|
|
self.drop_rate = drop_rate
|
|
|
@ -250,7 +256,7 @@ class ConvNeXt(nn.Module):
|
|
|
|
# NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
|
|
|
|
# NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
|
|
|
|
self.stem = nn.Sequential(
|
|
|
|
self.stem = nn.Sequential(
|
|
|
|
nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias),
|
|
|
|
nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias),
|
|
|
|
norm_layer(dims[0])
|
|
|
|
norm_layer(dims[0]),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
stem_stride = patch_size
|
|
|
|
stem_stride = patch_size
|
|
|
|
else:
|
|
|
|
else:
|
|
|
@ -376,7 +382,15 @@ def checkpoint_filter_fn(state_dict, model):
|
|
|
|
return state_dict # non-FB checkpoint
|
|
|
|
return state_dict # non-FB checkpoint
|
|
|
|
if 'model' in state_dict:
|
|
|
|
if 'model' in state_dict:
|
|
|
|
state_dict = state_dict['model']
|
|
|
|
state_dict = state_dict['model']
|
|
|
|
|
|
|
|
|
|
|
|
out_dict = {}
|
|
|
|
out_dict = {}
|
|
|
|
|
|
|
|
if 'visual.trunk.stem.0.weight' in state_dict:
|
|
|
|
|
|
|
|
out_dict = {k.replace('visual.trunk.', ''): v for k, v in state_dict.items() if k.startswith('visual.trunk.')}
|
|
|
|
|
|
|
|
if 'visual.head.proj.weight' in state_dict:
|
|
|
|
|
|
|
|
out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight']
|
|
|
|
|
|
|
|
out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
|
|
|
|
|
|
|
|
return out_dict
|
|
|
|
|
|
|
|
|
|
|
|
import re
|
|
|
|
import re
|
|
|
|
for k, v in state_dict.items():
|
|
|
|
for k, v in state_dict.items():
|
|
|
|
k = k.replace('downsample_layers.0.', 'stem.')
|
|
|
|
k = k.replace('downsample_layers.0.', 'stem.')
|
|
|
@ -395,6 +409,7 @@ def checkpoint_filter_fn(state_dict, model):
|
|
|
|
model_shape = model.state_dict()[k].shape
|
|
|
|
model_shape = model.state_dict()[k].shape
|
|
|
|
v = v.reshape(model_shape)
|
|
|
|
v = v.reshape(model_shape)
|
|
|
|
out_dict[k] = v
|
|
|
|
out_dict[k] = v
|
|
|
|
|
|
|
|
|
|
|
|
return out_dict
|
|
|
|
return out_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -685,6 +700,28 @@ default_cfgs = generate_default_cfgs({
|
|
|
|
num_classes=0),
|
|
|
|
num_classes=0),
|
|
|
|
|
|
|
|
|
|
|
|
'convnextv2_small.untrained': _cfg(),
|
|
|
|
'convnextv2_small.untrained': _cfg(),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# CLIP based weights, original image tower weights and fine-tunes
|
|
|
|
|
|
|
|
'convnext_base.clip_laion2b': _cfg(
|
|
|
|
|
|
|
|
hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K',
|
|
|
|
|
|
|
|
hf_hub_filename='open_clip_pytorch_model.bin',
|
|
|
|
|
|
|
|
input_size=(3, 256, 256), crop_pct=1.0, num_classes=640),
|
|
|
|
|
|
|
|
'convnext_base.clip_laion2b_augreg': _cfg(
|
|
|
|
|
|
|
|
hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg',
|
|
|
|
|
|
|
|
hf_hub_filename='open_clip_pytorch_model.bin',
|
|
|
|
|
|
|
|
input_size=(3, 256, 256), crop_pct=1.0, num_classes=640),
|
|
|
|
|
|
|
|
'convnext_base.clip_laiona': _cfg(
|
|
|
|
|
|
|
|
hf_hub_id='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K',
|
|
|
|
|
|
|
|
hf_hub_filename='open_clip_pytorch_model.bin',
|
|
|
|
|
|
|
|
input_size=(3, 256, 256), crop_pct=1.0, num_classes=640),
|
|
|
|
|
|
|
|
'convnext_base.clip_laiona_320': _cfg(
|
|
|
|
|
|
|
|
hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K',
|
|
|
|
|
|
|
|
hf_hub_filename='open_clip_pytorch_model.bin',
|
|
|
|
|
|
|
|
input_size=(3, 320, 320), crop_pct=1.0, num_classes=640),
|
|
|
|
|
|
|
|
'convnext_base.clip_laiona_augreg_320': _cfg(
|
|
|
|
|
|
|
|
hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg',
|
|
|
|
|
|
|
|
hf_hub_filename='open_clip_pytorch_model.bin',
|
|
|
|
|
|
|
|
input_size=(3, 320, 320), crop_pct=1.0, num_classes=640),
|
|
|
|
})
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|