diff --git a/models/genmobilenet.py b/models/genmobilenet.py index d0665316..914e82cc 100644 --- a/models/genmobilenet.py +++ b/models/genmobilenet.py @@ -549,8 +549,8 @@ class InvertedResidual(nn.Module): # Squeeze-and-excitation if self.has_se: - reduce_mult = mid_chs if se_reduce_mid else in_chs - self.se = SqueezeExcite(mid_chs, reduce_chs=max(1, int(reduce_mult * se_ratio)), + se_base_chs = mid_chs if se_reduce_mid else in_chs + self.se = SqueezeExcite(mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn) # Point-wise linear projection