A few more changes for 0.3.2 maint release. Linear layer change for mobilenetv3 and inception_v3, support no bias for linear wrapper.

pull/297/head
Ross Wightman 4 years ago
parent 6504a42832
commit 2ed8f24715

@ -121,7 +121,7 @@ if 'GITHUB_ACTIONS' not in os.environ:
create_model(model_name, pretrained=True, in_chans=in_chans)
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(pretrained=True))
@pytest.mark.parametrize('model_name', list_models(pretrained=True, exclude_filters=['vit_*']))
@pytest.mark.parametrize('batch_size', [1])
def test_model_features_pretrained(model_name, batch_size):
"""Create that pretrained weights load when features_only==True."""

@ -14,7 +14,7 @@ import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
from .layers import Conv2dSame
from .layers import Conv2dSame, Linear
_logger = logging.getLogger(__name__)
@ -234,7 +234,7 @@ def adapt_model_from_string(parent_module, model_string):
if isinstance(old_module, nn.Linear):
# FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
num_features = state_dict[n + '.weight'][1]
new_fc = nn.Linear(
new_fc = Linear(
in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
set_layer(new_module, n, new_fc)
if hasattr(new_module, 'num_features'):

@ -10,7 +10,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg
from .registry import register_model
from .layers import trunc_normal_, create_classifier
from .layers import trunc_normal_, create_classifier, Linear
def _cfg(url='', **kwargs):
@ -250,7 +250,7 @@ class InceptionAux(nn.Module):
self.conv0 = conv_block(in_channels, 128, kernel_size=1)
self.conv1 = conv_block(128, 768, kernel_size=5)
self.conv1.stddev = 0.01
self.fc = nn.Linear(768, num_classes)
self.fc = Linear(768, num_classes)
self.fc.stddev = 0.001
def forward(self, x):

@ -18,6 +18,7 @@ from .eca import EcaModule, CecaModule
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple
from .inplace_abn import InplaceAbn
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .norm_act import BatchNormAct2d
from .padding import get_padding

@ -13,6 +13,7 @@ class Linear(nn.Linear):
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
if torch.jit.is_scripting():
return F.linear(input, self.weight.to(dtype=input.dtype), self.bias.to(dtype=input.dtype))
bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias)
else:
return F.linear(input, self.weight, self.bias)

@ -18,7 +18,7 @@ from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_la
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
from .features import FeatureInfo, FeatureHooks
from .helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d, create_conv2d, get_act_fn, hard_sigmoid
from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, hard_sigmoid
from .registry import register_model
__all__ = ['MobileNetV3']
@ -105,7 +105,7 @@ class MobileNetV3(nn.Module):
num_pooled_chs = head_chs * self.global_pool.feat_mult()
self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
self.act2 = act_layer(inplace=True)
self.classifier = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
efficientnet_init_weights(self)
@ -123,7 +123,7 @@ class MobileNetV3(nn.Module):
self.num_classes = num_classes
# cannot meaningfully change pooling of efficient head after creation
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.classifier = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.conv_stem(x)

@ -327,6 +327,7 @@ def main():
bn_tf=args.bn_tf,
bn_momentum=args.bn_momentum,
bn_eps=args.bn_eps,
scriptable=args.torchscript,
checkpoint_path=args.initial_checkpoint)
if args.local_rank == 0:

Loading…
Cancel
Save