Fix two default args in DenseNet blocks... fix #1427

pull/1415/head
Ross Wightman 2 years ago
parent 527f9a4cb2
commit 1d8d6f6072

@@ -115,7 +115,7 @@ class DenseBlock(nn.ModuleDict):
     _version = 2
 
     def __init__(
-            self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=nn.ReLU,
+            self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=BatchNormAct2d,
             drop_rate=0., memory_efficient=False):
         super(DenseBlock, self).__init__()
         for i in range(num_layers):
@@ -138,7 +138,7 @@ class DenseBlock(nn.ModuleDict):
 
 
 class DenseTransition(nn.Sequential):
-    def __init__(self, num_input_features, num_output_features, norm_layer=nn.BatchNorm2d, aa_layer=None):
+    def __init__(self, num_input_features, num_output_features, norm_layer=BatchNormAct2d, aa_layer=None):
         super(DenseTransition, self).__init__()
         self.add_module('norm', norm_layer(num_input_features))
         self.add_module('conv', nn.Conv2d(
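For context (not part of the commit), a minimal sketch of why these defaults matter, assuming timm's BatchNormAct2d is importable as shown below: both blocks instantiate the layer as norm_layer(num_features), so the default has to be a norm (+ activation) module that accepts a channel count, not an activation like nn.ReLU or a plain nn.BatchNorm2d.

    # Minimal sketch (not from this commit), assuming timm's BatchNormAct2d.
    import torch
    import torch.nn as nn
    from timm.layers import BatchNormAct2d  # older timm: from timm.models.layers import BatchNormAct2d

    x = torch.randn(2, 64, 8, 8)

    norm = BatchNormAct2d(64)   # corrected default: BatchNorm2d fused with a ReLU activation
    print(norm(x).shape)        # torch.Size([2, 64, 8, 8]), normalized and activated

    old = nn.BatchNorm2d(64)    # old DenseTransition default: normalizes but never activates
    print(old(x).shape)         # same shape, but the expected activation is silently skipped

    # The old DenseBlock default is worse still: nn.ReLU(64) merely sets
    # inplace=64 (truthy), so normalization is dropped entirely without an error.

In short, the fix keeps the broken defaults from silently changing the architecture when a caller omits norm_layer.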
