|
|
|
@ -43,7 +43,7 @@ def _cfg(url='', **kwargs):
|
|
|
|
|
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
|
|
|
|
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
|
|
|
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
|
|
|
|
'first_conv': 'stem.conv', 'classifier': 'head.fc',
|
|
|
|
|
'first_conv': 'stem.conv1', 'classifier': 'head.fc',
|
|
|
|
|
'fixed_input_size': True,
|
|
|
|
|
**kwargs
|
|
|
|
|
}
|
|
|
|
@ -106,7 +106,7 @@ class Downsample2d(nn.Module):
|
|
|
|
|
dim_out=None,
|
|
|
|
|
reduction='conv',
|
|
|
|
|
act_layer=nn.GELU,
|
|
|
|
|
norm_layer=LayerNorm2d,
|
|
|
|
|
norm_layer=LayerNorm2d, # NOTE in NCHW
|
|
|
|
|
):
|
|
|
|
|
super().__init__()
|
|
|
|
|
dim_out = dim_out or dim
|
|
|
|
@ -163,12 +163,10 @@ class Stem(nn.Module):
|
|
|
|
|
self,
|
|
|
|
|
in_chs: int = 3,
|
|
|
|
|
out_chs: int = 96,
|
|
|
|
|
act_layer: str = 'gelu',
|
|
|
|
|
norm_layer: str = 'layernorm2d', # NOTE norm for NCHW
|
|
|
|
|
act_layer: Callable = nn.GELU,
|
|
|
|
|
norm_layer: Callable = LayerNorm2d, # NOTE stem in NCHW
|
|
|
|
|
):
|
|
|
|
|
super().__init__()
|
|
|
|
|
act_layer = get_act_layer(act_layer)
|
|
|
|
|
norm_layer = get_norm_layer(norm_layer)
|
|
|
|
|
self.conv1 = nn.Conv2d(in_chs, out_chs, kernel_size=3, stride=2, padding=1)
|
|
|
|
|
self.down = Downsample2d(out_chs, act_layer=act_layer, norm_layer=norm_layer)
|
|
|
|
|
|
|
|
|
@ -333,15 +331,11 @@ class GlobalContextVitStage(nn.Module):
|
|
|
|
|
proj_drop: float = 0.,
|
|
|
|
|
attn_drop: float = 0.,
|
|
|
|
|
drop_path: Union[List[float], float] = 0.0,
|
|
|
|
|
act_layer: str = 'gelu',
|
|
|
|
|
norm_layer: str = 'layernorm2d',
|
|
|
|
|
norm_layer_cl: str = 'layernorm',
|
|
|
|
|
act_layer: Callable = nn.GELU,
|
|
|
|
|
norm_layer: Callable = nn.LayerNorm,
|
|
|
|
|
norm_layer_cl: Callable = LayerNorm2d,
|
|
|
|
|
):
|
|
|
|
|
super().__init__()
|
|
|
|
|
act_layer = get_act_layer(act_layer)
|
|
|
|
|
norm_layer = get_norm_layer(norm_layer)
|
|
|
|
|
norm_layer_cl = get_norm_layer(norm_layer_cl)
|
|
|
|
|
|
|
|
|
|
if downsample:
|
|
|
|
|
self.downsample = Downsample2d(
|
|
|
|
|
dim=dim,
|
|
|
|
@ -421,8 +415,13 @@ class GlobalContextVit(nn.Module):
|
|
|
|
|
act_layer: str = 'gelu',
|
|
|
|
|
norm_layer: str = 'layernorm2d',
|
|
|
|
|
norm_layer_cl: str = 'layernorm',
|
|
|
|
|
norm_eps: float = 1e-5,
|
|
|
|
|
):
|
|
|
|
|
super().__init__()
|
|
|
|
|
act_layer = get_act_layer(act_layer)
|
|
|
|
|
norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
|
|
|
|
|
norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
|
|
|
|
|
|
|
|
|
|
img_size = to_2tuple(img_size)
|
|
|
|
|
feat_size = tuple(d // 4 for d in img_size) # stem reduction by 4
|
|
|
|
|
self.global_pool = global_pool
|
|
|
|
@ -432,7 +431,11 @@ class GlobalContextVit(nn.Module):
|
|
|
|
|
self.num_features = int(embed_dim * 2 ** (num_stages - 1))
|
|
|
|
|
|
|
|
|
|
self.stem = Stem(
|
|
|
|
|
in_chs=in_chans, out_chs=embed_dim, act_layer=act_layer, norm_layer=norm_layer)
|
|
|
|
|
in_chs=in_chans,
|
|
|
|
|
out_chs=embed_dim,
|
|
|
|
|
act_layer=act_layer,
|
|
|
|
|
norm_layer=norm_layer
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
|
|
|
|
|
stages = []
|
|
|
|
|