Make download code robust to symlinks.

coremltools seems to not like symlinks.
pull/10/head
Pedro Cuenca 1 year ago
parent 21a6f06eaa
commit a9475c1086

@ -46,18 +46,37 @@ Python | macOS | Xcode | iPadOS, iOS |
If you want to use any of those models you may download the weights and proceed to [generate images with Python](#image-generation-with-python) or [Swift](#image-generation-with-swift).
There are several variants in each model repository. You may clone the whole repos using `git` and `git lfs`, or select just the variants you need. For example, to do generation in Python using the `ORIGINAL` attention implementation (read [this section](#converting-models-to-core-ml) for details), you could do something like this:
There are several variants in each model repository. You may clone the whole repos using `git` and `git lfs`, or select just the variants you need. For example, to do generation in Python using the `ORIGINAL` attention implementation (read [this section](#converting-models-to-core-ml) for details), you could use the following helper code:
```Python
from huggingface_hub import snapshot_download
from huggingface_hub.file_download import repo_folder_name
from pathlib import Path
import shutil
repo_id = "apple/coreml-stable-diffusion-v1-4"
variant = "original/packages"
downloaded = snapshot_download(repo_id, allow_patterns=f"{variant}/*")
def download_model(repo_id, variant, output_dir):
destination = Path(output_dir) / (repo_id.split("/")[-1] + "_" + variant.replace("/", "_"))
if destination.exists():
raise Exception(f"Model already exists at {destination}")
# Download and copy without symlinks
downloaded = snapshot_download(repo_id, allow_patterns=f"{variant}/*", cache_dir=output_dir)
downloaded_bundle = Path(downloaded) / variant
shutil.copytree(downloaded_bundle, destination)
# Remove all downloaded files
cache_folder = Path(output_dir) / repo_folder_name(repo_id=repo_id, repo_type="model")
shutil.rmtree(cache_folder)
return destination
model_path = download_model(repo_id, variant, output_dir="./models")
print(f"Model downloaded at {model_path}")
```
`downloaded` would be the path in your local filesystem where the model checkpoint was saved. Please, refer to [this post](https://huggingface.co/blog/diffusers-coreml) for additional details on this process.
`model_path` would be the path in your local filesystem where the checkpoint was saved. Please, refer to [this post](https://huggingface.co/blog/diffusers-coreml) for additional details on this process.
If you prefer to use `git` to clone the repos with all the variants, you need to follow this process:

Loading…
Cancel
Save