diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py
index bad40bd6..fb375e2c 100644
--- a/timm/models/gcvit.py
+++ b/timm/models/gcvit.py
@@ -30,8 +30,8 @@ import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .fx_features import register_notrace_function
 from .helpers import build_model_with_cfg, named_apply
-from .layers import trunc_normal_tf_, DropPath, to_2tuple, Mlp, get_attn, get_act_layer, get_norm_layer, \
-    ClassifierHead, LayerNorm2d, _assert
+from .layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d,\
+    get_attn, get_act_layer, get_norm_layer, _assert
 from .registry import register_model
 from .vision_transformer_relpos import RelPosMlp, RelPosBias  # FIXME move to common location
@@ -321,7 +321,7 @@ class GlobalContextVitStage(nn.Module):
             depth: int,
             num_heads: int,
             feat_size: Tuple[int, int],
-            window_size: int,
+            window_size: Tuple[int, int],
             downsample: bool = True,
             global_norm: bool = False,
             stage_norm: bool = False,
@@ -347,8 +347,9 @@
         else:
             self.downsample = nn.Identity()
         self.feat_size = feat_size
+        window_size = to_2tuple(window_size)
 
-        feat_levels = int(math.log2(min(feat_size) / window_size))
+        feat_levels = int(math.log2(min(feat_size) / min(window_size)))
         self.global_block = FeatureBlock(dim, feat_levels)
         self.global_norm = norm_layer_cl(dim) if global_norm else nn.Identity()
 
@@ -400,7 +401,8 @@ class GlobalContextVit(nn.Module):
             num_classes: int = 1000,
             global_pool: str = 'avg',
             img_size: Tuple[int, int] = 224,
-            window_size: Tuple[int, ...] = (7, 7, 14, 7),
+            window_ratio: Tuple[int, ...] = (32, 32, 16, 32),
+            window_size: Tuple[int, ...] = None,
             embed_dim: int = 64,
             depths: Tuple[int, ...] = (3, 4, 19, 5),
             num_heads: Tuple[int, ...] = (2, 4, 8, 16),
@@ -411,7 +413,7 @@ class GlobalContextVit(nn.Module):
             proj_drop_rate: float = 0.,
             attn_drop_rate: float = 0.,
             drop_path_rate: float = 0.,
-            weight_init='vit',
+            weight_init='',
             act_layer: str = 'gelu',
             norm_layer: str = 'layernorm2d',
             norm_layer_cl: str = 'layernorm',
@@ -429,6 +431,11 @@
         self.drop_rate = drop_rate
         num_stages = len(depths)
         self.num_features = int(embed_dim * 2 ** (num_stages - 1))
+        if window_size is not None:
+            window_size = to_ntuple(num_stages)(window_size)
+        else:
+            assert window_ratio is not None
+            window_size = tuple([(img_size[0] // r, img_size[1] // r) for r in to_ntuple(num_stages)(window_ratio)])
 
         self.stem = Stem(
             in_chs=in_chans,
@@ -480,7 +487,7 @@
                 nn.init.zeros_(module.bias)
         else:
             if isinstance(module, nn.Linear):
-                trunc_normal_tf_(module.weight, std=.02)
+                nn.init.normal_(module.weight, std=.02)
                 if module.bias is not None:
                     nn.init.zeros_(module.bias)
 
@@ -490,7 +497,6 @@
             k for k, _ in self.named_parameters()
             if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])}
 
-    @torch.jit.ignore
     def group_matcher(self, coarse=False):
         matcher = dict(
@@ -567,7 +573,6 @@ def gcvit_small(pretrained=False, **kwargs):
     model_kwargs = dict(
         depths=(3, 4, 19, 5),
         num_heads=(3, 6, 12, 24),
-        window_size=(7, 7, 14, 7),
         embed_dim=96,
         mlp_ratio=2,
         layer_scale=1e-5,
@@ -580,7 +585,6 @@ def gcvit_base(pretrained=False, **kwargs):
     model_kwargs = dict(
         depths=(3, 4, 19, 5),
         num_heads=(4, 8, 16, 32),
-        window_size=(7, 7, 14, 7),
        embed_dim=128,
         mlp_ratio=2,
         layer_scale=1e-5,
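
Note on the change, with a minimal standalone sketch (the helper name `_window_sizes` below is illustrative, not part of the patch): the new `window_ratio` default of (32, 32, 16, 32) reproduces the previously hardcoded `window_size=(7, 7, 14, 7)` at the default 224x224 input (224 // 32 = 7, 224 // 16 = 14), while deriving proportional windows for other resolutions.

from timm.models.layers import to_ntuple


def _window_sizes(img_size, window_ratio, num_stages=4):
    # Mirrors the logic added in GlobalContextVit.__init__: each stage gets an
    # (H, W) window computed from the input size and its per-stage ratio.
    h, w = (img_size, img_size) if isinstance(img_size, int) else img_size
    return tuple((h // r, w // r) for r in to_ntuple(num_stages)(window_ratio))


print(_window_sizes(224, (32, 32, 16, 32)))  # ((7, 7), (7, 7), (14, 14), (7, 7))
print(_window_sizes(384, (32, 32, 16, 32)))  # ((12, 12), (12, 12), (24, 24), (12, 12))

An explicitly passed `window_size` still takes precedence: it is expanded per-stage via `to_ntuple(num_stages)`, and the `to_2tuple(window_size)` call added in GlobalContextVitStage keeps plain int per-stage sizes working (e.g. 7 becomes (7, 7)), which is why the stage-level type hint widens from int to Tuple[int, int].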