fixing config reader

pull/290/head
romamartyanov 5 years ago
parent c14fab4710
commit ff1a657f57

@ -7,7 +7,7 @@ Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
--- Usage: ---
model = ClassificationModel('configs/eval.yaml')
model = ClassificationModel()
img = Image.open("image.jpg")
out = model.eval(img)
print(out)
@ -34,8 +34,8 @@ def _update_config(config, params):
return config
def _fit(config_path, **kwargs):
with open(config_path) as stream:
def _fit(**kwargs):
with open('configs/eval.yaml') as stream:
base_config = yaml.safe_load(stream)
if "config" in kwargs.keys():
@ -51,8 +51,8 @@ def _fit(config_path, **kwargs):
return update_cfg
def _parse_args(config_path):
args = Dict(Fire(_fit(config_path)))
def _parse_args():
args = Dict(Fire(_fit))
# Cache the args as a text string to save them in the output dir later
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
@ -60,8 +60,8 @@ def _parse_args(config_path):
class ClassificationModel:
def __init__(self, config_path: str):
self.args, self.args_text = _parse_args(config_path)
def __init__(self):
self.args, self.args_text = _parse_args()
# might as well try to do something useful...
self.args.pretrained = self.args.pretrained or not self.args.checkpoint

@ -31,8 +31,8 @@ def _update_config(config, params):
return config
def _fit(config_path, **kwargs):
with open(config_path) as stream:
def _fit(**kwargs):
with open('configs/inference.yaml') as stream:
base_config = yaml.safe_load(stream)
if "config" in kwargs.keys():
@ -48,8 +48,8 @@ def _fit(config_path, **kwargs):
return update_cfg
def _parse_args(config_path):
args = Dict(Fire(_fit(config_path)))
def _parse_args():
args = Dict(Fire(_fit))
# Cache the args as a text string to save them in the output dir later
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
@ -58,7 +58,7 @@ def _parse_args(config_path):
def main():
setup_default_logging()
args, args_text = _parse_args('configs/inference.yaml')
args, args_text = _parse_args()
# might as well try to do something useful...
args.pretrained = args.pretrained or not args.checkpoint

@ -66,8 +66,8 @@ def _update_config(config, params):
return config
def _fit(config_path, **kwargs):
with open(config_path) as stream:
def _fit(**kwargs):
with open('configs/train.yaml') as stream:
base_config = yaml.safe_load(stream)
if "config" in kwargs.keys():
@ -83,8 +83,8 @@ def _fit(config_path, **kwargs):
return update_cfg
def _parse_args(config_path):
args = Dict(Fire(_fit(config_path)))
def _parse_args():
args = Dict(Fire(_fit))
# Cache the args as a text string to save them in the output dir later
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
@ -103,7 +103,7 @@ def set_deterministic(seed=42, precision=13):
def main():
setup_default_logging()
args, args_text = _parse_args('configs/train.yaml')
args, args_text = _parse_args()
set_deterministic(args.seed)
args.prefetcher = not args.no_prefetcher

@ -51,8 +51,8 @@ def _update_config(config, params):
return config
def _fit(config_path, **kwargs):
with open(config_path) as stream:
def _fit(**kwargs):
with open('configs/validate.yaml') as stream:
base_config = yaml.safe_load(stream)
if "config" in kwargs.keys():
@ -68,8 +68,8 @@ def _fit(config_path, **kwargs):
return update_cfg
def _parse_args(config_path):
args = Dict(Fire(_fit(config_path)))
def _parse_args():
args = Dict(Fire(_fit))
# Cache the args as a text string to save them in the output dir later
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
@ -233,7 +233,7 @@ def validate(args):
def main():
setup_default_logging()
args, args_text = _parse_args('configs/validate.yaml')
args, args_text = _parse_args()
model_cfgs = []
model_names = []

Loading…
Cancel
Save