Use the model hub config

pull/440/head
Sylvain Gugger 4 years ago
parent 4d0b5feed0
commit 7dafa71c0e

@ -54,11 +54,11 @@ def create_model(
else: else:
try: try:
model_cfg = load_hf_checkpoint_config(model_name, revision=kwargs.get("hf_revision")) model_cfg = load_hf_checkpoint_config(model_name, revision=kwargs.get("hf_revision"))
# This does not work, but there is probably a way to have the values in the config override
# the defaults if needed.
# model_args["model_cfg"] = model_cfg
create_fn = model_entrypoint(model_cfg.pop("architecture")) create_fn = model_entrypoint(model_cfg.pop("architecture"))
model = create_fn(**model_args, **kwargs) model = create_fn(**model_args, **kwargs)
# Probably need some extra stuff, but this is a PoC of how the config in the model hub
# could overwrite the default config values.
model.default_cfg.update(model_cfg)
except Exception as e: except Exception as e:
raise RuntimeError('Unknown model or checkpoint from the Hugging Face hub (%s)' % model_name) raise RuntimeError('Unknown model or checkpoint from the Hugging Face hub (%s)' % model_name)

Loading…
Cancel
Save