Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent c8559878fb
commit 62eb0c4346

@ -17,7 +17,8 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
# This source code is licensed under the MIT license # This source code is licensed under the MIT license
import itertools import itertools
from typing import Tuple from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -37,12 +38,16 @@ __all__ = ['DaViT']
class MySequential(nn.Sequential): class MySequential(nn.Sequential):
def forward(self, *inputs): @Overload
def forward(self, inputs : Tensor):
for module in self._modules.values():
inputs = module(inputs)
return inputs
@Overload
def forward(self, inputs : Tuple[Tensor, Tensor]):
for module in self._modules.values(): for module in self._modules.values():
if type(inputs) == tuple:
inputs = module(*inputs) inputs = module(*inputs)
else:
inputs = module(inputs)
return inputs return inputs

Loading…
Cancel
Save