#
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
#
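#
# Example invocation (a sketch only; it assumes this file is saved as
# torch2coreml.py inside the python_coreml_stable_diffusion package -- the
# flags themselves are defined in parser_spec() at the bottom of this file):
#
#   python -m python_coreml_stable_diffusion.torch2coreml \
#       --convert-unet --convert-text-encoder --convert-vae-decoder \
#       --convert-safety-checker \
#       --model-version CompVis/stable-diffusion-v1-4 \
#       --bundle-resources-for-swift-cli \
#       -o <output-directory>
#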
from python_coreml_stable_diffusion import unet
import argparse
from collections import OrderedDict, defaultdict
from copy import deepcopy
import coremltools as ct
from diffusers import StableDiffusionPipeline
import gc
import logging
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
import numpy as np
import os
from python_coreml_stable_diffusion import chunk_mlprogram
import requests
import shutil
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.set_grad_enabled(False)
from types import MethodType

def _get_coreml_inputs(sample_inputs, args):
    return [
        ct.TensorType(
            name=k,
            shape=v.shape,
            dtype=v.numpy().dtype if isinstance(v, torch.Tensor) else v.dtype,
        ) for k, v in sample_inputs.items()
    ]

def compute_psnr(a, b):
    """ Compute Peak-Signal-to-Noise-Ratio across two numpy.ndarray objects
    """
    max_b = np.abs(b).max()

    sumdeltasq = 0.0
    sumdeltasq = ((a - b) * (a - b)).sum()
    sumdeltasq /= b.size
    sumdeltasq = np.sqrt(sumdeltasq)

    eps = 1e-5
    eps2 = 1e-10
    psnr = 20 * np.log10((max_b + eps) / (sumdeltasq + eps2))

    return psnr

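# In other words, compute_psnr evaluates (a sketch of the math above; eps and
# eps2 only guard against division by zero and log of zero):
#
#   PSNR(a, b) = 20 * log10( max|b| / RMSE(a, b) )
#
# so identical tensors yield a very large but finite value rather than +inf.
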
ABSOLUTE_MIN_PSNR = 35

def report_correctness(original_outputs, final_outputs, log_prefix):
    """ Report PSNR values across two compatible tensors
    """
    original_psnr = compute_psnr(original_outputs, original_outputs)
    final_psnr = compute_psnr(original_outputs, final_outputs)

    dB_change = final_psnr - original_psnr
    logger.info(
        f"{log_prefix}: PSNR changed by {dB_change:.1f} dB ({original_psnr:.1f} -> {final_psnr:.1f})"
    )

    if final_psnr < ABSOLUTE_MIN_PSNR:
        raise ValueError(f"{final_psnr:.1f} dB is too low!")
    else:
        logger.info(
            f"{final_psnr:.1f} dB > {ABSOLUTE_MIN_PSNR} dB (minimum allowed) parity check passed"
        )
    return final_psnr

def _get_out_path(args, submodule_name):
    fname = f"Stable_Diffusion_version_{args.model_version}_{submodule_name}.mlpackage"
    fname = fname.replace("/", "_")
    return os.path.join(args.o, fname)


# https://github.com/apple/coremltools/issues/1680
def _save_mlpackage(model, output_path):
    # First recreate the MLModel object from its in-memory spec, then save
    ct.models.MLModel(model._spec,
                      weights_dir=model._weights_dir,
                      is_temp_package=True).save(output_path)

def _convert_to_coreml(submodule_name, torchscript_module, sample_inputs,
                       output_names, args):
    out_path = _get_out_path(args, submodule_name)

    if os.path.exists(out_path):
        logger.info(f"Skipping export because {out_path} already exists")
        logger.info(f"Loading model from {out_path}")

        start = time.time()
        # Note that each model load triggers a model compilation which may take up to a few minutes.
        # The Swift CLI we provide uses precompiled Core ML models (.mlmodelc), which incurs compilation
        # only upon first load and mitigates the load time in subsequent runs.
        coreml_model = ct.models.MLModel(
            out_path, compute_units=ct.ComputeUnit[args.compute_unit])
        logger.info(
            f"Loading {out_path} took {time.time() - start:.1f} seconds")

        coreml_model.compute_unit = ct.ComputeUnit[args.compute_unit]
    else:
        logger.info(f"Converting {submodule_name} to CoreML..")
        coreml_model = ct.convert(
            torchscript_module,
            convert_to="mlprogram",
            minimum_deployment_target=ct.target.macOS13,
            inputs=_get_coreml_inputs(sample_inputs, args),
            outputs=[ct.TensorType(name=name) for name in output_names],
            compute_units=ct.ComputeUnit[args.compute_unit],
            # skip_model_load=True,
        )

        del torchscript_module
        gc.collect()

        coreml_model.save(out_path)
        logger.info(f"Saved {submodule_name} model to {out_path}")

    return coreml_model, out_path


def quantize_weights_to_8bits(args):
    for model_name in [
            "text_encoder", "vae_decoder", "unet", "unet_chunk1",
            "unet_chunk2", "safety_checker"
    ]:
        out_path = _get_out_path(args, model_name)
        if os.path.exists(out_path):
            logger.info(f"Quantizing {model_name}")
            mlmodel = ct.models.MLModel(out_path,
                                        compute_units=ct.ComputeUnit.CPU_ONLY)
            mlmodel = ct.compression_utils.affine_quantize_weights(
                mlmodel, mode="linear")
            mlmodel.save(out_path)
            logger.info("Done")
        else:
            logger.info(
                f"Skipped quantizing {model_name} (not found at {out_path})")


def _compile_coreml_model(source_model_path, output_dir, final_name):
    """ Compiles Core ML models using the coremlcompiler utility from the Xcode toolchain
    """
    target_path = os.path.join(output_dir, f"{final_name}.mlmodelc")
    if os.path.exists(target_path):
        logger.warning(
            f"Found existing compiled model at {target_path}! Skipping..")
        return target_path

    logger.info(f"Compiling {source_model_path}")
    source_model_name = os.path.basename(
        os.path.splitext(source_model_path)[0])

    os.system(f"xcrun coremlcompiler compile {source_model_path} {output_dir}")
    compiled_output = os.path.join(output_dir, f"{source_model_name}.mlmodelc")
    shutil.move(compiled_output, target_path)

    return target_path


def bundle_resources_for_swift_cli(args):
    """
    - Compiles Core ML models from mlpackage into mlmodelc format
    - Downloads tokenizer resources for the text encoder
    """
    resources_dir = os.path.join(args.o, "Resources")
    if not os.path.exists(resources_dir):
        os.makedirs(resources_dir, exist_ok=True)
        logger.info(f"Created {resources_dir} for Swift CLI assets")

    # Compile models using coremlcompiler (significantly reduces the load time for unet)
    for source_name, target_name in [("text_encoder", "TextEncoder"),
                                     ("vae_decoder", "VAEDecoder"),
                                     ("unet", "Unet"),
                                     ("unet_chunk1", "UnetChunk1"),
                                     ("unet_chunk2", "UnetChunk2"),
                                     ("safety_checker", "SafetyChecker")]:
        source_path = _get_out_path(args, source_name)
        if os.path.exists(source_path):
            target_path = _compile_coreml_model(source_path, resources_dir,
                                                target_name)
            logger.info(f"Compiled {source_path} to {target_path}")
        else:
            logger.warning(
                f"{source_path} not found, skipping compilation to {target_name}.mlmodelc"
            )

    # Fetch and save the vocabulary JSON file for the text tokenizer
    logger.info("Downloading and saving tokenizer vocab.json")
    with open(os.path.join(resources_dir, "vocab.json"), "wb") as f:
        f.write(requests.get(args.text_encoder_vocabulary_url).content)
    logger.info("Done")

    # Fetch and save the merged pairs text file for the text tokenizer
    logger.info("Downloading and saving tokenizer merges.txt")
    with open(os.path.join(resources_dir, "merges.txt"), "wb") as f:
        f.write(requests.get(args.text_encoder_merges_url).content)
    logger.info("Done")

    return resources_dir


def convert_text_encoder(pipe, args):
    """ Converts the text encoder component of Stable Diffusion
    """
    out_path = _get_out_path(args, "text_encoder")
    if os.path.exists(out_path):
        logger.info(
            f"`text_encoder` already exists at {out_path}, skipping conversion."
        )
        return

    # Create sample inputs for tracing, conversion and correctness verification
    text_encoder_sequence_length = pipe.tokenizer.model_max_length
    text_encoder_hidden_size = pipe.text_encoder.config.hidden_size

    sample_text_encoder_inputs = {
        "input_ids":
        torch.randint(
            pipe.text_encoder.config.vocab_size,
            (1, text_encoder_sequence_length),
            # https://github.com/apple/coremltools/issues/1423
            dtype=torch.float32,
        )
    }
    sample_text_encoder_inputs_spec = {
        k: (v.shape, v.dtype)
        for k, v in sample_text_encoder_inputs.items()
    }
    logger.info(f"Sample inputs spec: {sample_text_encoder_inputs_spec}")

    def _build_causal_attention_mask(self, bsz, seq_len, dtype):
        mask = torch.ones((bsz, seq_len, seq_len), dtype=dtype) * -1e4
        mask.triu_(1)
        mask = mask.unsqueeze(1)
        return mask

    class TextEncoder(nn.Module):

        def __init__(self):
            super().__init__()
            self.text_encoder = pipe.text_encoder
            setattr(
                self.text_encoder.text_model, "_build_causal_attention_mask",
                MethodType(_build_causal_attention_mask,
                           self.text_encoder.text_model))

        def forward(self, input_ids):
            return self.text_encoder(input_ids, return_dict=False)

    reference_text_encoder = TextEncoder().eval()

    logger.info("JIT tracing text_encoder..")
    reference_text_encoder = torch.jit.trace(
        reference_text_encoder,
        (sample_text_encoder_inputs["input_ids"].to(torch.int32), ),
    )
    logger.info("Done.")

    coreml_text_encoder, out_path = _convert_to_coreml(
        "text_encoder", reference_text_encoder, sample_text_encoder_inputs,
        ["last_hidden_state", "pooled_outputs"], args)

    # Set model metadata
    coreml_text_encoder.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
    coreml_text_encoder.license = "OpenRAIL (https://huggingface.co/spaces/CompVis/stable-diffusion-license)"
    coreml_text_encoder.version = args.model_version
    coreml_text_encoder.short_description = \
        "Stable Diffusion generates images conditioned on text and/or other images as input through the diffusion process. " \
        "Please refer to https://arxiv.org/abs/2112.10752 for details."

    # Set the input descriptions
    coreml_text_encoder.input_description[
        "input_ids"] = "The token ids that represent the input text"

    # Set the output descriptions
    coreml_text_encoder.output_description[
        "last_hidden_state"] = "The token embeddings as encoded by the Transformer model"
    coreml_text_encoder.output_description[
        "pooled_outputs"] = "The version of the `last_hidden_state` output after pooling"

    _save_mlpackage(coreml_text_encoder, out_path)

    logger.info(f"Saved text_encoder into {out_path}")

    # Parity check PyTorch vs CoreML
    if args.check_output_correctness:
        baseline_out = pipe.text_encoder(
            sample_text_encoder_inputs["input_ids"].to(torch.int32),
            return_dict=False,
        )[1].numpy()

        coreml_out = list(
            coreml_text_encoder.predict(
                {k: v.numpy()
                 for k, v in sample_text_encoder_inputs.items()}).values())[0]
        report_correctness(
            baseline_out, coreml_out,
            "text_encoder baseline PyTorch to reference CoreML")

    del reference_text_encoder, coreml_text_encoder, pipe.text_encoder
    gc.collect()


def modify_coremltools_torch_frontend_badbmm():
    """
    Modifies the coremltools torch frontend for baddbmm to be robust to the `beta` argument being of non-float dtype:
    e.g. https://github.com/huggingface/diffusers/blob/v0.8.1/src/diffusers/models/attention.py#L315
    """
    from coremltools.converters.mil import register_torch_op
    from coremltools.converters.mil.mil import Builder as mb
    from coremltools.converters.mil.frontend.torch.ops import _get_inputs
    from coremltools.converters.mil.frontend.torch.torch_op_registry import _TORCH_OPS_REGISTRY

    if "baddbmm" in _TORCH_OPS_REGISTRY:
        del _TORCH_OPS_REGISTRY["baddbmm"]

    @register_torch_op
    def baddbmm(context, node):
        """
        baddbmm(Tensor input, Tensor batch1, Tensor batch2, Scalar beta=1, Scalar alpha=1)
        output = beta * input + alpha * batch1 * batch2

        Notice that batch1 and batch2 must be 3-D tensors each containing the same number of matrices.
        If batch1 is a (b×n×m) tensor, batch2 is a (b×m×p) tensor, then input must be broadcastable with a (b×n×p) tensor
        and out will be a (b×n×p) tensor.
        """
        assert len(node.outputs) == 1
        inputs = _get_inputs(context, node, expected=5)
        bias, batch1, batch2, beta, alpha = inputs

        if beta.val != 1.0:
            # Apply scaling factor beta to the bias.
            if beta.val.dtype == np.int32:
                beta = mb.cast(x=beta, dtype="fp32")
                logger.warning(
                    f"Casted the `beta`(value={beta.val}) argument of `baddbmm` op "
                    "from int32 to float32 dtype for conversion!")
            bias = mb.mul(x=beta, y=bias, name=bias.name + "_scaled")

            context.add(bias)

        if alpha.val != 1.0:
            # Apply scaling factor alpha to the input.
            batch1 = mb.mul(x=alpha, y=batch1, name=batch1.name + "_scaled")
            context.add(batch1)

        bmm_node = mb.matmul(x=batch1, y=batch2, name=node.name + "_bmm")
        context.add(bmm_node)

        baddbmm_node = mb.add(x=bias, y=bmm_node, name=node.name)
        context.add(baddbmm_node)

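# For reference, the op being re-registered above corresponds to
#
#   out = beta * input + alpha * torch.bmm(batch1, batch2)
#
# which is what torch.baddbmm computes; the MIL graph built above
# (mb.mul / mb.matmul / mb.add) mirrors that decomposition.
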

def convert_vae_decoder(pipe, args):
    """ Converts the VAE Decoder component of Stable Diffusion
    """
    out_path = _get_out_path(args, "vae_decoder")
    if os.path.exists(out_path):
        logger.info(
            f"`vae_decoder` already exists at {out_path}, skipping conversion."
        )
        return

    if not hasattr(pipe, "unet"):
        raise RuntimeError(
            "convert_unet() deletes pipe.unet to save RAM. "
            "Please use convert_vae_decoder() before convert_unet()")

    z_shape = (
        1,  # B
        pipe.vae.latent_channels,  # C
        args.latent_h or pipe.unet.config.sample_size,  # H
        args.latent_w or pipe.unet.config.sample_size,  # W
    )

    sample_vae_decoder_inputs = {
        "z": torch.rand(*z_shape, dtype=torch.float16)
    }

    class VAEDecoder(nn.Module):
        """ Wrapper nn.Module for the pipe.vae decode path
        """

        def __init__(self):
            super().__init__()
            self.post_quant_conv = pipe.vae.post_quant_conv
            self.decoder = pipe.vae.decoder

        def forward(self, z):
            return self.decoder(self.post_quant_conv(z))

    baseline_decoder = VAEDecoder().eval()

    # No optimization needed for the VAE Decoder as it is a pure ConvNet
    traced_vae_decoder = torch.jit.trace(
        baseline_decoder, (sample_vae_decoder_inputs["z"].to(torch.float32), ))

    modify_coremltools_torch_frontend_badbmm()
    coreml_vae_decoder, out_path = _convert_to_coreml(
        "vae_decoder", traced_vae_decoder, sample_vae_decoder_inputs,
        ["image"], args)

    # Set model metadata
    coreml_vae_decoder.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
    coreml_vae_decoder.license = "OpenRAIL (https://huggingface.co/spaces/CompVis/stable-diffusion-license)"
    coreml_vae_decoder.version = args.model_version
    coreml_vae_decoder.short_description = \
        "Stable Diffusion generates images conditioned on text and/or other images as input through the diffusion process. " \
        "Please refer to https://arxiv.org/abs/2112.10752 for details."

    # Set the input descriptions
    coreml_vae_decoder.input_description["z"] = \
        "The denoised latent embeddings from the unet model after the last step of reverse diffusion"

    # Set the output descriptions
    coreml_vae_decoder.output_description[
        "image"] = "Generated image normalized to range [-1, 1]"

    _save_mlpackage(coreml_vae_decoder, out_path)

    logger.info(f"Saved vae_decoder into {out_path}")

    # Parity check PyTorch vs CoreML
    if args.check_output_correctness:
        baseline_out = baseline_decoder(
            z=sample_vae_decoder_inputs["z"].to(torch.float32)).numpy()
        coreml_out = list(
            coreml_vae_decoder.predict(
                {k: v.numpy()
                 for k, v in sample_vae_decoder_inputs.items()}).values())[0]
        report_correctness(baseline_out, coreml_out,
                           "vae_decoder baseline PyTorch to baseline CoreML")

    del traced_vae_decoder, pipe.vae.decoder, coreml_vae_decoder
    gc.collect()


def convert_unet(pipe, args):
    """ Converts the UNet component of Stable Diffusion
    """
    out_path = _get_out_path(args, "unet")

    # Check if Unet was previously exported and then chunked
    unet_chunks_exist = all(
        os.path.exists(
            out_path.replace(".mlpackage", f"_chunk{idx+1}.mlpackage"))
        for idx in range(2))

    if args.chunk_unet and unet_chunks_exist:
        logger.info("`unet` chunks already exist, skipping conversion.")
        del pipe.unet
        gc.collect()
        return

    # If the original Unet does not exist, export it from PyTorch+diffusers
    elif not os.path.exists(out_path):
        # Register the selected attention implementation globally
        unet.ATTENTION_IMPLEMENTATION_IN_EFFECT = unet.AttentionImplementations[
            args.attention_implementation]
        logger.info(
            f"Attention implementation in effect: {unet.ATTENTION_IMPLEMENTATION_IN_EFFECT}"
        )

        # Prepare sample input shapes and values
        batch_size = 2  # for classifier-free guidance
        sample_shape = (
            batch_size,                    # B
            pipe.unet.config.in_channels,  # C
            pipe.unet.config.sample_size,  # H
            pipe.unet.config.sample_size,  # W
        )

        if not hasattr(pipe, "text_encoder"):
            raise RuntimeError(
                "convert_text_encoder() deletes pipe.text_encoder to save RAM. "
                "Please use convert_unet() before convert_text_encoder()")

        encoder_hidden_states_shape = (
            batch_size,
            pipe.text_encoder.config.hidden_size,
            1,
            pipe.text_encoder.config.max_position_embeddings,
        )

        # Create the scheduled timesteps for downstream use
        DEFAULT_NUM_INFERENCE_STEPS = 50
        pipe.scheduler.set_timesteps(DEFAULT_NUM_INFERENCE_STEPS)

        sample_unet_inputs = OrderedDict([
            ("sample", torch.rand(*sample_shape)),
            ("timestep",
             torch.tensor([pipe.scheduler.timesteps[0].item()] *
                          (batch_size)).to(torch.float32)),
            ("encoder_hidden_states", torch.rand(*encoder_hidden_states_shape))
        ])
        sample_unet_inputs_spec = {
            k: (v.shape, v.dtype)
            for k, v in sample_unet_inputs.items()
        }
        logger.info(f"Sample inputs spec: {sample_unet_inputs_spec}")

        # Initialize reference unet
        reference_unet = unet.UNet2DConditionModel(**pipe.unet.config).eval()
        load_state_dict_summary = reference_unet.load_state_dict(
            pipe.unet.state_dict())

        # Prepare inputs
        baseline_sample_unet_inputs = deepcopy(sample_unet_inputs)
        baseline_sample_unet_inputs[
            "encoder_hidden_states"] = baseline_sample_unet_inputs[
                "encoder_hidden_states"].squeeze(2).transpose(1, 2)

        # JIT trace
        logger.info("JIT tracing..")
        reference_unet = torch.jit.trace(reference_unet,
                                         list(sample_unet_inputs.values()))
        logger.info("Done.")

        if args.check_output_correctness:
            baseline_out = pipe.unet(**baseline_sample_unet_inputs,
                                     return_dict=False)[0].numpy()
            reference_out = reference_unet(**sample_unet_inputs)[0].numpy()
            report_correctness(baseline_out, reference_out,
                               "unet baseline to reference PyTorch")

        del pipe.unet
        gc.collect()

        coreml_sample_unet_inputs = {
            k: v.numpy().astype(np.float16)
            for k, v in sample_unet_inputs.items()
        }

        coreml_unet, out_path = _convert_to_coreml("unet", reference_unet,
                                                   coreml_sample_unet_inputs,
                                                   ["noise_pred"], args)
        del reference_unet
        gc.collect()

        # Set model metadata
        coreml_unet.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
        coreml_unet.license = "OpenRAIL (https://huggingface.co/spaces/CompVis/stable-diffusion-license)"
        coreml_unet.version = args.model_version
        coreml_unet.short_description = \
            "Stable Diffusion generates images conditioned on text and/or other images as input through the diffusion process. " \
            "Please refer to https://arxiv.org/abs/2112.10752 for details."

        # Set the input descriptions
        coreml_unet.input_description["sample"] = \
            "The low resolution latent feature maps being denoised through reverse diffusion"
        coreml_unet.input_description["timestep"] = \
            "A value emitted by the associated scheduler object to condition the model on a given noise schedule"
        coreml_unet.input_description["encoder_hidden_states"] = \
            "Output embeddings from the associated text_encoder model to condition the generated image on text. " \
            "A maximum of 77 tokens (~40 words) are allowed. Longer text is truncated. " \
            "Shorter text does not reduce computation."

        # Set the output descriptions
        coreml_unet.output_description["noise_pred"] = \
            "Same shape and dtype as the `sample` input. " \
            "The predicted noise to facilitate the reverse diffusion (denoising) process"

        _save_mlpackage(coreml_unet, out_path)
        logger.info(f"Saved unet into {out_path}")

        # Parity check PyTorch vs CoreML
        if args.check_output_correctness:
            coreml_out = list(
                coreml_unet.predict(coreml_sample_unet_inputs).values())[0]
            report_correctness(baseline_out, coreml_out,
                               "unet baseline PyTorch to reference CoreML")

        del coreml_unet
        gc.collect()
    else:
        del pipe.unet
        gc.collect()
        logger.info(
            f"`unet` already exists at {out_path}, skipping conversion.")

    if args.chunk_unet and not unet_chunks_exist:
        logger.info("Chunking unet into two approximately equal MLModels")
        args.mlpackage_path = out_path
        args.remove_original = False
        chunk_mlprogram.main(args)


def convert_safety_checker(pipe, args):
    """ Converts the Safety Checker component of Stable Diffusion
    """
    if pipe.safety_checker is None:
        logger.warning(
            f"diffusers pipeline for {args.model_version} does not have a `safety_checker` module! "
            "`--convert-safety-checker` will be ignored.")
        return

    out_path = _get_out_path(args, "safety_checker")
    if os.path.exists(out_path):
        logger.info(
            f"`safety_checker` already exists at {out_path}, skipping conversion."
        )
        return

    sample_image = np.random.randn(
        1,  # B
        pipe.vae.config.sample_size,  # H
        pipe.vae.config.sample_size,  # W
        3  # C
    ).astype(np.float32)

    # Note that pipe.feature_extractor is not an ML model. It simply
    # preprocesses data for the pipe.safety_checker module.
    safety_checker_input = pipe.feature_extractor(
        pipe.numpy_to_pil(sample_image),
        return_tensors="pt",
    ).pixel_values.to(torch.float32)

    sample_safety_checker_inputs = OrderedDict([
        ("clip_input", safety_checker_input),
        ("images", torch.from_numpy(sample_image)),
        ("adjustment", torch.tensor([0]).to(torch.float32)),
    ])

    sample_safety_checker_inputs_spec = {
        k: (v.shape, v.dtype)
        for k, v in sample_safety_checker_inputs.items()
    }
    logger.info(f"Sample inputs spec: {sample_safety_checker_inputs_spec}")

    # Patch safety_checker's forward pass to be vectorized and avoid conditional blocks
    # (similar to pipe.safety_checker.forward_onnx)
    from diffusers.pipelines.stable_diffusion import safety_checker

    def forward_coreml(self, clip_input, images, adjustment):
        """ Forward pass implementation for safety_checker
        """

        def cosine_distance(image_embeds, text_embeds):
            return F.normalize(image_embeds) @ F.normalize(
                text_embeds).transpose(0, 1)

        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)

        special_cos_dist = cosine_distance(image_embeds,
                                           self.special_care_embeds)
        cos_dist = cosine_distance(image_embeds, self.concept_embeds)

        special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
        special_care = special_scores.gt(0).float().sum(dim=1).gt(0).float()
        special_adjustment = special_care * 0.01
        special_adjustment = special_adjustment.unsqueeze(1).expand(
            -1, cos_dist.shape[1])

        concept_scores = (cos_dist -
                          self.concept_embeds_weights) + special_adjustment
        has_nsfw_concepts = concept_scores.gt(0).float().sum(dim=1).gt(0)[:, None, None, None]

        has_nsfw_concepts_inds, _ = torch.broadcast_tensors(
            has_nsfw_concepts, images)
        images[has_nsfw_concepts_inds] = 0.0  # black image

        return images, has_nsfw_concepts.float(), concept_scores

    baseline_safety_checker = deepcopy(pipe.safety_checker.eval())
    setattr(baseline_safety_checker, "forward",
            MethodType(forward_coreml, baseline_safety_checker))

    # In order to parity check the actual signal, we need to override the forward pass
    # to return `concept_scores`, which is the output before thresholding.
    # Reference: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py#L100
    def forward_extended_return(self, clip_input, images, adjustment):

        def cosine_distance(image_embeds, text_embeds):
            normalized_image_embeds = F.normalize(image_embeds)
            normalized_text_embeds = F.normalize(text_embeds)
            return torch.mm(normalized_image_embeds,
                            normalized_text_embeds.t())

        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)

        special_cos_dist = cosine_distance(image_embeds,
                                           self.special_care_embeds)
        cos_dist = cosine_distance(image_embeds, self.concept_embeds)

        adjustment = 0.0

        special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
        special_care = torch.any(special_scores > 0, dim=1)
        special_adjustment = special_care * 0.01
        special_adjustment = special_adjustment.unsqueeze(1).expand(
            -1, cos_dist.shape[1])

        concept_scores = (cos_dist -
                          self.concept_embeds_weights) + special_adjustment
        has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)

        images[has_nsfw_concepts] = 0.0

        return images, has_nsfw_concepts, concept_scores

    setattr(pipe.safety_checker, "forward",
            MethodType(forward_extended_return, pipe.safety_checker))

    # Trace the safety_checker model
    logger.info("JIT tracing..")
    traced_safety_checker = torch.jit.trace(
        baseline_safety_checker, list(sample_safety_checker_inputs.values()))
    logger.info("Done.")
    del baseline_safety_checker
    gc.collect()

    # Cast all inputs to float16
    coreml_sample_safety_checker_inputs = {
        k: v.numpy().astype(np.float16)
        for k, v in sample_safety_checker_inputs.items()
    }

    # Convert safety_checker model to Core ML
    coreml_safety_checker, out_path = _convert_to_coreml(
        "safety_checker", traced_safety_checker,
        coreml_sample_safety_checker_inputs,
        ["filtered_images", "has_nsfw_concepts", "concept_scores"], args)

    # Set model metadata
    coreml_safety_checker.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
    coreml_safety_checker.license = "OpenRAIL (https://huggingface.co/spaces/CompVis/stable-diffusion-license)"
    coreml_safety_checker.version = args.model_version
    coreml_safety_checker.short_description = \
        "Stable Diffusion generates images conditioned on text and/or other images as input through the diffusion process. " \
        "Please refer to https://arxiv.org/abs/2112.10752 for details."

    # Set the input descriptions
    coreml_safety_checker.input_description["clip_input"] = \
        "The normalized image input tensor resized to (224x224) in channels-first (BCHW) format"
    coreml_safety_checker.input_description["images"] = \
        f"Output of the vae_decoder ({pipe.vae.config.sample_size}x{pipe.vae.config.sample_size}) in channels-last (BHWC) format"
    coreml_safety_checker.input_description["adjustment"] = \
        "Bias added to the concept scores to trade off increased recall for reduced precision in the safety checker classifier"

    # Set the output descriptions
    coreml_safety_checker.output_description["filtered_images"] = \
        "Identical to the input `images`. If the safety checker detected any sensitive content, " \
        "the corresponding image is replaced with a blank image (zeros)"
    coreml_safety_checker.output_description["has_nsfw_concepts"] = \
        "Indicates whether the safety checker model found any sensitive content in the given image"
    coreml_safety_checker.output_description["concept_scores"] = \
        "Concept scores are the raw scores that, when thresholded at zero, yield the `has_nsfw_concepts` output. " \
        "These scores can be used to tune the `adjustment` input"

    _save_mlpackage(coreml_safety_checker, out_path)

    if args.check_output_correctness:
        baseline_out = pipe.safety_checker(
            **sample_safety_checker_inputs)[2].numpy()
        coreml_out = coreml_safety_checker.predict(
            coreml_sample_safety_checker_inputs)["concept_scores"]
        report_correctness(
            baseline_out, coreml_out,
            "safety_checker baseline PyTorch to reference CoreML")

    del traced_safety_checker, coreml_safety_checker, pipe.safety_checker
    gc.collect()


def main(args):
    os.makedirs(args.o, exist_ok=True)

    # Instantiate the diffusers pipe as reference
    logger.info(
        f"Initializing StableDiffusionPipeline with {args.model_version}..")
    pipe = StableDiffusionPipeline.from_pretrained(args.model_version,
                                                   use_auth_token=True)
    logger.info("Done.")

    # Convert models
    if args.convert_vae_decoder:
        logger.info("Converting vae_decoder")
        convert_vae_decoder(pipe, args)
        logger.info("Converted vae_decoder")

    if args.convert_unet:
        logger.info("Converting unet")
        convert_unet(pipe, args)
        logger.info("Converted unet")

    if args.convert_text_encoder:
        logger.info("Converting text_encoder")
        convert_text_encoder(pipe, args)
        logger.info("Converted text_encoder")

    if args.convert_safety_checker:
        logger.info("Converting safety_checker")
        convert_safety_checker(pipe, args)
        logger.info("Converted safety_checker")

    if args.bundle_resources_for_swift_cli:
        logger.info("Bundling resources for the Swift CLI")
        bundle_resources_for_swift_cli(args)
        logger.info("Bundled resources for the Swift CLI")

    if args.quantize_weights_to_8bits:
        # Note: Not recommended, significantly degrades generated image quality
        logger.info("Quantizing weights to 8-bit precision")
        quantize_weights_to_8bits(args)
        logger.info("Quantized weights to 8-bit precision")


def parser_spec():
    parser = argparse.ArgumentParser()

    # Select which models to export (all are needed for the text-to-image pipeline to function)
    parser.add_argument("--convert-text-encoder", action="store_true")
    parser.add_argument("--convert-vae-decoder", action="store_true")
    parser.add_argument("--convert-unet", action="store_true")
    parser.add_argument("--convert-safety-checker", action="store_true")

    parser.add_argument(
        "--model-version",
        default="CompVis/stable-diffusion-v1-4",
        help=
        ("The pre-trained model checkpoint and configuration to restore. "
         "For available versions: https://huggingface.co/models?search=stable-diffusion"
         ))
    parser.add_argument("--compute-unit",
                        choices=tuple(cu
                                      for cu in ct.ComputeUnit._member_names_),
                        default="ALL")

    parser.add_argument(
        "--latent-h",
        type=int,
        default=None,
        help=
        "The spatial resolution (number of rows) of the latent space. Defaults to `pipe.unet.config.sample_size`",
    )
    parser.add_argument(
        "--latent-w",
        type=int,
        default=None,
        help=
        "The spatial resolution (number of cols) of the latent space. Defaults to `pipe.unet.config.sample_size`",
    )
    parser.add_argument(
        "--attention-implementation",
        choices=tuple(ai
                      for ai in unet.AttentionImplementations._member_names_),
        default=unet.ATTENTION_IMPLEMENTATION_IN_EFFECT.name,
        help=
        "The enumerated implementations trade off between ANE and GPU performance",
    )
    parser.add_argument(
        "-o",
        default=os.getcwd(),
        help="The resulting mlpackages will be saved into this directory")
    parser.add_argument(
        "--check-output-correctness",
        action="store_true",
        help=
        ("If specified, compares the outputs of the original PyTorch and final CoreML models and reports PSNR in dB. "
         "Enabling this feature uses more memory. Disable it if your machine runs out of memory."
         ))
    parser.add_argument(
        "--chunk-unet",
        action="store_true",
        help=
        ("If specified, generates two mlpackages out of the unet model with approximately equal weight sizes. "
         "This is required for ANE deployment on iOS and iPadOS. Not required for macOS."
         ))
    parser.add_argument(
        "--quantize-weights-to-8bits",
        action="store_true",
        help=
        ("If specified, quantizes the 16-bit weights to 8-bit weights in-place for all models. "
         "Not recommended because the generated image quality degrades significantly after 8-bit weight quantization."
         ))

    # Swift CLI Resource Bundling
    parser.add_argument(
        "--bundle-resources-for-swift-cli",
        action="store_true",
        help=
        ("If specified, creates a resources directory compatible with the sample Swift CLI. "
         "It compiles all four models and adds them to a StableDiffusionResources directory "
         "along with a `vocab.json` and `merges.txt` for the text tokenizer"))
    parser.add_argument(
        "--text-encoder-vocabulary-url",
        default=
        "https://huggingface.co/openai/clip-vit-base-patch32/resolve/main/vocab.json",
        help="The URL to the vocabulary file used by the text tokenizer")
    parser.add_argument(
        "--text-encoder-merges-url",
        default=
        "https://huggingface.co/openai/clip-vit-base-patch32/resolve/main/merges.txt",
        help="The URL to the merged pairs used by the text tokenizer.")

    return parser


if __name__ == "__main__":
    parser = parser_spec()
    args = parser.parse_args()
    main(args)
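
# A minimal sketch of consuming one of the generated .mlpackage files from
# Python (ct.models.MLModel and .predict are the same coremltools APIs used in
# the parity checks above; the file name and input shapes below are
# placeholders, not guaranteed to match your exported model):
#
#   model = ct.models.MLModel(
#       "Stable_Diffusion_version_<model-version>_vae_decoder.mlpackage")
#   outputs = model.predict(
#       {"z": np.random.rand(1, 4, 64, 64).astype(np.float16)})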