#
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
#

from python_coreml_stable_diffusion.layer_norm import LayerNormANE

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers import ModelMixin

from enum import Enum

import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# Ensure minimum macOS version requirement is met for this particular model
from coremltools.models.utils import _macos_version
if not _macos_version() >= (13, 1):
    logger.warning(
        "!!! macOS 13.1 and newer or iOS/iPadOS 16.2 and newer is required for best performance !!!"
    )


class AttentionImplementations(Enum):
    ORIGINAL = "ORIGINAL"
    SPLIT_EINSUM = "SPLIT_EINSUM"


ATTENTION_IMPLEMENTATION_IN_EFFECT = AttentionImplementations.SPLIT_EINSUM

WARN_MSG = \
    "This `nn.Module` is intended for Apple Silicon deployment only. " \
    "PyTorch-specific optimizations and training are disabled"


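# Illustrative sketch (not from the original file): `ATTENTION_IMPLEMENTATION_IN_EFFECT`
# is read on every forward pass, so a conversion script may swap implementations by
# reassigning the module attribute (assuming this module is importable as
# `python_coreml_stable_diffusion.unet`):
#
#     from python_coreml_stable_diffusion import unet as unet_module
#     unet_module.ATTENTION_IMPLEMENTATION_IN_EFFECT = \
#         unet_module.AttentionImplementations.ORIGINAL

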
class CrossAttention(nn.Module):
    """ Apple Silicon friendly version of `diffusers.models.attention.CrossAttention`
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = context_dim if context_dim is not None else query_dim

        self.scale = dim_head**-0.5
        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Conv2d(query_dim, inner_dim, kernel_size=1, bias=False)
        self.to_k = nn.Conv2d(context_dim,
                              inner_dim,
                              kernel_size=1,
                              bias=False)
        self.to_v = nn.Conv2d(context_dim,
                              inner_dim,
                              kernel_size=1,
                              bias=False)
        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, query_dim, kernel_size=1, bias=True))

    def forward(self, hidden_states, context=None, mask=None):
        if self.training:
            raise NotImplementedError(WARN_MSG)

        batch_size, dim, _, sequence_length = hidden_states.shape

        q = self.to_q(hidden_states)
        context = context if context is not None else hidden_states
        k = self.to_k(context)
        v = self.to_v(context)

        # Validate mask
        if mask is not None:
            expected_mask_shape = [batch_size, sequence_length, 1, 1]
            if mask.dtype == torch.bool:
                mask = mask.logical_not().float() * -1e4
            elif mask.dtype == torch.int64:
                mask = (1 - mask).float() * -1e4
            elif mask.dtype != torch.float32:
                raise TypeError(f"Unexpected dtype for mask: {mask.dtype}")

            if len(mask.size()) == 2:
                mask = mask.unsqueeze(2).unsqueeze(2)

            if list(mask.size()) != expected_mask_shape:
                raise RuntimeError(
                    f"Invalid shape for `mask` (Expected {expected_mask_shape}, got {list(mask.size())})"
                )

        if ATTENTION_IMPLEMENTATION_IN_EFFECT == AttentionImplementations.ORIGINAL:
            # This version of the attention function is recommended for high GPU core count
            # devices such as the M1 Max and M1 Ultra
            bs = q.size(0)
            mh_q = q.view(bs, self.heads, self.dim_head, -1)
            mh_k = k.view(bs, self.heads, self.dim_head, -1)
            mh_v = v.view(bs, self.heads, self.dim_head, -1)

            attn_weights = torch.einsum("bhcq,bhck->bhqk", [mh_q, mh_k])
            attn_weights.mul_(self.scale)

            if mask is not None:
                attn_weights = attn_weights + mask

            attn_weights = attn_weights.softmax(dim=3)

            attn = torch.einsum("bhqk,bhck->bhcq", [attn_weights, mh_v])
            attn = attn.contiguous().view(bs, self.heads * self.dim_head, 1,
                                          -1)

        elif ATTENTION_IMPLEMENTATION_IN_EFFECT == AttentionImplementations.SPLIT_EINSUM:
            # The split attention and einsum from https://machinelearning.apple.com/research/neural-engine-transformers
            # are utilized to build an ANE implementation. This version is marginally slower on the GPU engine and is
            # not recommended for Max and Ultra Mac variants
            mh_q = [
                q[:, head_idx * self.dim_head:(head_idx + 1) *
                  self.dim_head, :, :] for head_idx in range(self.heads)
            ]  # (bs, dim_head, 1, max_seq_length) * heads

            k = k.transpose(1, 3)
            mh_k = [
                k[:, :, :,
                  head_idx * self.dim_head:(head_idx + 1) * self.dim_head]
                for head_idx in range(self.heads)
            ]  # (bs, max_seq_length, 1, dim_head) * heads

            mh_v = [
                v[:, head_idx * self.dim_head:(head_idx + 1) *
                  self.dim_head, :, :] for head_idx in range(self.heads)
            ]  # (bs, dim_head, 1, max_seq_length) * heads

            attn_weights = [
                torch.einsum("bchq,bkhc->bkhq", [qi, ki]) * self.scale
                for qi, ki in zip(mh_q, mh_k)
            ]  # (bs, max_seq_length, 1, max_seq_length) * heads

            if mask is not None:
                for head_idx in range(self.heads):
                    attn_weights[head_idx] = attn_weights[head_idx] + mask

            attn_weights = [
                aw.softmax(dim=1) for aw in attn_weights
            ]  # (bs, max_seq_length, 1, max_seq_length) * heads
            attn = [
                torch.einsum("bkhq,bchk->bchq", wi, vi)
                for wi, vi in zip(attn_weights, mh_v)
            ]  # (bs, dim_head, 1, max_seq_length) * heads

            attn = torch.cat(attn, dim=1)  # (bs, dim, 1, max_seq_length)

        else:
            raise ValueError(ATTENTION_IMPLEMENTATION_IN_EFFECT)

        return self.to_out(attn)


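# Illustrative usage sketch (not from the original file): the attention layers
# above operate on channels-first (batch, channels, 1, seq_len) tensors rather
# than the (batch, seq_len, channels) layout used by `diffusers`:
#
#     attn = CrossAttention(query_dim=320, heads=8, dim_head=40)
#     x = torch.randn(1, 320, 1, 64)   # (batch, channels, 1, seq_len)
#     y = attn(x)                      # -> (1, 320, 1, 64)

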
def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
                         missing_keys, unexpected_keys, error_msgs):
    """ Unsqueeze twice to map nn.Linear weights to nn.Conv2d weights
    """
    for k in state_dict:
        is_internal_proj = all(substr in k for substr in ["to_", ".weight"])
        is_ff_proj = all(substr in k for substr in ["ff.", ".weight"])
        is_temb_proj = all(substr in k for substr in ["time_emb", ".weight"])
        is_proj_in = "proj_in.weight" in k
        is_proj_out = "proj_out.weight" in k

        if is_internal_proj or is_ff_proj or is_temb_proj or is_proj_in or is_proj_out:
            if len(state_dict[k].shape) == 2:
                state_dict[k] = state_dict[k][:, :, None, None]


# Note: torch.nn.LayerNorm and ane_transformers.reference.layer_norm.LayerNormANE
# apply scale and bias terms in opposite orders. In order to accurately restore a
# state_dict trained using the former into the latter, we adjust the bias term
def correct_for_bias_scale_order_inversion(state_dict, prefix, local_metadata,
                                           strict, missing_keys,
                                           unexpected_keys, error_msgs):
    state_dict[prefix +
               "bias"] = state_dict[prefix + "bias"] / state_dict[prefix +
                                                                  "weight"]
    return state_dict


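# The rescaling above follows from the identity
#     weight * x_norm + bias == weight * (x_norm + bias / weight),
# so a checkpoint whose LayerNorm applies the bias after the scale can be
# loaded into LayerNormANE, which adds its bias before applying the scale.

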
class LayerNormANE(LayerNormANE):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._register_load_state_dict_pre_hook(
            correct_for_bias_scale_order_inversion)


# Reference: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
# (modified, e.g. the attention implementation)
class CrossAttnUpBlock2D(nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        prev_output_channel,
        temb_channels,
        num_layers=1,
        resnet_eps=1e-6,
        resnet_time_scale_shift="default",
        resnet_act_fn="swish",
        resnet_groups=32,
        attn_num_head_channels=1,
        cross_attention_dim=768,
        attention_type="default",
        output_scale_factor=1.0,
        downsample_padding=1,
        add_upsample=True,
    ):
        super().__init__()
        resnets = []
        attentions = []

        self.attention_type = attention_type
        self.attn_num_head_channels = attn_num_head_channels

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers -
                                                1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    time_embedding_norm=resnet_time_scale_shift,
                ))
            attentions.append(
                SpatialTransformer(
                    out_channels,
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    depth=1,
                    context_dim=cross_attention_dim,
                ))
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.upsamplers = None
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels)])

    def forward(self,
                hidden_states,
                res_hidden_states_tuple,
                temb=None,
                encoder_hidden_states=None):
        for resnet, attn in zip(self.resnets, self.attentions):
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states],
                                      dim=1)

            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, context=encoder_hidden_states)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states


class UpBlock2D(nn.Module):

    def __init__(
        self,
        in_channels,
        prev_output_channel,
        out_channels,
        temb_channels,
        num_layers=1,
        resnet_eps=1e-6,
        resnet_time_scale_shift="default",
        resnet_act_fn="swish",
        resnet_groups=32,
        add_upsample=True,
    ):
        super().__init__()
        resnets = []

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers -
                                                1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    time_embedding_norm=resnet_time_scale_shift,
                ))

        self.resnets = nn.ModuleList(resnets)
        self.upsamplers = None
        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels)])

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
        for resnet in self.resnets:
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states],
                                      dim=1)

            hidden_states = resnet(hidden_states, temb)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states)

        return hidden_states


class CrossAttnDownBlock2D(nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        temb_channels,
        num_layers=1,
        resnet_eps=1e-6,
        resnet_time_scale_shift="default",
        resnet_act_fn="swish",
        resnet_groups=32,
        attn_num_head_channels=1,
        cross_attention_dim=768,
        attention_type="default",
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
    ):
        super().__init__()
        resnets = []
        attentions = []

        self.attention_type = attention_type
        self.attn_num_head_channels = attn_num_head_channels

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    time_embedding_norm=resnet_time_scale_shift,
                ))
            attentions.append(
                SpatialTransformer(
                    out_channels,
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    depth=1,
                    context_dim=cross_attention_dim,
                ))
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList([Downsample2D(out_channels)])
        else:
            self.downsamplers = None

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        output_states = ()

        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, context=encoder_hidden_states)
            output_states += (hidden_states, )

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states, )

        return hidden_states, output_states


class DownBlock2D(nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        temb_channels,
        num_layers=1,
        resnet_eps=1e-6,
        resnet_time_scale_shift="default",
        resnet_act_fn="swish",
        resnet_groups=32,
    ):
        super().__init__()
        resnets = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    time_embedding_norm=resnet_time_scale_shift,
                ))

        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, temb=None):
        output_states = ()

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)
            output_states += (hidden_states, )

        return hidden_states, output_states


class ResnetBlock2D(nn.Module):

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        temb_channels=512,
        groups=32,
        groups_out=None,
        eps=1e-6,
        time_embedding_norm="default",
        use_nin_shortcut=None,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm

        if groups_out is None:
            groups_out = groups

        self.norm1 = torch.nn.GroupNorm(num_groups=groups,
                                        num_channels=in_channels,
                                        eps=eps,
                                        affine=True)

        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)

        if temb_channels is not None:
            self.time_emb_proj = torch.nn.Conv2d(temb_channels,
                                                 out_channels,
                                                 kernel_size=1)
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out,
                                        num_channels=out_channels,
                                        eps=eps,
                                        affine=True)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)

        self.nonlinearity = nn.SiLU()

        self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut

        self.conv_shortcut = None
        if self.use_nin_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                 out_channels,
                                                 kernel_size=1,
                                                 stride=1,
                                                 padding=0)

    def forward(self, x, temb):
        hidden_states = x
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            temb = self.time_emb_proj(self.nonlinearity(temb))
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            x = self.conv_shortcut(x)

        out = (x + hidden_states)

        return out


class Upsample2D(nn.Module):

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(x)


class Downsample2D(nn.Module):

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)


class SpatialTransformer(nn.Module):

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        context_dim=None,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_head
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(num_groups=32,
                                       num_channels=in_channels,
                                       eps=1e-6,
                                       affine=True)

        self.proj_in = nn.Conv2d(in_channels,
                                 inner_dim,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)

        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(inner_dim,
                                  n_heads,
                                  d_head,
                                  context_dim=context_dim)
            for d in range(depth)
        ])

        self.proj_out = nn.Conv2d(inner_dim,
                                  in_channels,
                                  kernel_size=1,
                                  stride=1,
                                  padding=0)

    def forward(self, hidden_states, context=None):
        batch, channel, height, width = hidden_states.shape
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        hidden_states = self.proj_in(hidden_states)
        hidden_states = hidden_states.view(batch, channel, 1, height * width)
        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, context=context)
        hidden_states = hidden_states.view(batch, channel, height, width)
        hidden_states = self.proj_out(hidden_states)
        return hidden_states + residual


class BasicTransformerBlock(nn.Module):

    def __init__(self, dim, n_heads, d_head, context_dim=None, gated_ff=True):
        super().__init__()
        self.attn1 = CrossAttention(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
        )
        self.ff = FeedForward(dim, glu=gated_ff)
        self.attn2 = CrossAttention(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
        )
        self.norm1 = LayerNormANE(dim)
        self.norm2 = LayerNormANE(dim)
        self.norm3 = LayerNormANE(dim)

    def forward(self, hidden_states, context=None):
        hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
        hidden_states = self.attn2(self.norm2(hidden_states),
                                   context=context) + hidden_states
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
        return hidden_states


class FeedForward(nn.Module):

    def __init__(self, dim, dim_out=None, mult=4, glu=False):
        super().__init__()
        inner_dim = int(dim * mult)
        self.net = nn.Sequential(
            GEGLU(dim_in=dim, dim_out=inner_dim), nn.Identity(),
            nn.Conv2d(inner_dim,
                      dim_out if dim_out is not None else dim,
                      kernel_size=1))

    def forward(self, hidden_states):
        return self.net(hidden_states)


class GEGLU(nn.Module):

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Conv2d(dim_in, dim_out * 2, kernel_size=1)

    def forward(self, hidden_states):
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=1)
        return hidden_states * F.gelu(gate)


class TimestepEmbedding(nn.Module):

    def __init__(self, channel, time_embed_dim, act_fn="silu"):
        super().__init__()

        self.linear_1 = nn.Conv2d(channel, time_embed_dim, kernel_size=1)
        self.act = None
        if act_fn == "silu":
            self.act = nn.SiLU()
        self.linear_2 = nn.Conv2d(time_embed_dim,
                                  time_embed_dim,
                                  kernel_size=1)

    def forward(self, sample):
        if len(sample.shape) == 2:
            sample = sample.unsqueeze(-1).unsqueeze(-1)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)
        return sample


class Timesteps(nn.Module):

    def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
        return t_emb


def get_timestep_embedding(
    timesteps,
    embedding_dim,
    flip_sin_to_cos=False,
    downscale_freq_shift=1,
    scale=1,
    max_period=10000,
):
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32)
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent).to(device=timesteps.device)
    emb = timesteps[:, None].float() * emb[None, :]
    emb = scale * emb
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


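# For a 1-d tensor of timesteps t and channel index i < half_dim, the function
# above returns sin(t / max_period**(i / (half_dim - downscale_freq_shift)))
# in the first half of the embedding and the matching cos(...) in the second
# half (halves are swapped when `flip_sin_to_cos` is set), e.g.:
#
#     emb = get_timestep_embedding(torch.tensor([0, 999]), 320)
#     assert emb.shape == (2, 320)

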
class UNetMidBlock2DCrossAttn(nn.Module):

    def __init__(
        self,
        in_channels,
        temb_channels,
        num_layers=1,
        resnet_eps=1e-6,
        resnet_time_scale_shift="default",
        resnet_act_fn="swish",
        resnet_groups=32,
        attn_num_head_channels=1,
        attention_type="default",
        cross_attention_dim=768,
        **kwargs,
    ):
        super().__init__()

        self.attention_type = attention_type
        self.attn_num_head_channels = attn_num_head_channels
        resnet_groups = resnet_groups if resnet_groups is not None else min(
            in_channels // 4, 32)

        resnets = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                time_embedding_norm=resnet_time_scale_shift,
            )
        ]
        attentions = []

        for _ in range(num_layers):
            attentions.append(
                SpatialTransformer(
                    in_channels,
                    attn_num_head_channels,
                    in_channels // attn_num_head_channels,
                    depth=1,
                    context_dim=cross_attention_dim,
                ))
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    time_embedding_norm=resnet_time_scale_shift,
                ))

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(hidden_states, encoder_hidden_states)
            hidden_states = resnet(hidden_states, temb)

        return hidden_states


class UNet2DConditionModel(ModelMixin, ConfigMixin):

    @register_to_config
    def __init__(
        self,
        sample_size=None,
        in_channels=4,
        out_channels=4,
        center_input_sample=False,
        flip_sin_to_cos=True,
        freq_shift=0,
        down_block_types=(
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D",
                        "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
        only_cross_attention=False,
        block_out_channels=(320, 640, 1280, 1280),
        layers_per_block=2,
        downsample_padding=1,
        mid_block_scale_factor=1,
        act_fn="silu",
        norm_num_groups=32,
        norm_eps=1e-5,
        cross_attention_dim=768,
        attention_head_dim=8,
        **kwargs,
    ):
        if kwargs.get("dual_cross_attention", None):
            raise NotImplementedError
        if kwargs.get("num_class_embeds", None):
            raise NotImplementedError
        if only_cross_attention:
            raise NotImplementedError
        if kwargs.get("use_linear_projection", None):
            logger.warning("`use_linear_projection=True` is ignored!")

        super().__init__()
        self._register_load_state_dict_pre_hook(linear_to_conv2d_map)

        self.sample_size = sample_size
        time_embed_dim = block_out_channels[0] * 4

        # input
        self.conv_in = nn.Conv2d(in_channels,
                                 block_out_channels[0],
                                 kernel_size=3,
                                 padding=(1, 1))

        # time
        time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos,
                              freq_shift)
        timestep_input_dim = block_out_channels[0]
        time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        self.time_proj = time_proj
        self.time_embedding = time_embedding

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[i],
                downsample_padding=downsample_padding,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock2DCrossAttn(
            in_channels=block_out_channels[-1],
            temb_channels=time_embed_dim,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift="default",
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attention_head_dim[-1],
            resnet_groups=norm_num_groups,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_attention_head_dim = list(reversed(attention_head_dim))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(
                i + 1,
                len(block_out_channels) - 1)]

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=reversed_attention_head_dim[i],
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0],
                                          num_groups=norm_num_groups,
                                          eps=norm_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0],
                                  out_channels,
                                  3,
                                  padding=1)

    def forward(
        self,
        sample,
        timestep,
        encoder_hidden_states,
    ):
        # 0. Project (or look-up) time embeddings
        t_emb = self.time_proj(timestep)
        emb = self.time_embedding(t_emb)

        # 1. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 2. pre-process
        sample = self.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample, )
        for downsample_block in self.down_blocks:
            if hasattr(
                    downsample_block,
                    "attentions") and downsample_block.attentions is not None:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states)
            else:
                sample, res_samples = downsample_block(hidden_states=sample,
                                                       temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(sample,
                                emb,
                                encoder_hidden_states=encoder_hidden_states)

        # 5. up
        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
            down_block_res_samples = down_block_res_samples[:-len(
                upsample_block.resnets)]

            if hasattr(upsample_block,
                       "attentions") and upsample_block.attentions is not None:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                )
            else:
                sample = upsample_block(hidden_states=sample,
                                        temb=emb,
                                        res_hidden_states_tuple=res_samples)

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return (sample, )


def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    cross_attention_dim=None,
    downsample_padding=None,
):
    down_block_type = down_block_type[7:] if down_block_type.startswith(
        "UNetRes") else down_block_type
    if down_block_type == "DownBlock2D":
        return DownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
        )
    elif down_block_type == "CrossAttnDownBlock2D":
        if cross_attention_dim is None:
            raise ValueError(
                "cross_attention_dim must be specified for CrossAttnDownBlock2D"
            )
        return CrossAttnDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
        )
    raise ValueError(f"{down_block_type} does not exist.")


def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    cross_attention_dim=None,
):
    up_block_type = up_block_type[7:] if up_block_type.startswith(
        "UNetRes") else up_block_type
    if up_block_type == "UpBlock2D":
        return UpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
        )
    elif up_block_type == "CrossAttnUpBlock2D":
        if cross_attention_dim is None:
            raise ValueError(
                "cross_attention_dim must be specified for CrossAttnUpBlock2D")
        return CrossAttnUpBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
        )
    raise ValueError(f"{up_block_type} does not exist.")
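

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original
    # conversion pipeline): instantiate the full-size Stable Diffusion UNet
    # with its default configuration and run a single denoising step on
    # random inputs. Note the (batch, 768, 1, 77) channels-first layout
    # expected for the text encoder hidden states by the Conv2d-based
    # attention layers above.
    with torch.no_grad():
        unet = UNet2DConditionModel().eval()
        latents = torch.randn(1, 4, 64, 64)
        timestep = torch.tensor([981.], dtype=torch.float32)
        text_embeddings = torch.randn(1, 768, 1, 77)
        noise_pred = unet(latents, timestep, text_embeddings)[0]
        logger.info(f"Predicted noise shape: {tuple(noise_pred.shape)}")  # (1, 4, 64, 64)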