fixing config reader

pull/290/head
romamartyanov 5 years ago
parent c14fab4710
commit ff1a657f57

@ -7,7 +7,7 @@ Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
--- Usage: --- --- Usage: ---
model = ClassificationModel('configs/eval.yaml') model = ClassificationModel()
img = Image.open("image.jpg") img = Image.open("image.jpg")
out = model.eval(img) out = model.eval(img)
print(out) print(out)
@ -34,8 +34,8 @@ def _update_config(config, params):
return config return config
def _fit(config_path, **kwargs): def _fit(**kwargs):
with open(config_path) as stream: with open('configs/eval.yaml') as stream:
base_config = yaml.safe_load(stream) base_config = yaml.safe_load(stream)
if "config" in kwargs.keys(): if "config" in kwargs.keys():
@ -51,8 +51,8 @@ def _fit(config_path, **kwargs):
return update_cfg return update_cfg
def _parse_args(config_path): def _parse_args():
args = Dict(Fire(_fit(config_path))) args = Dict(Fire(_fit))
# Cache the args as a text string to save them in the output dir later # Cache the args as a text string to save them in the output dir later
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
@ -60,8 +60,8 @@ def _parse_args(config_path):
class ClassificationModel: class ClassificationModel:
def __init__(self, config_path: str): def __init__(self):
self.args, self.args_text = _parse_args(config_path) self.args, self.args_text = _parse_args()
# might as well try to do something useful... # might as well try to do something useful...
self.args.pretrained = self.args.pretrained or not self.args.checkpoint self.args.pretrained = self.args.pretrained or not self.args.checkpoint

@ -31,8 +31,8 @@ def _update_config(config, params):
return config return config
def _fit(config_path, **kwargs): def _fit(**kwargs):
with open(config_path) as stream: with open('configs/inference.yaml') as stream:
base_config = yaml.safe_load(stream) base_config = yaml.safe_load(stream)
if "config" in kwargs.keys(): if "config" in kwargs.keys():
@ -48,8 +48,8 @@ def _fit(config_path, **kwargs):
return update_cfg return update_cfg
def _parse_args(config_path): def _parse_args():
args = Dict(Fire(_fit(config_path))) args = Dict(Fire(_fit))
# Cache the args as a text string to save them in the output dir later # Cache the args as a text string to save them in the output dir later
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
@ -58,7 +58,7 @@ def _parse_args(config_path):
def main(): def main():
setup_default_logging() setup_default_logging()
args, args_text = _parse_args('configs/inference.yaml') args, args_text = _parse_args()
# might as well try to do something useful... # might as well try to do something useful...
args.pretrained = args.pretrained or not args.checkpoint args.pretrained = args.pretrained or not args.checkpoint

@ -66,8 +66,8 @@ def _update_config(config, params):
return config return config
def _fit(config_path, **kwargs): def _fit(**kwargs):
with open(config_path) as stream: with open('configs/train.yaml') as stream:
base_config = yaml.safe_load(stream) base_config = yaml.safe_load(stream)
if "config" in kwargs.keys(): if "config" in kwargs.keys():
@ -83,8 +83,8 @@ def _fit(config_path, **kwargs):
return update_cfg return update_cfg
def _parse_args(config_path): def _parse_args():
args = Dict(Fire(_fit(config_path))) args = Dict(Fire(_fit))
# Cache the args as a text string to save them in the output dir later # Cache the args as a text string to save them in the output dir later
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
@ -103,7 +103,7 @@ def set_deterministic(seed=42, precision=13):
def main(): def main():
setup_default_logging() setup_default_logging()
args, args_text = _parse_args('configs/train.yaml') args, args_text = _parse_args()
set_deterministic(args.seed) set_deterministic(args.seed)
args.prefetcher = not args.no_prefetcher args.prefetcher = not args.no_prefetcher

@ -51,8 +51,8 @@ def _update_config(config, params):
return config return config
def _fit(config_path, **kwargs): def _fit(**kwargs):
with open(config_path) as stream: with open('configs/validate.yaml') as stream:
base_config = yaml.safe_load(stream) base_config = yaml.safe_load(stream)
if "config" in kwargs.keys(): if "config" in kwargs.keys():
@ -68,8 +68,8 @@ def _fit(config_path, **kwargs):
return update_cfg return update_cfg
def _parse_args(config_path): def _parse_args():
args = Dict(Fire(_fit(config_path))) args = Dict(Fire(_fit))
# Cache the args as a text string to save them in the output dir later # Cache the args as a text string to save them in the output dir later
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
@ -233,7 +233,7 @@ def validate(args):
def main(): def main():
setup_default_logging() setup_default_logging()
args, args_text = _parse_args('configs/validate.yaml') args, args_text = _parse_args()
model_cfgs = [] model_cfgs = []
model_names = [] model_names = []

Loading…
Cancel
Save