Disable use of timm nn.Linear wrapper since AMP autocast + torchscript use appears fixed

pull/659/merge
Ross Wightman 3 years ago
parent a0b2657497
commit 214c84a235

@@ -6,7 +6,6 @@ from torch import nn as nn
 from torch.nn import functional as F
 
 from .adaptive_avgmax_pool import SelectAdaptivePool2d
-from .linear import Linear
 
 
 def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
@@ -26,8 +25,7 @@ def _create_fc(num_features, num_classes, use_conv=False):
     elif use_conv:
         fc = nn.Conv2d(num_features, num_classes, 1, bias=True)
     else:
-        # NOTE: using my Linear wrapper that fixes AMP + torchscript casting issue
-        fc = Linear(num_features, num_classes, bias=True)
+        fc = nn.Linear(num_features, num_classes, bias=True)
     return fc
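
For reference, the timm Linear wrapper that this commit stops using follows the pattern sketched below. This is a hedged reconstruction, not the exact timm source: under torchscript, AMP autocast does not insert its automatic dtype casts, so the wrapper manually casts weight and bias to the input's dtype before calling F.linear.

import torch
from torch import nn
from torch.nn import functional as F


class Linear(nn.Linear):
    """Sketch of a Linear wrapper for AMP autocast + torchscript use.

    Hypothetical reconstruction of the removed helper: when a module is
    scripted, autocast's automatic casts do not apply, so weight/bias are
    cast to the input dtype by hand before the matmul.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if torch.jit.is_scripting():
            # Match params to the (possibly autocast-lowered) input dtype.
            bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
            return F.linear(input, self.weight.to(dtype=input.dtype), bias)
        else:
            # Eager mode: autocast handles dtypes, plain F.linear suffices.
            return F.linear(input, self.weight, self.bias)

With the casting behavior apparently fixed upstream in PyTorch, plain nn.Linear now works in this scenario, so the commit drops the indirection.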
