From 909705e7ffd8ac69ca9088dea90f4d09d0578006 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 20 Jul 2022 12:37:41 -0700
Subject: [PATCH] Remove some redundant requires_grad=True from nn.Parameter in
 third party code

---
 timm/models/beit.py       |  4 ++--
 timm/models/cait.py       |  8 ++++----
 timm/models/poolformer.py |  4 ++--
 timm/models/xcit.py       | 10 +++++-----
 4 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/timm/models/beit.py b/timm/models/beit.py
index a56653dd..a2083a4a 100644
--- a/timm/models/beit.py
+++ b/timm/models/beit.py
@@ -182,8 +182,8 @@ class Block(nn.Module):
         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
 
         if init_values:
-            self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
-            self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
+            self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
+            self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))
         else:
             self.gamma_1, self.gamma_2 = None, None
 
diff --git a/timm/models/cait.py b/timm/models/cait.py
index bcc91497..c0892099 100644
--- a/timm/models/cait.py
+++ b/timm/models/cait.py
@@ -122,8 +122,8 @@ class LayerScaleBlockClassAttn(nn.Module):
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-        self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
-        self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
+        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))
 
     def forward(self, x, x_cls):
         u = torch.cat((x_cls, x), dim=1)
@@ -189,8 +189,8 @@ class LayerScaleBlock(nn.Module):
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-        self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
-        self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
+        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))
 
     def forward(self, x):
         x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
diff --git a/timm/models/poolformer.py b/timm/models/poolformer.py
index a95195b4..09359bc8 100644
--- a/timm/models/poolformer.py
+++ b/timm/models/poolformer.py
@@ -117,8 +117,8 @@ class PoolFormerBlock(nn.Module):
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
         if layer_scale_init_value:
-            self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
-            self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
+            self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones(dim))
+            self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones(dim))
         else:
             self.layer_scale_1 = None
             self.layer_scale_2 = None
diff --git a/timm/models/xcit.py b/timm/models/xcit.py
index 69b97d64..d70500ce 100644
--- a/timm/models/xcit.py
+++ b/timm/models/xcit.py
@@ -230,8 +230,8 @@ class ClassAttentionBlock(nn.Module):
         self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
 
         if eta is not None:  # LayerScale Initialization (no layerscale when None)
-            self.gamma1 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)
-            self.gamma2 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)
+            self.gamma1 = nn.Parameter(eta * torch.ones(dim))
+            self.gamma2 = nn.Parameter(eta * torch.ones(dim))
         else:
             self.gamma1, self.gamma2 = 1.0, 1.0
 
@@ -308,9 +308,9 @@ class XCABlock(nn.Module):
         self.norm2 = norm_layer(dim)
         self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
 
-        self.gamma1 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)
-        self.gamma3 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)
-        self.gamma2 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)
+        self.gamma1 = nn.Parameter(eta * torch.ones(dim))
+        self.gamma3 = nn.Parameter(eta * torch.ones(dim))
+        self.gamma2 = nn.Parameter(eta * torch.ones(dim))
 
     def forward(self, x, H: int, W: int):
         x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x)))
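
Note (illustration, not part of the patch): torch.nn.Parameter is documented as Parameter(data=None, requires_grad=True), so requires_grad=True is already the default and the removals above do not change behavior. A minimal sketch checking the default; dim and init_values here are arbitrary example values:

    import torch
    import torch.nn as nn

    dim, init_values = 8, 1e-5
    # Same construction pattern as the patched LayerScale parameters.
    gamma = nn.Parameter(init_values * torch.ones(dim))
    print(gamma.requires_grad)  # True: requires_grad=True is the nn.Parameter default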