3d039591fe
Before this commit, credits were already in entry and licenses were already in root. This commit makes that info clearer.
346 lines
12 KiB
Python
346 lines
12 KiB
Python
# Taken from https://github.com/comfyanonymous/ComfyUI
|
|
# This file is only for reference, and not used in the backend or runtime.
|
|
|
|
|
|
import torch
|
|
from torch import nn
|
|
from ldm_patched.ldm.modules.attention import CrossAttention
|
|
from inspect import isfunction
|
|
|
|
|
|
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True
|
|
|
|
|
|
def uniq(arr):
    """Return a dict-keys view over the distinct elements of ``arr``.

    Elements keep the order of their first occurrence and must be hashable.
    """
    seen = dict.fromkeys(arr, True)
    return seen.keys()
|
|
|
|
|
|
def default(val, d):
    """Return ``val`` unless it is ``None``, otherwise fall back to ``d``.

    When ``d`` is a plain Python function (as judged by
    ``inspect.isfunction``) it is called with no arguments and its result is
    returned; any other ``d`` is returned as-is.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
|
|
|
|
|
|
# feedforward
|
|
class GEGLU(nn.Module):
    """Linear projection followed by a GELU-gated product.

    The input is projected to ``2 * dim_out`` features, split in half along
    the last dimension, and the first half is multiplied by ``gelu`` of the
    second half.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # A single projection produces both the value and the gate halves.
        self.proj = nn.Linear(dim_in, 2 * dim_out)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = torch.chunk(projected, 2, dim=-1)
        return value * torch.nn.functional.gelu(gate)
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Input projection, dropout, then an output projection.

    Args:
        dim: number of input features.
        dim_out: number of output features; ``None`` means "same as ``dim``".
        mult: the hidden width is ``int(dim * mult)``.
        glu: when true the input projection is a ``GEGLU``; otherwise it is
            ``Linear`` followed by ``GELU``.
        dropout: probability handed to ``nn.Dropout``.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        if dim_out is None:
            dim_out = dim

        if glu:
            project_in = GEGLU(dim, inner_dim)
        else:
            project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())

        layers = [
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
|
|
|
|
|
|
class GatedCrossAttentionDense(nn.Module):
    """Cross-attention and feed-forward residuals, each behind a tanh gate.

    ``x`` attends to ``objs`` (passed to the attention module twice, after the
    normalised ``x``). Both residual branches are multiplied by
    ``scale * tanh(alpha)``; the alphas start at 0, so a freshly constructed
    block returns its input unchanged.
    """

    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        self.attn = CrossAttention(query_dim=query_dim,
                                   context_dim=context_dim,
                                   heads=n_heads,
                                   dim_head=d_head)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        # Learnable scalar gates; assigning an nn.Parameter registers it.
        self.alpha_attn = nn.Parameter(torch.tensor(0.))
        self.alpha_dense = nn.Parameter(torch.tensor(0.))

        # External knob on the gate magnitude: with scale == 0 both residual
        # branches vanish and the block is the identity.
        self.scale = 1

    def forward(self, x, objs):
        attn_gate = self.scale * torch.tanh(self.alpha_attn)
        x = x + attn_gate * self.attn(self.norm1(x), objs, objs)

        dense_gate = self.scale * torch.tanh(self.alpha_dense)
        x = x + dense_gate * self.ff(self.norm2(x))

        return x
|
|
|
|
|
|
class GatedSelfAttentionDense(nn.Module):
    """Gated self-attention over the concatenation of ``x`` and ``objs``.

    ``objs`` is projected from ``context_dim`` to ``query_dim``, concatenated
    to ``x`` along dim 1, normalised and attended. Only the first
    ``x.shape[1]`` output tokens are kept as the residual. Both residual
    branches are multiplied by ``scale * tanh(alpha)``; the alphas start at 0.
    """

    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        # Projection needed so object features can be concatenated with the
        # visual features.
        self.linear = nn.Linear(context_dim, query_dim)

        self.attn = CrossAttention(query_dim=query_dim,
                                   context_dim=query_dim,
                                   heads=n_heads,
                                   dim_head=d_head)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        # Learnable scalar gates; assigning an nn.Parameter registers it.
        self.alpha_attn = nn.Parameter(torch.tensor(0.))
        self.alpha_dense = nn.Parameter(torch.tensor(0.))

        # External knob on the gate magnitude: with scale == 0 both residual
        # branches vanish and the block is the identity.
        self.scale = 1

    def forward(self, x, objs):
        n_visual = x.shape[1]
        projected_objs = self.linear(objs)

        tokens = self.norm1(torch.cat([x, projected_objs], dim=1))
        # Keep only the outputs that correspond to the original x tokens.
        attended = self.attn(tokens)[:, :n_visual, :]
        x = x + self.scale * torch.tanh(self.alpha_attn) * attended

        dense_gate = self.scale * torch.tanh(self.alpha_dense)
        x = x + dense_gate * self.ff(self.norm2(x))

        return x
|
|
|
|
|
|
class GatedSelfAttentionDense2(nn.Module):
    """Gated self-attention whose residual comes from the grounding tokens.

    ``objs`` is projected to ``query_dim`` and concatenated to ``x`` along
    dim 1 before attention. Unlike ``GatedSelfAttentionDense``, the outputs at
    the *grounding* positions are kept: they are laid out as a square grid,
    bicubically resized to the visual grid size, and added to ``x`` behind a
    ``scale * tanh(alpha_attn)`` gate. A gated feed-forward residual follows.

    Both token counts (``x.shape[1]`` and ``objs.shape[1]``) must be perfect
    squares; ``forward`` asserts this.

    Args:
        query_dim: feature size of ``x``.
        context_dim: feature size of ``objs`` before projection.
        n_heads: accepted for signature parity with the sibling classes.
        d_head: per-head dimension passed to ``CrossAttention``.
    """

    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        # we need a linear projection since we need cat visual feature and obj
        # feature
        self.linear = nn.Linear(context_dim, query_dim)

        # NOTE(review): n_heads is not forwarded to CrossAttention here, so
        # the attention module uses its own default head count. Left as-is
        # because changing it would alter weight shapes; confirm intent.
        self.attn = CrossAttention(
            query_dim=query_dim, context_dim=query_dim, dim_head=d_head)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # this can be useful: we can externally change magnitude of tanh(alpha)
        # for example, when it is set to 0, then the entire model is same as
        # original one
        self.scale = 1

    def forward(self, x, objs):
        # `math` is needed for the square-root sanity check below; import it
        # locally so this method does not rely on a module-level import.
        import math

        B, N_visual, _ = x.shape
        B, N_ground, _ = objs.shape

        objs = self.linear(objs)

        # sanity check
        size_v = math.sqrt(N_visual)
        size_g = math.sqrt(N_ground)
        assert int(size_v) == size_v, "Visual tokens must be square rootable"
        assert int(size_g) == size_g, "Grounding tokens must be square rootable"
        size_v = int(size_v)
        size_g = int(size_g)

        # select grounding token and resize it to visual token size as residual
        out = self.attn(self.norm1(torch.cat([x, objs], dim=1)))[
            :, N_visual:, :]
        # (B, N_ground, C) -> (B, C, size_g, size_g) so it can be resized as
        # an image-like grid.
        out = out.permute(0, 2, 1).reshape(B, -1, size_g, size_g)
        out = torch.nn.functional.interpolate(
            out, (size_v, size_v), mode='bicubic')
        # Back to token layout: (B, C, size_v, size_v) -> (B, N_visual, C).
        residual = out.reshape(B, -1, N_visual).permute(0, 2, 1)

        # add residual to visual feature
        x = x + self.scale * torch.tanh(self.alpha_attn) * residual
        x = x + self.scale * \
            torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))

        return x
