From 143f8e69b1302011d33283528a4b98d955633448 Mon Sep 17 00:00:00 2001
From: Fredo Guan
Date: Fri, 20 Jan 2023 01:53:36 -0800
Subject: [PATCH] Update metaformers.py

---
 timm/models/metaformers.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py
index fd992209..3b21a209 100644
--- a/timm/models/metaformers.py
+++ b/timm/models/metaformers.py
@@ -175,12 +175,9 @@ class Attention(nn.Module):
 class RandomMixing(nn.Module):
     def __init__(self, num_tokens=196, **kwargs):
         super().__init__()
-        '''
         self.random_matrix = nn.parameter.Parameter(
             data=torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1),
             requires_grad=False)
-        '''
-        self.random_matrix = torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1)
     def forward(self, x):
         B, H, W, C = x.shape
         x = x.reshape(B, H*W, C)
@@ -444,13 +441,20 @@ class MetaFormerBlock(nn.Module):
         x = self.res_scale1(x) + \
             self.layer_scale1(
                 self.drop_path1(
-                    self.token_mixer(self.norm1(x))
+                    self.token_mixer(
+                        self.norm1(
+                            x.permute(0, 3, 1, 2)
+                        ).permute(0, 2, 3, 1)
+                    )
                 )
             )
         x = self.res_scale2(x) + \
             self.layer_scale2(
                 self.drop_path2(
-                    self.mlp(self.norm2(x))
+                    self.mlp(self.norm2(
+                        x.permute(0, 3, 1, 2)
+                        ).permute(0, 2, 3, 1)
+                    )
                 )
             )
         #x = x.view(B, C, H, W)