diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 5b98cba8..d0c2b401 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -180,7 +180,7 @@ class RandomMixing(nn.Module): data=torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1), requires_grad=False) ''' - self.random_matrix = torch.softmax(torch.rand(num_tokens, num_tokens) + self.random_matrix = torch.softmax(torch.rand(num_tokens, num_tokens)) def forward(self, x): B, H, W, C = x.shape x = x.reshape(B, H*W, C)