|
|
@ -340,7 +340,7 @@ def convert_text_encoder(pipe, args):
|
|
|
|
def modify_coremltools_torch_frontend_badbmm():
|
|
|
|
def modify_coremltools_torch_frontend_badbmm():
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Modifies coremltools torch frontend for baddbmm to be robust to the `beta` argument being of non-float dtype:
|
|
|
|
Modifies coremltools torch frontend for baddbmm to be robust to the `beta` argument being of non-float dtype:
|
|
|
|
e.g. https://github.com/huggingface/diffusers/blob/v0.8.1/src/diffusers/models/attention.py#L315
|
|
|
|
e.g. https://github.com/huggingface/diffusers/blob/v0.8.1/src/diffusers/models/attention.py#L315
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
from coremltools.converters.mil import register_torch_op
|
|
|
|
from coremltools.converters.mil import register_torch_op
|
|
|
|
from coremltools.converters.mil.mil import Builder as mb
|
|
|
|
from coremltools.converters.mil.mil import Builder as mb
|
|
|
@ -490,12 +490,12 @@ def convert_vae_encoder(pipe, args):
|
|
|
|
(args.latent_h or pipe.unet.config.sample_size) * 8, # H
|
|
|
|
(args.latent_h or pipe.unet.config.sample_size) * 8, # H
|
|
|
|
(args.latent_w or pipe.unet.config.sample_size) * 8, # w
|
|
|
|
(args.latent_w or pipe.unet.config.sample_size) * 8, # w
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
noise_shape = (
|
|
|
|
noise_shape = (
|
|
|
|
1, # B
|
|
|
|
1, # B
|
|
|
|
4, # C
|
|
|
|
4, # C
|
|
|
|
pipe.unet.config.sample_size, # H
|
|
|
|
args.latent_h or pipe.unet.config.sample_size, # H
|
|
|
|
pipe.unet.config.sample_size, # w
|
|
|
|
args.latent_w or pipe.unet.config.sample_size, # w
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
float_value_shape = (
|
|
|
|
float_value_shape = (
|
|
|
@ -533,12 +533,12 @@ def convert_vae_encoder(pipe, args):
|
|
|
|
moments = self.quant_conv(h)
|
|
|
|
moments = self.quant_conv(h)
|
|
|
|
posterior = CoreMLDiagonalGaussianDistribution(moments, diagonal_noise)
|
|
|
|
posterior = CoreMLDiagonalGaussianDistribution(moments, diagonal_noise)
|
|
|
|
posteriorSample = posterior.sample()
|
|
|
|
posteriorSample = posterior.sample()
|
|
|
|
|
|
|
|
|
|
|
|
# Add the scaling operation and the latent noise for faster computation
|
|
|
|
# Add the scaling operation and the latent noise for faster computation
|
|
|
|
init_latents = 0.18215 * posteriorSample
|
|
|
|
init_latents = 0.18215 * posteriorSample
|
|
|
|
result = self.add_noise(init_latents, noise, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod)
|
|
|
|
result = self.add_noise(init_latents, noise, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod)
|
|
|
|
return result
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
|
|
def add_noise(
|
|
|
|
def add_noise(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
original_samples: torch.FloatTensor,
|
|
|
|
original_samples: torch.FloatTensor,
|
|
|
@ -548,7 +548,7 @@ def convert_vae_encoder(pipe, args):
|
|
|
|
) -> torch.FloatTensor:
|
|
|
|
) -> torch.FloatTensor:
|
|
|
|
noisy_samples = sqrt_alphas_cumprod * original_samples + sqrt_one_minus_alphas_cumprod * noise
|
|
|
|
noisy_samples = sqrt_alphas_cumprod * original_samples + sqrt_one_minus_alphas_cumprod * noise
|
|
|
|
return noisy_samples
|
|
|
|
return noisy_samples
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
baseline_encoder = VAEEncoder().eval()
|
|
|
|
baseline_encoder = VAEEncoder().eval()
|
|
|
|
|
|
|
|
|
|
|
@ -653,8 +653,8 @@ def convert_unet(pipe, args):
|
|
|
|
sample_shape = (
|
|
|
|
sample_shape = (
|
|
|
|
batch_size, # B
|
|
|
|
batch_size, # B
|
|
|
|
pipe.unet.config.in_channels, # C
|
|
|
|
pipe.unet.config.in_channels, # C
|
|
|
|
pipe.unet.config.sample_size, # H
|
|
|
|
args.latent_h or pipe.unet.config.sample_size, # H
|
|
|
|
pipe.unet.config.sample_size, # W
|
|
|
|
args.latent_w or pipe.unet.config.sample_size, # W
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if not hasattr(pipe, "text_encoder"):
|
|
|
|
if not hasattr(pipe, "text_encoder"):
|
|
|
@ -791,8 +791,8 @@ def convert_safety_checker(pipe, args):
|
|
|
|
|
|
|
|
|
|
|
|
sample_image = np.random.randn(
|
|
|
|
sample_image = np.random.randn(
|
|
|
|
1, # B
|
|
|
|
1, # B
|
|
|
|
pipe.vae.config.sample_size, # H
|
|
|
|
args.latent_h or pipe.vae.config.sample_size, # H
|
|
|
|
pipe.vae.config.sample_size, # w
|
|
|
|
args.latent_w or pipe.vae.config.sample_size, # w
|
|
|
|
3 # C
|
|
|
|
3 # C
|
|
|
|
).astype(np.float32)
|
|
|
|
).astype(np.float32)
|
|
|
|
|
|
|
|
|
|
|
@ -970,12 +970,12 @@ def main(args):
|
|
|
|
logger.info("Converting vae_decoder")
|
|
|
|
logger.info("Converting vae_decoder")
|
|
|
|
convert_vae_decoder(pipe, args)
|
|
|
|
convert_vae_decoder(pipe, args)
|
|
|
|
logger.info("Converted vae_decoder")
|
|
|
|
logger.info("Converted vae_decoder")
|
|
|
|
|
|
|
|
|
|
|
|
if args.convert_vae_encoder:
|
|
|
|
if args.convert_vae_encoder:
|
|
|
|
logger.info("Converting vae_encoder")
|
|
|
|
logger.info("Converting vae_encoder")
|
|
|
|
convert_vae_encoder(pipe, args)
|
|
|
|
convert_vae_encoder(pipe, args)
|
|
|
|
logger.info("Converted vae_encoder")
|
|
|
|
logger.info("Converted vae_encoder")
|
|
|
|
|
|
|
|
|
|
|
|
if args.convert_unet:
|
|
|
|
if args.convert_unet:
|
|
|
|
logger.info("Converting unet")
|
|
|
|
logger.info("Converting unet")
|
|
|
|
convert_unet(pipe, args)
|
|
|
|
convert_unet(pipe, args)
|
|
|
|