You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/eval.py

106 lines
3.0 KiB

#!/usr/bin/env python
"""PyTorch Evaluation Script
An example evaluation script that outputs results of model evaluation.
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
--- Usage: ---
model = ClassificationModel()
img = Image.open("image.jpg")
out = model.eval(img)
print(out)
"""
import yaml
from fire import Fire
from addict import Dict
import torch
from torchvision import transforms
from timm.models import create_model
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
# Enable cuDNN autotuning for the whole process; this runs at import time as a
# module-level side effect. Per the PyTorch docs, cuDNN then benchmarks
# convolution algorithms and reuses the fastest one, which helps most when
# input sizes do not change between calls.
torch.backends.cudnn.benchmark = True
def _update_config(config, params):
for k, v in params.items():
*path, key = k.split(".")
config.update({k: v})
print(f"Overwriting {k} = {v} (was {config.get(key)})")
return config
def _fit(**kwargs):
    """Build the evaluation config from YAML files plus keyword overrides.

    Invoked through ``Fire`` with the command-line flags as ``kwargs``. The
    base config is read from ``configs/eval.yaml`` (relative to the current
    working directory). If a ``config`` keyword is given, that YAML file is
    merged on top. Finally, every keyword argument, including ``config``
    itself, is applied as an override.

    Returns:
        The merged config dict.
    """
    # Explicit UTF-8 so decoding does not depend on the platform locale.
    with open("configs/eval.yaml", encoding="utf-8") as stream:
        # An empty YAML file loads as None; treat it as an empty config.
        base_config = yaml.safe_load(stream) or {}
    if "config" in kwargs:
        cfg_path = kwargs["config"]
        with open(cfg_path, encoding="utf-8") as cfg:
            # NOTE(review): FullLoader accepts more YAML tags than safe_load
            # (used for the base config above); prefer safe_load if config
            # files may come from untrusted sources.
            cfg_yaml = yaml.load(cfg, Loader=yaml.FullLoader)
        merged_cfg = _update_config(base_config, cfg_yaml or {})
    else:
        merged_cfg = base_config
    return _update_config(merged_cfg, kwargs)
def _parse_args():
    """Parse the command line into an attribute-accessible config.

    Returns:
        A tuple ``(args, args_text)``. ``args`` is the merged config produced
        by ``_fit``, wrapped in an addict ``Dict``. ``args_text`` is the same
        config serialized as block-style YAML, cached so it can be saved to an
        output dir later.
    """
    config = Fire(_fit)
    args = Dict(config)
    # Serialize the plain mapping returned by ``_fit``. ``args`` is a dict
    # subclass whose entries live in the mapping itself, so ``args.__dict__``
    # (the instance attribute namespace) does not contain the config values.
    args_text = yaml.safe_dump(config, default_flow_style=False)
    return args, args_text
class ClassificationModel:
    """Single-image classifier configured from YAML files and CLI flags.

    Constructing an instance parses the command line (see ``_parse_args``),
    builds a timm model with ``create_model`` and prepares the preprocessing
    pipeline. ``eval`` turns one image into softmax scores.
    """

    def __init__(self):
        self.args, self.args_text = _parse_args()
        # Use pretrained weights unless a checkpoint path was configured; an
        # explicitly truthy ``pretrained`` setting always wins.
        self.args.pretrained = self.args.pretrained or not self.args.checkpoint
        self.model = create_model(
            self.args.model,
            num_classes=self.args.num_classes,
            in_chans=3,
            pretrained=self.args.pretrained,
            checkpoint_path=self.args.checkpoint,
        )
        self.softmax = torch.nn.Softmax(dim=1)

        # ``self.args`` is an addict ``Dict``. Attribute access on a missing
        # key returns an empty ``Dict`` rather than None, which would pass an
        # ``is not None`` check and then break ``torch.tensor``. ``dict.get``
        # returns None for both a missing key and an explicit null.
        mean = self.args.get("mean")
        if mean is None:
            mean = IMAGENET_DEFAULT_MEAN
        std = self.args.get("std")
        if std is None:
            std = IMAGENET_DEFAULT_STD
        self.loader = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(self.args.img_size),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std)),
        ])

        # Run on CUDA when available and fall back to CPU otherwise, instead
        # of failing inside ``.cuda()`` on machines without a GPU.
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        num_gpu = self.args.get("num_gpu") or 1
        if use_cuda and num_gpu > 1:
            # Replicate the model across the first ``num_gpu`` CUDA devices.
            self.model = torch.nn.DataParallel(
                self.model, device_ids=list(range(num_gpu)))
        self.model = self.model.to(self.device)
        self.model.eval()

    def eval(self, input):
        """Return softmax scores for a single image.

        Args:
            input: An image accepted by ``transforms.ToTensor`` (the first step
                of ``self.loader``), e.g. a PIL image as in the module
                docstring. For OpenCV input, convert first, e.g. with
                ``Image.fromarray(np.uint8(input)).convert('RGB')``.

        Returns:
            A numpy array with a leading batch dimension of 1, holding the
            softmax over dim 1 of the model output.
        """
        with torch.no_grad():
            # Preprocess, add a batch dimension and move to the model's device.
            batch = self.loader(input).float().unsqueeze(0).to(self.device)
            scores = self.softmax(self.model(batch))
            return scores.cpu().numpy()