|
|
|
@ -54,11 +54,11 @@ def create_model(
|
|
|
|
|
else:
|
|
|
|
|
try:
|
|
|
|
|
model_cfg = load_hf_checkpoint_config(model_name, revision=kwargs.get("hf_revision"))
|
|
|
|
|
# This does not work, but there is probably a way to have the values in the config override
|
|
|
|
|
# the defaults if needed.
|
|
|
|
|
# model_args["model_cfg"] = model_cfg
|
|
|
|
|
create_fn = model_entrypoint(model_cfg.pop("architecture"))
|
|
|
|
|
model = create_fn(**model_args, **kwargs)
|
|
|
|
|
# Probably need some extra stuff, but this is a PoC of how the config in the model hub
|
|
|
|
|
# could overwrite the default config values.
|
|
|
|
|
model.default_cfg.update(model_cfg)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
raise RuntimeError('Unknown model or checkpoint from the Hugging Face hub (%s)' % model_name)
|
|
|
|
|
|
|
|
|
|