Pass latent size args to unet and safety checker (#123)

pull/125/head
Joshua Park 1 year ago committed by GitHub
parent e4ca3ff86b
commit f0fc326477
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -340,7 +340,7 @@ def convert_text_encoder(pipe, args):
def modify_coremltools_torch_frontend_badbmm():
"""
Modifies coremltools torch frontend for baddbmm to be robust to the `beta` argument being of non-float dtype:
e.g. https://github.com/huggingface/diffusers/blob/v0.8.1/src/diffusers/models/attention.py#L315
e.g. https://github.com/huggingface/diffusers/blob/v0.8.1/src/diffusers/models/attention.py#L315
"""
from coremltools.converters.mil import register_torch_op
from coremltools.converters.mil.mil import Builder as mb
@ -490,12 +490,12 @@ def convert_vae_encoder(pipe, args):
(args.latent_h or pipe.unet.config.sample_size) * 8, # H
(args.latent_w or pipe.unet.config.sample_size) * 8, # w
)
noise_shape = (
1, # B
4, # C
pipe.unet.config.sample_size, # H
pipe.unet.config.sample_size, # w
args.latent_h or pipe.unet.config.sample_size, # H
args.latent_w or pipe.unet.config.sample_size, # w
)
float_value_shape = (
@ -533,12 +533,12 @@ def convert_vae_encoder(pipe, args):
moments = self.quant_conv(h)
posterior = CoreMLDiagonalGaussianDistribution(moments, diagonal_noise)
posteriorSample = posterior.sample()
# Add the scaling operation and the latent noise for faster computation
init_latents = 0.18215 * posteriorSample
result = self.add_noise(init_latents, noise, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod)
return result
def add_noise(
self,
original_samples: torch.FloatTensor,
@ -548,7 +548,7 @@ def convert_vae_encoder(pipe, args):
) -> torch.FloatTensor:
noisy_samples = sqrt_alphas_cumprod * original_samples + sqrt_one_minus_alphas_cumprod * noise
return noisy_samples
baseline_encoder = VAEEncoder().eval()
@ -653,8 +653,8 @@ def convert_unet(pipe, args):
sample_shape = (
batch_size, # B
pipe.unet.config.in_channels, # C
pipe.unet.config.sample_size, # H
pipe.unet.config.sample_size, # W
args.latent_h or pipe.unet.config.sample_size, # H
args.latent_w or pipe.unet.config.sample_size, # W
)
if not hasattr(pipe, "text_encoder"):
@ -791,8 +791,8 @@ def convert_safety_checker(pipe, args):
sample_image = np.random.randn(
1, # B
pipe.vae.config.sample_size, # H
pipe.vae.config.sample_size, # w
args.latent_h or pipe.vae.config.sample_size, # H
args.latent_w or pipe.vae.config.sample_size, # w
3 # C
).astype(np.float32)
@ -970,12 +970,12 @@ def main(args):
logger.info("Converting vae_decoder")
convert_vae_decoder(pipe, args)
logger.info("Converted vae_decoder")
if args.convert_vae_encoder:
logger.info("Converting vae_encoder")
convert_vae_encoder(pipe, args)
logger.info("Converted vae_encoder")
if args.convert_unet:
logger.info("Converting unet")
convert_unet(pipe, args)

Loading…
Cancel
Save