abfa4ad8bc
Even if this causes chunks to be much smaller, performance isn't significantly impacted. This will usually reduce memory usage but should also help with poor performance when free memory is low.
673 lines
24 KiB
Python
673 lines
24 KiB
Python
from __future__ import annotations
|
|
import math
|
|
import psutil
|
|
import platform
|
|
|
|
import torch
|
|
from torch import einsum
|
|
|
|
from ldm.util import default
|
|
from einops import rearrange
|
|
|
|
from modules import shared, errors, devices, sub_quadratic_attention
|
|
from modules.hypernetworks import hypernetwork
|
|
|
|
import ldm.modules.attention
|
|
import ldm.modules.diffusionmodules.model
|
|
|
|
import sgm.modules.attention
|
|
import sgm.modules.diffusionmodules.model
|
|
|
|
# References to AttnBlock.forward of ldm and sgm as they are when this module is
# imported; SdOptimization.undo assigns them back after an optimization patched them.
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward
|
|
|
|
|
|
class SdOptimization:
    """Base class describing one attention optimization.

    Subclasses override the class attributes and `apply` to patch the
    `forward` methods of the ldm/sgm attention classes.
    """

    # identifier of the optimization (overridden by subclasses)
    name: str = None
    # optional human-readable description appended to the name by title()
    label: str | None = None
    # name of the related command line option, if any
    cmd_opt: str | None = None
    # numeric priority; how it is used is decided by callers outside this class
    priority: int = 0

    def title(self):
        """Return `name`, followed by ` - label` when a label is set."""
        return self.name if self.label is None else f"{self.name} - {self.label}"

    def is_available(self):
        """Report whether the optimization can be used; the base class always says yes."""
        return True

    def apply(self):
        """Install the optimization; the base class does nothing."""

    def undo(self):
        """Reassign the CrossAttention/AttnBlock forwards of ldm and sgm.

        CrossAttention.forward is set to hypernetwork.attention_CrossAttention_forward,
        AttnBlock.forward to the references captured when this module was imported.
        """
        cross_attention_forward = hypernetwork.attention_CrossAttention_forward

        ldm.modules.attention.CrossAttention.forward = cross_attention_forward
        sgm.modules.attention.CrossAttention.forward = cross_attention_forward

        ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward
|
|
|
|
|
|
class SdOptimizationXformers(SdOptimization):
    """Optimization that routes attention through xformers."""

    name = "xformers"
    cmd_opt = "xformers"
    priority = 100

    def is_available(self):
        """Return a truthy value when xformers is forced on, or when it was imported,
        CUDA is available and the device capability lies between (6, 0) and (9, 0)."""
        forced = shared.cmd_opts.force_enable_xformers
        if forced:
            return forced

        usable = shared.xformers_available and torch.cuda.is_available()
        if not usable:
            return usable

        capability = torch.cuda.get_device_capability(shared.device)
        return (6, 0) <= capability <= (9, 0)

    def apply(self):
        """Point CrossAttention and AttnBlock forwards of ldm and sgm at the xformers versions."""
        for attention in (ldm.modules.attention, sgm.modules.attention):
            attention.CrossAttention.forward = xformers_attention_forward
        for model in (ldm.modules.diffusionmodules.model, sgm.modules.diffusionmodules.model):
            model.AttnBlock.forward = xformers_attnblock_forward
|
|
|
|
|
|
class SdOptimizationSdpNoMem(SdOptimization):
    """Scaled-dot-product attention with the memory efficient kernel switched off."""

    name = "sdp-no-mem"
    label = "scaled dot product without memory efficient attention"
    cmd_opt = "opt_sdp_no_mem_attention"
    priority = 80

    def is_available(self):
        """True when torch.nn.functional exposes a callable scaled_dot_product_attention."""
        sdp = getattr(torch.nn.functional, "scaled_dot_product_attention", None)
        return callable(sdp)

    def apply(self):
        """Point CrossAttention and AttnBlock forwards of ldm and sgm at the sdp-no-mem versions."""
        for attention in (ldm.modules.attention, sgm.modules.attention):
            attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
        for model in (ldm.modules.diffusionmodules.model, sgm.modules.diffusionmodules.model):
            model.AttnBlock.forward = sdp_no_mem_attnblock_forward
