|
|
|
@ -55,10 +55,10 @@ def create_model(
|
|
|
|
|
try:
|
|
|
|
|
model_cfg = load_hf_checkpoint_config(model_name, revision=kwargs.get("hf_revision"))
|
|
|
|
|
create_fn = model_entrypoint(model_cfg.pop("architecture"))
|
|
|
|
|
model = create_fn(**model_args, **kwargs)
|
|
|
|
|
model = create_fn(**model_args, hf_checkpoint=model_name, **kwargs)
|
|
|
|
|
# Probably need some extra stuff, but this is a PoC of how the config in the model hub
|
|
|
|
|
# could overwrite the default config values.
|
|
|
|
|
model.default_cfg.update(model_cfg)
|
|
|
|
|
# model.default_cfg.update(model_cfg)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
raise RuntimeError('Unknown model or checkpoint from the Hugging Face hub (%s)' % model_name)
|
|
|
|
|
|
|
|
|
|