|
|
@ -24,14 +24,12 @@ An implementation of EfficienNet that covers variety of related models with effi
|
|
|
|
|
|
|
|
|
|
|
|
Hacked together by Ross Wightman
|
|
|
|
Hacked together by Ross Wightman
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
|
|
|
from .efficientnet_builder import *
|
|
|
|
from .efficientnet_builder import *
|
|
|
|
from .feature_hooks import FeatureHooks
|
|
|
|
from .feature_hooks import FeatureHooks
|
|
|
|
from .registry import register_model
|
|
|
|
|
|
|
|
from .helpers import load_pretrained, adapt_model_from_file
|
|
|
|
from .helpers import load_pretrained, adapt_model_from_file
|
|
|
|
from .layers import SelectAdaptivePool2d
|
|
|
|
from .layers import SelectAdaptivePool2d
|
|
|
|
from timm.models.layers import create_conv2d
|
|
|
|
from .registry import register_model
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = ['EfficientNet']
|
|
|
|
__all__ = ['EfficientNet']
|
|
|
|
|
|
|
|
|
|
|
@ -373,8 +371,11 @@ class EfficientNet(nn.Module):
|
|
|
|
def reset_classifier(self, num_classes, global_pool='avg'):
|
|
|
|
def reset_classifier(self, num_classes, global_pool='avg'):
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
|
|
|
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
|
|
|
self.classifier = nn.Linear(
|
|
|
|
if num_classes:
|
|
|
|
self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
|
|
|
|
num_features = self.num_features * self.global_pool.feat_mult()
|
|
|
|
|
|
|
|
self.classifier = nn.Linear(num_features, num_classes)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
self.classifier = nn.Identity()
|
|
|
|
|
|
|
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
def forward_features(self, x):
|
|
|
|
x = self.conv_stem(x)
|
|
|
|
x = self.conv_stem(x)
|
|
|
@ -1187,6 +1188,7 @@ def efficientnet_cc_b0_8e(pretrained=False, **kwargs):
|
|
|
|
pretrained=pretrained, **kwargs)
|
|
|
|
pretrained=pretrained, **kwargs)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def efficientnet_cc_b1_8e(pretrained=False, **kwargs):
|
|
|
|
def efficientnet_cc_b1_8e(pretrained=False, **kwargs):
|
|
|
|
""" EfficientNet-CondConv-B1 w/ 8 Experts """
|
|
|
|
""" EfficientNet-CondConv-B1 w/ 8 Experts """
|
|
|
@ -1242,8 +1244,6 @@ def efficientnet_lite4(pretrained=False, **kwargs):
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def efficientnet_b1_pruned(pretrained=False, **kwargs):
|
|
|
|
def efficientnet_b1_pruned(pretrained=False, **kwargs):
|
|
|
|
""" EfficientNet-B1 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
|
|
|
|
""" EfficientNet-B1 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
|
|
|
@ -1275,8 +1275,6 @@ def efficientnet_b3_pruned(pretrained=False, **kwargs):
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def tf_efficientnet_b0(pretrained=False, **kwargs):
|
|
|
|
def tf_efficientnet_b0(pretrained=False, **kwargs):
|
|
|
|
""" EfficientNet-B0. Tensorflow compatible variant """
|
|
|
|
""" EfficientNet-B0. Tensorflow compatible variant """
|
|
|
@ -1619,6 +1617,7 @@ def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs):
|
|
|
|
pretrained=pretrained, **kwargs)
|
|
|
|
pretrained=pretrained, **kwargs)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs):
|
|
|
|
def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs):
|
|
|
|
""" EfficientNet-CondConv-B1 w/ 8 Experts. Tensorflow compatible variant """
|
|
|
|
""" EfficientNet-CondConv-B1 w/ 8 Experts. Tensorflow compatible variant """
|
|
|
@ -1764,4 +1763,3 @@ def tf_mixnet_l(pretrained=False, **kwargs):
|
|
|
|
model = _gen_mixnet_m(
|
|
|
|
model = _gen_mixnet_m(
|
|
|
|
'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs)
|
|
|
|
'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|