# my-sd/modules/sd_hijack.py

import torch
from torch.nn.functional import silu
from types import MethodType

from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules

import sgm.modules.attention
import sgm.modules.diffusionmodules.model
import sgm.modules.diffusionmodules.openaimodel
import sgm.modules.encoders.modules

attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward

# new memory efficient cross attention blocks do not support hypernets and we already
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention

# silence new console spam from SD2
ldm.modules.attention.print = shared.ldm_print
ldm.modules.diffusionmodules.model.print = shared.ldm_print
ldm.util.print = shared.ldm_print
ldm.models.diffusion.ddpm.print = shared.ldm_print

optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
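# The next four lines wrap UNetModel.forward (both the ldm and sgm variants) with the
# dispatcher from sd_unet, so a custom SdUnet implementation can take over the forward
# pass when one is active (see modules/sd_unet.py); patches.patch installs the wrapper
# and returns the original method so it can be restored later.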
ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)
sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
def list_optimizers():
    new_optimizers = script_callbacks.list_optimizers_callback()

    new_optimizers = [x for x in new_optimizers if x.is_available()]

    new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)

    optimizers.clear()
    optimizers.extend(new_optimizers)


def apply_optimizations(option=None):
    return


def undo_optimizations():
    return


def fix_checkpoint():
    """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
    checkpoints to be added when not training (there's a warning)"""
    pass
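# The helpers below let training code apply a per-element weight to the diffusion loss:
# weighted_forward temporarily swaps the model's get_loss for a weight-aware version,
# runs the normal forward, and always restores the original in the finally block.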
def weighted_loss(sd_model, pred, target, mean=True):
    # Calculate the weight normally, but ignore the mean
    loss = sd_model._old_get_loss(pred, target, mean=False)

    # Check if we have weights available
    weight = getattr(sd_model, '_custom_loss_weight', None)
    if weight is not None:
        loss *= weight

    # Return the loss, as mean if specified
    return loss.mean() if mean else loss


def weighted_forward(sd_model, x, c, w, *args, **kwargs):
    try:
        # Temporarily append weights to a place accessible during loss calc
        sd_model._custom_loss_weight = w

        # Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
        # Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
        if not hasattr(sd_model, '_old_get_loss'):
            sd_model._old_get_loss = sd_model.get_loss
        sd_model.get_loss = MethodType(weighted_loss, sd_model)

        # Run the standard forward function, but with the patched 'get_loss'
        return sd_model.forward(x, c, *args, **kwargs)
    finally:
        try:
            # Delete temporary weights if appended
            del sd_model._custom_loss_weight
        except AttributeError:
            pass

        # If we have an old loss function, reset the loss function to the original one
        if hasattr(sd_model, '_old_get_loss'):
            sd_model.get_loss = sd_model._old_get_loss
            del sd_model._old_get_loss


def apply_weighted_forward(sd_model):
    # Add new function 'weighted_forward' that can be called to calc weighted loss
    sd_model.weighted_forward = MethodType(weighted_forward, sd_model)


def undo_weighted_forward(sd_model):
    try:
        del sd_model.weighted_forward
    except AttributeError:
        pass
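# A minimal usage sketch for the helpers above (hypothetical caller, not part of this
# module); the weight tensor w is expected to broadcast against the un-reduced loss:
#
#   apply_weighted_forward(sd_model)
#   loss = sd_model.weighted_forward(x, c, w)
#   undo_weighted_forward(sd_model)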
class StableDiffusionModelHijack:
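    """Container for the hijack state: a textual inversion EmbeddingDatabase plus
    per-generation bookkeeping such as comments and extra_generation_params."""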
    fixes = None
    layers = None
    circular_enabled = False
    clip = None
    optimization_method = None

    def __init__(self):
        import modules.textual_inversion.textual_inversion

        self.extra_generation_params = {}
        self.comments = []

        self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
        self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)

    def apply_optimizations(self, option=None):
        pass

    def convert_sdxl_to_ssd(self, m):
        pass

    def hijack(self, m):
        pass

    def undo_hijack(self, m):
        pass

    def apply_circular(self, enable):
        pass

    def clear_comments(self):
        self.comments = []
        self.extra_generation_params = {}

    def get_prompt_lengths(self, text, cond_stage_model):
        _, token_count = cond_stage_model.process_texts([text])
        return token_count, cond_stage_model.get_target_prompt_token_count(token_count)

    def redo_hijack(self, m):
        pass
class EmbeddingsWithFixes(torch.nn.Module):
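    """Wraps a token embedding layer and splices textual inversion embedding vectors
    into its output at the offsets recorded in self.embeddings.fixes (set elsewhere by
    the prompt-processing code before each forward pass)."""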
    def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings
        self.textual_inversion_key = textual_inversion_key
        self.weight = self.wrapped.weight

    def forward(self, input_ids):
        batch_fixes = self.embeddings.fixes
        self.embeddings.fixes = None

        inputs_embeds = self.wrapped(input_ids)

        if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
            return inputs_embeds

        vecs = []
        for fixes, tensor in zip(batch_fixes, inputs_embeds):
            for offset, embedding in fixes:
                vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
                emb = devices.cond_cast_unet(vec)
                emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
                tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])

            vecs.append(tensor)

        return torch.stack(vecs)
def add_circular_option_to_conv_2d():
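    """Monkeypatches torch.nn.Conv2d.__init__ so every Conv2d created afterwards uses
    circular padding, which lets generated images tile seamlessly."""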
    conv2d_constructor = torch.nn.Conv2d.__init__

    def conv2d_constructor_circular(self, *args, **kwargs):
        return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)

    torch.nn.Conv2d.__init__ = conv2d_constructor_circular


model_hijack = StableDiffusionModelHijack()
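# model_hijack is the shared module-level instance used by the rest of the codebase.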
def register_buffer(self, name, attr):
    """
    Fix register buffer bug for Mac OS.
    """

    if type(attr) == torch.Tensor:
        if attr.device != devices.device:
            attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))

    setattr(self, name, attr)
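# The two assignments below install the fixed register_buffer on the DDIM and PLMS
# samplers so their precomputed schedule buffers land on the active device (and are
# cast to float32 on MPS, which has no float64 support).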
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer