diff --git a/timm/models/layers/classifier.py b/timm/models/layers/classifier.py
index e9194f05..89fe5458 100644
--- a/timm/models/layers/classifier.py
+++ b/timm/models/layers/classifier.py
@@ -6,6 +6,7 @@ from torch import nn as nn
 from torch.nn import functional as F
 
 from .adaptive_avgmax_pool import SelectAdaptivePool2d
+from .linear import Linear
 
 
 def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
@@ -21,7 +22,8 @@ def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
     elif use_conv:
         fc = nn.Conv2d(num_pooled_features, num_classes, 1, bias=True)
     else:
-        fc = nn.Linear(num_pooled_features, num_classes, bias=True)
+        # NOTE: using my Linear wrapper that fixes AMP + torchscript casting issue
+        fc = Linear(num_pooled_features, num_classes, bias=True)
     return global_pool, fc
 
 
diff --git a/timm/models/layers/linear.py b/timm/models/layers/linear.py
new file mode 100644
index 00000000..4607f284
--- /dev/null
+++ b/timm/models/layers/linear.py
@@ -0,0 +1,18 @@
+""" Linear layer (alternate definition)
+"""
+import torch
+import torch.nn.functional as F
+from torch import nn as nn
+
+
+class Linear(nn.Linear):
+    r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
+
+    Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting
+    weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case.
+    """
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if torch.jit.is_scripting():
+            return F.linear(input, self.weight.to(dtype=input.dtype), self.bias.to(dtype=input.dtype))
+        else:
+            return F.linear(input, self.weight, self.bias)
\ No newline at end of file
diff --git a/train.py b/train.py
index 722c79e4..23a8e9b0 100755
--- a/train.py
+++ b/train.py
@@ -367,7 +367,6 @@ def main():
     if args.torchscript:
         assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
         assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
-        # FIXME I ran into a bug w/ AMP + torchscript + Linear layers
         model = torch.jit.script(model)
 
     optimizer = create_optimizer(args, model)
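
Context on the casting issue the NOTE refers to, as I understand it from the docstring above: under torch.cuda.amp.autocast, convolutions run in float16, so the pooled features reaching the classifier are half precision while the Linear parameters stay float32. Eager-mode autocast reconciles the dtypes for F.linear, but on the PyTorch versions this targets the scripted path did not, and torch.addmm raised a dtype mismatch. A minimal sketch of the use case (the Head module and shapes are hypothetical, a CUDA device is assumed, and the import path assumes linear.py as added above):

    import torch
    from timm.models.layers.linear import Linear

    class Head(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 512, 3)     # runs in fp16 under autocast
            self.pool = torch.nn.AdaptiveAvgPool2d(1)
            self.fc = Linear(512, 10)                  # wrapper casts weight/bias when scripted

        def forward(self, x):
            x = self.pool(self.conv(x)).flatten(1)     # fp16 activations under autocast
            return self.fc(x)                          # fp32 params meet fp16 input here

    model = torch.jit.script(Head()).cuda()
    with torch.cuda.amp.autocast():
        y = model(torch.randn(2, 3, 32, 32, device='cuda'))  # ok with the wrapper's explicit casts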
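
One caveat if the wrapper gets reused outside create_classifier: nn.Linear allows bias=False, in which case self.bias is None and self.bias.to(...) in the scripted branch would fail. A sketch of a guard for that case (a hypothetical variant, not part of this diff; same imports as linear.py above):

    class Linear(nn.Linear):
        def forward(self, input: torch.Tensor) -> torch.Tensor:
            if torch.jit.is_scripting():
                # cast bias only when it exists; bias=False leaves self.bias as None
                bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
                return F.linear(input, self.weight.to(dtype=input.dtype), bias)
            return F.linear(input, self.weight, self.bias)

create_classifier always passes bias=True, so the diff as written is fine for its own call site.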
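
With the FIXME resolved, the combination that used to fail should now run, i.e. scripting plus native AMP together, along the lines of (dataset path hypothetical; --amp and --torchscript per train.py's existing args):

    ./train.py /path/to/imagenet --model resnet50 --amp --torchscript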