diff --git a/timm/models/convit.py b/timm/models/convit.py index 60ba59fc..695c7c4f 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -64,6 +64,7 @@ class GPSA(nn.Module): self.dim = dim head_dim = dim // num_heads self.scale = qk_scale or head_dim ** -0.5 + self.locality_strength = locality_strength self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias) self.v = nn.Linear(dim, dim, bias=qkv_bias) @@ -72,7 +73,6 @@ class GPSA(nn.Module): self.proj = nn.Linear(dim, dim) self.pos_proj = nn.Linear(3, num_heads) self.proj_drop = nn.Dropout(proj_drop) - self.locality_strength = locality_strength self.gating_param = nn.Parameter(torch.ones(self.num_heads)) self.rel_indices: torch.Tensor = torch.zeros(1, 1, 1, 3) # silly torchscript hack, won't work with None diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index 462ad8fe..279780be 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -92,8 +92,7 @@ class LinearBottleneck(nn.Module): if self.use_shortcut: if self.drop_path is not None: x = self.drop_path(x) - - x[:, 0:self.in_channels] += shortcut + x[:, 0:self.in_channels] += shortcut return x