|
|
|
@ -12,7 +12,7 @@ import torch.nn.parallel
|
|
|
|
|
import torch.utils.data as data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from models import model_factory
|
|
|
|
|
from models import create_model, transforms_imagenet_eval
|
|
|
|
|
from dataset import Dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -29,12 +29,12 @@ parser.add_argument('--img-size', default=224, type=int,
|
|
|
|
|
metavar='N', help='Input image dimension')
|
|
|
|
|
parser.add_argument('--print-freq', '-p', default=10, type=int,
|
|
|
|
|
metavar='N', help='print frequency (default: 10)')
|
|
|
|
|
parser.add_argument('--restore-checkpoint', default='', type=str, metavar='PATH',
|
|
|
|
|
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
|
|
|
|
|
help='path to latest checkpoint (default: none)')
|
|
|
|
|
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
|
|
|
|
|
help='use pre-trained model')
|
|
|
|
|
parser.add_argument('--multi-gpu', dest='multi_gpu', action='store_true',
|
|
|
|
|
help='use multiple-gpus')
|
|
|
|
|
parser.add_argument('--num-gpu', type=int, default=1,
|
|
|
|
|
help='Number of GPUS to use')
|
|
|
|
|
parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
|
|
|
|
|
help='disable test time pool for DPN models')
|
|
|
|
|
|
|
|
|
@ -48,7 +48,7 @@ def main():
|
|
|
|
|
|
|
|
|
|
# create model
|
|
|
|
|
num_classes = 1000
|
|
|
|
|
model = model_factory.create_model(
|
|
|
|
|
model = create_model(
|
|
|
|
|
args.model,
|
|
|
|
|
num_classes=num_classes,
|
|
|
|
|
pretrained=args.pretrained,
|
|
|
|
@ -57,23 +57,21 @@ def main():
|
|
|
|
|
print('Model %s created, param count: %d' %
|
|
|
|
|
(args.model, sum([m.numel() for m in model.parameters()])))
|
|
|
|
|
|
|
|
|
|
print(model)
|
|
|
|
|
|
|
|
|
|
# optionally resume from a checkpoint
|
|
|
|
|
if args.restore_checkpoint and os.path.isfile(args.restore_checkpoint):
|
|
|
|
|
print("=> loading checkpoint '{}'".format(args.restore_checkpoint))
|
|
|
|
|
checkpoint = torch.load(args.restore_checkpoint)
|
|
|
|
|
if args.checkpoint and os.path.isfile(args.checkpoint):
|
|
|
|
|
print("=> loading checkpoint '{}'".format(args.checkpoint))
|
|
|
|
|
checkpoint = torch.load(args.checkpoint)
|
|
|
|
|
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
|
|
|
|
model.load_state_dict(checkpoint['state_dict'])
|
|
|
|
|
else:
|
|
|
|
|
model.load_state_dict(checkpoint)
|
|
|
|
|
print("=> loaded checkpoint '{}'".format(args.restore_checkpoint))
|
|
|
|
|
print("=> loaded checkpoint '{}'".format(args.checkpoint))
|
|
|
|
|
elif not args.pretrained:
|
|
|
|
|
print("=> no checkpoint found at '{}'".format(args.restore_checkpoint))
|
|
|
|
|
print("=> no checkpoint found at '{}'".format(args.checkpoint))
|
|
|
|
|
exit(1)
|
|
|
|
|
|
|
|
|
|
if args.multi_gpu:
|
|
|
|
|
model = torch.nn.DataParallel(model).cuda()
|
|
|
|
|
if args.num_gpu > 1:
|
|
|
|
|
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
|
|
|
|
|
else:
|
|
|
|
|
model = model.cuda()
|
|
|
|
|
|
|
|
|
@ -82,13 +80,9 @@ def main():
|
|
|
|
|
|
|
|
|
|
cudnn.benchmark = True
|
|
|
|
|
|
|
|
|
|
transforms = model_factory.get_transforms_eval(
|
|
|
|
|
args.model,
|
|
|
|
|
args.img_size)
|
|
|
|
|
|
|
|
|
|
dataset = Dataset(
|
|
|
|
|
args.data,
|
|
|
|
|
transforms)
|
|
|
|
|
transforms_imagenet_eval(args.model, args.img_size))
|
|
|
|
|
|
|
|
|
|
loader = data.DataLoader(
|
|
|
|
|
dataset,
|
|
|
|
|