|
|
|
@ -15,22 +15,21 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
|
|
|
|
|
# FIXME remove unused imports
|
|
|
|
|
|
|
|
|
|
import itertools
|
|
|
|
|
from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, overload, Tuple, TypeVar, Union
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
from torch import Tensor
|
|
|
|
|
import torch.utils.checkpoint as checkpoint
|
|
|
|
|
|
|
|
|
|
from .features import FeatureInfo
|
|
|
|
|
from .fx_features import register_notrace_function, register_notrace_module
|
|
|
|
|
from .helpers import build_model_with_cfg, pretrained_cfg_for_features
|
|
|
|
|
from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
|
|
|
|
|
from .pretrained import generate_default_cfgs
|
|
|
|
|
from .registry import register_model
|
|
|
|
|
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
|
from timm.layers import DropPath, to_2tuple, trunc_normal_, ClassifierHead, Mlp
|
|
|
|
|
from ._features import FeatureInfo
|
|
|
|
|
from ._features_fx import register_notrace_function
|
|
|
|
|
from ._helpers import build_model_with_cfg
|
|
|
|
|
from ._manipulate import checkpoint_seq
|
|
|
|
|
from ._pretrained import generate_default_cfgs
|
|
|
|
|
from ._registry import register_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = ['DaViT']
|
|
|
|
|
|
|
|
|
|