Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent c902731699
commit e8fd52e39f

@ -17,7 +17,7 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
# This source code is licensed under the MIT license
import itertools
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, TypeVar, Union, List
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union
import torch
@ -36,11 +36,28 @@ from .registry import register_model
__all__ = ['DaViT']
class MySequential(nn.Sequential):
def forward(self, inputs : Tuple):
for module in self._modules.values():
if type(inputs) == tuple:
inputs = module(*inputs)
else:
inputs = module(inputs)
return inputs
'''
class MySequential(nn.Sequential):
@overload
def forward(self, inputs : Tensor):
for module in self._modules.values():
inputs = module(inputs)
return inputs
@overload
def forward(self, inputs : Tuple[Tensor, Tensor]):
for module in self._modules.values():
inputs = module(*inputs)
return inputs
'''
class ConvPosEnc(nn.Module):
def __init__(self, dim, k=3, act=False, normtype=False):
super(ConvPosEnc, self).__init__()
@ -539,6 +556,7 @@ def checkpoint_filter_fn(state_dict, model):
state_dict = state_dict['state_dict']
out_dict = {}
import re
for k, v in state_dict.items():
k = k.replace('head.', 'head.fc.')

Loading…
Cancel
Save