diff --git a/tests/test_models.py b/tests/test_models.py index f5698462..71d643dd 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -99,10 +99,11 @@ def test_model_default_cfgs(model_name, batch_size): assert outputs.shape[-1] == model.num_features # test model forward without pooling and classifier + model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through + outputs = model.forward(input_tensor) + assert len(outputs.shape) == 4 if not isinstance(model, timm.models.MobileNetV3): - model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through - outputs = model.forward(input_tensor) - assert len(outputs.shape) == 4 + # FIXME mobilenetv3 forward_features vs removed pooling differ assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2] # check classifier and first convolution names match those in default_cfg diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index e0ad7c95..e20b6d34 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -101,7 +101,7 @@ class MobileNetV3(nn.Module): head_chs = builder.in_chs # Head + Pooling - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) if global_pool else nn.Identity() + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) num_pooled_chs = head_chs * self.global_pool.feat_mult() self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias) self.act2 = act_layer(inplace=True) @@ -122,7 +122,7 @@ class MobileNetV3(nn.Module): def reset_classifier(self, num_classes, global_pool='avg'): self.num_classes = num_classes # cannot meaningfully change pooling of efficient head after creation - assert global_pool == self.global_pool.pool_type + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.classifier = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): @@ -136,7 +136,9 @@ class MobileNetV3(nn.Module): return x def forward(self, x): - x = self.forward_features(x).flatten(1) + x = self.forward_features(x) + if not self.global_pool.is_identity(): + x = x.flatten(1) if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) return self.classifier(x)