|
|
|
|
|
|
class FourierEmbedder():
    """Sin/cos embedding of a tensor at geometrically spaced frequencies.

    The frequencies are ``temperature ** (i / num_freqs)`` for
    ``i`` in ``range(num_freqs)``.
    """

    def __init__(self, num_freqs=64, temperature=100):
        self.num_freqs = num_freqs
        self.temperature = temperature
        exponents = torch.arange(num_freqs) / num_freqs
        self.freq_bands = temperature ** exponents

    @torch.no_grad()
    def __call__(self, x, cat_dim=-1):
        """Embed ``x`` (any shape) and concatenate along ``cat_dim``.

        For every frequency, ``sin(freq * x)`` then ``cos(freq * x)`` are
        appended, so the size of ``cat_dim`` grows by ``2 * num_freqs``.
        """
        parts = []
        for freq in self.freq_bands:
            scaled = freq * x
            parts.append(torch.sin(scaled))
            parts.append(torch.cos(scaled))
        return torch.cat(parts, cat_dim)
|
|
|
|
|
|
class PositionNet(nn.Module):
    """Fuse per-object embeddings with Fourier-embedded boxes via an MLP.

    Entries whose mask is 0 have both their embedding and their box embedding
    replaced by learnable "null" vectors before the MLP is applied.

    Args:
        in_dim: feature size of the per-object embeddings.
        out_dim: feature size produced by the MLP.
        fourier_freqs: number of frequencies used by the Fourier embedder.
    """

    def __init__(self, in_dim, out_dim, fourier_freqs=8):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
        # 2 is sin&cos, 4 is xyxy
        self.position_dim = fourier_freqs * 2 * 4

        mlp_in = self.in_dim + self.position_dim
        self.linears = nn.Sequential(
            nn.Linear(mlp_in, 512),
            nn.SiLU(),
            nn.Linear(512, 512),
            nn.SiLU(),
            nn.Linear(512, out_dim),
        )

        # Learnable stand-ins for masked-out (padding) entries.
        self.null_positive_feature = nn.Parameter(torch.zeros([self.in_dim]))
        self.null_position_feature = nn.Parameter(
            torch.zeros([self.position_dim]))

    def forward(self, boxes, masks, positive_embeddings):
        """Return fused object tokens of shape ``(B, N, out_dim)``.

        ``boxes`` must be 3-dimensional ``(B, N, _)``. ``masks`` gets a
        trailing singleton dimension so it broadcasts over features. All
        inputs are cast to the dtype of the first linear layer's weight.
        """
        B, N, _ = boxes.shape
        dtype = self.linears[0].weight.dtype

        keep = masks.unsqueeze(-1).to(dtype)
        drop = 1 - keep
        text_feats = positive_embeddings.to(dtype)

        # Embed positions (this may include padding used as a placeholder).
        # B*N*4 --> B*N*C
        box_feats = self.fourier_embedder(boxes.to(dtype))

        # Learnable null embeddings, reshaped to broadcast over (B, N).
        null_text = self.null_positive_feature.view(1, 1, -1)
        null_box = self.null_position_feature.view(1, 1, -1)

        # Replace padding with the learnable null embeddings.
        text_feats = text_feats * keep + drop * null_text
        box_feats = box_feats * keep + drop * null_box

        fused = torch.cat([text_feats, box_feats], dim=-1)
        objs = self.linears(fused)
        assert objs.shape == torch.Size([B, N, self.out_dim])
        return objs
