@@ -14,7 +14,7 @@ import torch.nn.parallel
 from collections import OrderedDict
 
 from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
-from timm.data import Dataset, create_loader, resolve_data_config
+from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config
 from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
 
 torch.backends.cudnn.benchmark = True
@@ -24,7 +24,7 @@ parser.add_argument('data', metavar='DIR',
                     help='path to dataset')
 parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92',
                     help='model architecture (default: dpn92)')
-parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
+parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                     help='number of data loading workers (default: 2)')
 parser.add_argument('-b', '--batch-size', default=256, type=int,
                     metavar='N', help='mini-batch size (default: 256)')
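The hunk below is where the newly imported DatasetTar comes into play: validate() now checks whether the positional data argument points at a regular file ending in .tar and, if so, reads samples straight from the archive via DatasetTar instead of walking a folder tree with Dataset. A minimal standalone sketch of that dispatch, using only the stdlib os calls that appear in the diff; the pick_dataset_kind helper and the example paths are illustrative inventions, not part of timm:

import os

def pick_dataset_kind(data_path):
    # Same test as the added branch below: a regular file with a .tar
    # extension is read as a tar archive, anything else is treated as
    # a folder of images.
    if os.path.splitext(data_path)[1] == '.tar' and os.path.isfile(data_path):
        return 'tar archive -> DatasetTar'
    return 'image folder -> Dataset'

print(pick_dataset_kind('/data/imagenet-val.tar'))  # tar branch, provided the file exists
print(pick_dataset_kind('/data/imagenet/val'))      # folder branch

With the patch applied, the validation script (validate.py in timm) accepts either layout through the same positional argument, e.g. a hypothetical "python validate.py /data/imagenet-val.tar --model dpn92" versus a folder path; both paths here are placeholders.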
@@ -91,9 +91,14 @@ def validate(args):
 
     criterion = nn.CrossEntropyLoss().cuda()
 
+    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
+        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing)
+    else:
+        dataset = Dataset(args.data, load_bytes=args.tf_preprocessing)
+
     crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
     loader = create_loader(
-        Dataset(args.data, load_bytes=args.tf_preprocessing),
+        dataset,
         input_size=data_config['input_size'],
         batch_size=args.batch_size,
         use_prefetcher=args.prefetcher,