Work around casting issue with combination of native torch AMP and torchscript for Linear layers

pull/297/head
Ross Wightman 4 years ago
parent 5f4b6076d8
commit 460eba7f24

timm/models/layers/classifier.py
@@ -6,6 +6,7 @@ from torch import nn as nn
 from torch.nn import functional as F
 
 from .adaptive_avgmax_pool import SelectAdaptivePool2d
+from .linear import Linear
 
 
 def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
@@ -21,7 +22,8 @@ def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
     elif use_conv:
         fc = nn.Conv2d(num_pooled_features, num_classes, 1, bias=True)
     else:
-        fc = nn.Linear(num_pooled_features, num_classes, bias=True)
+        # NOTE: using my Linear wrapper that fixes AMP + torchscript casting issue
+        fc = Linear(num_pooled_features, num_classes, bias=True)
     return global_pool, fc
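For context, a minimal sketch (not part of this diff) of how a model head typically consumes create_classifier; the SimpleHead class, feature sizes, and the timm.models.layers import path are assumptions for illustration:

import torch
from torch import nn

from timm.models.layers import create_classifier  # import path assumed for this era of the codebase


class SimpleHead(nn.Module):
    """Illustrative pooled-feature head; fc is now the AMP + torchscript-safe Linear wrapper."""

    def __init__(self, num_features: int = 2048, num_classes: int = 1000):
        super().__init__()
        self.global_pool, self.fc = create_classifier(
            num_features, num_classes, pool_type='avg', use_conv=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.global_pool(x)  # NCHW feature map -> pooled (and flattened) NC features
        return self.fc(x)


# Usage: push a fake NCHW feature map through the head
head = SimpleHead()
out = head(torch.randn(2, 2048, 7, 7))  # -> shape (2, 1000)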

timm/models/layers/linear.py (new file)
@@ -0,0 +1,18 @@
+""" Linear layer (alternate definition)
+"""
+import torch
+import torch.nn.functional as F
+from torch import nn as nn
+
+
+class Linear(nn.Linear):
+    r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
+
+    Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting
+    weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case.
+    """
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if torch.jit.is_scripting():
+            return F.linear(input, self.weight.to(dtype=input.dtype), self.bias.to(dtype=input.dtype))
+        else:
+            return F.linear(input, self.weight, self.bias)
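To see why the manual cast is needed, a hedged, standalone sketch of the failure mode: when a module is scripted, native torch AMP (autocast) did not apply its automatic casts inside the scripted graph at the time of this commit, so F.linear / torch.addmm could receive a float16 input alongside float32 weight and bias and raise a dtype-mismatch error. The wrapper above sidesteps this by casting weight and bias to input.dtype on the scripting path. The module path and exact error behavior below are assumptions for illustration, and a CUDA device is required for autocast:

import torch
from torch import nn

from timm.models.layers.linear import Linear  # module path assumed from this diff

x = torch.randn(4, 16, device='cuda')

# Plain nn.Linear, scripted: under autocast the input becomes fp16 while weight/bias stay
# fp32, which (at the time of this commit) could fail inside torch.addmm with a dtype mismatch.
plain = torch.jit.script(nn.Linear(16, 8)).cuda()
with torch.cuda.amp.autocast():
    try:
        plain(x)
    except RuntimeError as e:
        print('scripted nn.Linear failed under autocast:', e)

# The wrapper casts weight/bias to the input dtype when scripting, so this path works.
safe = torch.jit.script(Linear(16, 8)).cuda()
with torch.cuda.amp.autocast():
    y = safe(x)
print(y.dtype)  # typically torch.float16 under autocast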

train.py
@@ -367,7 +367,6 @@ def main():
     if args.torchscript:
         assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
         assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
-        # FIXME I ran into a bug w/ AMP + torchscript + Linear layers
         model = torch.jit.script(model)
     optimizer = create_optimizer(args, model)
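With the Linear wrapper in place, the FIXME above is removed and torchscript can be combined with native torch AMP during training. A minimal hedged sketch of that combination (illustrative only, not the actual train.py loop):

import torch

def train_step(model, optimizer, scaler, inputs, targets, loss_fn):
    """One illustrative native-AMP step with a torch.jit.script()-ed model."""
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        output = model(inputs)        # scripted model; classifier head is the Linear wrapper
        loss = loss_fn(output, targets)
    scaler.scale(loss).backward()     # scale the loss so fp16 gradients don't underflow
    scaler.step(optimizer)            # unscales gradients, skips the step if non-finite
    scaler.update()
    return loss.detach()

# scaler = torch.cuda.amp.GradScaler() would be created once, outside the loop.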