|
|
|
|
|
|
class SdOptimizationSdp(SdOptimizationSdpNoMem):
    """Scaled-dot-product attention; availability check is inherited from SdOptimizationSdpNoMem."""

    name = "sdp"
    label = "scaled dot product"
    cmd_opt = "opt_sdp_attention"
    priority = 70

    def apply(self):
        """Point CrossAttention and AttnBlock forwards of ldm and sgm at the sdp versions."""
        ldm.modules.attention.CrossAttention.forward = sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
|
|
|
|
|
|
class SdOptimizationSubQuad(SdOptimization):
    """Optimization that uses the sub-quadratic attention implementation."""

    name = "sub-quadratic"
    cmd_opt = "opt_sub_quad_attention"
    priority = 10

    def apply(self):
        """Point CrossAttention and AttnBlock forwards of ldm and sgm at the sub-quadratic versions."""
        for attention in (ldm.modules.attention, sgm.modules.attention):
            attention.CrossAttention.forward = sub_quad_attention_forward
        for model in (ldm.modules.diffusionmodules.model, sgm.modules.diffusionmodules.model):
            model.AttnBlock.forward = sub_quad_attnblock_forward
|
|
|
|
|
|
class SdOptimizationV1(SdOptimization):
    """Optimization using the original v1 split attention."""

    name = "V1"
    label = "original v1"
    cmd_opt = "opt_split_attention_v1"
    priority = 10

    def apply(self):
        """Patch only CrossAttention.forward of ldm and sgm; AttnBlock is left untouched."""
        for attention in (ldm.modules.attention, sgm.modules.attention):
            attention.CrossAttention.forward = split_cross_attention_forward_v1
|
|
|
|
|
|
class SdOptimizationInvokeAI(SdOptimization):
    """Optimization using the InvokeAI split attention."""

    name = "InvokeAI"
    cmd_opt = "opt_split_attention_invokeai"

    @property
    def priority(self):
        """10 when CUDA is available, 1000 otherwise."""
        if torch.cuda.is_available():
            return 10
        return 1000

    def apply(self):
        """Patch only CrossAttention.forward of ldm and sgm; AttnBlock is left untouched."""
        for attention in (ldm.modules.attention, sgm.modules.attention):
            attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
|
|
|
|
|
|
class SdOptimizationDoggettx(SdOptimization):
    """Optimization using the Doggettx split attention."""

    name = "Doggettx"
    cmd_opt = "opt_split_attention"
    priority = 90

    def apply(self):
        """Point CrossAttention and AttnBlock forwards of ldm and sgm at the split-attention versions."""
        ldm.modules.attention.CrossAttention.forward = sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
|
|
|
|
|
|
def list_optimizers(res):
    """Append one instance of every optimization class defined in this module to `res`."""
    optimizer_classes = (
        SdOptimizationXformers,
        SdOptimizationSdpNoMem,
        SdOptimizationSdp,
        SdOptimizationSubQuad,
        SdOptimizationV1,
        SdOptimizationInvokeAI,
        SdOptimizationDoggettx,
    )
    # build the full list first so `res` is extended in one go
    res.extend([optimizer_class() for optimizer_class in optimizer_classes])
|
|
|
|
|
|
# Import xformers only when one of the xformers command line flags is set; a failed
# import is reported and leaves shared.xformers_available unchanged.
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
    try:
        import xformers.ops
    except Exception:
        errors.report("Cannot import xformers", exc_info=True)
    else:
        shared.xformers_available = True
|
|
|
|
|
|
def get_available_vram():
    """Return the number of bytes considered free for attention buffers.

    For a CUDA `shared.device` this is the free memory reported by CUDA plus the
    memory torch has reserved but is not actively using; for any other device
    type it is the available system memory reported by psutil.
    """
    if shared.device.type != 'cuda':
        return psutil.virtual_memory().available

    stats = torch.cuda.memory_stats(shared.device)
    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
    # reserved-but-inactive bytes can be reused by torch without a new allocation
    mem_free_torch = stats['reserved_bytes.all.current'] - stats['active_bytes.all.current']
    return mem_free_cuda + mem_free_torch
