diff --git a/python_coreml_stable_diffusion/torch2coreml.py b/python_coreml_stable_diffusion/torch2coreml.py
index 84e201c..abacf9c 100644
--- a/python_coreml_stable_diffusion/torch2coreml.py
+++ b/python_coreml_stable_diffusion/torch2coreml.py
@@ -340,7 +340,7 @@ def convert_text_encoder(pipe, args):
 def modify_coremltools_torch_frontend_badbmm():
     """
     Modifies coremltools torch frontend for baddbmm to be robust to the `beta` argument being of non-float dtype:
-    e.g. https://github.com/huggingface/diffusers/blob/v0.8.1/src/diffusers/models/attention.py#L315
+    e.g. https://github.com/huggingface/diffusers/blob/v0.8.1/src/diffusers/models/attention.py#L315
     """
     from coremltools.converters.mil import register_torch_op
     from coremltools.converters.mil.mil import Builder as mb
@@ -490,12 +490,12 @@ def convert_vae_encoder(pipe, args):
         (args.latent_h or pipe.unet.config.sample_size) * 8,  # H
         (args.latent_w or pipe.unet.config.sample_size) * 8,  # w
     )
-
+
     noise_shape = (
         1,  # B
         4,  # C
-        pipe.unet.config.sample_size,  # H
-        pipe.unet.config.sample_size,  # w
+        args.latent_h or pipe.unet.config.sample_size,  # H
+        args.latent_w or pipe.unet.config.sample_size,  # w
     )
 
     float_value_shape = (
@@ -533,12 +533,12 @@ def convert_vae_encoder(pipe, args):
             moments = self.quant_conv(h)
             posterior = CoreMLDiagonalGaussianDistribution(moments, diagonal_noise)
             posteriorSample = posterior.sample()
-
+
             # Add the scaling operation and the latent noise for faster computation
             init_latents = 0.18215 * posteriorSample
             result = self.add_noise(init_latents, noise, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod)
             return result
-
+
         def add_noise(
             self,
             original_samples: torch.FloatTensor,
@@ -548,7 +548,7 @@ def convert_vae_encoder(pipe, args):
         ) -> torch.FloatTensor:
             noisy_samples = sqrt_alphas_cumprod * original_samples + sqrt_one_minus_alphas_cumprod * noise
             return noisy_samples
-
+
 
     baseline_encoder = VAEEncoder().eval()
 
@@ -653,8 +653,8 @@ def convert_unet(pipe, args):
     sample_shape = (
         batch_size,  # B
         pipe.unet.config.in_channels,  # C
-        pipe.unet.config.sample_size,  # H
-        pipe.unet.config.sample_size,  # W
+        args.latent_h or pipe.unet.config.sample_size,  # H
+        args.latent_w or pipe.unet.config.sample_size,  # W
     )
 
     if not hasattr(pipe, "text_encoder"):
@@ -791,8 +791,8 @@ def convert_safety_checker(pipe, args):
 
     sample_image = np.random.randn(
         1,  # B
-        pipe.vae.config.sample_size,  # H
-        pipe.vae.config.sample_size,  # w
+        args.latent_h or pipe.vae.config.sample_size,  # H
+        args.latent_w or pipe.vae.config.sample_size,  # w
         3  # C
     ).astype(np.float32)
 
@@ -970,12 +970,12 @@ def main(args):
         logger.info("Converting vae_decoder")
         convert_vae_decoder(pipe, args)
         logger.info("Converted vae_decoder")
-
+
     if args.convert_vae_encoder:
         logger.info("Converting vae_encoder")
         convert_vae_encoder(pipe, args)
         logger.info("Converted vae_encoder")
-
+
     if args.convert_unet:
         logger.info("Converting unet")
         convert_unet(pipe, args)
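
Note (not part of the patch above): the change replaces hard-coded `pipe.unet.config.sample_size` dimensions with `args.latent_h or ...` / `args.latent_w or ...`, so an explicit latent resolution overrides the model default while `None` falls back to it, and the traced `VAEEncoder` wrapper folds the forward-diffusion noising step into the converted graph. Below is a minimal standalone sketch of both pieces of logic; the helper names and example values are illustrative only, and it assumes `latent_h`/`latent_w` arrive as optional arguments (their argparse definitions are not shown in this diff).

import torch


def resolve_latent_shape(latent_h, latent_w, sample_size, batch=1, channels=4):
    # Illustrative helper (not from the repo): `latent_h`/`latent_w` are the
    # optional overrides; `or` falls back to the UNet default when they are None.
    return (batch, channels, latent_h or sample_size, latent_w or sample_size)


def add_noise(original_samples, noise, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod):
    # Same forward-diffusion step the VAEEncoder wrapper traces:
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    return sqrt_alphas_cumprod * original_samples + sqrt_one_minus_alphas_cumprod * noise


if __name__ == "__main__":
    # Stable Diffusion v1.x default: 64x64 latents -> 512x512 pixels.
    print(resolve_latent_shape(None, None, sample_size=64))  # (1, 4, 64, 64)
    # Overriding to 96x64 latents would correspond to a 768x512 image.
    print(resolve_latent_shape(96, 64, sample_size=64))      # (1, 4, 96, 64)

    x0 = torch.randn(resolve_latent_shape(96, 64, sample_size=64))
    eps = torch.randn_like(x0)
    sqrt_abar = torch.tensor(0.9)  # placeholder value for sqrt(alpha_bar_t)
    xt = add_noise(x0, eps, sqrt_abar, torch.sqrt(1.0 - sqrt_abar ** 2))
    print(xt.shape)  # torch.Size([1, 4, 96, 64])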