diff --git a/timm/models/davit.py b/timm/models/davit.py index 4c3acb32..41af0ebe 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -523,9 +523,9 @@ class DaViT(nn.Module): features.append(features[-1]) sizes.append(sizes[-1]) - ''' - + + ''' for block_index, block_param in enumerate(self.architecture): branch_ids = sorted(set(block_param)) @@ -562,7 +562,12 @@ class DaViT(nn.Module): H, W = sizes[i] out = x_out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous() outs.append(out) + + + ''' + + # non-normalized pyramid features + corresponding sizes return tuple(features), tuple(sizes)