|
|
|
|
|
|
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
|
|
def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs):
    """CrossAttention.forward replacement computing attention two (batch*head) rows at a time.

    `mask` and extra keyword arguments are accepted but not used.
    """
    heads = self.heads

    query_proj = self.to_q(x)
    context = default(context, x)

    hn_k, hn_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    key_proj = self.to_k(hn_k)
    value_proj = self.to_v(hn_v)
    del context, hn_k, hn_v, x

    # fold the heads into the leading dimension: (b, n, h*d) -> (b*h, n, d)
    q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=heads) for t in (query_proj, key_proj, value_proj))
    del query_proj, key_proj, value_proj

    out_dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    chunk = 2  # number of (batch*head) rows handled per iteration
    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        result = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
        for start in range(0, q.shape[0], chunk):
            stop = start + chunk
            scores = einsum('b i d, b j d -> b i j', q[start:stop], k[start:stop])
            scores *= self.scale

            weights = scores.softmax(dim=-1)
            del scores

            result[start:stop] = einsum('b i j, b j d -> b i d', weights, v[start:stop])
            del weights
        del q, k, v

    result = result.to(out_dtype)

    # unfold the heads again: (b*h, n, d) -> (b, n, h*d)
    merged = rearrange(result, '(b h) n d -> b n (h d)', h=heads)
    del result

    return self.to_out(merged)
|
|
|
|
|
|
# taken from https://github.com/Doggettx/stable-diffusion and modified
|
|
def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs):
    """CrossAttention.forward replacement that slices the query tokens to fit free memory.

    The size of the attention score tensor is estimated and compared with
    get_available_vram(); when it does not fit, the query tokens are processed in
    a power-of-two number of slices. `mask` and extra keyword arguments are
    accepted but not used.

    Raises:
        RuntimeError: if more than 64 slices would be needed.
    """
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    dtype = q_in.dtype
    if shared.opts.upcast_attn:
        # v is left in its original dtype on mps devices
        q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        # scaling k once replaces scaling every score slice
        k_in = k_in * self.scale

        del context, x

        # fold the heads into the leading dimension: (b, n, h*d) -> (b*h, n, d)
        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
        del q_in, k_in, v_in

        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

        mem_free_total = get_available_vram()

        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
        modifier = 3 if q.element_size() == 2 else 2.5
        mem_required = tensor_size * modifier
        steps = 1

        if mem_required > mem_free_total:
            steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))

        if steps > 64:
            gb = 1024 ** 3
            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                               f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

        # steps may exceed the token count; a zero step would make range() raise
        slice_size = max(q.shape[1] // steps, 1)
        for i in range(0, q.shape[1], slice_size):
            end = min(i + slice_size, q.shape[1])
            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)

            s2 = s1.softmax(dim=-1, dtype=q.dtype)
            del s1

            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
            del s2

        del q, k, v

    r1 = r1.to(dtype)

    # unfold the heads again: (b*h, n, d) -> (b, n, h*d)
    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)
|
|
|
|
|
|
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
|
|
# Total system memory in whole GiB, read once when the module is imported.
mem_total_gb = psutil.virtual_memory().total // 2 ** 30
|
|
|
|
|
|
def einsum_op_compvis(q, k, v):
    """Unsliced attention: softmax over the key axis of q·kᵀ, applied to v.

    No scaling is applied here; callers scale their inputs beforehand.
    """
    scores = einsum('b i d, b j d -> b i j', q, k)
    weights = torch.softmax(scores, dim=-1, dtype=scores.dtype)
    return einsum('b i j, b j d -> b i d', weights, v)
