From 1d8d6f6072659e905d91a2b297d53e927853457d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 25 Aug 2022 15:00:35 -0700 Subject: [PATCH] Fix two default args in DenseNet blocks... fix #1427 --- timm/models/densenet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/densenet.py b/timm/models/densenet.py index a46b86ad..1afdfd7b 100644 --- a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -115,7 +115,7 @@ class DenseBlock(nn.ModuleDict): _version = 2 def __init__( - self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=nn.ReLU, + self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=BatchNormAct2d, drop_rate=0., memory_efficient=False): super(DenseBlock, self).__init__() for i in range(num_layers): @@ -138,7 +138,7 @@ class DenseBlock(nn.ModuleDict): class DenseTransition(nn.Sequential): - def __init__(self, num_input_features, num_output_features, norm_layer=nn.BatchNorm2d, aa_layer=None): + def __init__(self, num_input_features, num_output_features, norm_layer=BatchNormAct2d, aa_layer=None): super(DenseTransition, self).__init__() self.add_module('norm', norm_layer(num_input_features)) self.add_module('conv', nn.Conv2d(