@@ -122,8 +122,8 @@ class LayerScaleBlockClassAttn(nn.Module):
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-        self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
-        self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
+        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))

     def forward(self, x, x_cls):
         u = torch.cat((x_cls, x), dim=1)
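
Both hunks make the same purely cosmetic simplification: nn.Parameter already defaults to requires_grad=True, and (dim) is not a tuple (that would be (dim,)), so torch.ones((dim)) and torch.ones(dim) build the same 1-D tensor. A minimal standalone sketch checking the equivalence (not part of the patched file):

    import torch
    import torch.nn as nn

    dim, init_values = 8, 1e-5

    # Old spelling: explicit requires_grad and redundant parentheses around dim
    old = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
    # New spelling: nn.Parameter defaults to requires_grad=True
    new = nn.Parameter(init_values * torch.ones(dim))

    assert old.requires_grad and new.requires_grad
    assert torch.equal(old, new)  # identical shape and values
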
@@ -189,8 +189,8 @@ class LayerScaleBlock(nn.Module):
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-        self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
-        self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
+        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))

     def forward(self, x):
         x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
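
For context, the gamma_1/gamma_2 vectors touched here appear to implement the LayerScale trick from CaiT: each residual branch is scaled per channel by a small learnable vector initialized near zero. A self-contained sketch of that pattern, with plain nn.Linear layers standing in for the real attention and MLP branches (placeholder modules, not the classes in this file):

    import torch
    import torch.nn as nn

    class LayerScaleSketch(nn.Module):
        """Illustrative LayerScale residual block, not the timm implementation."""
        def __init__(self, dim: int, init_values: float = 1e-5):
            super().__init__()
            self.norm1 = nn.LayerNorm(dim)
            self.norm2 = nn.LayerNorm(dim)
            self.attn = nn.Linear(dim, dim)  # stand-in for the attention branch
            self.mlp = nn.Linear(dim, dim)   # stand-in for the MLP branch
            # Per-channel residual scaling, written in the patch's new style
            self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
            self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))

        def forward(self, x):
            x = x + self.gamma_1 * self.attn(self.norm1(x))
            x = x + self.gamma_2 * self.mlp(self.norm2(x))
            return x

    x = torch.randn(2, 16, 32)            # (batch, tokens, dim)
    print(LayerScaleSketch(32)(x).shape)  # torch.Size([2, 16, 32])
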