|
|
@ -23,6 +23,7 @@ from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, T
|
|
|
|
import torch
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
import torch.Tensor as Tensor
|
|
|
|
from .helpers import build_model_with_cfg
|
|
|
|
from .helpers import build_model_with_cfg
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
|
|
|
|
from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
|
|
|
@ -33,9 +34,17 @@ from .registry import register_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = ['DaViT']
|
|
|
|
__all__ = ['DaViT']
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
class MySequential(nn.Sequential):
|
|
|
|
|
|
|
|
def forward(self, *inputs):
|
|
|
|
|
|
|
|
for module in self._modules.values():
|
|
|
|
|
|
|
|
if type(inputs) == tuple:
|
|
|
|
|
|
|
|
inputs = module(*inputs)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
inputs = module(inputs)
|
|
|
|
|
|
|
|
return inputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MySequential(nn.Sequential):
|
|
|
|
class MySequential(nn.Sequential):
|
|
|
|
@overload
|
|
|
|
@overload
|
|
|
|