@@ -42,6 +42,7 @@ from typing import Callable, Optional, Union, Tuple, List
 
 import torch
 from torch import nn
+from torch.jit import Final
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import Mlp, ConvMlp, DropPath, LayerNorm, ClassifierHead, NormMlpClassifierHead
@@ -140,6 +141,8 @@ class MaxxVitCfg:
 
 
 class Attention2d(nn.Module):
+    fast_attn: Final[bool]
+
     """ multi-head attention for 2D NCHW tensors"""
     def __init__(
             self,
@@ -208,6 +211,8 @@ class Attention2d(nn.Module):
 
 class AttentionCl(nn.Module):
     """ Channels-last multi-head attention (B, ..., C) """
+    fast_attn: Final[bool]
+
     def __init__(
             self,
             dim: int,
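
The hunks above only add the `torch.jit.Final` import and the `fast_attn: Final[bool]` annotations on the two attention modules. The point of the pattern is that TorchScript treats a `Final` attribute as a compile-time constant, so an `if self.fast_attn:` branch can be resolved at scripting time rather than evaluated dynamically. Below is a minimal, self-contained sketch of how such a flag typically gates a fused F.scaled_dot_product_attention path against a plain matmul-softmax fallback; the module body is illustrative only (the diff does not show the actual forward), and the hasattr-based gate is an assumption, not code from this commit.

import torch
import torch.nn.functional as F
from torch import nn
from torch.jit import Final


class Attention(nn.Module):
    # Final[...] lets TorchScript fold the flag into a constant,
    # so only one of the two forward branches survives scripting.
    fast_attn: Final[bool]

    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        # assumption: enable the fused kernel when PyTorch provides it (>= 2.0)
        self.fast_attn = hasattr(F, 'scaled_dot_product_attention')
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # each (B, heads, N, head_dim)
        if self.fast_attn:
            # fused kernel path: scaling and softmax happen inside the op
            x = F.scaled_dot_product_attention(q, k, v)
        else:
            # reference path: explicit scaled dot-product attention
            attn = (q @ k.transpose(-2, -1)) * self.scale
            x = attn.softmax(dim=-1) @ v
        x = x.transpose(1, 2).reshape(B, N, C)
        return self.proj(x)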