|
|
|
@ -7,7 +7,7 @@ Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
|
|
|
|
|
|
|
|
|
|
--- Usage: ---
|
|
|
|
|
|
|
|
|
|
model = ClassificationModel('configs/eval.yaml')
|
|
|
|
|
model = ClassificationModel()
|
|
|
|
|
img = Image.open("image.jpg")
|
|
|
|
|
out = model.eval(img)
|
|
|
|
|
print(out)
|
|
|
|
@ -34,8 +34,8 @@ def _update_config(config, params):
|
|
|
|
|
return config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _fit(config_path, **kwargs):
|
|
|
|
|
with open(config_path) as stream:
|
|
|
|
|
def _fit(**kwargs):
|
|
|
|
|
with open('configs/eval.yaml') as stream:
|
|
|
|
|
base_config = yaml.safe_load(stream)
|
|
|
|
|
|
|
|
|
|
if "config" in kwargs.keys():
|
|
|
|
@ -51,8 +51,8 @@ def _fit(config_path, **kwargs):
|
|
|
|
|
return update_cfg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _parse_args(config_path):
|
|
|
|
|
args = Dict(Fire(_fit(config_path)))
|
|
|
|
|
def _parse_args():
|
|
|
|
|
args = Dict(Fire(_fit))
|
|
|
|
|
|
|
|
|
|
# Cache the args as a text string to save them in the output dir later
|
|
|
|
|
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
|
|
|
|
@ -60,8 +60,8 @@ def _parse_args(config_path):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ClassificationModel:
|
|
|
|
|
def __init__(self, config_path: str):
|
|
|
|
|
self.args, self.args_text = _parse_args(config_path)
|
|
|
|
|
def __init__(self):
|
|
|
|
|
self.args, self.args_text = _parse_args()
|
|
|
|
|
|
|
|
|
|
# might as well try to do something useful...
|
|
|
|
|
self.args.pretrained = self.args.pretrained or not self.args.checkpoint
|
|
|
|
|