diff --git a/timm/models/_hub.py b/timm/models/_hub.py index 7c64df0b..df1a1ef7 100644 --- a/timm/models/_hub.py +++ b/timm/models/_hub.py @@ -209,6 +209,7 @@ def push_to_hf_hub( private: bool = False, create_pr: bool = False, model_config: Optional[dict] = None, + model_card: Optional[dict] = None, ): # Create repo if it doesn't exist yet repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) @@ -232,9 +233,23 @@ def push_to_hf_hub( # Add readme if it does not exist if not has_readme: + model_card = model_card or {} model_name = repo_id.split('/')[-1] readme_path = Path(tmpdir) / "README.md" - readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {model_name}' + readme_text = "---\n" + readme_text += "tags:\n- image-classification\n- timm\n" + readme_text += "library_tag: timm\n" + readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n" + readme_text += "---\n" + readme_text += f"# Model card for {model_name}\n" + if 'description' in model_card: + readme_text += f"\n{model_card['description']}\n" + if 'details' in model_card: + readme_text += f"\n## Model Details\n" + for k, v in model_card['details'].items(): + readme_text += f"- **{k}:** {v}\n" + if 'citation' in model_card: + readme_text += f"\n## Citation\n```\n{model_card['citation']}```\n" readme_path.write_text(readme_text) # Upload model and return