diff --git a/README.md b/README.md
index f108f36..6d506df 100644
--- a/README.md
+++ b/README.md
@@ -220,6 +220,10 @@ Both of these products require the Core ML models and tokenization resources to
- `vocab.json` (tokenizer vocabulary file)
- `merges.text` (merges for byte pair encoding file)
+Optionally, for image2image, in-painting, or similar:
+
+- `VAEEncoder.mlmodelc` (image encoder model)
+
Optionally, it may also include the safety checker model that some versions of Stable Diffusion include:
- `SafetyChecker.mlmodelc`
@@ -321,6 +325,7 @@ Differences may be less or more pronounced for different inputs. Please see the
A3: In order to minimize the memory impact of the model conversion process, please execute the following command instead:
```bash
+python -m python_coreml_stable_diffusion.torch2coreml --convert-vae-encoder -o && \
python -m python_coreml_stable_diffusion.torch2coreml --convert-vae-decoder -o && \
python -m python_coreml_stable_diffusion.torch2coreml --convert-unet -o && \
python -m python_coreml_stable_diffusion.torch2coreml --convert-text-encoder -o && \
diff --git a/python_coreml_stable_diffusion/torch2coreml.py b/python_coreml_stable_diffusion/torch2coreml.py
index 6d6c2fa..84e201c 100644
--- a/python_coreml_stable_diffusion/torch2coreml.py
+++ b/python_coreml_stable_diffusion/torch2coreml.py
@@ -43,6 +43,21 @@ def _get_coreml_inputs(sample_inputs, args):
) for k, v in sample_inputs.items()
]
+# Simpler version of `DiagonalGaussianDistribution` with only needed calculations
+# as implemented in vae.py as part of the AutoencoderKL class
+# https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/vae.py#L312
+# coremltools-6.1 does not yet implement the randn operation with the option of setting a random seed
+class CoreMLDiagonalGaussianDistribution(object):
+ def __init__(self, parameters, noise):
+ self.parameters = parameters
+ self.noise = noise
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.std = torch.exp(0.5 * self.logvar)
+
+ def sample(self) -> torch.FloatTensor:
+ x = self.mean + self.std * self.noise
+ return x
def compute_psnr(a, b):
""" Compute Peak-Signal-to-Noise-Ratio across two numpy.ndarray objects
@@ -140,7 +155,7 @@ def _convert_to_coreml(submodule_name, torchscript_module, sample_inputs,
def quantize_weights_to_8bits(args):
for model_name in [
- "text_encoder", "vae_decoder", "unet", "unet_chunk1",
+ "text_encoder", "vae_decoder", "vae_encoder", "unet", "unet_chunk1",
"unet_chunk2", "safety_checker"
]:
out_path = _get_out_path(args, model_name)
@@ -190,6 +205,7 @@ def bundle_resources_for_swift_cli(args):
# Compile model using coremlcompiler (Significantly reduces the load time for unet)
for source_name, target_name in [("text_encoder", "TextEncoder"),
("vae_decoder", "VAEDecoder"),
+ ("vae_encoder", "VAEEncoder"),
("unet", "Unet"),
("unet_chunk1", "UnetChunk1"),
("unet_chunk2", "UnetChunk2"),
@@ -453,6 +469,159 @@ def convert_vae_decoder(pipe, args):
gc.collect()
+def convert_vae_encoder(pipe, args):
+ """ Converts the VAE Encoder component of Stable Diffusion
+ """
+ out_path = _get_out_path(args, "vae_encoder")
+ if os.path.exists(out_path):
+ logger.info(
+ f"`vae_encoder` already exists at {out_path}, skipping conversion."
+ )
+ return
+
+ if not hasattr(pipe, "unet"):
+ raise RuntimeError(
+ "convert_unet() deletes pipe.unet to save RAM. "
+ "Please use convert_vae_encoder() before convert_unet()")
+
+ sample_shape = (
+ 1, # B
+ 3, # C (RGB range from -1 to 1)
+ (args.latent_h or pipe.unet.config.sample_size) * 8, # H
+ (args.latent_w or pipe.unet.config.sample_size) * 8, # w
+ )
+
+ noise_shape = (
+ 1, # B
+ 4, # C
+ pipe.unet.config.sample_size, # H
+ pipe.unet.config.sample_size, # w
+ )
+
+ float_value_shape = (
+ 1,
+ 1,
+ )
+
+ sqrt_alphas_cumprod_torch_shape = torch.tensor([[0.2,]])
+ sqrt_one_minus_alphas_cumprod_torch_shape = torch.tensor([[0.8,]])
+
+ sample_vae_encoder_inputs = {
+ "sample": torch.rand(*sample_shape, dtype=torch.float16),
+ "diagonal_noise": torch.rand(*noise_shape, dtype=torch.float16),
+ "noise": torch.rand(*noise_shape, dtype=torch.float16),
+ "sqrt_alphas_cumprod": torch.rand(*float_value_shape, dtype=torch.float16),
+ "sqrt_one_minus_alphas_cumprod": torch.rand(*float_value_shape, dtype=torch.float16),
+ }
+
+ class VAEEncoder(nn.Module):
+ """ Wrapper nn.Module wrapper for pipe.encode() method
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.quant_conv = pipe.vae.quant_conv
+ self.alphas_cumprod = pipe.scheduler.alphas_cumprod
+ self.encoder = pipe.vae.encoder
+
+ # Because CoreMLTools does not support the torch.randn op, we pass in both
+ # the diagonal Noise for the `DiagonalGaussianDistribution` operation and
+ # the noise tensor combined with precalculated `sqrt_alphas_cumprod` and `sqrt_one_minus_alphas_cumprod`
+ # for faster computation.
+ def forward(self, sample, diagonal_noise, noise, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod):
+ h = self.encoder(sample)
+ moments = self.quant_conv(h)
+ posterior = CoreMLDiagonalGaussianDistribution(moments, diagonal_noise)
+ posteriorSample = posterior.sample()
+
+ # Add the scaling operation and the latent noise for faster computation
+ init_latents = 0.18215 * posteriorSample
+ result = self.add_noise(init_latents, noise, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod)
+ return result
+
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ sqrt_alphas_cumprod: torch.FloatTensor,
+ sqrt_one_minus_alphas_cumprod: torch.FloatTensor
+ ) -> torch.FloatTensor:
+ noisy_samples = sqrt_alphas_cumprod * original_samples + sqrt_one_minus_alphas_cumprod * noise
+ return noisy_samples
+
+
+ baseline_encoder = VAEEncoder().eval()
+
+ # No optimization needed for the VAE Encoder as it is a pure ConvNet
+ traced_vae_encoder = torch.jit.trace(
+ baseline_encoder, (
+ sample_vae_encoder_inputs["sample"].to(torch.float32),
+ sample_vae_encoder_inputs["diagonal_noise"].to(torch.float32),
+ sample_vae_encoder_inputs["noise"].to(torch.float32),
+ sqrt_alphas_cumprod_torch_shape.to(torch.float32),
+ sqrt_one_minus_alphas_cumprod_torch_shape.to(torch.float32)
+ ))
+
+ modify_coremltools_torch_frontend_badbmm()
+ coreml_vae_encoder, out_path = _convert_to_coreml(
+ "vae_encoder", traced_vae_encoder, sample_vae_encoder_inputs,
+ ["latent_dist"], args)
+
+ # Set model metadata
+ coreml_vae_encoder.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
+ coreml_vae_encoder.license = "OpenRAIL (https://huggingface.co/spaces/CompVis/stable-diffusion-license)"
+ coreml_vae_encoder.version = args.model_version
+ coreml_vae_encoder.short_description = \
+ "Stable Diffusion generates images conditioned on text and/or other images as input through the diffusion process. " \
+ "Please refer to https://arxiv.org/abs/2112.10752 for details."
+
+ # Set the input descriptions
+ coreml_vae_encoder.input_description["sample"] = \
+ "An image of the correct size to create the latent space with, image2image and in-painting."
+ coreml_vae_encoder.input_description["diagonal_noise"] = \
+ "Latent noise for `DiagonalGaussianDistribution` operation."
+ coreml_vae_encoder.input_description["noise"] = \
+ "Latent noise for use with strength parameter of image2image"
+ coreml_vae_encoder.input_description["sqrt_alphas_cumprod"] = \
+ "Precalculated `sqrt_alphas_cumprod` value based on strength and the current schedular's alphasCumprod values"
+ coreml_vae_encoder.input_description["sqrt_one_minus_alphas_cumprod"] = \
+ "Precalculated `sqrt_one_minus_alphas_cumprod` value based on strength and the current schedular's alphasCumprod values"
+
+ # Set the output descriptions
+ coreml_vae_encoder.output_description[
+ "latent_dist"] = "The latent embeddings from the unet model from the input image."
+
+ _save_mlpackage(coreml_vae_encoder, out_path)
+
+ logger.info(f"Saved vae_encoder into {out_path}")
+
+ # Parity check PyTorch vs CoreML
+ if args.check_output_correctness:
+ baseline_out = baseline_encoder(
+ sample=sample_vae_encoder_inputs["sample"].to(torch.float32),
+ diagonal_noise=sample_vae_encoder_inputs["diagonal_noise"].to(torch.float32),
+ noise=sample_vae_encoder_inputs["noise"].to(torch.float32),
+ sqrt_alphas_cumprod=sqrt_alphas_cumprod_torch_shape,
+ sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod_torch_shape,
+ ).numpy(),
+
+ coreml_out = list(
+ coreml_vae_encoder.predict(
+ {
+ "sample": sample_vae_encoder_inputs["sample"].numpy(),
+ "diagonal_noise": sample_vae_encoder_inputs["diagonal_noise"].numpy(),
+ "noise": sample_vae_encoder_inputs["noise"].numpy(),
+ "sqrt_alphas_cumprod": sqrt_alphas_cumprod_torch_shape.numpy(),
+ "sqrt_one_minus_alphas_cumprod": sqrt_one_minus_alphas_cumprod_torch_shape.numpy()
+ }).values())
+
+ report_correctness(baseline_out[0], coreml_out[0],
+ "vae_encoder baseline PyTorch to baseline CoreML")
+
+ del traced_vae_encoder, pipe.vae.encoder, coreml_vae_encoder
+ gc.collect()
+
+
def convert_unet(pipe, args):
""" Converts the UNet component of Stable Diffusion
"""
@@ -801,7 +970,12 @@ def main(args):
logger.info("Converting vae_decoder")
convert_vae_decoder(pipe, args)
logger.info("Converted vae_decoder")
-
+
+ if args.convert_vae_encoder:
+ logger.info("Converting vae_encoder")
+ convert_vae_encoder(pipe, args)
+ logger.info("Converted vae_encoder")
+
if args.convert_unet:
logger.info("Converting unet")
convert_unet(pipe, args)
@@ -835,6 +1009,7 @@ def parser_spec():
# Select which models to export (All are needed for text-to-image pipeline to function)
parser.add_argument("--convert-text-encoder", action="store_true")
parser.add_argument("--convert-vae-decoder", action="store_true")
+ parser.add_argument("--convert-vae-encoder", action="store_true")
parser.add_argument("--convert-unet", action="store_true")
parser.add_argument("--convert-safety-checker", action="store_true")
parser.add_argument(