Fixup validate/inference script args, fix senet init for better test accuracy

pull/1/head
Ross Wightman 6 years ago
parent b1a5a71151
commit 31055466fc

@@ -10,10 +10,9 @@ import time
 import argparse
 import numpy as np
 import torch
-import torch.autograd as autograd
 import torch.utils.data as data
-import model_factory
+from models import create_model, transforms_imagenet_eval
 from dataset import Dataset
@@ -32,12 +31,12 @@ parser.add_argument('--img-size', default=224, type=int,
                     metavar='N', help='Input image dimension')
 parser.add_argument('--print-freq', '-p', default=10, type=int,
                     metavar='N', help='print frequency (default: 10)')
-parser.add_argument('--restore-checkpoint', default='', type=str, metavar='PATH',
+parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                     help='path to latest checkpoint (default: none)')
 parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                     help='use pre-trained model')
-parser.add_argument('--multi-gpu', dest='multi_gpu', action='store_true',
-                    help='use multiple-gpus')
+parser.add_argument('--num-gpu', type=int, default=1,
+                    help='Number of GPUS to use')
 parser.add_argument('--no-test-pool', dest='test_time_pool', action='store_false',
                     help='use pre-trained model')
@@ -47,37 +46,33 @@ def main():
     # create model
     num_classes = 1000
-    model = model_factory.create_model(
+    model = create_model(
         args.model,
         num_classes=num_classes,
         pretrained=args.pretrained,
         test_time_pool=args.test_time_pool)
     # resume from a checkpoint
-    if args.restore_checkpoint and os.path.isfile(args.restore_checkpoint):
-        print("=> loading checkpoint '{}'".format(args.restore_checkpoint))
-        checkpoint = torch.load(args.restore_checkpoint)
+    if args.checkpoint and os.path.isfile(args.checkpoint):
+        print("=> loading checkpoint '{}'".format(args.checkpoint))
+        checkpoint = torch.load(args.checkpoint)
         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
             model.load_state_dict(checkpoint['state_dict'])
         else:
             model.load_state_dict(checkpoint)
-        print("=> loaded checkpoint '{}'".format(args.restore_checkpoint))
+        print("=> loaded checkpoint '{}'".format(args.checkpoint))
     elif not args.pretrained:
-        print("=> no checkpoint found at '{}'".format(args.restore_checkpoint))
+        print("=> no checkpoint found at '{}'".format(args.checkpoint))
         exit(1)
-    if args.multi_gpu:
-        model = torch.nn.DataParallel(model).cuda()
+    if args.num_gpu > 1:
+        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
     else:
         model = model.cuda()
-    transforms = model_factory.get_transforms_eval(
-        args.model,
-        args.img_size)
     dataset = Dataset(
         args.data,
-        transforms)
+        transforms_imagenet_eval(args.model, args.img_size))
     loader = data.DataLoader(
         dataset,

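The script-level change above is the switch from a boolean --multi-gpu flag to an integer --num-gpu, which lets DataParallel be pinned to an explicit device list instead of all visible GPUs. A minimal standalone sketch of that pattern, using a placeholder module rather than the repo's create_model():

```python
import argparse

import torch
import torch.nn as nn

parser = argparse.ArgumentParser()
parser.add_argument('--num-gpu', type=int, default=1,
                    help='Number of GPUS to use')
args = parser.parse_args([])  # parse defaults, just for illustration

model = nn.Linear(16, 16)  # placeholder for the real model
if args.num_gpu > 1:
    # replicate across the first N visible GPUs
    model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
elif torch.cuda.is_available():
    model = model.cuda()
```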
@@ -105,14 +105,10 @@ pretrained_config = {
 def _weight_init(m, n='', ll=''):
-    print(m, n, ll)
     if isinstance(m, nn.Conv2d):
         nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
     elif isinstance(m, nn.BatchNorm2d):
-        if ll and n == ll:
-            nn.init.constant_(m.weight, 0.)
-        else:
-            nn.init.constant_(m.weight, 1.)
+        nn.init.constant_(m.weight, 1.)
         nn.init.constant_(m.bias, 0.)
@@ -128,9 +124,6 @@ class SEModule(nn.Module):
             channels // reduction, channels, kernel_size=1, padding=0)
         self.sigmoid = nn.Sigmoid()
-        for m in self.modules():
-            _weight_init(m)
     def forward(self, x):
         module_input = x
         x = self.avg_pool(x)
@@ -191,9 +184,6 @@ class SEBottleneck(Bottleneck):
         self.downsample = downsample
         self.stride = stride
-        for n, m in self.named_modules():
-            _weight_init(m, n, ll='bn3')
 class SEResNetBottleneck(Bottleneck):
     """
@@ -219,9 +209,6 @@ class SEResNetBottleneck(Bottleneck):
         self.downsample = downsample
         self.stride = stride
-        for n, m in self.named_modules():
-            _weight_init(m, n, ll='bn3')
 class SEResNeXtBottleneck(Bottleneck):
     """
@@ -246,9 +233,6 @@ class SEResNeXtBottleneck(Bottleneck):
         self.downsample = downsample
         self.stride = stride
-        for n, m in self.named_modules():
-            _weight_init(m, n, ll='bn3')
 class SEResNetBlock(nn.Module):
     expansion = 1
@@ -266,9 +250,6 @@ class SEResNetBlock(nn.Module):
         self.downsample = downsample
         self.stride = stride
-        for n, m in self.named_modules():
-            _weight_init(m, n, ll='bn2')
     def forward(self, x):
         residual = x
@@ -405,11 +386,8 @@ class SENet(nn.Module):
         self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
         self.last_linear = nn.Linear(512 * block.expansion, num_classes)
-        for n, m in self.named_children():
-            if n == 'layer0':
-                m.apply(_weight_init)
-            else:
-                _weight_init(m)
+        for m in self.modules():
+            _weight_init(m)
     def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
                     downsample_kernel_size=1, downsample_padding=0):

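The accuracy fix is in the SENet init: instead of zero-initializing the final BatchNorm of each residual block and initializing layer0 separately, the weights are now set in a single uniform pass over self.modules() in SENet.__init__. A standalone sketch of the resulting scheme, with a toy module standing in for SENet and the now-unused n/ll arguments dropped for brevity:

```python
import torch.nn as nn

def _weight_init(m):
    # Kaiming-normal for convolutions, gamma=1 / beta=0 for every BatchNorm
    # (no special zero-init of a block's final BN layer any more)
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1.)
        nn.init.constant_(m.bias, 0.)

class TinyBlock(nn.Module):
    # toy stand-in for an SE block, just to demonstrate the init pass
    def __init__(self):
        super(TinyBlock, self).__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(8)

net = TinyBlock()
for m in net.modules():  # same traversal SENet.__init__ now uses
    _weight_init(m)
```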
@@ -21,7 +21,7 @@ class LeNormalize(object):
         return tensor
-def transforms_imagenet_train(model_name, img_size=224, scale=(0.1, 1.0), color_jitter=(0.333, 0.333, 0.333)):
+def transforms_imagenet_train(model_name, img_size=224, scale=(0.1, 1.0), color_jitter=(0.4, 0.4, 0.4)):
     if 'dpn' in model_name:
         normalize = transforms.Normalize(
             mean=IMAGENET_DPN_MEAN,

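The only change to the training transforms is the default color jitter strength, 0.333 to 0.4 per parameter. Presumably the tuple feeds torchvision's ColorJitter; a rough sketch of such a pipeline with standard ImageNet normalization, not the repo's exact transform list:

```python
from torchvision import transforms

color_jitter = (0.4, 0.4, 0.4)  # new default: brightness, contrast, saturation
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.1, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(*color_jitter),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
```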
@@ -180,8 +180,8 @@ def main():
         assert False and "Invalid optimizer"
         exit(1)
-    if optimizer_state is not None:
-        optimizer.load_state_dict(optimizer_state)
+    #if optimizer_state is not None:
+    #    optimizer.load_state_dict(optimizer_state)
     if args.sched == 'cosine':
         lr_scheduler = scheduler.CosineLRScheduler(

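In the training script, restoring the optimizer state on resume is commented out rather than removed. For reference, the usual save/resume round trip looks roughly like this, a generic sketch with toy objects, a hypothetical checkpoint path, and assumed dict keys rather than this repo's exact checkpoint format:

```python
import torch
import torch.nn as nn

# toy model/optimizer standing in for the training script's objects
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# save both model and optimizer state
torch.save({'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()}, 'checkpoint.pth.tar')

# resume
checkpoint = torch.load('checkpoint.pth.tar', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
if 'optimizer' in checkpoint:
    # this is the step the commit disables (optimizer state no longer restored)
    optimizer.load_state_dict(checkpoint['optimizer'])
```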
@@ -12,7 +12,7 @@ import torch.nn.parallel
 import torch.utils.data as data
-from models import model_factory
+from models import create_model, transforms_imagenet_eval
 from dataset import Dataset
@@ -29,12 +29,12 @@ parser.add_argument('--img-size', default=224, type=int,
                     metavar='N', help='Input image dimension')
 parser.add_argument('--print-freq', '-p', default=10, type=int,
                     metavar='N', help='print frequency (default: 10)')
-parser.add_argument('--restore-checkpoint', default='', type=str, metavar='PATH',
+parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                     help='path to latest checkpoint (default: none)')
 parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                     help='use pre-trained model')
-parser.add_argument('--multi-gpu', dest='multi_gpu', action='store_true',
-                    help='use multiple-gpus')
+parser.add_argument('--num-gpu', type=int, default=1,
+                    help='Number of GPUS to use')
 parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
                     help='disable test time pool for DPN models')
@@ -48,7 +48,7 @@ def main():
     # create model
     num_classes = 1000
-    model = model_factory.create_model(
+    model = create_model(
         args.model,
         num_classes=num_classes,
         pretrained=args.pretrained,
@@ -57,23 +57,21 @@ def main():
     print('Model %s created, param count: %d' %
           (args.model, sum([m.numel() for m in model.parameters()])))
-    print(model)
     # optionally resume from a checkpoint
-    if args.restore_checkpoint and os.path.isfile(args.restore_checkpoint):
-        print("=> loading checkpoint '{}'".format(args.restore_checkpoint))
-        checkpoint = torch.load(args.restore_checkpoint)
+    if args.checkpoint and os.path.isfile(args.checkpoint):
+        print("=> loading checkpoint '{}'".format(args.checkpoint))
+        checkpoint = torch.load(args.checkpoint)
         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
             model.load_state_dict(checkpoint['state_dict'])
         else:
             model.load_state_dict(checkpoint)
-        print("=> loaded checkpoint '{}'".format(args.restore_checkpoint))
+        print("=> loaded checkpoint '{}'".format(args.checkpoint))
     elif not args.pretrained:
-        print("=> no checkpoint found at '{}'".format(args.restore_checkpoint))
+        print("=> no checkpoint found at '{}'".format(args.checkpoint))
         exit(1)
-    if args.multi_gpu:
-        model = torch.nn.DataParallel(model).cuda()
+    if args.num_gpu > 1:
+        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
     else:
         model = model.cuda()
@@ -82,13 +80,9 @@ def main():
     cudnn.benchmark = True
-    transforms = model_factory.get_transforms_eval(
-        args.model,
-        args.img_size)
     dataset = Dataset(
         args.data,
-        transforms)
+        transforms_imagenet_eval(args.model, args.img_size))
     loader = data.DataLoader(
         dataset,

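Both the validate and inference scripts keep the same tolerant checkpoint loading after this change: accept either a bare state_dict or a dict wrapping one under 'state_dict'. A standalone sketch of that pattern with a toy model:

```python
import torch
import torch.nn as nn

def load_checkpoint(model, path):
    # accepts either a {'state_dict': ...} wrapper or a bare state_dict
    checkpoint = torch.load(path, map_location='cpu')
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        model.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint)
    print("=> loaded checkpoint '{}'".format(path))

model = nn.Linear(4, 2)
torch.save({'state_dict': model.state_dict()}, 'toy_checkpoint.pth')
load_checkpoint(model, 'toy_checkpoint.pth')
```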