@ -43,6 +43,21 @@ def _get_coreml_inputs(sample_inputs, args):
) for k , v in sample_inputs . items ( )
) for k , v in sample_inputs . items ( )
]
]
# Simpler version of `DiagonalGaussianDistribution` with only needed calculations
# as implemented in vae.py as part of the AutoencoderKL class
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/vae.py#L312
# coremltools-6.1 does not yet implement the randn operation with the option of setting a random seed
class CoreMLDiagonalGaussianDistribution ( object ) :
def __init__ ( self , parameters , noise ) :
self . parameters = parameters
self . noise = noise
self . mean , self . logvar = torch . chunk ( parameters , 2 , dim = 1 )
self . logvar = torch . clamp ( self . logvar , - 30.0 , 20.0 )
self . std = torch . exp ( 0.5 * self . logvar )
def sample ( self ) - > torch . FloatTensor :
x = self . mean + self . std * self . noise
return x
def compute_psnr ( a , b ) :
def compute_psnr ( a , b ) :
""" Compute Peak-Signal-to-Noise-Ratio across two numpy.ndarray objects
""" Compute Peak-Signal-to-Noise-Ratio across two numpy.ndarray objects
@ -140,7 +155,7 @@ def _convert_to_coreml(submodule_name, torchscript_module, sample_inputs,
def quantize_weights_to_8bits ( args ) :
def quantize_weights_to_8bits ( args ) :
for model_name in [
for model_name in [
" text_encoder " , " vae_decoder " , " unet" , " unet_chunk1 " ,
" text_encoder " , " vae_decoder " , " vae_encoder" , " unet" , " unet_chunk1 " ,
" unet_chunk2 " , " safety_checker "
" unet_chunk2 " , " safety_checker "
] :
] :
out_path = _get_out_path ( args , model_name )
out_path = _get_out_path ( args , model_name )
@ -190,6 +205,7 @@ def bundle_resources_for_swift_cli(args):
# Compile model using coremlcompiler (Significantly reduces the load time for unet)
# Compile model using coremlcompiler (Significantly reduces the load time for unet)
for source_name , target_name in [ ( " text_encoder " , " TextEncoder " ) ,
for source_name , target_name in [ ( " text_encoder " , " TextEncoder " ) ,
( " vae_decoder " , " VAEDecoder " ) ,
( " vae_decoder " , " VAEDecoder " ) ,
( " vae_encoder " , " VAEEncoder " ) ,
( " unet " , " Unet " ) ,
( " unet " , " Unet " ) ,
( " unet_chunk1 " , " UnetChunk1 " ) ,
( " unet_chunk1 " , " UnetChunk1 " ) ,
( " unet_chunk2 " , " UnetChunk2 " ) ,
( " unet_chunk2 " , " UnetChunk2 " ) ,
@ -453,6 +469,159 @@ def convert_vae_decoder(pipe, args):
gc . collect ( )
gc . collect ( )
def convert_vae_encoder ( pipe , args ) :
""" Converts the VAE Encoder component of Stable Diffusion
"""
out_path = _get_out_path ( args , " vae_encoder " )
if os . path . exists ( out_path ) :
logger . info (
f " `vae_encoder` already exists at { out_path } , skipping conversion. "
)
return
if not hasattr ( pipe , " unet " ) :
raise RuntimeError (
" convert_unet() deletes pipe.unet to save RAM. "
" Please use convert_vae_encoder() before convert_unet() " )
sample_shape = (
1 , # B
3 , # C (RGB range from -1 to 1)
( args . latent_h or pipe . unet . config . sample_size ) * 8 , # H
( args . latent_w or pipe . unet . config . sample_size ) * 8 , # w
)
noise_shape = (
1 , # B
4 , # C
pipe . unet . config . sample_size , # H
pipe . unet . config . sample_size , # w
)
float_value_shape = (
1 ,
1 ,
)
sqrt_alphas_cumprod_torch_shape = torch . tensor ( [ [ 0.2 , ] ] )
sqrt_one_minus_alphas_cumprod_torch_shape = torch . tensor ( [ [ 0.8 , ] ] )
sample_vae_encoder_inputs = {
" sample " : torch . rand ( * sample_shape , dtype = torch . float16 ) ,
" diagonal_noise " : torch . rand ( * noise_shape , dtype = torch . float16 ) ,
" noise " : torch . rand ( * noise_shape , dtype = torch . float16 ) ,
" sqrt_alphas_cumprod " : torch . rand ( * float_value_shape , dtype = torch . float16 ) ,
" sqrt_one_minus_alphas_cumprod " : torch . rand ( * float_value_shape , dtype = torch . float16 ) ,
}
class VAEEncoder ( nn . Module ) :
""" Wrapper nn.Module wrapper for pipe.encode() method
"""
def __init__ ( self ) :
super ( ) . __init__ ( )
self . quant_conv = pipe . vae . quant_conv
self . alphas_cumprod = pipe . scheduler . alphas_cumprod
self . encoder = pipe . vae . encoder
# Because CoreMLTools does not support the torch.randn op, we pass in both
# the diagonal Noise for the `DiagonalGaussianDistribution` operation and
# the noise tensor combined with precalculated `sqrt_alphas_cumprod` and `sqrt_one_minus_alphas_cumprod`
# for faster computation.
def forward ( self , sample , diagonal_noise , noise , sqrt_alphas_cumprod , sqrt_one_minus_alphas_cumprod ) :
h = self . encoder ( sample )
moments = self . quant_conv ( h )
posterior = CoreMLDiagonalGaussianDistribution ( moments , diagonal_noise )
posteriorSample = posterior . sample ( )
# Add the scaling operation and the latent noise for faster computation
init_latents = 0.18215 * posteriorSample
result = self . add_noise ( init_latents , noise , sqrt_alphas_cumprod , sqrt_one_minus_alphas_cumprod )
return result
def add_noise (
self ,
original_samples : torch . FloatTensor ,
noise : torch . FloatTensor ,
sqrt_alphas_cumprod : torch . FloatTensor ,
sqrt_one_minus_alphas_cumprod : torch . FloatTensor
) - > torch . FloatTensor :
noisy_samples = sqrt_alphas_cumprod * original_samples + sqrt_one_minus_alphas_cumprod * noise
return noisy_samples
baseline_encoder = VAEEncoder ( ) . eval ( )
# No optimization needed for the VAE Encoder as it is a pure ConvNet
traced_vae_encoder = torch . jit . trace (
baseline_encoder , (
sample_vae_encoder_inputs [ " sample " ] . to ( torch . float32 ) ,
sample_vae_encoder_inputs [ " diagonal_noise " ] . to ( torch . float32 ) ,
sample_vae_encoder_inputs [ " noise " ] . to ( torch . float32 ) ,
sqrt_alphas_cumprod_torch_shape . to ( torch . float32 ) ,
sqrt_one_minus_alphas_cumprod_torch_shape . to ( torch . float32 )
) )
modify_coremltools_torch_frontend_badbmm ( )
coreml_vae_encoder , out_path = _convert_to_coreml (
" vae_encoder " , traced_vae_encoder , sample_vae_encoder_inputs ,
[ " latent_dist " ] , args )
# Set model metadata
coreml_vae_encoder . author = f " Please refer to the Model Card available at huggingface.co/ { args . model_version } "
coreml_vae_encoder . license = " OpenRAIL (https://huggingface.co/spaces/CompVis/stable-diffusion-license) "
coreml_vae_encoder . version = args . model_version
coreml_vae_encoder . short_description = \
" Stable Diffusion generates images conditioned on text and/or other images as input through the diffusion process. " \
" Please refer to https://arxiv.org/abs/2112.10752 for details. "
# Set the input descriptions
coreml_vae_encoder . input_description [ " sample " ] = \
" An image of the correct size to create the latent space with, image2image and in-painting. "
coreml_vae_encoder . input_description [ " diagonal_noise " ] = \
" Latent noise for `DiagonalGaussianDistribution` operation. "
coreml_vae_encoder . input_description [ " noise " ] = \
" Latent noise for use with strength parameter of image2image "
coreml_vae_encoder . input_description [ " sqrt_alphas_cumprod " ] = \
" Precalculated `sqrt_alphas_cumprod` value based on strength and the current schedular ' s alphasCumprod values "
coreml_vae_encoder . input_description [ " sqrt_one_minus_alphas_cumprod " ] = \
" Precalculated `sqrt_one_minus_alphas_cumprod` value based on strength and the current schedular ' s alphasCumprod values "
# Set the output descriptions
coreml_vae_encoder . output_description [
" latent_dist " ] = " The latent embeddings from the unet model from the input image. "
_save_mlpackage ( coreml_vae_encoder , out_path )
logger . info ( f " Saved vae_encoder into { out_path } " )
# Parity check PyTorch vs CoreML
if args . check_output_correctness :
baseline_out = baseline_encoder (
sample = sample_vae_encoder_inputs [ " sample " ] . to ( torch . float32 ) ,
diagonal_noise = sample_vae_encoder_inputs [ " diagonal_noise " ] . to ( torch . float32 ) ,
noise = sample_vae_encoder_inputs [ " noise " ] . to ( torch . float32 ) ,
sqrt_alphas_cumprod = sqrt_alphas_cumprod_torch_shape ,
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod_torch_shape ,
) . numpy ( ) ,
coreml_out = list (
coreml_vae_encoder . predict (
{
" sample " : sample_vae_encoder_inputs [ " sample " ] . numpy ( ) ,
" diagonal_noise " : sample_vae_encoder_inputs [ " diagonal_noise " ] . numpy ( ) ,
" noise " : sample_vae_encoder_inputs [ " noise " ] . numpy ( ) ,
" sqrt_alphas_cumprod " : sqrt_alphas_cumprod_torch_shape . numpy ( ) ,
" sqrt_one_minus_alphas_cumprod " : sqrt_one_minus_alphas_cumprod_torch_shape . numpy ( )
} ) . values ( ) )
report_correctness ( baseline_out [ 0 ] , coreml_out [ 0 ] ,
" vae_encoder baseline PyTorch to baseline CoreML " )
del traced_vae_encoder , pipe . vae . encoder , coreml_vae_encoder
gc . collect ( )
def convert_unet ( pipe , args ) :
def convert_unet ( pipe , args ) :
""" Converts the UNet component of Stable Diffusion
""" Converts the UNet component of Stable Diffusion
"""
"""
@ -802,6 +971,11 @@ def main(args):
convert_vae_decoder ( pipe , args )
convert_vae_decoder ( pipe , args )
logger . info ( " Converted vae_decoder " )
logger . info ( " Converted vae_decoder " )
if args . convert_vae_encoder :
logger . info ( " Converting vae_encoder " )
convert_vae_encoder ( pipe , args )
logger . info ( " Converted vae_encoder " )
if args . convert_unet :
if args . convert_unet :
logger . info ( " Converting unet " )
logger . info ( " Converting unet " )
convert_unet ( pipe , args )
convert_unet ( pipe , args )
@ -835,6 +1009,7 @@ def parser_spec():
# Select which models to export (All are needed for text-to-image pipeline to function)
# Select which models to export (All are needed for text-to-image pipeline to function)
parser . add_argument ( " --convert-text-encoder " , action = " store_true " )
parser . add_argument ( " --convert-text-encoder " , action = " store_true " )
parser . add_argument ( " --convert-vae-decoder " , action = " store_true " )
parser . add_argument ( " --convert-vae-decoder " , action = " store_true " )
parser . add_argument ( " --convert-vae-encoder " , action = " store_true " )
parser . add_argument ( " --convert-unet " , action = " store_true " )
parser . add_argument ( " --convert-unet " , action = " store_true " )
parser . add_argument ( " --convert-safety-checker " , action = " store_true " )
parser . add_argument ( " --convert-safety-checker " , action = " store_true " )
parser . add_argument (
parser . add_argument (