From 0e6ad2cb7aafb8c50e2e6b10e51e2d60a4bbe304 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 5 Jan 2023 17:50:11 -0800 Subject: [PATCH] Working on improved model card template for push_to_hf_hub --- timm/models/_hub.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/timm/models/_hub.py b/timm/models/_hub.py index 7c64df0b..df1a1ef7 100644 --- a/timm/models/_hub.py +++ b/timm/models/_hub.py @@ -209,6 +209,7 @@ def push_to_hf_hub( private: bool = False, create_pr: bool = False, model_config: Optional[dict] = None, + model_card: Optional[dict] = None, ): # Create repo if it doesn't exist yet repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) @@ -232,9 +233,23 @@ def push_to_hf_hub( # Add readme if it does not exist if not has_readme: + model_card = model_card or {} model_name = repo_id.split('/')[-1] readme_path = Path(tmpdir) / "README.md" - readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {model_name}' + readme_text = "---\n" + readme_text += "tags:\n- image-classification\n- timm\n" + readme_text += "library_tag: timm\n" + readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n" + readme_text += "---\n" + readme_text += f"# Model card for {model_name}\n" + if 'description' in model_card: + readme_text += f"\n{model_card['description']}\n" + if 'details' in model_card: + readme_text += f"\n## Model Details\n" + for k, v in model_card['details'].items(): + readme_text += f"- **{k}:** {v}\n" + if 'citation' in model_card: + readme_text += f"\n## Citation\n```\n{model_card['citation']}```\n" readme_path.write_text(readme_text) # Upload model and return