From e5847a93c4e428fbceee426f3ef831636a6653f9 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 6 Dec 2022 20:42:22 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index 35b6e00a..61388fe2 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -17,7 +17,7 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig # This source code is licensed under the MIT license import itertools -from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union +from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union, List import torch @@ -34,18 +34,16 @@ from .registry import register_model __all__ = ['DaViT'] -''' + class MySequential(nn.Sequential): - def forward(self, *inputs): + def forward(self, inputs : List[Tensor]): for module in self._modules.values(): - if type(inputs) == tuple: + if len(inputs) > 1: inputs = module(*inputs) else: inputs = module(inputs) return inputs - ''' - class MySequential(nn.Sequential): @overload def forward(self, inputs : Tensor): @@ -59,7 +57,7 @@ class MySequential(nn.Sequential): inputs = module(*inputs) return inputs - +''' class ConvPosEnc(nn.Module): def __init__(self, dim, k=3, act=False, normtype=False): super(ConvPosEnc, self).__init__()