@ -4,8 +4,8 @@ from pathlib import Path
import shutil
import shutil
# From apple: https://huggingface.co/apple
# From apple: https://huggingface.co/apple
# repo_id = "apple/coreml-stable-diffusion-v1-5"
# repo_id = "apple/coreml-stable-diffusion-v1-4"
# repo_id = "apple/coreml-stable-diffusion-v1-4"
# repo_id = "apple/coreml-stable-diffusion-v1-5"
# repo_id = "apple/coreml-stable-diffusion-2-base"
# repo_id = "apple/coreml-stable-diffusion-2-base"
# For Swift
# For Swift
@ -16,20 +16,27 @@ import shutil
# From coreml: https://huggingface.co/coreml
# From coreml: https://huggingface.co/coreml
repo_id = " coreml/coreml-stable-diffusion-2-1-base "
repo_id = " coreml/coreml-stable-diffusion-2-1-base "
variant = " original "
variant = " original "
# variant = "split_einsum"
def download_model ( repo_id , variant , output_dir ) :
def download_model ( repo_id , variant , output_dir ) :
destination = Path ( output_dir ) / ( repo_id . split ( " / " ) [ - 1 ] + " _ " + variant . replace ( " / " , " _ " ) )
destination = Path ( output_dir ) / (
repo_id . split ( " / " ) [ - 1 ] + " _ " + variant . replace ( " / " , " _ " )
)
if destination . exists ( ) :
if destination . exists ( ) :
raise Exception ( f " Model already exists at { destination } " )
raise Exception ( f " Model already exists at { destination } " )
# Download and copy without symlinks
# Download and copy without symlinks
downloaded = snapshot_download ( repo_id , allow_patterns = f " { variant } /* " , cache_dir = output_dir )
downloaded = snapshot_download (
repo_id , allow_patterns = f " { variant } /* " , cache_dir = output_dir
)
downloaded_bundle = Path ( downloaded ) / variant
downloaded_bundle = Path ( downloaded ) / variant
shutil . copytree ( downloaded_bundle , destination )
shutil . copytree ( downloaded_bundle , destination )
# Remove all downloaded files
# Remove all downloaded files
cache_folder = Path ( output_dir ) / repo_folder_name ( repo_id = repo_id , repo_type = " model " )
cache_folder = Path ( output_dir ) / repo_folder_name (
repo_id = repo_id , repo_type = " model "
)
shutil . rmtree ( cache_folder )
shutil . rmtree ( cache_folder )
return destination
return destination