Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 6692983832
commit 0ed0e9ac35

@ -23,6 +23,7 @@ from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, T
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.Tensor as Tensor
from .helpers import build_model_with_cfg
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
@ -33,9 +34,17 @@ from .registry import register_model
__all__ = ['DaViT']
'''
class MySequential(nn.Sequential):
def forward(self, *inputs):
for module in self._modules.values():
if type(inputs) == tuple:
inputs = module(*inputs)
else:
inputs = module(inputs)
return inputs
'''
class MySequential(nn.Sequential):
@overload

Loading…
Cancel
Save