Pass latent size args to unet and safety checker (#123)

pull/125/head
Joshua Park 1 year ago committed by GitHub
parent e4ca3ff86b
commit f0fc326477
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -494,8 +494,8 @@ def convert_vae_encoder(pipe, args):
noise_shape = (
1, # B
4, # C
pipe.unet.config.sample_size, # H
pipe.unet.config.sample_size, # w
args.latent_h or pipe.unet.config.sample_size, # H
args.latent_w or pipe.unet.config.sample_size, # w
)
float_value_shape = (
@ -653,8 +653,8 @@ def convert_unet(pipe, args):
sample_shape = (
batch_size, # B
pipe.unet.config.in_channels, # C
pipe.unet.config.sample_size, # H
pipe.unet.config.sample_size, # W
args.latent_h or pipe.unet.config.sample_size, # H
args.latent_w or pipe.unet.config.sample_size, # W
)
if not hasattr(pipe, "text_encoder"):
@ -791,8 +791,8 @@ def convert_safety_checker(pipe, args):
sample_image = np.random.randn(
1, # B
pipe.vae.config.sample_size, # H
pipe.vae.config.sample_size, # w
args.latent_h or pipe.vae.config.sample_size, # H
args.latent_w or pipe.vae.config.sample_size, # w
3 # C
).astype(np.float32)

Loading…
Cancel
Save