Fix MobileNetV3 crash with global_pool='', output consistent with other models but not equivalent due to efficient head.

pull/227/head
Ross Wightman 4 years ago
parent fc8b8afb6f
commit 470220b1f4

@ -99,10 +99,11 @@ def test_model_default_cfgs(model_name, batch_size):
assert outputs.shape[-1] == model.num_features
# test model forward without pooling and classifier
if not isinstance(model, timm.models.MobileNetV3):
model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through
outputs = model.forward(input_tensor)
assert len(outputs.shape) == 4
if not isinstance(model, timm.models.MobileNetV3):
# FIXME mobilenetv3 forward_features vs removed pooling differ
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
# check classifier and first convolution names match those in default_cfg

@ -101,7 +101,7 @@ class MobileNetV3(nn.Module):
head_chs = builder.in_chs
# Head + Pooling
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) if global_pool else nn.Identity()
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
num_pooled_chs = head_chs * self.global_pool.feat_mult()
self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
self.act2 = act_layer(inplace=True)
@ -122,7 +122,7 @@ class MobileNetV3(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
# cannot meaningfully change pooling of efficient head after creation
assert global_pool == self.global_pool.pool_type
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.classifier = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
@ -136,7 +136,9 @@ class MobileNetV3(nn.Module):
return x
def forward(self, x):
x = self.forward_features(x).flatten(1)
x = self.forward_features(x)
if not self.global_pool.is_identity():
x = x.flatten(1)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
return self.classifier(x)

Loading…
Cancel
Save