Fix pooling in mnasnet, more sensible default for AMP opt level

Ross Wightman 6 years ago
parent 996c77aa94
commit e9c7961efc

@@ -49,7 +49,7 @@ def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
     return x


-class AdaptiveAvgMaxPool2d(torch.nn.Module):
+class AdaptiveAvgMaxPool2d(nn.Module):
     def __init__(self, output_size=1):
         super(AdaptiveAvgMaxPool2d, self).__init__()
         self.output_size = output_size
@@ -58,7 +58,7 @@ class AdaptiveAvgMaxPool2d(torch.nn.Module):
         return adaptive_avgmax_pool2d(x, self.output_size)


-class AdaptiveCatAvgMaxPool2d(torch.nn.Module):
+class AdaptiveCatAvgMaxPool2d(nn.Module):
     def __init__(self, output_size=1):
         super(AdaptiveCatAvgMaxPool2d, self).__init__()
         self.output_size = output_size
@@ -67,7 +67,7 @@ class AdaptiveCatAvgMaxPool2d(torch.nn.Module):
         return adaptive_catavgmax_pool2d(x, self.output_size)


-class SelectAdaptivePool2d(torch.nn.Module):
+class SelectAdaptivePool2d(nn.Module):
     """Selectable global pooling layer with dynamic input kernel size
     """
     def __init__(self, output_size=1, pool_type='avg'):
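For context, a minimal sketch of what the two pooling helpers referenced above compute, assuming the convention that 'avgmax' averages the avg- and max-pooled results while 'catavgmax' concatenates them along the channel dimension (doubling the channel count). The torch.nn.Module -> nn.Module edits themselves are cosmetic, relying on the module's existing import of torch.nn as nn.

import torch
import torch.nn.functional as F

def adaptive_avgmax_pool2d(x, output_size=1):
    # Mean of adaptive average pooling and adaptive max pooling; output
    # has the same shape as either pooling result alone.
    return 0.5 * (F.adaptive_avg_pool2d(x, output_size) +
                  F.adaptive_max_pool2d(x, output_size))

def adaptive_catavgmax_pool2d(x, output_size=1):
    # Concatenate avg- and max-pooled features along the channel dim,
    # so downstream layers see 2x the input channel count.
    return torch.cat([F.adaptive_avg_pool2d(x, output_size),
                      F.adaptive_max_pool2d(x, output_size)], dim=1)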

@@ -185,7 +185,6 @@ class MnasBlock(nn.Module):
         # Pointwise projection
         x = self.conv_project(x)
         x = self.bn2(x)
-        # Residual
         if self.has_residual:
             return x + residual
         else:
@@ -268,7 +267,7 @@ class MnasNet(nn.Module):
         x = self.bn1(x)
         x = self.act_fn(x)
         if pool:
-            x = self.avg_pool(x)
+            x = self.global_pool(x)
             x = x.view(x.size(0), -1)
         return x

@@ -156,6 +156,9 @@ def main():
         global_pool=args.gp,
         checkpoint_path=args.initial_checkpoint)

+    print('Model %s created, param count: %d' %
+          (args.model, sum([m.numel() for m in model.parameters()])))
+
     data_config = resolve_data_config(model, args, verbose=args.local_rank == 0)

     # optionally resume from a checkpoint
@ -178,7 +181,7 @@ def main():
optimizer.load_state_dict(optimizer_state) optimizer.load_state_dict(optimizer_state)
if has_apex and args.amp: if has_apex and args.amp:
model, optimizer = amp.initialize(model, optimizer, opt_level='O3') model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
use_amp = True use_amp = True
print('AMP enabled') print('AMP enabled')
else: else:
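The opt_level change above moves Apex AMP from O3 (pure FP16, useful mainly as a speed baseline and often unstable for training) to O1 (patched mixed precision: ops run in FP16 where safe, with FP32 master weights and dynamic loss scaling), the generally recommended default. A sketch of how this sits in a training loop using the apex.amp API; the loop itself is illustrative:

from apex import amp

# O1 = mixed precision with FP32 master weights and dynamic loss scaling.
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

for input, target in loader:
    output = model(input)
    loss = criterion(output, target)
    optimizer.zero_grad()
    # Scale the loss so small FP16 gradients do not underflow to zero.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()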
