@ -14,7 +14,7 @@ from torch.jit.annotations import List
from timm . data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
from . helpers import load_pretrained
from . layers import SelectAdaptivePool2d , BatchNormAct2d , create_norm_act
from . layers import SelectAdaptivePool2d , BatchNormAct2d , create_norm_act , BlurPool2d
from . registry import register_model
__all__ = [ ' DenseNet ' ]
@ -71,9 +71,9 @@ class DenseLayer(nn.Module):
def call_checkpoint_bottleneck ( self , x ) :
# type: (List[torch.Tensor]) -> torch.Tensor
def closure ( * xs ) :
return self . bottleneck_fn ( * xs )
return self . bottleneck_fn ( xs )
return cp . checkpoint ( closure , x )
return cp . checkpoint ( closure , * x )
@torch.jit._overload_method # noqa: F811
def forward ( self , x ) :
@ -132,11 +132,14 @@ class DenseBlock(nn.ModuleDict):
class DenseTransition ( nn . Sequential ) :
def __init__ ( self , num_input_features , num_output_features , norm_act_layer = nn . BatchNorm2d ):
def __init__ ( self , num_input_features , num_output_features , norm_act_layer = nn . BatchNorm2d , aa_layer = None ):
super ( DenseTransition , self ) . __init__ ( )
self . add_module ( ' norm ' , norm_act_layer ( num_input_features ) )
self . add_module ( ' conv ' , nn . Conv2d (
num_input_features , num_output_features , kernel_size = 1 , stride = 1 , bias = False ) )
if aa_layer is not None :
self . add_module ( ' pool ' , aa_layer ( num_output_features , stride = 2 ) )
else :
self . add_module ( ' pool ' , nn . AvgPool2d ( kernel_size = 2 , stride = 2 ) )
@ -301,6 +304,17 @@ def densenet121(pretrained=False, **kwargs):
return model
@register_model
def densenetblur121d ( pretrained = False , * * kwargs ) :
r """ Densenet-121 model from
` " Densely Connected Convolutional Networks " < https : / / arxiv . org / pdf / 1608.06993 . pdf > `
"""
model = _densenet (
' densenet121 ' , growth_rate = 32 , block_config = ( 6 , 12 , 24 , 16 ) , pretrained = pretrained , stem_type = ' deep ' ,
aa_layer = BlurPool2d , * * kwargs )
return model
@register_model
def densenet121d ( pretrained = False , * * kwargs ) :
r """ Densenet-121 model from