Added class for easy model evaluation on one image

pull/290/head
romamartyanov 5 years ago
parent fa2e5c6f16
commit c14fab4710

@ -0,0 +1,8 @@
model: 'tf_efficientnet_b0' # model architecture (default: dpn92)
img_size: 224 # Input image dimension
mean: null # Override mean pixel value of dataset
std: null # Override std deviation of of dataset
num_classes: 2 # Number classes in dataset
checkpoint: 'output/train/20201124-182940-tf_efficientnet_b0-224/model_best.pth.tar' # path to latest checkpoint (default: none)
pretrained: False # use pre-trained model
num_gpu: 1 # Number of GPUS to use

@ -0,0 +1,105 @@
#!/usr/bin/env python
"""PyTorch Evaluation Script
An example evaluation script that outputs results of model evaluation.
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
--- Usage: ---
model = ClassificationModel('configs/eval.yaml')
img = Image.open("image.jpg")
out = model.eval(img)
print(out)
"""
import yaml
from fire import Fire
from addict import Dict
import torch
from torchvision import transforms
from timm.models import create_model
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
torch.backends.cudnn.benchmark = True
def _update_config(config, params):
for k, v in params.items():
*path, key = k.split(".")
config.update({k: v})
print(f"Overwriting {k} = {v} (was {config.get(key)})")
return config
def _fit(config_path, **kwargs):
with open(config_path) as stream:
base_config = yaml.safe_load(stream)
if "config" in kwargs.keys():
cfg_path = kwargs["config"]
with open(cfg_path) as cfg:
cfg_yaml = yaml.load(cfg, Loader=yaml.FullLoader)
merged_cfg = _update_config(base_config, cfg_yaml)
else:
merged_cfg = base_config
update_cfg = _update_config(merged_cfg, kwargs)
return update_cfg
def _parse_args(config_path):
args = Dict(Fire(_fit(config_path)))
# Cache the args as a text string to save them in the output dir later
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
return args, args_text
class ClassificationModel:
def __init__(self, config_path: str):
self.args, self.args_text = _parse_args(config_path)
# might as well try to do something useful...
self.args.pretrained = self.args.pretrained or not self.args.checkpoint
# create model
self.model = create_model(
self.args.model,
num_classes=self.args.num_classes,
in_chans=3,
pretrained=self.args.pretrained,
checkpoint_path=self.args.checkpoint)
self.softmax = torch.nn.Softmax(dim=1)
mean = self.args.mean if self.args.mean is not None else IMAGENET_DEFAULT_MEAN
std = self.args.std if self.args.std is not None else IMAGENET_DEFAULT_STD
self.loader = transforms.Compose([
transforms.ToTensor(),
transforms.Resize(self.args.img_size),
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std)),
])
if self.args.num_gpu > 1:
self.model = torch.nn.DataParallel(self.model, device_ids=list(range(self.args.num_gpu))).cuda()
else:
self.model = self.model.cuda()
# self.model = self.model.cpu()
self.model.eval()
def eval(self, input):
with torch.no_grad():
# for OpenCV input
# input = Image.fromarray(np.uint8(input)).convert('RGB')
input = self.loader(input).float()
input = input.cuda()
labels = self.model(input[None, ...])
labels = self.softmax(labels)
labels = labels.cpu()
return labels.numpy()
Loading…
Cancel
Save