pull/297/head
parent
5f4b6076d8
commit
460eba7f24
@ -0,0 +1,18 @@
|
||||
""" Linear layer (alternate definition)
|
||||
"""
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn as nn
|
||||
|
||||
|
||||
class Linear(nn.Linear):
|
||||
r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
|
||||
|
||||
Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting
|
||||
weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case.
|
||||
"""
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
if torch.jit.is_scripting():
|
||||
return F.linear(input, self.weight.to(dtype=input.dtype), self.bias.to(dtype=input.dtype))
|
||||
else:
|
||||
return F.linear(input, self.weight, self.bias)
|
Loading…
Reference in new issue