diff --git a/python_coreml_stable_diffusion/unet.py b/python_coreml_stable_diffusion/unet.py index 581cfcd..f75c7f2 100644 --- a/python_coreml_stable_diffusion/unet.py +++ b/python_coreml_stable_diffusion/unet.py @@ -376,7 +376,7 @@ class CrossAttnDownBlock2D(nn.Module): self.resnets = nn.ModuleList(resnets) if add_downsample: - self.downsamplers = nn.ModuleList([Downsample2D(in_channels)]) + self.downsamplers = nn.ModuleList([Downsample2D(out_channels)]) else: self.downsamplers = None