|
|
|
@ -494,8 +494,8 @@ def convert_vae_encoder(pipe, args):
|
|
|
|
|
noise_shape = (
|
|
|
|
|
1, # B
|
|
|
|
|
4, # C
|
|
|
|
|
pipe.unet.config.sample_size, # H
|
|
|
|
|
pipe.unet.config.sample_size, # w
|
|
|
|
|
args.latent_h or pipe.unet.config.sample_size, # H
|
|
|
|
|
args.latent_w or pipe.unet.config.sample_size, # w
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
float_value_shape = (
|
|
|
|
@ -653,8 +653,8 @@ def convert_unet(pipe, args):
|
|
|
|
|
sample_shape = (
|
|
|
|
|
batch_size, # B
|
|
|
|
|
pipe.unet.config.in_channels, # C
|
|
|
|
|
pipe.unet.config.sample_size, # H
|
|
|
|
|
pipe.unet.config.sample_size, # W
|
|
|
|
|
args.latent_h or pipe.unet.config.sample_size, # H
|
|
|
|
|
args.latent_w or pipe.unet.config.sample_size, # W
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if not hasattr(pipe, "text_encoder"):
|
|
|
|
@ -791,8 +791,8 @@ def convert_safety_checker(pipe, args):
|
|
|
|
|
|
|
|
|
|
sample_image = np.random.randn(
|
|
|
|
|
1, # B
|
|
|
|
|
pipe.vae.config.sample_size, # H
|
|
|
|
|
pipe.vae.config.sample_size, # w
|
|
|
|
|
args.latent_h or pipe.vae.config.sample_size, # H
|
|
|
|
|
args.latent_w or pipe.vae.config.sample_size, # w
|
|
|
|
|
3 # C
|
|
|
|
|
).astype(np.float32)
|
|
|
|
|
|
|
|
|
|