Fixup validate/inference script args, fix senet init for better test accuracy

pull/1/head
Ross Wightman 6 years ago
parent b1a5a71151
commit 31055466fc

@ -10,10 +10,9 @@ import time
import argparse
import numpy as np
import torch
import torch.autograd as autograd
import torch.utils.data as data
import model_factory
from models import create_model, transforms_imagenet_eval
from dataset import Dataset
@ -32,12 +31,12 @@ parser.add_argument('--img-size', default=224, type=int,
metavar='N', help='Input image dimension')
parser.add_argument('--print-freq', '-p', default=10, type=int,
metavar='N', help='print frequency (default: 10)')
parser.add_argument('--restore-checkpoint', default='', type=str, metavar='PATH',
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model')
parser.add_argument('--multi-gpu', dest='multi_gpu', action='store_true',
help='use multiple-gpus')
parser.add_argument('--num-gpu', type=int, default=1,
help='Number of GPUS to use')
parser.add_argument('--no-test-pool', dest='test_time_pool', action='store_false',
help='use pre-trained model')
@ -47,37 +46,33 @@ def main():
# create model
num_classes = 1000
model = model_factory.create_model(
model = create_model(
args.model,
num_classes=num_classes,
pretrained=args.pretrained,
test_time_pool=args.test_time_pool)
# resume from a checkpoint
if args.restore_checkpoint and os.path.isfile(args.restore_checkpoint):
print("=> loading checkpoint '{}'".format(args.restore_checkpoint))
checkpoint = torch.load(args.restore_checkpoint)
if args.checkpoint and os.path.isfile(args.checkpoint):
print("=> loading checkpoint '{}'".format(args.checkpoint))
checkpoint = torch.load(args.checkpoint)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
model.load_state_dict(checkpoint['state_dict'])
else:
model.load_state_dict(checkpoint)
print("=> loaded checkpoint '{}'".format(args.restore_checkpoint))
print("=> loaded checkpoint '{}'".format(args.checkpoint))
elif not args.pretrained:
print("=> no checkpoint found at '{}'".format(args.restore_checkpoint))
print("=> no checkpoint found at '{}'".format(args.checkpoint))
exit(1)
if args.multi_gpu:
model = torch.nn.DataParallel(model).cuda()
if args.num_gpu > 1:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
else:
model = model.cuda()
transforms = model_factory.get_transforms_eval(
args.model,
args.img_size)
dataset = Dataset(
args.data,
transforms)
transforms_imagenet_eval(args.model, args.img_size))
loader = data.DataLoader(
dataset,

@ -105,13 +105,9 @@ pretrained_config = {
def _weight_init(m, n='', ll=''):
print(m, n, ll)
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
if ll and n == ll:
nn.init.constant_(m.weight, 0.)
else:
nn.init.constant_(m.weight, 1.)
nn.init.constant_(m.bias, 0.)
@ -128,9 +124,6 @@ class SEModule(nn.Module):
channels // reduction, channels, kernel_size=1, padding=0)
self.sigmoid = nn.Sigmoid()
for m in self.modules():
_weight_init(m)
def forward(self, x):
module_input = x
x = self.avg_pool(x)
@ -191,9 +184,6 @@ class SEBottleneck(Bottleneck):
self.downsample = downsample
self.stride = stride
for n, m in self.named_modules():
_weight_init(m, n, ll='bn3')
class SEResNetBottleneck(Bottleneck):
"""
@ -219,9 +209,6 @@ class SEResNetBottleneck(Bottleneck):
self.downsample = downsample
self.stride = stride
for n, m in self.named_modules():
_weight_init(m, n, ll='bn3')
class SEResNeXtBottleneck(Bottleneck):
"""
@ -246,9 +233,6 @@ class SEResNeXtBottleneck(Bottleneck):
self.downsample = downsample
self.stride = stride
for n, m in self.named_modules():
_weight_init(m, n, ll='bn3')
class SEResNetBlock(nn.Module):
expansion = 1
@ -266,9 +250,6 @@ class SEResNetBlock(nn.Module):
self.downsample = downsample
self.stride = stride
for n, m in self.named_modules():
_weight_init(m, n, ll='bn2')
def forward(self, x):
residual = x
@ -405,10 +386,7 @@ class SENet(nn.Module):
self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
self.last_linear = nn.Linear(512 * block.expansion, num_classes)
for n, m in self.named_children():
if n == 'layer0':
m.apply(_weight_init)
else:
for m in self.modules():
_weight_init(m)
def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,

@ -21,7 +21,7 @@ class LeNormalize(object):
return tensor
def transforms_imagenet_train(model_name, img_size=224, scale=(0.1, 1.0), color_jitter=(0.333, 0.333, 0.333)):
def transforms_imagenet_train(model_name, img_size=224, scale=(0.1, 1.0), color_jitter=(0.4, 0.4, 0.4)):
if 'dpn' in model_name:
normalize = transforms.Normalize(
mean=IMAGENET_DPN_MEAN,

@ -180,8 +180,8 @@ def main():
assert False and "Invalid optimizer"
exit(1)
if optimizer_state is not None:
optimizer.load_state_dict(optimizer_state)
#if optimizer_state is not None:
# optimizer.load_state_dict(optimizer_state)
if args.sched == 'cosine':
lr_scheduler = scheduler.CosineLRScheduler(

@ -12,7 +12,7 @@ import torch.nn.parallel
import torch.utils.data as data
from models import model_factory
from models import create_model, transforms_imagenet_eval
from dataset import Dataset
@ -29,12 +29,12 @@ parser.add_argument('--img-size', default=224, type=int,
metavar='N', help='Input image dimension')
parser.add_argument('--print-freq', '-p', default=10, type=int,
metavar='N', help='print frequency (default: 10)')
parser.add_argument('--restore-checkpoint', default='', type=str, metavar='PATH',
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model')
parser.add_argument('--multi-gpu', dest='multi_gpu', action='store_true',
help='use multiple-gpus')
parser.add_argument('--num-gpu', type=int, default=1,
help='Number of GPUS to use')
parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
help='disable test time pool for DPN models')
@ -48,7 +48,7 @@ def main():
# create model
num_classes = 1000
model = model_factory.create_model(
model = create_model(
args.model,
num_classes=num_classes,
pretrained=args.pretrained,
@ -57,23 +57,21 @@ def main():
print('Model %s created, param count: %d' %
(args.model, sum([m.numel() for m in model.parameters()])))
print(model)
# optionally resume from a checkpoint
if args.restore_checkpoint and os.path.isfile(args.restore_checkpoint):
print("=> loading checkpoint '{}'".format(args.restore_checkpoint))
checkpoint = torch.load(args.restore_checkpoint)
if args.checkpoint and os.path.isfile(args.checkpoint):
print("=> loading checkpoint '{}'".format(args.checkpoint))
checkpoint = torch.load(args.checkpoint)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
model.load_state_dict(checkpoint['state_dict'])
else:
model.load_state_dict(checkpoint)
print("=> loaded checkpoint '{}'".format(args.restore_checkpoint))
print("=> loaded checkpoint '{}'".format(args.checkpoint))
elif not args.pretrained:
print("=> no checkpoint found at '{}'".format(args.restore_checkpoint))
print("=> no checkpoint found at '{}'".format(args.checkpoint))
exit(1)
if args.multi_gpu:
model = torch.nn.DataParallel(model).cuda()
if args.num_gpu > 1:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
else:
model = model.cuda()
@ -82,13 +80,9 @@ def main():
cudnn.benchmark = True
transforms = model_factory.get_transforms_eval(
args.model,
args.img_size)
dataset = Dataset(
args.data,
transforms)
transforms_imagenet_eval(args.model, args.img_size))
loader = data.DataLoader(
dataset,

Loading…
Cancel
Save