|
|
|
@ -43,6 +43,23 @@ def _get_coreml_inputs(sample_inputs, args):
|
|
|
|
|
) for k, v in sample_inputs.items()
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
# Simpler version of `DiagonalGaussianDistribution` with only needed calculations
|
|
|
|
|
# as implemented in vae.py as part of the AutoencoderKL class
|
|
|
|
|
# The noise is passed in as an input tensor because coremltools does not support the `randn` operation.
|
|
|
|
|
class CoreMLDiagonalGaussianDistribution:
    """Diagonal Gaussian whose sampling noise is supplied by the caller.

    The distribution is described by a single `parameters` tensor that is
    split in half along dim 1 into a mean and a log-variance. Instead of
    drawing random numbers itself, `sample()` combines the mean and standard
    deviation with the `noise` tensor given at construction time.
    """

    def __init__(self, parameters, noise):
        """Split `parameters` into mean / log-variance and precompute the std.

        Args:
            parameters: Tensor holding the mean and the log-variance stacked
                along dim 1 (first half mean, second half log-variance).
            noise: Tensor used in place of freshly drawn random values when
                sampling; it must be broadcastable against the mean.
        """
        self.parameters = parameters
        self.noise = noise
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Bound the log-variance before exponentiating so `std` stays finite.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.std = torch.exp(0.5 * self.logvar)

    def sample(self) -> torch.FloatTensor:
        """Return `mean + std * noise` using the caller-provided noise."""
        # Make sure the noise lives on the same device as the parameters and
        # has the same dtype before mixing it in.
        noise = self.noise.to(device=self.parameters.device,
                              dtype=self.parameters.dtype)
        return self.mean + self.std * noise
