2023-11-26 11:37:12 +03:00

346 lines
14 KiB

"""
Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE
Warning: the patch works well only if the input image has a width and height that are multiples of 128
Original author: @tfernd (GitHub: https://github.com/tfernd/HyperTile)
"""
from __future__ import annotations
import functools
from dataclasses import dataclass
from typing import Callable
from functools import wraps, cache
import math
import torch.nn as nn
import random
from einops import rearrange
@dataclass(eq=False)
class HypertileParams:
    """
    Per-layer hypertile state, written by hypertile_hook_model and read by the
    wrapper that self_attn_forward builds around the layer's forward

    eq=False keeps identity-based comparison and hashing, so instances behave
    like plain objects apart from gaining a keyword constructor and a repr
    """
    # key of the depth table under which the layer was found; scales the
    # minimum tile size and is compared against max_depth when enabling
    depth: int = 0
    # name of the hooked module as reported by model.named_modules()
    layer_name: str = ""
    # tile size chosen for the current image; presumably in pixels, since the
    # wrapper divides it by 8 before tiling -- TODO confirm
    tile_size: int = 0
    # forwarded to random_divisor as max_options (how many tile counts to pick from)
    swap_size: int = 0
    # width / height of the current image
    aspect_ratio: float = 1.0
    # the module's original forward, called by the wrapper
    forward: Callable | None = None
    # when False the wrapper calls the original forward unchanged
    enabled: bool = False
