More models in sotabench, more control over sotabench run, dataset filename extraction consistency

pull/244/head
Ross Wightman 4 years ago
parent 9c406532bd
commit e8ca45854c

@ -73,7 +73,7 @@ def main():
(args.model, sum([m.numel() for m in model.parameters()])))
config = resolve_data_config(vars(args), model=model)
model, test_time_pool = apply_test_time_pool(model, config, args)
model, test_time_pool = model, False if args.no_test_pool else apply_test_time_pool(model, config)
if args.num_gpu > 1:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
@ -115,9 +115,8 @@ def main():
topk_ids = np.concatenate(topk_ids, axis=0).squeeze()
with open(os.path.join(args.output_dir, './topk_ids.csv'), 'w') as out_file:
filenames = loader.dataset.filenames()
filenames = loader.dataset.filenames(basename=True)
for filename, label in zip(filenames, topk_ids):
filename = os.path.basename(filename)
out_file.write('{0},{1},{2},{3},{4},{5}\n'.format(
filename, label[0], label[1], label[2], label[3], label[4]))

@ -1,8 +1,10 @@
import torch
from torchbench.image_classification import ImageNet
from sotabencheval.image_classification import ImageNetEvaluator
from sotabencheval.utils import is_server
from timm import create_model
from timm.data import resolve_data_config, create_transform
from timm.models import TestTimePoolHead
from timm.data import resolve_data_config, create_loader, DatasetTar
from timm.models import apply_test_time_pool
from tqdm import tqdm
import os
NUM_GPU = 1
@ -148,6 +150,10 @@ model_list = [
_entry('ese_vovnet19b_dw', 'VoVNet-19-DW-V2', '1911.06667'),
_entry('ese_vovnet39b', 'VoVNet-39-V2', '1911.06667'),
_entry('cspresnet50', 'CSPResNet-50', '1911.11929'),
_entry('cspresnext50', 'CSPResNeXt-50', '1911.11929'),
_entry('cspdarknet53', 'CSPDarkNet-53', '1911.11929'),
_entry('tf_efficientnet_b0', 'EfficientNet-B0 (AutoAugment)', '1905.11946',
model_desc='Ported from official Google AI Tensorflow weights'),
_entry('tf_efficientnet_b1', 'EfficientNet-B1 (AutoAugment)', '1905.11946',
@ -448,8 +454,20 @@ model_list = [
_entry('regnety_160', 'RegNetY-16GF', '2003.13678'),
_entry('regnety_320', 'RegNetY-32GF', '2003.13678', batch_size=BATCH_SIZE // 2),
_entry('rexnet_100', 'ReXNet-1.0x', '2007.00992'),
_entry('rexnet_130', 'ReXNet-1.3x', '2007.00992'),
_entry('rexnet_150', 'ReXNet-1.5x', '2007.00992'),
_entry('rexnet_200', 'ReXNet-2.0x', '2007.00992'),
]
if is_server():
DATA_ROOT = './.data/vision/imagenet'
else:
# local settings
DATA_ROOT = './'
DATA_FILENAME = 'ILSVRC2012_img_val.tar'
TAR_PATH = os.path.join(DATA_ROOT, DATA_FILENAME)
for m in model_list:
model_name = m['model']
# create model from name
@ -457,25 +475,60 @@ for m in model_list:
param_count = sum([m.numel() for m in model.parameters()])
print('Model %s, %s created. Param count: %d' % (model_name, m['paper_model_name'], param_count))
dataset = DatasetTar(TAR_PATH)
filenames = [os.path.splitext(f)[0] for f in dataset.filenames()]
# get appropriate transform for model's default pretrained config
data_config = resolve_data_config(m['args'], model=model, verbose=True)
test_time_pool = False
if m['ttp']:
model = TestTimePoolHead(model, model.default_cfg['pool_size'])
model, test_time_pool = apply_test_time_pool(model, data_config)
data_config['crop_pct'] = 1.0
input_transform = create_transform(**data_config)
# Run the benchmark
ImageNet.benchmark(
model=model,
model_description=m.get('model_description', None),
paper_model_name=m['paper_model_name'],
batch_size = m['batch_size']
loader = create_loader(
dataset,
input_size=data_config['input_size'],
batch_size=batch_size,
use_prefetcher=True,
interpolation=data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
num_workers=6,
crop_pct=data_config['crop_pct'],
pin_memory=True)
evaluator = ImageNetEvaluator(
root=DATA_ROOT,
model_name=m['paper_model_name'],
paper_arxiv_id=m['paper_arxiv_id'],
input_transform=input_transform,
batch_size=m['batch_size'],
num_gpu=NUM_GPU,
data_root=os.environ.get('IMAGENET_DIR', './.data/vision/imagenet')
model_description=m.get('model_description', None),
)
model.cuda()
model.eval()
with torch.no_grad():
# warmup
input = torch.randn((batch_size,) + data_config['input_size']).cuda()
model(input)
bar = tqdm(desc="Evaluation", mininterval=5, total=50000)
evaluator.reset_time()
sample_count = 0
for input, target in loader:
output = model(input)
num_samples = len(output)
image_ids = [filenames[i] for i in range(sample_count, sample_count + num_samples)]
output = output.cpu().numpy()
evaluator.add(dict(zip(image_ids, list(output))))
sample_count += num_samples
bar.update(num_samples)
bar.close()
evaluator.save()
for k, v in evaluator.results.items():
print(k, v)
for k, v in evaluator.speed_mem_metrics.items():
print(k, v)
torch.cuda.empty_cache()

@ -3,10 +3,11 @@ source /workspace/venv/bin/activate
pip install -r requirements-sotabench.txt
apt-get install -y libjpeg-dev zlib1g-dev libpng-dev libwebp-dev
pip uninstall -y pillow
CC="cc -mavx2" pip install -U --force-reinstall pillow-simd
# FIXME this shouldn't be needed but sb dataset upload functionality doesn't seem to work
apt-get install wget
wget https://onedrive.hyper.ai/down/ImageNet/data/ImageNet2012/ILSVRC2012_devkit_t12.tar.gz -P ./.data/vision/imagenet
wget https://onedrive.hyper.ai/down/ImageNet/data/ImageNet2012/ILSVRC2012_img_val.tar -P ./.data/vision/imagenet
#wget -q https://onedrive.hyper.ai/down/ImageNet/data/ImageNet2012/ILSVRC2012_devkit_t12.tar.gz -P ./.data/vision/imagenet
wget -q https://onedrive.hyper.ai/down/ImageNet/data/ImageNet2012/ILSVRC2012_img_val.tar -P ./.data/vision/imagenet

@ -94,17 +94,21 @@ class Dataset(data.Dataset):
def __len__(self):
return len(self.samples)
def filenames(self, indices=[], basename=False):
if indices:
if basename:
return [os.path.basename(self.samples[i][0]) for i in indices]
else:
return [self.samples[i][0] for i in indices]
else:
if basename:
return [os.path.basename(x[0]) for x in self.samples]
else:
return [x[0] for x in self.samples]
def filename(self, index, basename=False, absolute=False):
filename = self.samples[index][0]
if basename:
filename = os.path.basename(filename)
elif not absolute:
filename = os.path.relpath(filename, self.root)
return filename
def filenames(self, basename=False, absolute=False):
fn = lambda x: x
if basename:
fn = os.path.basename
elif not absolute:
fn = lambda x: os.path.relpath(x, self.root)
return [fn(x[0]) for x in self.samples]
def _extract_tar_info(tarfile, class_to_idx=None, sort=True):
@ -160,6 +164,16 @@ class DatasetTar(data.Dataset):
def __len__(self):
return len(self.samples)
def filename(self, index, basename=False):
filename = self.samples[index][0].name
if basename:
filename = os.path.basename(filename)
return filename
def filenames(self, basename=False):
fn = os.path.basename if basename else lambda x: x
return [fn(x[0].name) for x in self.samples]
class AugMixDataset(torch.utils.data.Dataset):
"""Dataset wrapper to perform AugMix or other clean/augmentation mixes"""

@ -36,13 +36,12 @@ class TestTimePoolHead(nn.Module):
return x.view(x.size(0), -1)
def apply_test_time_pool(model, config, args):
def apply_test_time_pool(model, config):
test_time_pool = False
if not hasattr(model, 'default_cfg') or not model.default_cfg:
return model, False
if not args.no_test_pool and \
config['input_size'][-1] > model.default_cfg['input_size'][-1] and \
config['input_size'][-2] > model.default_cfg['input_size'][-2]:
if (config['input_size'][-1] > model.default_cfg['input_size'][-1] and
config['input_size'][-2] > model.default_cfg['input_size'][-2]):
_logger.info('Target input size %s > pretrained default %s, using test time pooling' %
(str(config['input_size'][-2:]), str(model.default_cfg['input_size'][-2:])))
model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])

@ -166,6 +166,7 @@ class ReXNetV1(nn.Module):
se_rd=12, ch_div=1, drop_rate=0.2, feature_location='bottleneck'):
super(ReXNetV1, self).__init__()
self.drop_rate = drop_rate
self.num_classes = num_classes
assert output_stride == 32 # FIXME support dilation
stem_base_chs = 32 / width_mult if width_mult < 1.0 else 32

@ -139,7 +139,7 @@ def validate(args):
_logger.info('Model %s created, param count: %d' % (args.model, param_count))
data_config = resolve_data_config(vars(args), model=model)
model, test_time_pool = apply_test_time_pool(model, data_config, args)
model, test_time_pool = model, False if args.no_test_pool else apply_test_time_pool(model, data_config)
if args.torchscript:
torch.jit.optimized_execution(True)

Loading…
Cancel
Save