parent
fa2e5c6f16
commit
c14fab4710
@ -0,0 +1,8 @@
|
||||
model: 'tf_efficientnet_b0' # model architecture (default: dpn92)
|
||||
img_size: 224 # Input image dimension
|
||||
mean: null # Override mean pixel value of dataset
|
||||
std: null # Override std deviation of of dataset
|
||||
num_classes: 2 # Number classes in dataset
|
||||
checkpoint: 'output/train/20201124-182940-tf_efficientnet_b0-224/model_best.pth.tar' # path to latest checkpoint (default: none)
|
||||
pretrained: False # use pre-trained model
|
||||
num_gpu: 1 # Number of GPUS to use
|
@ -0,0 +1,105 @@
|
||||
#!/usr/bin/env python
|
||||
"""PyTorch Evaluation Script
|
||||
|
||||
An example evaluation script that outputs results of model evaluation.
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
|
||||
|
||||
--- Usage: ---
|
||||
|
||||
model = ClassificationModel('configs/eval.yaml')
|
||||
img = Image.open("image.jpg")
|
||||
out = model.eval(img)
|
||||
print(out)
|
||||
|
||||
"""
|
||||
import yaml
|
||||
from fire import Fire
|
||||
from addict import Dict
|
||||
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
|
||||
from timm.models import create_model
|
||||
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
def _update_config(config, params):
|
||||
for k, v in params.items():
|
||||
*path, key = k.split(".")
|
||||
config.update({k: v})
|
||||
print(f"Overwriting {k} = {v} (was {config.get(key)})")
|
||||
return config
|
||||
|
||||
|
||||
def _fit(config_path, **kwargs):
|
||||
with open(config_path) as stream:
|
||||
base_config = yaml.safe_load(stream)
|
||||
|
||||
if "config" in kwargs.keys():
|
||||
cfg_path = kwargs["config"]
|
||||
with open(cfg_path) as cfg:
|
||||
cfg_yaml = yaml.load(cfg, Loader=yaml.FullLoader)
|
||||
|
||||
merged_cfg = _update_config(base_config, cfg_yaml)
|
||||
else:
|
||||
merged_cfg = base_config
|
||||
|
||||
update_cfg = _update_config(merged_cfg, kwargs)
|
||||
return update_cfg
|
||||
|
||||
|
||||
def _parse_args(config_path):
|
||||
args = Dict(Fire(_fit(config_path)))
|
||||
|
||||
# Cache the args as a text string to save them in the output dir later
|
||||
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
|
||||
return args, args_text
|
||||
|
||||
|
||||
class ClassificationModel:
|
||||
def __init__(self, config_path: str):
|
||||
self.args, self.args_text = _parse_args(config_path)
|
||||
|
||||
# might as well try to do something useful...
|
||||
self.args.pretrained = self.args.pretrained or not self.args.checkpoint
|
||||
|
||||
# create model
|
||||
self.model = create_model(
|
||||
self.args.model,
|
||||
num_classes=self.args.num_classes,
|
||||
in_chans=3,
|
||||
pretrained=self.args.pretrained,
|
||||
checkpoint_path=self.args.checkpoint)
|
||||
self.softmax = torch.nn.Softmax(dim=1)
|
||||
|
||||
mean = self.args.mean if self.args.mean is not None else IMAGENET_DEFAULT_MEAN
|
||||
std = self.args.std if self.args.std is not None else IMAGENET_DEFAULT_STD
|
||||
self.loader = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Resize(self.args.img_size),
|
||||
transforms.Normalize(
|
||||
mean=torch.tensor(mean),
|
||||
std=torch.tensor(std)),
|
||||
])
|
||||
|
||||
if self.args.num_gpu > 1:
|
||||
self.model = torch.nn.DataParallel(self.model, device_ids=list(range(self.args.num_gpu))).cuda()
|
||||
else:
|
||||
self.model = self.model.cuda()
|
||||
# self.model = self.model.cpu()
|
||||
self.model.eval()
|
||||
|
||||
def eval(self, input):
|
||||
with torch.no_grad():
|
||||
# for OpenCV input
|
||||
# input = Image.fromarray(np.uint8(input)).convert('RGB')
|
||||
input = self.loader(input).float()
|
||||
input = input.cuda()
|
||||
|
||||
labels = self.model(input[None, ...])
|
||||
labels = self.softmax(labels)
|
||||
labels = labels.cpu()
|
||||
return labels.numpy()
|
Loading…
Reference in new issue