# TODO add SD-XL layers
0: [
# SD 1.5 U-Net (diffusers)
# SD 1.5 U-Net (ldm)
# SD 1.5 VAE
1: [
# SD 1.5 U-Net (diffusers)
# SD 1.5 U-Net (ldm)
2: [
# SD 1.5 U-Net (diffusers)
# SD 1.5 U-Net (ldm)
3: [
# SD 1.5 U-Net (diffusers)
# SD 1.5 U-Net (ldm)
# XL layers, thanks to GitHub@gel-crabs for the help
0: [
# SD 1.5 U-Net (diffusers)
# SD 1.5 U-Net (ldm)
# SD 1.5 VAE
1: [
# SD 1.5 U-Net (diffusers)
# SD 1.5 U-Net (ldm)
2: [
# SD 1.5 U-Net (diffusers)
# SD 1.5 U-Net (ldm)
3 : [] # TODO - separate layers for SD-XL
# Module-level random generator; random_divisor draws its index from it
RNG_INSTANCE = random.Random()
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
    """
    Returns a random divisor x of value such that
        x * min_value <= value
    i.e. value can be cut into x equal pieces of at least min_value each
    The choice is made among the max_options largest such divisors
    if max_options is 1, the behavior is deterministic (the largest divisor)

    value is expected to be a positive integer
    max_options and min_value below 1 are treated as 1
    """
    # an empty candidate list would make randint(0, -1) raise ValueError
    max_options = max(1, max_options)
    # capping at value keeps value itself as a candidate (piece count 1);
    # the floor of 1 keeps 0 out of the range, where value % 0 would raise
    min_value = max(1, min(min_value, value))

    # piece sizes that divide value, in small -> big order
    piece_sizes = [i for i in range(min_value, value + 1) if value % i == 0]
    # matching piece counts, in big -> small order, limited to max_options
    ns = [value // i for i in piece_sizes[:max_options]]

    idx = RNG_INSTANCE.randint(0, len(ns) - 1)
    return ns[idx]
def set_hypertile_seed(seed: int) -> None:
    """
    Seeds the module-level RNG_INSTANCE so the divisors picked by
    random_divisor are reproducible
    """
    RNG_INSTANCE.seed(seed)
def largest_tile_size_available(width: int, height: int) -> int:
    """
    Calculates the largest tile size available for a given width and height
    Tile size is always a power of 2: the largest one dividing both dimensions

    Raises ValueError if width and height are both 0, because every power of 2
    divides 0 and no largest one exists
    """
    common = math.gcd(width, height)
    if common == 0:
        raise ValueError("width and height must not both be zero")
    # the lowest set bit of a positive integer is the largest power of 2 dividing it
    return common & -common
def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]:
    """
    Finds h and w such that h*w = hw and w/h is as close as possible to aspect_ratio
    We check all divisor pairs of hw, including (1, hw) and (hw, 1)
    On a tie the pair with the smaller h wins

    Raises ValueError if hw < 1 (there are no divisor pairs to choose from)

    NOTE(review): the ratio compared here is w/h, while the closed-form guess
    in find_hw_candidates yields h/w ~= aspect_ratio -- confirm which
    orientation callers expect
    """
    # every (h, w) with h * w == hw, in increasing h
    pairs = [(h, hw // h) for h in range(1, hw + 1) if hw % h == 0]
    return min(pairs, key=lambda pair: abs(pair[1] / pair[0] - aspect_ratio))
def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]:
    """
    Finds h and w such that h*w = hw and h/w is close to aspect_ratio

    The rounded closed-form guess is used when it multiplies back to hw
    Otherwise one rounded side is kept (h first, then w) if it divides hw
    exactly, and the other side is derived from it
    If neither side divides hw, falls back to iterative_closest_divisors
    """
    h = round(math.sqrt(hw * aspect_ratio))
    w = round(math.sqrt(hw / aspect_ratio))
    if h * w == hw:
        return h, w

    # a side that rounded down to 0 cannot be a divisor; skipping it avoids
    # a ZeroDivisionError and lets the fallback decide
    if h and hw % h == 0:
        return h, hw // h
    if w and hw % w == 0:
        return hw // w, w

    return iterative_closest_divisors(hw, aspect_ratio)
def self_attn_forward(params: HypertileParams, scale_depth=True) -> Callable:
    """
    Builds a replacement for params.forward that splits the first positional
    argument into tiles, runs the original forward on the tiles (stacked along
    the batch dimension) and stitches the result back together

    params is read on every call, so changing params.enabled, tile_size,
    swap_size or aspect_ratio later affects the returned wrapper
    If scale_depth is true the minimum tile size of the U-Net path is
    multiplied by 2 ** params.depth
    """

    @wraps(params.forward)
    def wrapper(*args, **kwargs):
        if not params.enabled:
            return params.forward(*args, **kwargs)

        # tile_size is floored at 128 and divided by 8; presumably converts
        # pixels to latent units -- TODO confirm
        latent_tile_size = max(128, params.tile_size) // 8
        x = args[0]

        if x.ndim == 4:
            # 4-D input laid out as (b, c, h, w); presumably the VAE attention -- TODO confirm
            _, _, h, w = x.shape

            nh = random_divisor(h, latent_tile_size, params.swap_size)
            nw = random_divisor(w, latent_tile_size, params.swap_size)
            tiled = nh * nw > 1

            if tiled:
                x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw)  # split into nh * nw tiles

            out = params.forward(x, *args[1:], **kwargs)

            if tiled:
                out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)
        else:
            # U-Net
            # dimension 1 holds the flattened h * w positions, so h and w
            # have to be recovered from the aspect ratio
            hw: int = x.size(1)
            h, w = find_hw_candidates(hw, params.aspect_ratio)
            assert h * w == hw, f"Invalid aspect ratio {params.aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}"

            factor = 2 ** params.depth if scale_depth else 1
            nh = random_divisor(h, latent_tile_size * factor, params.swap_size)
            nw = random_divisor(w, latent_tile_size * factor, params.swap_size)
            tiled = nh * nw > 1

            if tiled:
                x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)

            out = params.forward(x, *args[1:], **kwargs)

            if tiled:
                out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
                out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)

        return out

    return wrapper
def hypertile_hook_model(model: nn.Module, width, height, *, enable=False, tile_size_max=128, swap_size=1, max_depth=3, is_sdxl=False):
    """
    Hooks the attention layers of model for hypertile and updates their settings

    On the first call with enable=True, every module whose name ends with an
    entry of DEPTH_LAYERS (DEPTH_LAYERS_XL if is_sdxl) gets a HypertileParams
    stored on it and its forward replaced by a self_attn_forward wrapper
    The hooked names are remembered on model.__webui_hypertile_layers
    On every call that finds or creates hooks, the per-layer tile_size,
    swap_size, aspect_ratio and enabled flag are refreshed for the given
    width and height; a layer is enabled only if its depth <= max_depth

    If the model has never been hooked and enable is False, nothing happens
    """
    hypertile_layers = getattr(model, "__webui_hypertile_layers", None)
    if hypertile_layers is None:
        if not enable:
            # nothing hooked yet and nothing requested: leave the model untouched
            return

        hypertile_layers = {}
        layers = DEPTH_LAYERS_XL if is_sdxl else DEPTH_LAYERS

        for depth in range(4):
            # str.endswith accepts a tuple, and an empty tuple matches nothing
            suffixes = tuple(layers[depth])
            for layer_name, module in model.named_modules():
                if layer_name.endswith(suffixes):
                    params = HypertileParams()
                    module.__webui_hypertile_params = params
                    # capture the original forward before it is replaced
                    params.forward = module.forward
                    params.depth = depth
                    params.layer_name = layer_name
                    module.forward = self_attn_forward(params)

                    hypertile_layers[layer_name] = 1

        model.__webui_hypertile_layers = hypertile_layers

    aspect_ratio = width / height
    tile_size = min(largest_tile_size_available(width, height), tile_size_max)

    for layer_name, module in model.named_modules():
        if layer_name in hypertile_layers:
            params = module.__webui_hypertile_params

            params.tile_size = tile_size
            params.swap_size = swap_size
            params.aspect_ratio = aspect_ratio
            params.enabled = enable and params.depth <= max_depth