diff --git a/configs/eval.yaml b/configs/eval.yaml new file mode 100644 index 00000000..aef709fc --- /dev/null +++ b/configs/eval.yaml @@ -0,0 +1,8 @@ +model: 'tf_efficientnet_b0' # model architecture (default: dpn92) +img_size: 224 # Input image dimension +mean: null # Override mean pixel value of dataset +std: null # Override std deviation of of dataset +num_classes: 2 # Number classes in dataset +checkpoint: 'output/train/20201124-182940-tf_efficientnet_b0-224/model_best.pth.tar' # path to latest checkpoint (default: none) +pretrained: False # use pre-trained model +num_gpu: 1 # Number of GPUS to use diff --git a/eval.py b/eval.py new file mode 100644 index 00000000..da5e5d89 --- /dev/null +++ b/eval.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python +"""PyTorch Evaluation Script + +An example evaluation script that outputs results of model evaluation. + +Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) + +--- Usage: --- + +model = ClassificationModel('configs/eval.yaml') +img = Image.open("image.jpg") +out = model.eval(img) +print(out) + +""" +import yaml +from fire import Fire +from addict import Dict + +import torch +from torchvision import transforms + +from timm.models import create_model +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + +torch.backends.cudnn.benchmark = True + + +def _update_config(config, params): + for k, v in params.items(): + *path, key = k.split(".") + config.update({k: v}) + print(f"Overwriting {k} = {v} (was {config.get(key)})") + return config + + +def _fit(config_path, **kwargs): + with open(config_path) as stream: + base_config = yaml.safe_load(stream) + + if "config" in kwargs.keys(): + cfg_path = kwargs["config"] + with open(cfg_path) as cfg: + cfg_yaml = yaml.load(cfg, Loader=yaml.FullLoader) + + merged_cfg = _update_config(base_config, cfg_yaml) + else: + merged_cfg = base_config + + update_cfg = _update_config(merged_cfg, kwargs) + return update_cfg + + +def _parse_args(config_path): + args = Dict(Fire(_fit(config_path))) + + # Cache the args as a text string to save them in the output dir later + args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) + return args, args_text + + +class ClassificationModel: + def __init__(self, config_path: str): + self.args, self.args_text = _parse_args(config_path) + + # might as well try to do something useful... + self.args.pretrained = self.args.pretrained or not self.args.checkpoint + + # create model + self.model = create_model( + self.args.model, + num_classes=self.args.num_classes, + in_chans=3, + pretrained=self.args.pretrained, + checkpoint_path=self.args.checkpoint) + self.softmax = torch.nn.Softmax(dim=1) + + mean = self.args.mean if self.args.mean is not None else IMAGENET_DEFAULT_MEAN + std = self.args.std if self.args.std is not None else IMAGENET_DEFAULT_STD + self.loader = transforms.Compose([ + transforms.ToTensor(), + transforms.Resize(self.args.img_size), + transforms.Normalize( + mean=torch.tensor(mean), + std=torch.tensor(std)), + ]) + + if self.args.num_gpu > 1: + self.model = torch.nn.DataParallel(self.model, device_ids=list(range(self.args.num_gpu))).cuda() + else: + self.model = self.model.cuda() + # self.model = self.model.cpu() + self.model.eval() + + def eval(self, input): + with torch.no_grad(): + # for OpenCV input + # input = Image.fromarray(np.uint8(input)).convert('RGB') + input = self.loader(input).float() + input = input.cuda() + + labels = self.model(input[None, ...]) + labels = self.softmax(labels) + labels = labels.cpu() + return labels.numpy()