diff --git a/models/adaptive_avgmax_pool.py b/models/adaptive_avgmax_pool.py index 2672fb0c..9dee407f 100644 --- a/models/adaptive_avgmax_pool.py +++ b/models/adaptive_avgmax_pool.py @@ -49,7 +49,7 @@ def select_adaptive_pool2d(x, pool_type='avg', output_size=1): return x -class AdaptiveAvgMaxPool2d(torch.nn.Module): +class AdaptiveAvgMaxPool2d(nn.Module): def __init__(self, output_size=1): super(AdaptiveAvgMaxPool2d, self).__init__() self.output_size = output_size @@ -58,7 +58,7 @@ class AdaptiveAvgMaxPool2d(torch.nn.Module): return adaptive_avgmax_pool2d(x, self.output_size) -class AdaptiveCatAvgMaxPool2d(torch.nn.Module): +class AdaptiveCatAvgMaxPool2d(nn.Module): def __init__(self, output_size=1): super(AdaptiveCatAvgMaxPool2d, self).__init__() self.output_size = output_size @@ -67,7 +67,7 @@ class AdaptiveCatAvgMaxPool2d(torch.nn.Module): return adaptive_catavgmax_pool2d(x, self.output_size) -class SelectAdaptivePool2d(torch.nn.Module): +class SelectAdaptivePool2d(nn.Module): """Selectable global pooling layer with dynamic input kernel size """ def __init__(self, output_size=1, pool_type='avg'): diff --git a/models/mnasnet.py b/models/mnasnet.py index 1a2a2990..133e654e 100644 --- a/models/mnasnet.py +++ b/models/mnasnet.py @@ -185,7 +185,6 @@ class MnasBlock(nn.Module): # Pointwise projection x = self.conv_project(x) x = self.bn2(x) - # Residual if self.has_residual: return x + residual else: @@ -268,7 +267,7 @@ class MnasNet(nn.Module): x = self.bn1(x) x = self.act_fn(x) if pool: - x = self.avg_pool(x) + x = self.global_pool(x) x = x.view(x.size(0), -1) return x diff --git a/train.py b/train.py index d81400cf..03076d88 100644 --- a/train.py +++ b/train.py @@ -156,6 +156,9 @@ def main(): global_pool=args.gp, checkpoint_path=args.initial_checkpoint) + print('Model %s created, param count: %d' % + (args.model, sum([m.numel() for m in model.parameters()]))) + data_config = resolve_data_config(model, args, verbose=args.local_rank == 0) # optionally resume from a checkpoint @@ -178,7 +181,7 @@ def main(): optimizer.load_state_dict(optimizer_state) if has_apex and args.amp: - model, optimizer = amp.initialize(model, optimizer, opt_level='O3') + model, optimizer = amp.initialize(model, optimizer, opt_level='O1') use_amp = True print('AMP enabled') else: