From 23d0beb0d336ea86ee8dc39d76c2db830194624f Mon Sep 17 00:00:00 2001
From: Fredo Guan
Date: Thu, 8 Dec 2022 06:06:20 -0800
Subject: [PATCH] Update davit.py

---
 timm/models/davit.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/timm/models/davit.py b/timm/models/davit.py
index 70971699..68b8f66b 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -33,7 +33,7 @@ from .registry import register_model
 
 __all__ = ['DaViT']
 
-
+'''
 
 class MySequential(nn.Sequential):
     def forward(self, inputs : Tensor, size : Tuple[int, int]):
@@ -42,7 +42,7 @@ class MySequential(nn.Sequential):
             inputs : Tensor = output[0]
             size : Tuple[int, int] = output[1]
         return inputs
-
+'''
 
 class ConvPosEnc(nn.Module):
     def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
@@ -421,8 +421,8 @@ class DaViT(nn.Module):
         for stage_id, stage_param in enumerate(self.architecture):
             layer_offset_id = len(list(itertools.chain(*self.architecture[:stage_id])))
 
-            stage = MySequential(*[
-                MySequential(*[
+            stage = nn.Sequential(*[
+                nn.Sequential(*[
                     ChannelBlock(
                         dim=self.embed_dims[item],
                         num_heads=self.num_heads[item],
@@ -492,8 +492,8 @@ class DaViT(nn.Module):
 
         for patch_layer, stage in zip(self.patch_embeds, self.stages):
            features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1])
-            for _, block in enumerate(stage):
-                for _, layer in enumerate(block):
+            for block in stage:
+                for layer in block:
                    if self.grad_checkpointing and not torch.jit.is_scripting():
                        features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1])
                    else:
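
Below is a minimal sketch, separate from the patch itself, illustrating the pattern the diff switches to: DaViT blocks consume and return an (x, size) pair, so plain nn.Sequential.forward (which forwards a single positional argument) cannot chain them directly; the patch keeps nn.Sequential purely as a container and threads the pair through an explicit loop, optionally wrapping each layer in torch.utils.checkpoint. DummyBlock and run_stage are hypothetical stand-ins for this sketch, not timm code.

from typing import Tuple

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint


class DummyBlock(nn.Module):
    # Stand-in for a DaViT block: takes (x, size), returns (x, size).
    def forward(self, x: torch.Tensor, size: Tuple[int, int]):
        return x, size


def run_stage(stage: nn.Sequential, x: torch.Tensor, size: Tuple[int, int],
              grad_checkpointing: bool = False):
    # Mirrors the rewritten loop in the patch: iterate over the nested
    # nn.Sequential containers and call each layer by hand so the
    # (x, size) tuple is threaded through, with per-layer checkpointing
    # when enabled and not scripting.
    for block in stage:
        for layer in block:
            if grad_checkpointing and not torch.jit.is_scripting():
                x, size = checkpoint.checkpoint(layer, x, size)
            else:
                x, size = layer(x, size)
    return x, size


if __name__ == "__main__":
    stage = nn.Sequential(nn.Sequential(DummyBlock(), DummyBlock()))
    out, sz = run_stage(stage, torch.randn(1, 8, 96), (4, 2))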