Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent e5847a93c4
commit 3a422f735a

@ -17,7 +17,7 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
# This source code is licensed under the MIT license
import itertools
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union, List
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, TypeVar, Union, List
import torch
@ -43,21 +43,7 @@ class MySequential(nn.Sequential):
else:
inputs = module(inputs)
return inputs
'''
class MySequential(nn.Sequential):
@overload
def forward(self, inputs : Tensor):
for module in self._modules.values():
inputs = module(inputs)
return inputs
@overload
def forward(self, inputs : Tuple[Tensor, Tensor]):
for module in self._modules.values():
inputs = module(*inputs)
return inputs
'''
class ConvPosEnc(nn.Module):
def __init__(self, dim, k=3, act=False, normtype=False):
super(ConvPosEnc, self).__init__()

Loading…
Cancel
Save