#
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
#
from python_coreml_stable_diffusion.layer_norm import LayerNormANE
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.modeling_utils import ModelMixin
from enum import Enum
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# Ensure minimum macOS version requirement is met for this particular model
from coremltools.models.utils import _macos_version
if not _macos_version() >= (13, 1):
logger.warning(
"!!! macOS 13.1 and newer or iOS/iPadOS 16.2 and newer is required for best performance !!!"
)
class AttentionImplementations(Enum):
ORIGINAL = "ORIGINAL"
SPLIT_EINSUM = "SPLIT_EINSUM"
ATTENTION_IMPLEMENTATION_IN_EFFECT = AttentionImplementations.SPLIT_EINSUM
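# Global switch read by CrossAttention.forward below. A conversion script may, for
# example, reassign it before tracing to pick the batched einsum path instead:
#   unet.ATTENTION_IMPLEMENTATION_IN_EFFECT = unet.AttentionImplementations.ORIGINAL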
WARN_MSG = \
"This `nn.Module` is intended for Apple Silicon deployment only. " \
"PyTorch-specific optimizations and training is disabled"
class CrossAttention(nn.Module):
""" Apple Silicon friendly version of `diffusers.models.attention.CrossAttention`
"""
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64):
super().__init__()
inner_dim = dim_head * heads
context_dim = context_dim if context_dim is not None else query_dim
self.scale = dim_head**-0.5
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Conv2d(query_dim, inner_dim, kernel_size=1, bias=False)
self.to_k = nn.Conv2d(context_dim,
inner_dim,
kernel_size=1,
bias=False)
self.to_v = nn.Conv2d(context_dim,
inner_dim,
kernel_size=1,
bias=False)
self.to_out = nn.Sequential(
nn.Conv2d(inner_dim, query_dim, kernel_size=1, bias=True))
def forward(self, hidden_states, context=None, mask=None):
if self.training:
raise NotImplementedError(WARN_MSG)
batch_size, dim, _, sequence_length = hidden_states.shape
q = self.to_q(hidden_states)
context = context if context is not None else hidden_states
k = self.to_k(context)
v = self.to_v(context)
# Validate mask
if mask is not None:
expected_mask_shape = [batch_size, sequence_length, 1, 1]
if mask.dtype == torch.bool:
mask = mask.logical_not().float() * -1e4
elif mask.dtype == torch.int64:
mask = (1 - mask).float() * -1e4
elif mask.dtype != torch.float32:
raise TypeError(f"Unexpected dtype for mask: {mask.dtype}")
if len(mask.size()) == 2:
mask = mask.unsqueeze(2).unsqueeze(2)
if list(mask.size()) != expected_mask_shape:
raise RuntimeError(
f"Invalid shape for `mask` (Expected {expected_mask_shape}, got {list(mask.size())}"
)
if ATTENTION_IMPLEMENTATION_IN_EFFECT == AttentionImplementations.ORIGINAL:
# This version of the attention function is recommended for high GPU core count
# devices such as the M1 Max and M1 Ultra
bs = q.size(0)
mh_q = q.view(bs, self.heads, self.dim_head, -1)
mh_k = k.view(bs, self.heads, self.dim_head, -1)
mh_v = v.view(bs, self.heads, self.dim_head, -1)
attn_weights = torch.einsum("bhcq,bhck->bhqk", [mh_q, mh_k])
attn_weights.mul_(self.scale)
if mask is not None:
attn_weights = attn_weights + mask
attn_weights = attn_weights.softmax(dim=3)
attn = torch.einsum("bhqk,bhck->bhcq", [attn_weights, mh_v])
attn = attn.contiguous().view(bs, self.heads * self.dim_head, 1,
-1)
elif ATTENTION_IMPLEMENTATION_IN_EFFECT == AttentionImplementations.SPLIT_EINSUM:
# The split attention and einsum from https://machinelearning.apple.com/research/neural-engine-transformers
# are utilized to build an ANE implementation. This version is marginally slower on the GPU engine and is
# not recommended for Max and Ultra Mac variants
mh_q = [
q[:, head_idx * self.dim_head:(head_idx + 1) *
self.dim_head, :, :] for head_idx in range(self.heads)
] # (bs, dim_head, 1, max_seq_length) * heads
k = k.transpose(1, 3)
mh_k = [
k[:, :, :,
head_idx * self.dim_head:(head_idx + 1) * self.dim_head]
for head_idx in range(self.heads)
] # (bs, max_seq_length, 1, dim_head) * heads
mh_v = [
v[:, head_idx * self.dim_head:(head_idx + 1) *
self.dim_head, :, :] for head_idx in range(self.heads)
] # (bs, dim_head, 1, max_seq_length) * heads
attn_weights = [
torch.einsum("bchq,bkhc->bkhq", [qi, ki]) * self.scale
for qi, ki in zip(mh_q, mh_k)
] # (bs, max_seq_length, 1, max_seq_length) * heads
if mask is not None:
for head_idx in range(self.heads):
attn_weights[head_idx] = attn_weights[head_idx] + mask
attn_weights = [
aw.softmax(dim=1) for aw in attn_weights
] # (bs, max_seq_length, 1, max_seq_length) * heads
attn = [
torch.einsum("bkhq,bchk->bchq", wi, vi)
for wi, vi in zip(attn_weights, mh_v)
] # (bs, dim_head, 1, max_seq_length) * heads
attn = torch.cat(attn, dim=1) # (bs, dim, 1, max_seq_length)
else:
raise ValueError(ATTENTION_IMPLEMENTATION_IN_EFFECT)
return self.to_out(attn)
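# Per https://machinelearning.apple.com/research/neural-engine-transformers, the
# transformer blocks in this module keep activations in a channels-first
# (batch, channels, 1, seq_len) layout and express all projections as 1x1 nn.Conv2d
# rather than nn.Linear. The two load-time hooks below remap pretrained Hugging Face
# checkpoints (nn.Linear / nn.LayerNorm based) into this layout.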
def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
""" Unsqueeze twice to map nn.Linear weights to nn.Conv2d weights
"""
for k in state_dict:
is_internal_proj = all(substr in k for substr in ["to_", ".weight"])
is_ff_proj = all(substr in k for substr in ["ff.", ".weight"])
is_temb_proj = all(substr in k for substr in ["time_emb", ".weight"])
is_proj_in = "proj_in.weight" in k
is_proj_out = "proj_out.weight" in k
if is_internal_proj or is_ff_proj or is_temb_proj or is_proj_in or is_proj_out:
if len(state_dict[k].shape) == 2:
state_dict[k] = state_dict[k][:, :, None, None]
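# Illustrative example: a `to_k.weight` stored as an nn.Linear tensor of shape
# (320, 768) is reshaped to (320, 768, 1, 1) so it loads into a 1x1 nn.Conv2d.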
# Note: torch.nn.LayerNorm and ane_transformers.reference.layer_norm.LayerNormANE
# apply scale and bias terms in opposite orders. In order to accurately restore a
# state_dict trained using the former into the latter, we adjust the bias term
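# (torch.nn.LayerNorm computes y = x_hat * weight + bias, whereas LayerNormANE
# computes y = (x_hat + bias) * weight, so equivalence requires bias -> bias / weight)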
def correct_for_bias_scale_order_inversion(state_dict, prefix, local_metadata,
strict, missing_keys,
unexpected_keys, error_msgs):
state_dict[prefix +
"bias"] = state_dict[prefix + "bias"] / state_dict[prefix +
"weight"]
return state_dict
class LayerNormANE(LayerNormANE):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._register_load_state_dict_pre_hook(
correct_for_bias_scale_order_inversion)
# Reference: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
# (modified, e.g. the attention implementation)
class CrossAttnUpBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
prev_output_channel,
temb_channels,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
attn_num_head_channels=1,
cross_attention_dim=768,
attention_type="default",
output_scale_factor=1.0,
downsample_padding=1,
add_upsample=True,
):
super().__init__()
resnets = []
attentions = []
self.attention_type = attention_type
self.attn_num_head_channels = attn_num_head_channels
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers -
1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
time_embedding_norm=resnet_time_scale_shift,
))
attentions.append(
SpatialTransformer(
out_channels,
attn_num_head_channels,
out_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
))
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.upsamplers = None
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels)])
def forward(self,
hidden_states,
res_hidden_states_tuple,
temb=None,
encoder_hidden_states=None):
for resnet, attn in zip(self.resnets, self.attentions):
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states],
dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, context=encoder_hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
class UpBlock2D(nn.Module):
def __init__(
self,
in_channels,
prev_output_channel,
out_channels,
temb_channels,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
add_upsample=True,
):
super().__init__()
resnets = []
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers -
1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
time_embedding_norm=resnet_time_scale_shift,
))
self.resnets = nn.ModuleList(resnets)
self.upsamplers = None
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels)])
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
for resnet in self.resnets:
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states],
dim=1)
hidden_states = resnet(hidden_states, temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
class CrossAttnDownBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
temb_channels,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
attn_num_head_channels=1,
cross_attention_dim=768,
attention_type="default",
output_scale_factor=1.0,
downsample_padding=1,
add_downsample=True,
):
super().__init__()
resnets = []
attentions = []
self.attention_type = attention_type
self.attn_num_head_channels = attn_num_head_channels
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
time_embedding_norm=resnet_time_scale_shift,
))
attentions.append(
SpatialTransformer(
out_channels,
attn_num_head_channels,
out_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
))
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if add_downsample:
            self.downsamplers = nn.ModuleList([Downsample2D(out_channels)])
else:
self.downsamplers = None
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, context=encoder_hidden_states)
output_states += (hidden_states, )
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states += (hidden_states, )
return hidden_states, output_states
class DownBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
temb_channels,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
):
super().__init__()
resnets = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
time_embedding_norm=resnet_time_scale_shift,
))
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None):
output_states = ()
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb)
output_states += (hidden_states, )
return hidden_states, output_states
class ResnetBlock2D(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
temb_channels=512,
groups=32,
groups_out=None,
eps=1e-6,
time_embedding_norm="default",
use_nin_shortcut=None,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
if groups_out is None:
groups_out = groups
self.norm1 = torch.nn.GroupNorm(num_groups=groups,
num_channels=in_channels,
eps=eps,
affine=True)
self.conv1 = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if temb_channels is not None:
self.time_emb_proj = torch.nn.Conv2d(temb_channels,
out_channels,
kernel_size=1)
else:
self.time_emb_proj = None
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out,
num_channels=out_channels,
eps=eps,
affine=True)
self.conv2 = torch.nn.Conv2d(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
self.nonlinearity = nn.SiLU()
self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
self.conv_shortcut = None
if self.use_nin_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x, temb):
hidden_states = x
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if temb is not None:
temb = self.time_emb_proj(self.nonlinearity(temb))
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
x = self.conv_shortcut(x)
out = (x + hidden_states)
return out
class Upsample2D(nn.Module):
def __init__(self, channels):
super().__init__()
self.conv = nn.Conv2d(channels, channels, 3, padding=1)
def forward(self, x):
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
return self.conv(x)
class Downsample2D(nn.Module):
def __init__(self, channels):
super().__init__()
self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
def forward(self, x):
return self.conv(x)
class SpatialTransformer(nn.Module):
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
context_dim=None,
):
super().__init__()
self.n_heads = n_heads
self.d_head = d_head
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = torch.nn.GroupNorm(num_groups=32,
num_channels=in_channels,
eps=1e-6,
affine=True)
self.proj_in = nn.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(inner_dim,
n_heads,
d_head,
context_dim=context_dim)
for d in range(depth)
])
self.proj_out = nn.Conv2d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, hidden_states, context=None):
        batch, channel, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)
        hidden_states = hidden_states.view(batch, channel, 1, height * width)
for block in self.transformer_blocks:
hidden_states = block(hidden_states, context=context)
        hidden_states = hidden_states.view(batch, channel, height, width)
hidden_states = self.proj_out(hidden_states)
return hidden_states + residual
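# Pre-LayerNorm transformer block: self-attention, cross-attention over the text
# context, then a GEGLU feed-forward network, each wrapped in a residual connection.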
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, context_dim=None, gated_ff=True):
super().__init__()
self.attn1 = CrossAttention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
)
self.ff = FeedForward(dim, glu=gated_ff)
self.attn2 = CrossAttention(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
)
self.norm1 = LayerNormANE(dim)
self.norm2 = LayerNormANE(dim)
self.norm3 = LayerNormANE(dim)
def forward(self, hidden_states, context=None):
hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
hidden_states = self.attn2(self.norm2(hidden_states),
context=context) + hidden_states
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
return hidden_states
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False):
super().__init__()
inner_dim = int(dim * mult)
self.net = nn.Sequential(
GEGLU(dim_in=dim, dim_out=inner_dim), nn.Identity(),
nn.Conv2d(inner_dim,
dim_out if dim_out is not None else dim,
kernel_size=1))
def forward(self, hidden_states):
return self.net(hidden_states)
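# GEGLU activation (a GLU variant): the 1x1 conv doubles the channel count, the result
# is split into a value and a gate along channels, and the output is value * GELU(gate).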
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Conv2d(dim_in, dim_out * 2, kernel_size=1)
def forward(self, hidden_states):
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=1)
return hidden_states * F.gelu(gate)
class TimestepEmbedding(nn.Module):
def __init__(self, channel, time_embed_dim, act_fn="silu"):
super().__init__()
self.linear_1 = nn.Conv2d(channel, time_embed_dim, kernel_size=1)
self.act = None
if act_fn == "silu":
self.act = nn.SiLU()
self.linear_2 = nn.Conv2d(time_embed_dim,
time_embed_dim,
kernel_size=1)
def forward(self, sample):
if len(sample.shape) == 2:
sample = sample.unsqueeze(-1).unsqueeze(-1)
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class Timesteps(nn.Module):
def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
)
return t_emb
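# Sinusoidal timestep embedding, as in DDPM / transformer positional encodings:
#   freq_i = max_period ** (-i / (half_dim - downscale_freq_shift)),  i in [0, half_dim)
#   emb(t) = [sin(t * freq), cos(t * freq)]   (optionally flipped to [cos, sin])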
def get_timestep_embedding(
timesteps,
embedding_dim,
flip_sin_to_cos=False,
downscale_freq_shift=1,
scale=1,
max_period=10000,
):
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent).to(device=timesteps.device)
emb = timesteps[:, None].float() * emb[None, :]
emb = scale * emb
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
class UNetMidBlock2DCrossAttn(nn.Module):
def __init__(
self,
in_channels,
temb_channels,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
attn_num_head_channels=1,
attention_type="default",
cross_attention_dim=768,
**kwargs,
):
super().__init__()
self.attention_type = attention_type
self.attn_num_head_channels = attn_num_head_channels
resnet_groups = resnet_groups if resnet_groups is not None else min(
in_channels // 4, 32)
resnets = [
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
time_embedding_norm=resnet_time_scale_shift,
)
]
attentions = []
for _ in range(num_layers):
attentions.append(
SpatialTransformer(
in_channels,
attn_num_head_channels,
in_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
))
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
time_embedding_norm=resnet_time_scale_shift,
))
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(hidden_states, encoder_hidden_states)
hidden_states = resnet(hidden_states, temb)
return hidden_states
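# Full conditional UNet: conv_in -> 4 down blocks (3 with cross-attention) -> mid block
# with cross-attention -> 4 up blocks consuming the skip connections in reverse order ->
# GroupNorm / SiLU / conv_out producing the predicted noise residual.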
class UNet2DConditionModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
sample_size=None,
in_channels=4,
out_channels=4,
center_input_sample=False,
flip_sin_to_cos=True,
freq_shift=0,
down_block_types=(
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
up_block_types=("UpBlock2D", "CrossAttnUpBlock2D",
"CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
only_cross_attention=False,
block_out_channels=(320, 640, 1280, 1280),
layers_per_block=2,
downsample_padding=1,
mid_block_scale_factor=1,
act_fn="silu",
norm_num_groups=32,
norm_eps=1e-5,
cross_attention_dim=768,
attention_head_dim=8,
**kwargs,
):
if kwargs.get("dual_cross_attention", None):
raise NotImplementedError
if kwargs.get("num_classs_embeds", None):
raise NotImplementedError
if only_cross_attention:
raise NotImplementedError
if kwargs.get("use_linear_projection", None):
logger.warning("`use_linear_projection=True` is ignored!")
super().__init__()
self._register_load_state_dict_pre_hook(linear_to_conv2d_map)
self.sample_size = sample_size
time_embed_dim = block_out_channels[0] * 4
# input
self.conv_in = nn.Conv2d(in_channels,
block_out_channels[0],
kernel_size=3,
padding=(1, 1))
# time
time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos,
freq_shift)
timestep_input_dim = block_out_channels[0]
time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
self.time_proj = time_proj
self.time_embedding = time_embedding
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[i],
downsample_padding=downsample_padding,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift="default",
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[i],
resnet_groups=norm_num_groups,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(
i + 1,
len(block_out_channels) - 1)]
is_final_block = i == len(block_out_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
add_upsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=reversed_attention_head_dim[i],
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0],
num_groups=norm_num_groups,
eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0],
out_channels,
3,
padding=1)
def forward(
self,
sample,
timestep,
encoder_hidden_states,
):
# 0. Project (or look-up) time embeddings
t_emb = self.time_proj(timestep)
emb = self.time_embedding(t_emb)
# 1. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 2. pre-process
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample, )
for downsample_block in self.down_blocks:
if hasattr(
downsample_block,
"attentions") and downsample_block.attentions is not None:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states)
else:
sample, res_samples = downsample_block(hidden_states=sample,
temb=emb)
down_block_res_samples += res_samples
# 4. mid
sample = self.mid_block(sample,
emb,
encoder_hidden_states=encoder_hidden_states)
# 5. up
for upsample_block in self.up_blocks:
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
down_block_res_samples = down_block_res_samples[:-len(
upsample_block.resnets)]
if hasattr(upsample_block,
"attentions") and upsample_block.attentions is not None:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
)
else:
sample = upsample_block(hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples)
# 6. post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return (sample, )
def get_down_block(
down_block_type,
num_layers,
in_channels,
out_channels,
temb_channels,
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
cross_attention_dim=None,
downsample_padding=None,
):
down_block_type = down_block_type[7:] if down_block_type.startswith(
"UNetRes") else down_block_type
if down_block_type == "DownBlock2D":
return DownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
)
elif down_block_type == "CrossAttnDownBlock2D":
if cross_attention_dim is None:
raise ValueError(
"cross_attention_dim must be specified for CrossAttnDownBlock2D"
)
return CrossAttnDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
        )
    raise ValueError(f"{down_block_type} does not exist.")
def get_up_block(
up_block_type,
num_layers,
in_channels,
out_channels,
prev_output_channel,
temb_channels,
add_upsample,
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
cross_attention_dim=None,
):
up_block_type = up_block_type[7:] if up_block_type.startswith(
"UNetRes") else up_block_type
if up_block_type == "UpBlock2D":
return UpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
)
elif up_block_type == "CrossAttnUpBlock2D":
if cross_attention_dim is None:
raise ValueError(
"cross_attention_dim must be specified for CrossAttnUpBlock2D")
return CrossAttnUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
)
raise ValueError(f"{up_block_type} does not exist.")