Add Final annotation to attn_fas to avoid symbol lookup of new scaled_dot_product_attn fn on old PyTorch in jit

pull/1680/head
Ross Wightman 1 year ago committed by Ross Wightman
parent 621e1b2182
commit 122621daef

@ -42,6 +42,7 @@ from typing import Callable, Optional, Union, Tuple, List
import torch
from torch import nn
from torch.jit import Final
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import Mlp, ConvMlp, DropPath, LayerNorm, ClassifierHead, NormMlpClassifierHead
@ -140,6 +141,8 @@ class MaxxVitCfg:
class Attention2d(nn.Module):
fast_attn: Final[bool]
""" multi-head attention for 2D NCHW tensors"""
def __init__(
self,
@ -208,6 +211,8 @@ class Attention2d(nn.Module):
class AttentionCl(nn.Module):
""" Channels-last multi-head attention (B, ..., C) """
fast_attn: Final[bool]
def __init__(
self,
dim: int,

@ -33,6 +33,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.jit import Final
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
@ -51,6 +52,8 @@ _logger = logging.getLogger(__name__)
class Attention(nn.Module):
fast_attn: Final[bool]
def __init__(
self,
dim,

@ -11,6 +11,7 @@ from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch.jit import Final
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
@ -25,6 +26,8 @@ _logger = logging.getLogger(__name__)
class RelPosAttention(nn.Module):
fast_attn: Final[bool]
def __init__(
self,
dim,

Loading…
Cancel
Save