|
|
|
|
|
|
class Gligen(nn.Module):
    """Holds the per-transformer gated modules and the position network.

    ``set_position`` / ``set_empty`` build padded box, mask and embedding
    tensors (always ``max_objs`` rows per batch item), run them through
    ``position_net`` once, and return a closure that applies the gated module
    selected by ``extra_options["transformer_index"]``.
    """

    def __init__(self, modules, position_net, key_dim):
        super().__init__()
        self.module_list = nn.ModuleList(modules)
        self.position_net = position_net
        self.key_dim = key_dim
        self.max_objs = 30
        self.current_device = torch.device("cpu")

    def _set_position(self, boxes, masks, positive_embeddings):
        """Compute object tokens once and return the per-layer patch function."""
        objs = self.position_net(boxes, masks, positive_embeddings)

        def func(x, extra_options):
            index = extra_options["transformer_index"]
            gated_module = self.module_list[index]
            return gated_module(x, objs)

        return func

    def set_position(self, latent_image_shape, position_params, device):
        """Build the patch function for a list of positioned embeddings.

        Each entry ``p`` of ``position_params`` is indexed as follows:
        ``p[0]`` is the embedding tensor, ``p[4]`` / ``p[3]`` are the box's
        x / y origin and ``p[2]`` / ``p[1]`` its width / height, all divided
        by the latent ``w`` / ``h`` to give normalised ``(x1, y1, x2, y2)``.
        Unused slots up to ``max_objs`` are zero-padded and masked out.
        """
        batch, _, h, w = latent_image_shape
        masks = torch.zeros([self.max_objs], device="cpu")
        boxes = []
        positive_embeddings = []
        for slot, p in enumerate(position_params):
            left = p[4] / w
            top = p[3] / h
            right = (p[4] + p[2]) / w
            bottom = (p[3] + p[1]) / h
            masks[slot] = 1.0
            boxes.append(
                torch.tensor((left, top, right, bottom)).unsqueeze(0))
            positive_embeddings.append(p[0])

        # Pad up to max_objs rows with zeros; the mask marks them as unused.
        missing = self.max_objs - len(boxes)
        if missing > 0:
            boxes.append(torch.zeros([missing, 4], device="cpu"))
            positive_embeddings.append(
                torch.zeros([missing, self.key_dim], device="cpu"))

        box_out = torch.cat(boxes).unsqueeze(0).repeat(batch, 1, 1)
        masks = masks.unsqueeze(0).repeat(batch, 1)
        conds = torch.cat(positive_embeddings).unsqueeze(0).repeat(batch, 1, 1)
        return self._set_position(
            box_out.to(device), masks.to(device), conds.to(device))

    def set_empty(self, latent_image_shape, device):
        """Build the patch function with every object slot empty (all zeros)."""
        batch, _, _, _ = latent_image_shape
        empty_masks = torch.zeros([self.max_objs], device="cpu")
        empty_boxes = torch.zeros([self.max_objs, 4], device="cpu")
        empty_conds = torch.zeros([self.max_objs, self.key_dim], device="cpu")

        masks = empty_masks.repeat(batch, 1)
        box_out = empty_boxes.repeat(batch, 1, 1)
        conds = empty_conds.repeat(batch, 1, 1)
        return self._set_position(
            box_out.to(device), masks.to(device), conds.to(device))
