Add Final annotation to attn_fas to avoid symbol lookup of new scaled_dot_product_attn fn on old PyTorch in jit

pull/1674/head
Ross Wightman 2 years ago
parent a9739258f4
commit b6eb652924

@ -42,6 +42,7 @@ from typing import Callable, Optional, Union, Tuple, List
import torch import torch
from torch import nn from torch import nn
from torch.jit import Final
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import Mlp, ConvMlp, DropPath, LayerNorm, ClassifierHead, NormMlpClassifierHead from timm.layers import Mlp, ConvMlp, DropPath, LayerNorm, ClassifierHead, NormMlpClassifierHead
@ -140,6 +141,8 @@ class MaxxVitCfg:
class Attention2d(nn.Module): class Attention2d(nn.Module):
fast_attn: Final[bool]
""" multi-head attention for 2D NCHW tensors""" """ multi-head attention for 2D NCHW tensors"""
def __init__( def __init__(
self, self,
@ -208,6 +211,8 @@ class Attention2d(nn.Module):
class AttentionCl(nn.Module): class AttentionCl(nn.Module):
""" Channels-last multi-head attention (B, ..., C) """ """ Channels-last multi-head attention (B, ..., C) """
fast_attn: Final[bool]
def __init__( def __init__(
self, self,
dim: int, dim: int,

@ -33,6 +33,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.checkpoint import torch.utils.checkpoint
from torch.jit import Final
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
@ -51,6 +52,8 @@ _logger = logging.getLogger(__name__)
class Attention(nn.Module): class Attention(nn.Module):
fast_attn: Final[bool]
def __init__( def __init__(
self, self,
dim, dim,

@ -11,6 +11,7 @@ from typing import Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.jit import Final
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
@ -25,6 +26,8 @@ _logger = logging.getLogger(__name__)
class RelPosAttention(nn.Module): class RelPosAttention(nn.Module):
fast_attn: Final[bool]
def __init__( def __init__(
self, self,
dim, dim,

Loading…
Cancel
Save