Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent cf1cc8b1d1
commit e5847a93c4

@ -17,7 +17,7 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
# This source code is licensed under the MIT license
import itertools
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union, List
import torch
@ -34,18 +34,16 @@ from .registry import register_model
__all__ = ['DaViT']
'''
class MySequential(nn.Sequential):
def forward(self, *inputs):
def forward(self, inputs : List[Tensor]):
for module in self._modules.values():
if type(inputs) == tuple:
if len(inputs) > 1:
inputs = module(*inputs)
else:
inputs = module(inputs)
return inputs
'''
class MySequential(nn.Sequential):
@overload
def forward(self, inputs : Tensor):
@ -59,7 +57,7 @@ class MySequential(nn.Sequential):
inputs = module(*inputs)
return inputs
'''
class ConvPosEnc(nn.Module):
def __init__(self, dim, k=3, act=False, normtype=False):
super(ConvPosEnc, self).__init__()

Loading…
Cancel
Save