|
|
|
|
|
|
|
|
|
|
def compute_psnr(a, b):
|
|
|
|
|
""" Compute Peak-Signal-to-Noise-Ratio across two numpy.ndarray objects
|
|
|
|
@ -140,7 +157,7 @@ def _convert_to_coreml(submodule_name, torchscript_module, sample_inputs,
|
|
|
|
|
|
|
|
|
|
def quantize_weights_to_8bits(args):
|
|
|
|
|
for model_name in [
|
|
|
|
|
"text_encoder", "vae_decoder", "unet", "unet_chunk1",
|
|
|
|
|
"text_encoder", "vae_decoder", "vae_encoder", "unet", "unet_chunk1",
|
|
|
|
|
"unet_chunk2", "safety_checker"
|
|
|
|
|
]:
|
|
|
|
|
out_path = _get_out_path(args, model_name)
|
|
|
|
@ -190,6 +207,7 @@ def bundle_resources_for_swift_cli(args):
|
|
|
|
|
# Compile model using coremlcompiler (Significantly reduces the load time for unet)
|
|
|
|
|
for source_name, target_name in [("text_encoder", "TextEncoder"),
|
|
|
|
|
("vae_decoder", "VAEDecoder"),
|
|
|
|
|
("vae_encoder", "VAEEncoder"),
|
|
|
|
|
("unet", "Unet"),
|
|
|
|
|
("unet_chunk1", "UnetChunk1"),
|
|
|
|
|
("unet_chunk2", "UnetChunk2"),
|
|
|
|
@ -453,6 +471,163 @@ def convert_vae_decoder(pipe, args):
|
|
|
|
|
gc.collect()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_vae_encoder(pipe, args):
    """ Converts the VAE Encoder component of Stable Diffusion

    The exported model takes an image plus two externally generated noise
    tensors and two precomputed scalar factors, and returns the scaled,
    noised latents in a single output named `latent_dist`.

    Args:
        pipe: Pipeline object providing `vae.encoder`, `vae.quant_conv` and
            `unet.config.sample_size`. `pipe.vae.encoder` is deleted at the
            end of this function to free memory.
        args: Parsed command-line arguments. This function reads
            `latent_h`, `latent_w`, `model_version` and
            `check_output_correctness`; `args` is also forwarded to the
            path / conversion helpers.

    Raises:
        RuntimeError: If `pipe.unet` is no longer available (its config is
            needed to derive the default input sizes).
    """
    out_path = _get_out_path(args, "vae_encoder")
    if os.path.exists(out_path):
        logger.info(
            f"`vae_encoder` already exists at {out_path}, skipping conversion."
        )
        return

    if not hasattr(pipe, "unet"):
        raise RuntimeError(
            "convert_unet() deletes pipe.unet to save RAM. "
            "Please use convert_vae_encoder() before convert_unet()")

    # Resolve the latent size first and derive the pixel size from it.
    # Writing `args.latent_h or sample_size * 8` would bind as
    # `args.latent_h or (sample_size * 8)` and skip the upscaling whenever a
    # custom latent size is supplied.
    latent_h = args.latent_h or pipe.unet.config.sample_size
    latent_w = args.latent_w or pipe.unet.config.sample_size

    sample_shape = (
        1,  # B
        3,  # C (RGB range from -1 to 1)
        latent_h * 8,  # H
        latent_w * 8,  # w
    )

    # Both noise inputs are combined element-wise with the latents, so they
    # must follow the same (possibly user-overridden) latent size.
    noise_shape = (
        1,  # B
        4,  # C
        latent_h,  # H
        latent_w,  # w
    )

    float_value_shape = (
        1,
        1,
    )

    # Fixed example values for the two scalar factors; used for tracing and
    # for the parity check below.
    sqrt_alphas_cumprod_example = torch.tensor([[0.2]])
    sqrt_one_minus_alphas_cumprod_example = torch.tensor([[0.8]])

    sample_vae_encoder_inputs = {
        "sample": torch.rand(*sample_shape, dtype=torch.float16),
        "diagonalNoise": torch.rand(*noise_shape, dtype=torch.float16),
        "noise": torch.rand(*noise_shape, dtype=torch.float16),
        "sqrtAlphasCumprod": torch.rand(*float_value_shape, dtype=torch.float16),
        "sqrtOneMinusAlphasCumprod": torch.rand(*float_value_shape, dtype=torch.float16),
    }

    class VAEEncoder(nn.Module):
        """ nn.Module wrapper around the VAE encoder path that also scales
        the sampled latents and mixes in the provided noise
        """

        def __init__(self):
            super().__init__()
            self.quant_conv = pipe.vae.quant_conv
            self.encoder = pipe.vae.encoder

        # Because CoreMLTools does not support the torch.randn op, we pass in both
        # the diagonal Noise for the `DiagonalGaussianDistribution` operation and
        # the noise tensor combined with precalculated `sqrtAlphasCumprod` and `sqrtOneMinusAlphasCumprod`
        # for faster computation.
        def forward(self, sample, diagonalNoise, noise, sqrtAlphasCumprod, sqrtOneMinusAlphasCumprod):
            h = self.encoder(sample)
            moments = self.quant_conv(h)
            diagonalNoise = diagonalNoise.to(sample.device)
            posterior = CoreMLDiagonalGaussianDistribution(moments, diagonalNoise)
            posteriorSample = posterior.sample()

            # Add the scaling operation and the latent noise for faster computation
            init_latents = 0.18215 * posteriorSample
            result = self.add_noise(init_latents, noise, sqrtAlphasCumprod, sqrtOneMinusAlphasCumprod)
            return result

        def add_noise(
            self,
            original_samples: torch.FloatTensor,
            noise: torch.FloatTensor,
            sqrtAlphasCumprod: torch.FloatTensor,
            sqrtOneMinusAlphasCumprod: torch.FloatTensor
        ) -> torch.FloatTensor:
            """ Blend `original_samples` with `noise` using the two given factors
            """
            noise = noise.to(original_samples.device)
            sqrtAlphasCumprod = sqrtAlphasCumprod.to(original_samples.device)
            sqrtOneMinusAlphasCumprod = sqrtOneMinusAlphasCumprod.to(original_samples.device)
            noisy_samples = sqrtAlphasCumprod * original_samples + sqrtOneMinusAlphasCumprod * noise
            return noisy_samples

    baseline_encoder = VAEEncoder().eval()

    # No optimization needed for the VAE Encoder as it is a pure ConvNet
    traced_vae_encoder = torch.jit.trace(
        baseline_encoder, (
            sample_vae_encoder_inputs["sample"].to(torch.float32),
            sample_vae_encoder_inputs["diagonalNoise"].to(torch.float32),
            sample_vae_encoder_inputs["noise"].to(torch.float32),
            sqrt_alphas_cumprod_example.to(torch.float32),
            sqrt_one_minus_alphas_cumprod_example.to(torch.float32)
        ))

    modify_coremltools_torch_frontend_badbmm()
    coreml_vae_encoder, out_path = _convert_to_coreml(
        "vae_encoder", traced_vae_encoder, sample_vae_encoder_inputs,
        ["latent_dist"], args)

    # Set model metadata
    coreml_vae_encoder.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
    coreml_vae_encoder.license = "OpenRAIL (https://huggingface.co/spaces/CompVis/stable-diffusion-license)"
    coreml_vae_encoder.version = args.model_version
    coreml_vae_encoder.short_description = \
        "Stable Diffusion generates images conditioned on text and/or other images as input through the diffusion process. " \
        "Please refer to https://arxiv.org/abs/2112.10752 for details."

    # Set the input descriptions
    coreml_vae_encoder.input_description["sample"] = \
        "An image of the correct size to create the latent space with, image2image and in-painting."
    coreml_vae_encoder.input_description["diagonalNoise"] = \
        "Latent noise for `DiagonalGaussianDistribution` operation."
    coreml_vae_encoder.input_description["noise"] = \
        "Latent noise for use with strength parameter of image2image"
    coreml_vae_encoder.input_description["sqrtAlphasCumprod"] = \
        "Precalculated `sqrtAlphasCumprod` value based on strength and the current schedular's alphasCumprod values"
    coreml_vae_encoder.input_description["sqrtOneMinusAlphasCumprod"] = \
        "Precalculated `sqrtOneMinusAlphasCumprod` value based on strength and the current schedular's alphasCumprod values"

    # Set the output descriptions
    coreml_vae_encoder.output_description[
        "latent_dist"] = "The latent embeddings from the unet model from the input image for image2image."

    _save_mlpackage(coreml_vae_encoder, out_path)

    logger.info(f"Saved vae_encoder into {out_path}")

    # Parity check PyTorch vs CoreML
    if args.check_output_correctness:
        # Keep this a plain ndarray: a trailing comma after `.numpy()` would
        # silently wrap the result in a 1-tuple.
        baseline_out = baseline_encoder(
            sample=sample_vae_encoder_inputs["sample"].to(torch.float32),
            diagonalNoise=sample_vae_encoder_inputs["diagonalNoise"].to(torch.float32),
            noise=sample_vae_encoder_inputs["noise"].to(torch.float32),
            sqrtAlphasCumprod=sqrt_alphas_cumprod_example,
            sqrtOneMinusAlphasCumprod=sqrt_one_minus_alphas_cumprod_example,
        ).numpy()

        # The model has a single declared output, so the first value is it.
        coreml_out = list(
            coreml_vae_encoder.predict(
                {
                    "sample": sample_vae_encoder_inputs["sample"].numpy(),
                    "diagonalNoise": sample_vae_encoder_inputs["diagonalNoise"].numpy(),
                    "noise": sample_vae_encoder_inputs["noise"].numpy(),
                    "sqrtAlphasCumprod": sqrt_alphas_cumprod_example.numpy(),
                    "sqrtOneMinusAlphasCumprod": sqrt_one_minus_alphas_cumprod_example.numpy()
                }).values())[0]

        report_correctness(baseline_out, coreml_out,
                           "vae_encoder baseline PyTorch to baseline CoreML")

    # Release the large objects before the next conversion step runs.
    del traced_vae_encoder, pipe.vae.encoder, coreml_vae_encoder
    gc.collect()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_unet(pipe, args):
|
|
|
|
|
""" Converts the UNet component of Stable Diffusion
|
|
|
|
|
"""
|
|
|
|
@ -802,6 +977,11 @@ def main(args):
|
|
|
|
|
convert_vae_decoder(pipe, args)
|
|
|
|
|
logger.info("Converted vae_decoder")
|
|
|
|
|
|
|
|
|
|
if args.convert_vae_encoder:
|
|
|
|
|
logger.info("Converting vae_encoder")
|
|
|
|
|
convert_vae_encoder(pipe, args)
|
|
|
|
|
logger.info("Converted vae_encoder")
|
|
|
|
|
|
|
|
|
|
if args.convert_unet:
|
|
|
|
|
logger.info("Converting unet")
|
|
|
|
|
convert_unet(pipe, args)
|
|
|
|
@ -835,6 +1015,7 @@ def parser_spec():
|
|
|
|
|
# Select which models to export (All are needed for text-to-image pipeline to function)
|
|
|
|
|
parser.add_argument("--convert-text-encoder", action="store_true")
|
|
|
|
|
parser.add_argument("--convert-vae-decoder", action="store_true")
|
|
|
|
|
parser.add_argument("--convert-vae-encoder", action="store_true")
|
|
|
|
|
parser.add_argument("--convert-unet", action="store_true")
|
|
|
|
|
parser.add_argument("--convert-safety-checker", action="store_true")
|
|
|
|
|
parser.add_argument(
|
|
|
|
|