From 122621daefe4ea1f1c3454139968e67aa2ea5f80 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 15 Feb 2023 08:53:50 -0800 Subject: [PATCH] Add Final annotation to attn_fas to avoid symbol lookup of new scaled_dot_product_attn fn on old PyTorch in jit --- timm/models/maxxvit.py | 5 +++++ timm/models/vision_transformer.py | 3 +++ timm/models/vision_transformer_relpos.py | 3 +++ 3 files changed, 11 insertions(+) diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index 5a164e88..1f8f5f24 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -42,6 +42,7 @@ from typing import Callable, Optional, Union, Tuple, List import torch from torch import nn +from torch.jit import Final from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import Mlp, ConvMlp, DropPath, LayerNorm, ClassifierHead, NormMlpClassifierHead @@ -140,6 +141,8 @@ class MaxxVitCfg: class Attention2d(nn.Module): + fast_attn: Final[bool] + """ multi-head attention for 2D NCHW tensors""" def __init__( self, @@ -208,6 +211,8 @@ class Attention2d(nn.Module): class AttentionCl(nn.Module): """ Channels-last multi-head attention (B, ..., C) """ + fast_attn: Final[bool] + def __init__( self, dim: int, diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 95d87126..dcd0a6fa 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -33,6 +33,7 @@ import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint +from torch.jit import Final from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ OPENAI_CLIP_MEAN, OPENAI_CLIP_STD @@ -51,6 +52,8 @@ _logger = logging.getLogger(__name__) class Attention(nn.Module): + fast_attn: Final[bool] + def __init__( self, dim, diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index f9fede53..835b213a 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -11,6 +11,7 @@ from typing import Optional, Tuple import torch import torch.nn as nn +from torch.jit import Final from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD @@ -25,6 +26,8 @@ _logger = logging.getLogger(__name__) class RelPosAttention(nn.Module): + fast_attn: Final[bool] + def __init__( self, dim,