diff --git a/timm/models/xcit.py b/timm/models/xcit.py
index 886aa930..b7af3b26 100644
--- a/timm/models/xcit.py
+++ b/timm/models/xcit.py
@@ -141,7 +141,7 @@ def conv3x3(in_planes, out_planes, stride=1):
 class ConvPatchEmbed(nn.Module):
     """Image to Patch Embedding using multiple convolutional layers"""
 
-    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, act_layer=nn.GELU):
         super().__init__()
         img_size = to_2tuple(img_size)
         num_patches = (img_size[1] // patch_size) * (img_size[0] // patch_size)
@@ -152,19 +152,19 @@ class ConvPatchEmbed(nn.Module):
         if patch_size == 16:
             self.proj = torch.nn.Sequential(
                 conv3x3(in_chans, embed_dim // 8, 2),
-                nn.GELU(),
+                act_layer(),
                 conv3x3(embed_dim // 8, embed_dim // 4, 2),
-                nn.GELU(),
+                act_layer(),
                 conv3x3(embed_dim // 4, embed_dim // 2, 2),
-                nn.GELU(),
+                act_layer(),
                 conv3x3(embed_dim // 2, embed_dim, 2),
             )
         elif patch_size == 8:
             self.proj = torch.nn.Sequential(
-                conv3x3(3, embed_dim // 4, 2),
-                nn.GELU(),
+                conv3x3(in_chans, embed_dim // 4, 2),
+                act_layer(),
                 conv3x3(embed_dim // 4, embed_dim // 2, 2),
-                nn.GELU(),
+                act_layer(),
                 conv3x3(embed_dim // 2, embed_dim, 2),
             )
         else:
@@ -323,7 +323,7 @@ class XCiT(nn.Module):
 
     def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                  num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
-                 norm_layer=None, cls_attn_layers=2, use_pos_embed=True, eta=1., tokens_norm=False):
+                 act_layer=None, norm_layer=None, cls_attn_layers=2, use_pos_embed=True, eta=1., tokens_norm=False):
         """
         Args:
             img_size (int, tuple): input image size
@@ -356,9 +356,10 @@ class XCiT(nn.Module):
         self.num_classes = num_classes
         self.num_features = self.embed_dim = embed_dim
         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+        act_layer = act_layer or nn.GELU
 
         self.patch_embed = ConvPatchEmbed(
-            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, act_layer=act_layer)
 
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
         self.use_pos_embed = use_pos_embed
@@ -369,13 +370,13 @@ class XCiT(nn.Module):
         self.blocks = nn.ModuleList([
             XCABlock(
                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
-                attn_drop=attn_drop_rate, drop_path=drop_path_rate, norm_layer=norm_layer, eta=eta)
+                attn_drop=attn_drop_rate, drop_path=drop_path_rate, act_layer=act_layer, norm_layer=norm_layer, eta=eta)
             for _ in range(depth)])
 
         self.cls_attn_blocks = nn.ModuleList([
             ClassAttentionBlock(
                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
-                attn_drop=attn_drop_rate, norm_layer=norm_layer, eta=eta, tokens_norm=tokens_norm)
+                attn_drop=attn_drop_rate, act_layer=act_layer, norm_layer=norm_layer, eta=eta, tokens_norm=tokens_norm)
             for _ in range(cls_attn_layers)])
 
         # Classifier head
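
Usage sketch (not part of the patch): with this change applied, the activation used in the conv patch embed and in every XCA/class-attention block can be swapped by passing act_layer to the XCiT constructor; act_layer=None keeps the previous behavior via the act_layer or nn.GELU fallback. The nn.SiLU choice and the small depth/embed_dim below are illustrative assumptions, not values from the diff.

    import torch
    import torch.nn as nn

    from timm.models.xcit import XCiT  # module patched by this diff

    # Small XCiT variant with a non-default activation; with the default
    # act_layer=None the model is identical to the pre-patch nn.GELU version.
    model = XCiT(
        img_size=224,
        patch_size=16,
        embed_dim=192,
        depth=2,
        num_heads=4,
        act_layer=nn.SiLU,  # illustrative; any nn.Module factory works
    )

    # act_layer is threaded through ConvPatchEmbed, XCABlock, and
    # ClassAttentionBlock, so the whole network uses the same activation.
    x = torch.randn(1, 3, 224, 224)  # dummy batch
    with torch.no_grad():
        out = model(x)
    print(out.shape)  # expected: torch.Size([1, 1000]) with default num_classes

Passing the class (nn.SiLU) rather than an instance mirrors the existing norm_layer convention, so each act_layer() call site constructs a fresh module.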