Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent c8559878fb
commit 62eb0c4346

@ -17,7 +17,8 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
# This source code is licensed under the MIT license
import itertools
from typing import Tuple
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union
import torch
import torch.nn as nn
@ -37,14 +38,18 @@ __all__ = ['DaViT']
class MySequential(nn.Sequential):
def forward(self, *inputs):
@Overload
def forward(self, inputs : Tensor):
for module in self._modules.values():
if type(inputs) == tuple:
inputs = module(*inputs)
else:
inputs = module(inputs)
return inputs
@Overload
def forward(self, inputs : Tuple[Tensor, Tensor]):
for module in self._modules.values():
inputs = module(*inputs)
return inputs
class ConvPosEnc(nn.Module):
def __init__(self, dim, k=3, act=False, normtype=False):

Loading…
Cancel
Save