From e8fd52e39f95219ff07973e00178e6970be25f2f Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 6 Dec 2022 20:53:01 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index 452f30c7..d36a2742 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -17,7 +17,7 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig # This source code is licensed under the MIT license import itertools -from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, TypeVar, Union, List +from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union import torch @@ -36,11 +36,28 @@ from .registry import register_model __all__ = ['DaViT'] class MySequential(nn.Sequential): + def forward(self, inputs : Tuple): + for module in self._modules.values(): + if type(inputs) == tuple: + inputs = module(*inputs) + else: + inputs = module(inputs) + return inputs +''' +class MySequential(nn.Sequential): + @overload + def forward(self, inputs : Tensor): + for module in self._modules.values(): + inputs = module(inputs) + return inputs + + @overload def forward(self, inputs : Tuple[Tensor, Tensor]): for module in self._modules.values(): inputs = module(*inputs) return inputs +''' class ConvPosEnc(nn.Module): def __init__(self, dim, k=3, act=False, normtype=False): super(ConvPosEnc, self).__init__() @@ -539,6 +556,7 @@ def checkpoint_filter_fn(state_dict, model): state_dict = state_dict['state_dict'] out_dict = {} + import re for k, v in state_dict.items(): k = k.replace('head.', 'head.fc.')