Pass latent size args to unet and safety checker (#123)

pull/125/head
Joshua Park 2 years ago committed by GitHub
parent e4ca3ff86b
commit f0fc326477
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -340,7 +340,7 @@ def convert_text_encoder(pipe, args):
def modify_coremltools_torch_frontend_badbmm(): def modify_coremltools_torch_frontend_badbmm():
""" """
Modifies coremltools torch frontend for baddbmm to be robust to the `beta` argument being of non-float dtype: Modifies coremltools torch frontend for baddbmm to be robust to the `beta` argument being of non-float dtype:
e.g. https://github.com/huggingface/diffusers/blob/v0.8.1/src/diffusers/models/attention.py#L315 e.g. https://github.com/huggingface/diffusers/blob/v0.8.1/src/diffusers/models/attention.py#L315
""" """
from coremltools.converters.mil import register_torch_op from coremltools.converters.mil import register_torch_op
from coremltools.converters.mil.mil import Builder as mb from coremltools.converters.mil.mil import Builder as mb
@@ -490,12 +490,12 @@ def convert_vae_encoder(pipe, args):
(args.latent_h or pipe.unet.config.sample_size) * 8, # H (args.latent_h or pipe.unet.config.sample_size) * 8, # H
(args.latent_w or pipe.unet.config.sample_size) * 8, # w (args.latent_w or pipe.unet.config.sample_size) * 8, # w
) )
noise_shape = ( noise_shape = (
1, # B 1, # B
4, # C 4, # C
pipe.unet.config.sample_size, # H args.latent_h or pipe.unet.config.sample_size, # H
pipe.unet.config.sample_size, # w args.latent_w or pipe.unet.config.sample_size, # w
) )
float_value_shape = ( float_value_shape = (
@@ -533,12 +533,12 @@ def convert_vae_encoder(pipe, args):
moments = self.quant_conv(h) moments = self.quant_conv(h)
posterior = CoreMLDiagonalGaussianDistribution(moments, diagonal_noise) posterior = CoreMLDiagonalGaussianDistribution(moments, diagonal_noise)
posteriorSample = posterior.sample() posteriorSample = posterior.sample()
# Add the scaling operation and the latent noise for faster computation # Add the scaling operation and the latent noise for faster computation
init_latents = 0.18215 * posteriorSample init_latents = 0.18215 * posteriorSample
result = self.add_noise(init_latents, noise, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod) result = self.add_noise(init_latents, noise, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod)
return result return result
def add_noise( def add_noise(
self, self,
original_samples: torch.FloatTensor, original_samples: torch.FloatTensor,
@@ -548,7 +548,7 @@ def convert_vae_encoder(pipe, args):
) -> torch.FloatTensor: ) -> torch.FloatTensor:
noisy_samples = sqrt_alphas_cumprod * original_samples + sqrt_one_minus_alphas_cumprod * noise noisy_samples = sqrt_alphas_cumprod * original_samples + sqrt_one_minus_alphas_cumprod * noise
return noisy_samples return noisy_samples
baseline_encoder = VAEEncoder().eval() baseline_encoder = VAEEncoder().eval()
@@ -653,8 +653,8 @@ def convert_unet(pipe, args):
sample_shape = ( sample_shape = (
batch_size, # B batch_size, # B
pipe.unet.config.in_channels, # C pipe.unet.config.in_channels, # C
pipe.unet.config.sample_size, # H args.latent_h or pipe.unet.config.sample_size, # H
pipe.unet.config.sample_size, # W args.latent_w or pipe.unet.config.sample_size, # W
) )
if not hasattr(pipe, "text_encoder"): if not hasattr(pipe, "text_encoder"):
@@ -791,8 +791,8 @@ def convert_safety_checker(pipe, args):
sample_image = np.random.randn( sample_image = np.random.randn(
1, # B 1, # B
pipe.vae.config.sample_size, # H args.latent_h or pipe.vae.config.sample_size, # H
pipe.vae.config.sample_size, # w args.latent_w or pipe.vae.config.sample_size, # w
3 # C 3 # C
).astype(np.float32) ).astype(np.float32)
@@ -970,12 +970,12 @@ def main(args):
logger.info("Converting vae_decoder") logger.info("Converting vae_decoder")
convert_vae_decoder(pipe, args) convert_vae_decoder(pipe, args)
logger.info("Converted vae_decoder") logger.info("Converted vae_decoder")
if args.convert_vae_encoder: if args.convert_vae_encoder:
logger.info("Converting vae_encoder") logger.info("Converting vae_encoder")
convert_vae_encoder(pipe, args) convert_vae_encoder(pipe, args)
logger.info("Converted vae_encoder") logger.info("Converted vae_encoder")
if args.convert_unet: if args.convert_unet:
logger.info("Converting unet") logger.info("Converting unet")
convert_unet(pipe, args) convert_unet(pipe, args)

Loading…
Cancel
Save