Actually load the model

pull/440/head
Sylvain Gugger 4 years ago
parent 7dafa71c0e
commit 5482d2e501

@ -55,10 +55,10 @@ def create_model(
try: try:
model_cfg = load_hf_checkpoint_config(model_name, revision=kwargs.get("hf_revision")) model_cfg = load_hf_checkpoint_config(model_name, revision=kwargs.get("hf_revision"))
create_fn = model_entrypoint(model_cfg.pop("architecture")) create_fn = model_entrypoint(model_cfg.pop("architecture"))
model = create_fn(**model_args, **kwargs) model = create_fn(**model_args, hf_checkpoint=model_name, **kwargs)
# Probably need some extra stuff, but this is a PoC of how the config in the model hub # Probably need some extra stuff, but this is a PoC of how the config in the model hub
# could overwrite the default config values. # could overwrite the default config values.
model.default_cfg.update(model_cfg) # model.default_cfg.update(model_cfg)
except Exception as e: except Exception as e:
raise RuntimeError('Unknown model or checkpoint from the Hugging Face hub (%s)' % model_name) raise RuntimeError('Unknown model or checkpoint from the Hugging Face hub (%s)' % model_name)

Loading…
Cancel
Save