|
|
|
|
|
|
def load_gligen(sd):
    """Build a ``Gligen`` model from a GLIGEN state dict.

    For every ``{input_blocks,middle_block,output_blocks}.<0..19>.`` prefix
    that has ``.fuser.`` keys, a ``GatedSelfAttentionDense`` is created, sized
    from that fuser's ``linear.weight``, and loaded (non-strictly) with the
    fuser's weights. The position network is sized from
    ``position_net.null_positive_feature`` and
    ``position_net.linears.4.weight`` and loaded non-strictly from ``sd``.

    Args:
        sd: mapping from parameter name to tensor.

    Returns:
        A ``Gligen`` wrapping the gated modules (in discovery order), the
        position network, and the last ``key_dim`` seen (768 if no fuser
        weights were found).

    Raises:
        ValueError: if ``sd`` has no ``position_net.null_positive_feature``
            entry, i.e. no position network can be built.
    """
    sd_k = sd.keys()

    # Fail fast with a clear message: without this key no position network
    # can be constructed and there would be nothing to hand to Gligen.
    if "position_net.null_positive_feature" not in sd_k:
        raise ValueError(
            "state dict is missing 'position_net.null_positive_feature'; "
            "cannot build the GLIGEN position_net")

    output_list = []
    key_dim = 768
    for block_name in ("input_blocks", "middle_block", "output_blocks"):
        for index in range(20):
            prefix = "{}.{}.".format(block_name, index)
            # Collect this block's fuser weights, keyed by the name after
            # ".fuser." so they match GatedSelfAttentionDense's parameters.
            n_sd = {
                k.split(".fuser.")[-1]: sd[k]
                for k in sd_k
                if prefix in k and ".fuser." in k
            }
            if not n_sd:
                continue

            query_dim = n_sd["linear.weight"].shape[0]
            key_dim = n_sd["linear.weight"].shape[1]

            if key_dim == 768:  # SD1.x
                n_heads = 8
                d_head = query_dim // n_heads
            else:
                d_head = 64
                n_heads = query_dim // d_head

            gated = GatedSelfAttentionDense(
                query_dim, key_dim, n_heads, d_head)
            gated.load_state_dict(n_sd, strict=False)
            output_list.append(gated)

    in_dim = sd["position_net.null_positive_feature"].shape[0]
    out_dim = sd["position_net.linears.4.weight"].shape[0]

    # Throwaway container so the "position_net." prefixed keys in ``sd`` line
    # up with the module hierarchy during load_state_dict.
    class WeightsLoader(torch.nn.Module):
        pass

    w = WeightsLoader()
    w.position_net = PositionNet(in_dim, out_dim)
    w.load_state_dict(sd, strict=False)

    gligen = Gligen(output_list, w.position_net, key_dim)
    return gligen
|