|
|
|
|
|
|
def einsum_op_slice_0(q, k, v, slice_size):
    """Attention computed in slices of `slice_size` along dimension 0 of q, k and v."""
    out_shape = (q.shape[0], q.shape[1], v.shape[2])
    out = torch.zeros(out_shape, device=q.device, dtype=q.dtype)
    for start in range(0, q.shape[0], slice_size):
        rows = slice(start, start + slice_size)
        out[rows] = einsum_op_compvis(q[rows], k[rows], v[rows])
    return out
|
|
|
|
|
|
def einsum_op_slice_1(q, k, v, slice_size):
    """Attention computed in slices of `slice_size` along dimension 1 of q; k and v are used whole."""
    out_shape = (q.shape[0], q.shape[1], v.shape[2])
    out = torch.zeros(out_shape, device=q.device, dtype=q.dtype)
    for start in range(0, q.shape[1], slice_size):
        tokens = slice(start, start + slice_size)
        out[:, tokens] = einsum_op_compvis(q[:, tokens], k, v)
    return out
|
|
|
|
|
|
def einsum_op_mps_v1(q, k, v):
    """Attention for mps: unsliced for small inputs, otherwise sliced along dimension 1 of q."""
    rows = q.shape[0] * q.shape[1]
    if rows <= 2**16:  # (512x512) max q.shape[1]: 4096
        return einsum_op_compvis(q, k, v)

    slice_size = math.floor(2**30 / rows)
    # a slice size that is a multiple of 4096 is avoided by stepping one below it
    if slice_size % 4096 == 0:
        slice_size -= 1
    return einsum_op_slice_1(q, k, v, slice_size)
|
|
|
|
|
|
def einsum_op_mps_v2(q, k, v):
    """Attention for mps: unsliced only with more than 8 GiB of RAM and a small input,
    otherwise one row of dimension 0 at a time."""
    if mem_total_gb <= 8 or q.shape[0] * q.shape[1] > 2**16:
        return einsum_op_slice_0(q, k, v, 1)
    return einsum_op_compvis(q, k, v)
|
|
|
|
|
|
def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
    """Attention that keeps the score tensor within roughly `max_tensor_mb` MiB.

    When the full score tensor is too large, the work is divided by a power of
    two: along dimension 0 if that dimension is large enough, otherwise along
    dimension 1 of q.
    """
    score_bytes = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
    size_mb = score_bytes // (1 << 20)
    if size_mb <= max_tensor_mb:
        return einsum_op_compvis(q, k, v)

    # 1 << bit_length yields a power of two strictly greater than the truncated ratio
    ratio = int((size_mb - 1) / max_tensor_mb)
    div = 1 << ratio.bit_length()
    if div > q.shape[0]:
        return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
    return einsum_op_slice_0(q, k, v, q.shape[0] // div)
|
|
|
|
|
|
def einsum_op_cuda(q, k, v):
    """Attention for CUDA tensors, sized from the memory currently free on q's device."""
    stats = torch.cuda.memory_stats(q.device)
    mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
    # memory torch has reserved but is not actively using counts as free as well
    mem_free_torch = stats['reserved_bytes.all.current'] - stats['active_bytes.all.current']
    mem_free_total = mem_free_cuda + mem_free_torch
    # Divide factor of safety as there's copying and fragmentation
    max_tensor_mb = mem_free_total / 3.3 / (1 << 20)
    return einsum_op_tensor_mem(q, k, v, max_tensor_mb)
|
|
|
|
|
|
def einsum_op(q, k, v):
    """Dispatch attention to the cuda, mps or generic implementation based on q's device type."""
    device_type = q.device.type

    if device_type == 'cuda':
        return einsum_op_cuda(q, k, v)

    if device_type == 'mps':
        use_v1 = mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18
        if use_v1:
            return einsum_op_mps_v1(q, k, v)
        return einsum_op_mps_v2(q, k, v)

    # Smaller slices are faster due to L2/L3/SLC caches.
    # Tested on i7 with 8MB L3 cache.
    return einsum_op_tensor_mem(q, k, v, 32)
|
|
|
|
|
|
def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs):
    """CrossAttention.forward replacement that delegates the attention itself to einsum_op.

    `mask` and extra keyword arguments are accepted but not used.
    """
    heads = self.heads

    q = self.to_q(x)
    context = default(context, x)

    hn_k, hn_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k = self.to_k(hn_k)
    v = self.to_v(hn_v)
    del context, hn_k, hn_v, x

    out_dtype = q.dtype
    if shared.opts.upcast_attn:
        # v is left in its original dtype on mps devices
        q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        # einsum_op applies no scaling, so k is scaled up front
        k = k * self.scale

        # fold the heads into the leading dimension: (b, n, h*d) -> (b*h, n, d)
        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=heads) for t in (q, k, v))
        attended = einsum_op(q, k, v)

    attended = attended.to(out_dtype)
    merged = rearrange(attended, '(b h) n d -> b n (h d)', h=heads)
    return self.to_out(merged)
