Make gcvit window size ratio based to improve resolution changing support #1449. Change default init to original.

pull/804/merge
Ross Wightman 2 years ago
parent c45c6ee8e4
commit f489f02ad1

@ -30,8 +30,8 @@ import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg, named_apply
from .layers import trunc_normal_tf_, DropPath, to_2tuple, Mlp, get_attn, get_act_layer, get_norm_layer, \
ClassifierHead, LayerNorm2d, _assert
from .layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d,\
get_attn, get_act_layer, get_norm_layer, _assert
from .registry import register_model
from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move to common location
@ -321,7 +321,7 @@ class GlobalContextVitStage(nn.Module):
depth: int,
num_heads: int,
feat_size: Tuple[int, int],
window_size: int,
window_size: Tuple[int, int],
downsample: bool = True,
global_norm: bool = False,
stage_norm: bool = False,
@ -347,8 +347,9 @@ class GlobalContextVitStage(nn.Module):
else:
self.downsample = nn.Identity()
self.feat_size = feat_size
window_size = to_2tuple(window_size)
feat_levels = int(math.log2(min(feat_size) / window_size))
feat_levels = int(math.log2(min(feat_size) / min(window_size)))
self.global_block = FeatureBlock(dim, feat_levels)
self.global_norm = norm_layer_cl(dim) if global_norm else nn.Identity()
@ -400,7 +401,8 @@ class GlobalContextVit(nn.Module):
num_classes: int = 1000,
global_pool: str = 'avg',
img_size: Tuple[int, int] = 224,
window_size: Tuple[int, ...] = (7, 7, 14, 7),
window_ratio: Tuple[int, ...] = (32, 32, 16, 32),
window_size: Tuple[int, ...] = None,
embed_dim: int = 64,
depths: Tuple[int, ...] = (3, 4, 19, 5),
num_heads: Tuple[int, ...] = (2, 4, 8, 16),
@ -411,7 +413,7 @@ class GlobalContextVit(nn.Module):
proj_drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
weight_init='vit',
weight_init='',
act_layer: str = 'gelu',
norm_layer: str = 'layernorm2d',
norm_layer_cl: str = 'layernorm',
@ -429,6 +431,11 @@ class GlobalContextVit(nn.Module):
self.drop_rate = drop_rate
num_stages = len(depths)
self.num_features = int(embed_dim * 2 ** (num_stages - 1))
if window_size is not None:
window_size = to_ntuple(num_stages)(window_size)
else:
assert window_ratio is not None
window_size = tuple([(img_size[0] // r, img_size[1] // r) for r in to_ntuple(num_stages)(window_ratio)])
self.stem = Stem(
in_chs=in_chans,
@ -480,7 +487,7 @@ class GlobalContextVit(nn.Module):
nn.init.zeros_(module.bias)
else:
if isinstance(module, nn.Linear):
trunc_normal_tf_(module.weight, std=.02)
nn.init.normal_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
@ -490,7 +497,6 @@ class GlobalContextVit(nn.Module):
k for k, _ in self.named_parameters()
if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])}
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
@ -567,7 +573,6 @@ def gcvit_small(pretrained=False, **kwargs):
model_kwargs = dict(
depths=(3, 4, 19, 5),
num_heads=(3, 6, 12, 24),
window_size=(7, 7, 14, 7),
embed_dim=96,
mlp_ratio=2,
layer_scale=1e-5,
@ -580,7 +585,6 @@ def gcvit_base(pretrained=False, **kwargs):
model_kwargs = dict(
depths=(3, 4, 19, 5),
num_heads=(4, 8, 16, 32),
window_size=(7, 7, 14, 7),
embed_dim=128,
mlp_ratio=2,
layer_scale=1e-5,

Loading…
Cancel
Save