From ff1a657f57c983a2ae542e94986ed88102d3cb0e Mon Sep 17 00:00:00 2001 From: romamartyanov Date: Sun, 29 Nov 2020 22:05:46 +0300 Subject: [PATCH] fixing config reader --- eval.py | 14 +++++++------- inference.py | 10 +++++----- train.py | 10 +++++----- validate.py | 10 +++++----- 4 files changed, 22 insertions(+), 22 deletions(-) diff --git a/eval.py b/eval.py index da5e5d89..db0fa825 100644 --- a/eval.py +++ b/eval.py @@ -7,7 +7,7 @@ Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) --- Usage: --- -model = ClassificationModel('configs/eval.yaml') +model = ClassificationModel() img = Image.open("image.jpg") out = model.eval(img) print(out) @@ -34,8 +34,8 @@ def _update_config(config, params): return config -def _fit(config_path, **kwargs): - with open(config_path) as stream: +def _fit(**kwargs): + with open('configs/eval.yaml') as stream: base_config = yaml.safe_load(stream) if "config" in kwargs.keys(): @@ -51,8 +51,8 @@ def _fit(config_path, **kwargs): return update_cfg -def _parse_args(config_path): - args = Dict(Fire(_fit(config_path))) +def _parse_args(): + args = Dict(Fire(_fit)) # Cache the args as a text string to save them in the output dir later args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) @@ -60,8 +60,8 @@ def _parse_args(config_path): class ClassificationModel: - def __init__(self, config_path: str): - self.args, self.args_text = _parse_args(config_path) + def __init__(self): + self.args, self.args_text = _parse_args() # might as well try to do something useful... self.args.pretrained = self.args.pretrained or not self.args.checkpoint diff --git a/inference.py b/inference.py index 941d9ccc..c15c7b5c 100755 --- a/inference.py +++ b/inference.py @@ -31,8 +31,8 @@ def _update_config(config, params): return config -def _fit(config_path, **kwargs): - with open(config_path) as stream: +def _fit(**kwargs): + with open('configs/inference.yaml') as stream: base_config = yaml.safe_load(stream) if "config" in kwargs.keys(): @@ -48,8 +48,8 @@ def _fit(config_path, **kwargs): return update_cfg -def _parse_args(config_path): - args = Dict(Fire(_fit(config_path))) +def _parse_args(): + args = Dict(Fire(_fit)) # Cache the args as a text string to save them in the output dir later args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) @@ -58,7 +58,7 @@ def _parse_args(config_path): def main(): setup_default_logging() - args, args_text = _parse_args('configs/inference.yaml') + args, args_text = _parse_args() # might as well try to do something useful... args.pretrained = args.pretrained or not args.checkpoint diff --git a/train.py b/train.py index 8a99733b..4e90a0bc 100755 --- a/train.py +++ b/train.py @@ -66,8 +66,8 @@ def _update_config(config, params): return config -def _fit(config_path, **kwargs): - with open(config_path) as stream: +def _fit(**kwargs): + with open('configs/train.yaml') as stream: base_config = yaml.safe_load(stream) if "config" in kwargs.keys(): @@ -83,8 +83,8 @@ def _fit(config_path, **kwargs): return update_cfg -def _parse_args(config_path): - args = Dict(Fire(_fit(config_path))) +def _parse_args(): + args = Dict(Fire(_fit)) # Cache the args as a text string to save them in the output dir later args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) @@ -103,7 +103,7 @@ def set_deterministic(seed=42, precision=13): def main(): setup_default_logging() - args, args_text = _parse_args('configs/train.yaml') + args, args_text = _parse_args() set_deterministic(args.seed) args.prefetcher = not args.no_prefetcher diff --git a/validate.py b/validate.py index 35f0b9a7..968d0e5d 100755 --- a/validate.py +++ b/validate.py @@ -51,8 +51,8 @@ def _update_config(config, params): return config -def _fit(config_path, **kwargs): - with open(config_path) as stream: +def _fit(**kwargs): + with open('configs/validate.yaml') as stream: base_config = yaml.safe_load(stream) if "config" in kwargs.keys(): @@ -68,8 +68,8 @@ def _fit(config_path, **kwargs): return update_cfg -def _parse_args(config_path): - args = Dict(Fire(_fit(config_path))) +def _parse_args(): + args = Dict(Fire(_fit)) # Cache the args as a text string to save them in the output dir later args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) @@ -233,7 +233,7 @@ def validate(args): def main(): setup_default_logging() - args, args_text = _parse_args('configs/validate.yaml') + args, args_text = _parse_args() model_cfgs = [] model_names = []