|
|
|
|
# -- End of code from https://github.com/invoke-ai/InvokeAI --
|
|
|
|
|
|
# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
|
|
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
|
|
def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs):
    """CrossAttention.forward replacement based on sub_quad_attention.

    A non-None `mask` is rejected by an assertion; extra keyword arguments are ignored.
    """
    assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."

    heads = self.heads

    def split_heads(t):
        # (b, n, h*d) -> (b*h, n, d)
        return t.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1)

    q = self.to_q(x)
    context = default(context, x)

    hn_k, hn_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k = self.to_k(hn_k)
    v = self.to_v(hn_v)
    del context, hn_k, hn_v, x

    q = split_heads(q)
    k = split_heads(k)
    v = split_heads(v)

    if q.device.type == 'mps':
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

    out_dtype = q.dtype
    if shared.opts.upcast_attn:
        # only q and k are upcast; v keeps its dtype
        q, k = q.float(), k.float()

    cmd_opts = shared.cmd_opts
    x = sub_quad_attention(
        q,
        k,
        v,
        q_chunk_size=cmd_opts.sub_quad_q_chunk_size,
        kv_chunk_size=cmd_opts.sub_quad_kv_chunk_size,
        chunk_threshold=cmd_opts.sub_quad_chunk_threshold,
        use_checkpoint=self.training,
    )

    x = x.to(out_dtype)

    # (b*h, n, d) -> (b, n, h*d)
    x = x.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)

    out_proj, dropout = self.to_out
    return dropout(out_proj(x))
|
|
|
|
|
|
def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
    """Run sub_quadratic_attention.efficient_dot_product_attention with chunk sizes derived from free memory.

    chunk_threshold: None picks a default byte budget (a fixed amount on mps,
        70% of get_available_vram() otherwise); 0 disables the budget; any other
        value is a percentage of get_available_vram().
    kv_chunk_size_min: None derives a minimum from the byte budget when there is
        one; 0 is passed on as None.
    When the whole q·k matmul fits the budget, kv_chunk_size is set to the number
    of key tokens so everything is done in one chunk.
    """
    bytes_per_token = torch.finfo(q.dtype).bits//8
    batch_x_heads, q_tokens, _ = q.shape
    _, k_tokens, _ = k.shape
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

    if chunk_threshold is None:
        if q.device.type == 'mps':
            per_unit = 2 if platform.processor() == 'i386' else bytes_per_token
            chunk_threshold_bytes = 268435456 * per_unit
        else:
            chunk_threshold_bytes = int(get_available_vram() * 0.7)
    elif chunk_threshold == 0:
        chunk_threshold_bytes = None
    else:
        chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())

    has_budget = chunk_threshold_bytes is not None

    if kv_chunk_size_min is None and has_budget:
        kv_row_bytes = batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2])
        kv_chunk_size_min = chunk_threshold_bytes // kv_row_bytes
    elif kv_chunk_size_min == 0:
        kv_chunk_size_min = None

    if has_budget and qk_matmul_size_bytes <= chunk_threshold_bytes:
        # the big matmul fits into our memory limit; do everything in 1 chunk,
        # i.e. send it down the unchunked fast-path
        kv_chunk_size = k_tokens

    with devices.without_autocast(disable=q.dtype == v.dtype):
        return sub_quadratic_attention.efficient_dot_product_attention(
            q,
            k,
            v,
            query_chunk_size=q_chunk_size,
            kv_chunk_size=kv_chunk_size,
            kv_chunk_size_min=kv_chunk_size_min,
            use_checkpoint=use_checkpoint,
        )
