pr comments

pull/115/head
Timothy Kautz 1 year ago
parent f28d64c788
commit 45b36be2dc

@@ -121,7 +121,7 @@ pip install -e .
**Step 4:** Execute the following command from the Terminal to generate Core ML model files (`.mlpackage`)
```shell
python -m python_coreml_stable_diffusion.torch2coreml --convert-unet --convert-text-encoder --convert-vae-decoder --convert-vae-encoder --convert-safety-checker -o <output-mlpackages-directory>
python -m python_coreml_stable_diffusion.torch2coreml --convert-unet --convert-text-encoder --convert-vae-decoder --convert-safety-checker -o <output-mlpackages-directory>
```
**WARNING:** This command will download several GB worth of PyTorch checkpoints from Hugging Face. Please ensure that you are on Wi-Fi and have enough disk space.
@@ -222,7 +222,7 @@ Both of these products require the Core ML models and tokenization resources to
Optionally, for image2image, in-painting, or similar:
- `VAEEnecoder.mlmodelc` (image encoder model)
- `VAEEncoder.mlmodelc` (image encoder model)
Optionally, it may also include the safety checker model that some versions of Stable Diffusion include:

@@ -45,7 +45,8 @@ def _get_coreml_inputs(sample_inputs, args):
# Simpler version of `DiagonalGaussianDistribution` with only needed calculations
# as implemented in vae.py as part of the AutoencoderKL class
# This is because coreml tools does not support the `randn` operation, so we pass in a random tensor.
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/vae.py#L312
# coremltools-6.1 does not yet implement the randn operation with the option of setting a random seed
class CoreMLDiagonalGaussianDistribution(object):
def __init__(self, parameters, noise):
self.parameters = parameters
@@ -55,10 +56,7 @@ class CoreMLDiagonalGaussianDistribution(object):
self.std = torch.exp(0.5 * self.logvar)
def sample(self) -> torch.FloatTensor:
device = self.parameters.device
# make sure sample is on the same device as the parameters and has same dtype
sample = self.noise.to(device=device, dtype=self.parameters.dtype)
x = self.mean + self.std * sample
x = self.mean + self.std * self.noise
return x
def compute_psnr(a, b):
@@ -489,8 +487,8 @@ def convert_vae_encoder(pipe, args):
sample_shape = (
1, # B
3, # C (RGB range from -1 to 1)
args.latent_h or pipe.unet.config.sample_size * 8, # H
args.latent_w or pipe.unet.config.sample_size * 8, # w
(args.latent_h or pipe.unet.config.sample_size) * 8, # H
(args.latent_w or pipe.unet.config.sample_size) * 8, # w
)
noise_shape = (
@@ -505,15 +503,15 @@ def convert_vae_encoder(pipe, args):
1,
)
sqrtAlphasCumprodTorchShape = torch.tensor([[0.2,]])
sqrtOneMinusAlphasCumprodTorchShape = torch.tensor([[0.8,]])
sqrt_alphas_cumprod_torch_shape = torch.tensor([[0.2,]])
sqrt_one_minus_alphas_cumprod_torch_shape = torch.tensor([[0.8,]])
sample_vae_encoder_inputs = {
"sample": torch.rand(*sample_shape, dtype=torch.float16),
"diagonalNoise": torch.rand(*noise_shape, dtype=torch.float16),
"diagonal_noise": torch.rand(*noise_shape, dtype=torch.float16),
"noise": torch.rand(*noise_shape, dtype=torch.float16),
"sqrtAlphasCumprod": torch.rand(*float_value_shape, dtype=torch.float16),
"sqrtOneMinusAlphasCumprod": torch.rand(*float_value_shape, dtype=torch.float16),
"sqrt_alphas_cumprod": torch.rand(*float_value_shape, dtype=torch.float16),
"sqrt_one_minus_alphas_cumprod": torch.rand(*float_value_shape, dtype=torch.float16),
}
class VAEEncoder(nn.Module):
@@ -528,31 +526,27 @@ def convert_vae_encoder(pipe, args):
# Because CoreMLTools does not support the torch.randn op, we pass in both
# the diagonal Noise for the `DiagonalGaussianDistribution` operation and
# the noise tensor combined with precalculated `sqrtAlphasCumprod` and `sqrtOneMinusAlphasCumprod`
# the noise tensor combined with precalculated `sqrt_alphas_cumprod` and `sqrt_one_minus_alphas_cumprod`
# for faster computation.
def forward(self, sample, diagonalNoise, noise, sqrtAlphasCumprod, sqrtOneMinusAlphasCumprod):
def forward(self, sample, diagonal_noise, noise, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod):
h = self.encoder(sample)
moments = self.quant_conv(h)
diagonalNoise = diagonalNoise.to(sample.device)
posterior = CoreMLDiagonalGaussianDistribution(moments, diagonalNoise)
posterior = CoreMLDiagonalGaussianDistribution(moments, diagonal_noise)
posteriorSample = posterior.sample()
# Add the scaling operation and the latent noise for faster computation
init_latents = 0.18215 * posteriorSample
result = self.add_noise(init_latents, noise, sqrtAlphasCumprod, sqrtOneMinusAlphasCumprod)
result = self.add_noise(init_latents, noise, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod)
return result
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
sqrtAlphasCumprod: torch.FloatTensor,
sqrtOneMinusAlphasCumprod: torch.FloatTensor
sqrt_alphas_cumprod: torch.FloatTensor,
sqrt_one_minus_alphas_cumprod: torch.FloatTensor
) -> torch.FloatTensor:
noise = noise.to(original_samples.device)
sqrtAlphasCumprod = sqrtAlphasCumprod.to(original_samples.device)
sqrtOneMinusAlphasCumprod = sqrtOneMinusAlphasCumprod.to(original_samples.device)
noisy_samples = sqrtAlphasCumprod * original_samples + sqrtOneMinusAlphasCumprod * noise
noisy_samples = sqrt_alphas_cumprod * original_samples + sqrt_one_minus_alphas_cumprod * noise
return noisy_samples
@@ -562,10 +556,10 @@ def convert_vae_encoder(pipe, args):
traced_vae_encoder = torch.jit.trace(
baseline_encoder, (
sample_vae_encoder_inputs["sample"].to(torch.float32),
sample_vae_encoder_inputs["diagonalNoise"].to(torch.float32),
sample_vae_encoder_inputs["diagonal_noise"].to(torch.float32),
sample_vae_encoder_inputs["noise"].to(torch.float32),
sqrtAlphasCumprodTorchShape.to(torch.float32),
sqrtOneMinusAlphasCumprodTorchShape.to(torch.float32)
sqrt_alphas_cumprod_torch_shape.to(torch.float32),
sqrt_one_minus_alphas_cumprod_torch_shape.to(torch.float32)
))
modify_coremltools_torch_frontend_badbmm()
@@ -584,18 +578,18 @@ def convert_vae_encoder(pipe, args):
# Set the input descriptions
coreml_vae_encoder.input_description["sample"] = \
"An image of the correct size to create the latent space with, image2image and in-painting."
coreml_vae_encoder.input_description["diagonalNoise"] = \
coreml_vae_encoder.input_description["diagonal_noise"] = \
"Latent noise for `DiagonalGaussianDistribution` operation."
coreml_vae_encoder.input_description["noise"] = \
"Latent noise for use with strength parameter of image2image"
coreml_vae_encoder.input_description["sqrtAlphasCumprod"] = \
"Precalculated `sqrtAlphasCumprod` value based on strength and the current schedular's alphasCumprod values"
coreml_vae_encoder.input_description["sqrtOneMinusAlphasCumprod"] = \
"Precalculated `sqrtOneMinusAlphasCumprod` value based on strength and the current schedular's alphasCumprod values"
coreml_vae_encoder.input_description["sqrt_alphas_cumprod"] = \
"Precalculated `sqrt_alphas_cumprod` value based on strength and the current schedular's alphasCumprod values"
coreml_vae_encoder.input_description["sqrt_one_minus_alphas_cumprod"] = \
"Precalculated `sqrt_one_minus_alphas_cumprod` value based on strength and the current schedular's alphasCumprod values"
# Set the output descriptions
coreml_vae_encoder.output_description[
"latent_dist"] = "The latent embeddings from the unet model from the input image for image2image."
"latent_dist"] = "The latent embeddings from the unet model from the input image."
_save_mlpackage(coreml_vae_encoder, out_path)
@@ -605,20 +599,20 @@ def convert_vae_encoder(pipe, args):
if args.check_output_correctness:
baseline_out = baseline_encoder(
sample=sample_vae_encoder_inputs["sample"].to(torch.float32),
diagonalNoise=sample_vae_encoder_inputs["diagonalNoise"].to(torch.float32),
diagonal_noise=sample_vae_encoder_inputs["diagonal_noise"].to(torch.float32),
noise=sample_vae_encoder_inputs["noise"].to(torch.float32),
sqrtAlphasCumprod=sqrtAlphasCumprodTorchShape,
sqrtOneMinusAlphasCumprod=sqrtOneMinusAlphasCumprodTorchShape,
sqrt_alphas_cumprod=sqrt_alphas_cumprod_torch_shape,
sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod_torch_shape,
).numpy(),
coreml_out = list(
coreml_vae_encoder.predict(
{
"sample": sample_vae_encoder_inputs["sample"].numpy(),
"diagonalNoise": sample_vae_encoder_inputs["diagonalNoise"].numpy(),
"diagonal_noise": sample_vae_encoder_inputs["diagonal_noise"].numpy(),
"noise": sample_vae_encoder_inputs["noise"].numpy(),
"sqrtAlphasCumprod": sqrtAlphasCumprodTorchShape.numpy(),
"sqrtOneMinusAlphasCumprod": sqrtOneMinusAlphasCumprodTorchShape.numpy()
"sqrt_alphas_cumprod": sqrt_alphas_cumprod_torch_shape.numpy(),
"sqrt_one_minus_alphas_cumprod": sqrt_one_minus_alphas_cumprod_torch_shape.numpy()
}).values())
report_correctness(baseline_out[0], coreml_out[0],

Loading…
Cancel
Save