#
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
#

import argparse

from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.schedulers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)
from diffusers.schedulers.scheduling_utils import SchedulerMixin

import gc
import inspect

import logging

logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

import numpy as np
import os

from python_coreml_stable_diffusion.coreml_model import (
    CoreMLModel,
    _load_mlpackage,
    get_available_compute_units,
)

import time
import torch  # Only used for `torch.from_numpy` in `pipe.scheduler.step()`
from transformers import CLIPFeatureExtractor, CLIPTokenizer
from typing import List, Optional, Union


class CoreMLStableDiffusionPipeline(DiffusionPipeline):
    """ Core ML version of
    `diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline`
    """

    def __init__(
        self,
        text_encoder: CoreMLModel,
        unet: CoreMLModel,
        vae_decoder: CoreMLModel,
        feature_extractor: CLIPFeatureExtractor,
        safety_checker: Optional[CoreMLModel],
        scheduler: Union[DDIMScheduler,
                         DPMSolverMultistepScheduler,
                         EulerAncestralDiscreteScheduler,
                         EulerDiscreteScheduler,
                         LMSDiscreteScheduler,
                         PNDMScheduler],
        tokenizer: CLIPTokenizer,
    ):
        super().__init__()

        # Register non-Core ML components of the pipeline similar to the original pipeline
        self.register_modules(
            tokenizer=tokenizer,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )

        if safety_checker is None:
            # Reproduce original warning:
            # https://github.com/huggingface/diffusers/blob/v0.9.0/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L119
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        # Register Core ML components of the pipeline
        self.safety_checker = safety_checker
        self.text_encoder = text_encoder
        self.unet = unet
        self.unet.in_channels = self.unet.expected_inputs["sample"]["shape"][1]

        self.vae_decoder = vae_decoder

        VAE_DECODER_UPSAMPLE_FACTOR = 8

        # In PyTorch, users can determine the tensor shapes dynamically by default
        # In CoreML, tensors have static shapes unless flexible shapes were used during export
        # See https://coremltools.readme.io/docs/flexible-inputs
        latent_h, latent_w = self.unet.expected_inputs["sample"]["shape"][2:]
        self.height = latent_h * VAE_DECODER_UPSAMPLE_FACTOR
        self.width = latent_w * VAE_DECODER_UPSAMPLE_FACTOR

        logger.info(
            f"Stable Diffusion configured to generate {self.height}x{self.width} images"
        )

    def _encode_prompt(self, prompt, num_images_per_prompt,
                       do_classifier_free_guidance, negative_prompt):
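        """ Encode `prompt` (and, under classifier-free guidance, `negative_prompt`) into
        the text embedding tensor consumed by the Core ML UNet
        """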
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="np",
        )
        text_input_ids = text_inputs.input_ids

        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
            removed_text = self.tokenizer.batch_decode(
                text_input_ids[:, self.tokenizer.model_max_length:])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}")
            text_input_ids = text_input_ids[:, :self.tokenizer.model_max_length]

        text_embeddings = self.text_encoder(
            input_ids=text_input_ids.astype(np.float32))["last_hidden_state"]

        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}.")
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt] * batch_size
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`.")
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="np",
            )

            uncond_embeddings = self.text_encoder(
                input_ids=uncond_input.input_ids.astype(
                    np.float32))["last_hidden_state"]

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = np.concatenate(
                [uncond_embeddings, text_embeddings])
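
        # The text encoder returns hidden states as (batch, sequence, hidden); the converted
        # Core ML UNet consumes them as a 4D (batch, hidden, 1, sequence) tensor, so move the
        # hidden dimension to the channel axis and insert a singleton spatial dimension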
        text_embeddings = text_embeddings.transpose(0, 2, 1)[:, :, None, :]

        return text_embeddings

    def run_safety_checker(self, image):
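        """ Run the Core ML safety checker (when one was loaded) and return the filtered
        images along with per-image NSFW flags
        """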
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(
                self.numpy_to_pil(image),
                return_tensors="np",
            )

            safety_checker_outputs = self.safety_checker(
                clip_input=safety_checker_input.pixel_values.astype(
                    np.float16),
                images=image.astype(np.float16),
                adjustment=np.array([0.]).astype(
                    np.float16),  # defaults to 0 in original pipeline
            )

            # Unpack dict
            has_nsfw_concept = safety_checker_outputs["has_nsfw_concepts"]
            image = safety_checker_outputs["filtered_images"]
            concept_scores = safety_checker_outputs["concept_scores"]

            logger.info(
                f"Generated image has nsfw concept={has_nsfw_concept.any()}")
        else:
            has_nsfw_concept = None

        return image, has_nsfw_concept

    def decode_latents(self, latents):
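        """ Rescale the denoised latents by the VAE scaling factor (1 / 0.18215) and decode
        them into [0, 1]-ranged images in NHWC layout
        """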
        latents = 1 / 0.18215 * latents
        image = self.vae_decoder(z=latents.astype(np.float16))["image"]
        image = np.clip(image / 2 + 0.5, 0, 1)
        image = image.transpose((0, 2, 3, 1))

        return image

    def prepare_latents(self,
                        batch_size,
                        num_channels_latents,
                        height,
                        width,
                        latents=None):
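        """ Draw (or validate) the initial latent noise and scale it by the scheduler's
        `init_noise_sigma`
        """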
        latents_shape = (batch_size, num_channels_latents, self.height // 8,
                         self.width // 8)
        if latents is None:
            latents = np.random.randn(*latents_shape).astype(np.float16)
        elif latents.shape != latents_shape:
            raise ValueError(
                f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}"
            )

        latents = latents * self.scheduler.init_noise_sigma

        return latents

    def check_inputs(self, prompt, height, width, callback_steps):
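        """ Validate `prompt`, `height`, `width` and `callback_steps` before generation """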
        if height != self.height or width != self.width:
            logger.warning(
                "`height` and `width` dimensions (of the output image tensor) are fixed when exporting the Core ML models "
                "unless flexible shapes are used during export (https://coremltools.readme.io/docs/flexible-inputs). "
                f"This pipeline was provided with Core ML models that generate {self.height}x{self.width} images (user requested {height}x{width})"
            )

        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
            )

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(
                f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
            )

        if (callback_steps is None) or (callback_steps is not None and
                                        (not isinstance(callback_steps, int)
                                         or callback_steps <= 0)):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}.")

    def prepare_extra_step_kwargs(self, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(
            inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        return extra_step_kwargs

    def __call__(
        self,
        prompt,
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
        negative_prompt=None,
        num_images_per_prompt=1,
        eta=0.0,
        latents=None,
        output_type="pil",
        return_dict=True,
        callback=None,
        callback_steps=1,
        **kwargs,
    ):
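        """ Run text-to-image generation with the registered Core ML models.

        Only a single prompt and one image per prompt are supported (see the
        NotImplementedError below); the output resolution is fixed by the loaded models.
        """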
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        if batch_size > 1 or num_images_per_prompt > 1:
            raise NotImplementedError(
                "For batched generation of multiple images and/or multiple prompts, please refer to the Swift package."
            )

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            latents,
        )

        # 6. Prepare extra step kwargs
        extra_step_kwargs = self.prepare_extra_step_kwargs(eta)

        # 7. Denoising loop
        for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = np.concatenate(
                [latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(
                latent_model_input, t)

            # predict the noise residual
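            # Core ML inputs are numpy arrays cast to the precision the converted models expect
            # (float16 here); the two timestep entries mirror the batch-of-two latent input used
            # for classifier-free guidance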
            noise_pred = self.unet(
                sample=latent_model_input.astype(np.float16),
                timestep=np.array([t, t], np.float16),
                encoder_hidden_states=text_embeddings.astype(np.float16),
            )["noise_pred"]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(torch.from_numpy(noise_pred),
                                          t,
                                          torch.from_numpy(latents),
                                          **extra_step_kwargs,
                                          ).prev_sample.numpy()

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        # 8. Post-processing
        image = self.decode_latents(latents)

        # 9. Run safety checker
        image, has_nsfw_concept = self.run_safety_checker(image)

        # 10. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=has_nsfw_concept)


def get_available_schedulers():
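    """ Map scheduler names (class names without the `Scheduler` suffix) to the corresponding
    diffusers scheduler classes
    """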
    schedulers = {}
    for scheduler in [DDIMScheduler,
                      DPMSolverMultistepScheduler,
                      EulerAncestralDiscreteScheduler,
                      EulerDiscreteScheduler,
                      LMSDiscreteScheduler,
                      PNDMScheduler]:
        schedulers[scheduler().__class__.__name__.replace("Scheduler", "")] = scheduler
    return schedulers


SCHEDULER_MAP = get_available_schedulers()
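

# A minimal usage sketch for `get_coreml_pipe` (mirrors `main()` below); the model version
# and mlpackages directory are placeholders:
#
#     pytorch_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
#     coreml_pipe = get_coreml_pipe(pytorch_pipe, mlpackages_dir="<mlpackages>",
#                                   model_version="CompVis/stable-diffusion-v1-4",
#                                   compute_unit="ALL")
#     image = coreml_pipe(prompt="a photo of an astronaut riding a horse on mars",
#                         height=coreml_pipe.height, width=coreml_pipe.width)["images"][0]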
def get_coreml_pipe(pytorch_pipe,
                    mlpackages_dir,
                    model_version,
                    compute_unit,
                    delete_original_pipe=True,
                    scheduler_override=None):
    """ Initializes and returns a `CoreMLStableDiffusionPipeline` from an original
    diffusers PyTorch pipeline
    """
    # Ensure `scheduler_override` object is of correct type if specified
    if scheduler_override is not None:
        assert isinstance(scheduler_override, SchedulerMixin)
        logger.warning(
            "Overriding scheduler in pipeline: "
            f"Default={pytorch_pipe.scheduler}, Override={scheduler_override}")

    # Gather configured tokenizer and scheduler attributes from the original pipe
    coreml_pipe_kwargs = {
        "tokenizer": pytorch_pipe.tokenizer,
        "scheduler": pytorch_pipe.scheduler if scheduler_override is None else scheduler_override,
        "feature_extractor": pytorch_pipe.feature_extractor,
    }

    model_names_to_load = ["text_encoder", "unet", "vae_decoder"]
    if getattr(pytorch_pipe, "safety_checker", None) is not None:
        model_names_to_load.append("safety_checker")
    else:
        logger.warning(
            f"Original diffusers pipeline for {model_version} does not have a safety_checker, "
            "Core ML pipeline will mirror this behavior.")
        coreml_pipe_kwargs["safety_checker"] = None

    if delete_original_pipe:
        del pytorch_pipe
        gc.collect()
        logger.info("Removed PyTorch pipe to reduce peak memory consumption")

    # Load Core ML models
    logger.info(f"Loading Core ML models in memory from {mlpackages_dir}")
    coreml_pipe_kwargs.update({
        model_name: _load_mlpackage(
            model_name,
            mlpackages_dir,
            model_version,
            compute_unit,
        )
        for model_name in model_names_to_load
    })
    logger.info("Done.")

    logger.info("Initializing Core ML pipe for image generation")
    coreml_pipe = CoreMLStableDiffusionPipeline(**coreml_pipe_kwargs)
    logger.info("Done.")

    return coreml_pipe


def get_image_path(args, **override_kwargs):
    """ mkdir output folder and encode metadata in the filename
    """
    out_folder = os.path.join(args.o, "_".join(args.prompt.replace("/", "_").rsplit(" ")))
    os.makedirs(out_folder, exist_ok=True)

    out_fname = f"randomSeed_{override_kwargs.get('seed', None) or args.seed}"
    out_fname += f"_computeUnit_{override_kwargs.get('compute_unit', None) or args.compute_unit}"
    out_fname += f"_modelVersion_{override_kwargs.get('model_version', None) or args.model_version.replace('/', '_')}"

    if args.scheduler is not None:
        out_fname += f"_customScheduler_{override_kwargs.get('scheduler', None) or args.scheduler}"
    out_fname += f"_numInferenceSteps{override_kwargs.get('num_inference_steps', None) or args.num_inference_steps}"

    return os.path.join(out_folder, out_fname + ".png")


def main(args):
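    """ Set the random seed, build the Core ML pipeline from the PyTorch reference pipeline
    and generate a single image
    """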
    logger.info(f"Setting random seed to {args.seed}")
    np.random.seed(args.seed)

    logger.info("Initializing PyTorch pipe for reference configuration")
    from diffusers import StableDiffusionPipeline
    pytorch_pipe = StableDiffusionPipeline.from_pretrained(args.model_version,
                                                           use_auth_token=True)

    user_specified_scheduler = None
    if args.scheduler is not None:
        user_specified_scheduler = SCHEDULER_MAP[
            args.scheduler].from_config(pytorch_pipe.scheduler.config)

    coreml_pipe = get_coreml_pipe(pytorch_pipe=pytorch_pipe,
                                  mlpackages_dir=args.i,
                                  model_version=args.model_version,
                                  compute_unit=args.compute_unit,
                                  scheduler_override=user_specified_scheduler)

    logger.info("Beginning image generation.")
    image = coreml_pipe(
        prompt=args.prompt,
        height=coreml_pipe.height,
        width=coreml_pipe.width,
        num_inference_steps=args.num_inference_steps,
    )

    out_path = get_image_path(args)
    logger.info(f"Saving generated image to {out_path}")
    image["images"][0].save(out_path)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--prompt",
        required=True,
        help="The text prompt to be used for text-to-image generation.")
    parser.add_argument(
        "-i",
        required=True,
        help=("Path to input directory with the .mlpackage files generated by "
              "python_coreml_stable_diffusion.torch2coreml"))
    parser.add_argument("-o", required=True)
    parser.add_argument("--seed",
                        "-s",
                        default=93,
                        type=int,
                        help="Random seed to be able to reproduce results")
    parser.add_argument(
        "--model-version",
        default="CompVis/stable-diffusion-v1-4",
        help=("The pre-trained model checkpoint and configuration to restore. "
              "For available versions: https://huggingface.co/models?search=stable-diffusion"))
    parser.add_argument(
        "--compute-unit",
        choices=get_available_compute_units(),
        default="ALL",
        help=("The compute units to be used when executing Core ML models. "
              f"Options: {get_available_compute_units()}"))
    parser.add_argument(
        "--scheduler",
        choices=tuple(SCHEDULER_MAP.keys()),
        default=None,
        help=("The scheduler to use for running the reverse diffusion process. "
              "If not specified, the default scheduler from the diffusers pipeline is utilized"))
    parser.add_argument(
        "--num-inference-steps",
        default=50,
        type=int,
        help="The number of iterations the unet model will be executed throughout the reverse diffusion process")

    args = parser.parse_args()
    main(args)