|
|
|
|
|
|
def get_xformers_flash_attention_op(q, k, v):
    """Return the xformers flash-attention op when it is enabled and supports these inputs, else None.

    Any exception raised while probing is shown through errors.display_once and
    results in None.
    """
    if shared.cmd_opts.xformers_flash_attention:
        try:
            op_pair = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
            forward_op, _backward_op = op_pair
            probe = xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)
            if forward_op.supports(probe):
                return op_pair
        except Exception as e:
            errors.display_once(e, "enabling flash attention")

    return None
|
|
|
|
|
|
def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
    """CrossAttention.forward replacement using xformers.ops.memory_efficient_attention.

    `mask` and extra keyword arguments are accepted but not used.
    """
    heads = self.heads
    query_proj = self.to_q(x)
    context = default(context, x)

    hn_k, hn_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    key_proj = self.to_k(hn_k)
    value_proj = self.to_v(hn_v)

    # split the heads into their own axis: (b, n, h*d) -> (b, n, h, d)
    q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=heads) for t in (query_proj, key_proj, value_proj))
    del query_proj, key_proj, value_proj

    out_dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    flash_op = get_xformers_flash_attention_op(q, k, v)
    attended = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=flash_op)

    attended = attended.to(out_dtype)

    merged = rearrange(attended, 'b n h d -> b n (h d)', h=heads)
    return self.to_out(merged)
|
|
|
|
|
|
# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
|
|
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
|
|
def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs):
    """CrossAttention.forward replacement using torch.nn.functional.scaled_dot_product_attention.

    When `mask` is given it is passed through self.prepare_attention_mask and
    reshaped per head before being used as attn_mask. Extra keyword arguments
    are ignored.
    """
    batch_size, sequence_length, inner_dim = x.shape

    if mask is not None:
        mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
        mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])

    heads = self.heads
    query_proj = self.to_q(x)
    context = default(context, x)

    hn_k, hn_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    key_proj = self.to_k(hn_k)
    value_proj = self.to_v(hn_v)

    head_dim = inner_dim // heads

    def to_heads(t):
        # (batch, seq, heads*head_dim) -> (batch, heads, seq, head_dim)
        return t.view(batch_size, -1, heads, head_dim).transpose(1, 2)

    q = to_heads(query_proj)
    k = to_heads(key_proj)
    v = to_heads(value_proj)

    del query_proj, key_proj, value_proj

    out_dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    # the output of sdp = (batch, num_heads, seq_len, head_dim)
    attended = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
    )

    attended = attended.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
    attended = attended.to(out_dtype)

    # self.to_out[0] is the linear projection, self.to_out[1] the dropout
    projected = self.to_out[0](attended)
    return self.to_out[1](projected)
|
|
|
|
|
|
def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs):
    """Run scaled_dot_product_attention_forward with the memory efficient sdp kernel disabled.

    Extra keyword arguments are not forwarded.
    """
    kernel_selection = torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False)
    with kernel_selection:
        return scaled_dot_product_attention_forward(self, x, context, mask)
