diff --git a/timm/models/densenet.py b/timm/models/densenet.py index a46b86ad..1afdfd7b 100644 --- a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -115,7 +115,7 @@ class DenseBlock(nn.ModuleDict): _version = 2 def __init__( - self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=nn.ReLU, + self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=BatchNormAct2d, drop_rate=0., memory_efficient=False): super(DenseBlock, self).__init__() for i in range(num_layers): @@ -138,7 +138,7 @@ class DenseBlock(nn.ModuleDict): class DenseTransition(nn.Sequential): - def __init__(self, num_input_features, num_output_features, norm_layer=nn.BatchNorm2d, aa_layer=None): + def __init__(self, num_input_features, num_output_features, norm_layer=BatchNormAct2d, aa_layer=None): super(DenseTransition, self).__init__() self.add_module('norm', norm_layer(num_input_features)) self.add_module('conv', nn.Conv2d(