|
|
|
|
|
|
def cross_attention_attnblock_forward(self, x):
    """AttnBlock.forward replacement that slices the query positions to fit free memory.

    The size of the (hw x hw) score tensor is estimated and compared with
    get_available_vram(); when it does not fit, the query positions are processed
    in a power-of-two number of slices (plus a shorter last slice when hw is not
    divisible). Returns proj_out(attention) + x.
    """
    h_ = x
    h_ = self.norm(h_)
    q1 = self.q(h_)
    k1 = self.k(h_)
    v = self.v(h_)

    # compute attention
    b, c, h, w = q1.shape

    q2 = q1.reshape(b, c, h*w)
    del q1

    q = q2.permute(0, 2, 1)   # b,hw,c
    del q2

    k = k1.reshape(b, c, h*w)  # b,c,hw
    del k1

    h_ = torch.zeros_like(k, device=q.device)

    mem_free_total = get_available_vram()

    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    mem_required = tensor_size * 2.5
    steps = 1

    if mem_required > mem_free_total:
        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))

    # Always slice when memory is short, even if hw is not divisible by steps;
    # the last slice is simply shorter. max() keeps the range step non-zero.
    slice_size = max(q.shape[1] // steps, 1)

    # loop invariants: value matrix and softmax scale
    v1 = v.reshape(b, c, h*w)
    scale = int(c)**(-0.5)

    for i in range(0, q.shape[1], slice_size):
        end = min(i + slice_size, q.shape[1])

        w1 = torch.bmm(q[:, i:end], k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w2 = w1 * scale
        del w1
        w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
        del w2

        # attend to values
        w4 = w3.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)
        del w3

        h_[:, :, i:end] = torch.bmm(v1, w4)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        del w4

    del v1

    h2 = h_.reshape(b, c, h, w)
    del h_

    h3 = self.proj_out(h2)
    del h2

    h3 += x

    return h3
|
|
|
|
|
|
def xformers_attnblock_forward(self, x):
    """AttnBlock.forward replacement using xformers.ops.memory_efficient_attention.

    Falls back to cross_attention_attnblock_forward when a NotImplementedError
    is raised anywhere in the xformers path. Returns x + proj_out(attention).
    """
    try:
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        b, c, h, w = q.shape
        # flatten the spatial axes into a token axis: (b, c, h, w) -> (b, h*w, c)
        q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
        dtype = q.dtype
        if shared.opts.upcast_attn:
            # upcast v together with q and k so all three share one dtype, as in
            # xformers_attention_forward and sdp_attnblock_forward
            q, k, v = q.float(), k.float(), v.float()
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
        out = out.to(dtype)
        out = rearrange(out, 'b (h w) c -> b c h w', h=h)
        out = self.proj_out(out)
        return x + out
    except NotImplementedError:
        return cross_attention_attnblock_forward(self, x)
|
|
|
|
|
|
def sdp_attnblock_forward(self, x):
    """AttnBlock.forward replacement using torch.nn.functional.scaled_dot_product_attention.

    Returns x + proj_out(attention).
    """
    normed = self.norm(x)
    q = self.q(normed)
    k = self.k(normed)
    v = self.v(normed)
    b, c, h, w = q.shape

    # flatten the spatial axes into a token axis: (b, c, h, w) -> (b, h*w, c)
    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))

    out_dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()

    attended = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    attended = attended.to(out_dtype)
    attended = rearrange(attended, 'b (h w) c -> b c h w', h=h)
    return x + self.proj_out(attended)
|
|
|
|
|
|
def sdp_no_mem_attnblock_forward(self, x):
    """Run sdp_attnblock_forward with the memory efficient sdp kernel disabled."""
    kernel_selection = torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False)
    with kernel_selection:
        return sdp_attnblock_forward(self, x)
|
|
|
|
|
|
def sub_quad_attnblock_forward(self, x):
    """AttnBlock.forward replacement based on sub_quad_attention.

    Returns x + proj_out(attention).
    """
    normed = self.norm(x)
    q = self.q(normed)
    k = self.k(normed)
    v = self.v(normed)
    b, c, h, w = q.shape

    # flatten the spatial axes into a token axis: (b, c, h, w) -> (b, h*w, c)
    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()

    cmd_opts = shared.cmd_opts
    attended = sub_quad_attention(
        q,
        k,
        v,
        q_chunk_size=cmd_opts.sub_quad_q_chunk_size,
        kv_chunk_size=cmd_opts.sub_quad_kv_chunk_size,
        chunk_threshold=cmd_opts.sub_quad_chunk_threshold,
        use_checkpoint=self.training,
    )
    attended = rearrange(attended, 'b (h w) c -> b c h w', h=h)
    return x + self.proj_out(attended)
|