import torch
from torch.nn.functional import silu
from types import MethodType

from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18

import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules

import sgm.modules.attention
import sgm.modules.diffusionmodules.model
import sgm.modules.diffusionmodules.openaimodel
import sgm.modules.encoders.modules
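
# originals of functions that the attention optimizations may replace, saved here so the
# unmodified versions stay accessible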
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward

# new memory efficient cross attention blocks do not support hypernets and we already
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention

# silence new console spam from SD2
ldm.modules.attention.print = shared.ldm_print
ldm.modules.diffusionmodules.model.print = shared.ldm_print
ldm.util.print = shared.ldm_print
ldm.models.diffusion.ddpm.print = shared.ldm_print
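
# attention optimizations available to the webui; populated by list_optimizers() below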
optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
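
# Patch UNetModel.forward in both the ldm and sgm codebases with a wrapper from sd_unet so that
# replacement UNet implementations can take over; patches.patch returns the original forward so
# it can be restored later.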
ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)

sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
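

# Rebuild the global `optimizers` list from the optimizers registered via script callbacks,
# keeping only those available on this system, ordered by priority (highest first).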
def list_optimizers():
    new_optimizers = script_callbacks.list_optimizers_callback()

    new_optimizers = [x for x in new_optimizers if x.is_available()]

    new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)

    optimizers.clear()
    optimizers.extend(new_optimizers)


def apply_optimizations(option=None):
    return


def undo_optimizations():
    return


def fix_checkpoint():
    """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
    checkpoints to be added when not training (there's a warning)"""

    pass
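

# Weighted-loss support: weighted_forward temporarily swaps the model's get_loss for
# weighted_loss, so a per-sample weight tensor can scale the element-wise loss before reduction.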
def weighted_loss(sd_model, pred, target, mean=True):
    # Calculate the loss normally, but without reducing it to the mean
    loss = sd_model._old_get_loss(pred, target, mean=False)

    # Check if we have weights available
    weight = getattr(sd_model, '_custom_loss_weight', None)
    if weight is not None:
        loss *= weight

    # Return the loss, as mean if specified
    return loss.mean() if mean else loss


def weighted_forward(sd_model, x, c, w, *args, **kwargs):
    try:
        # Temporarily attach the weights somewhere accessible during loss calculation
        sd_model._custom_loss_weight = w

        # Replace 'get_loss' with a weight-aware one; otherwise 'forward' would have to be reimplemented completely.
        # Keep the original in '_old_get_loss', but don't overwrite it if it's already set.
        if not hasattr(sd_model, '_old_get_loss'):
            sd_model._old_get_loss = sd_model.get_loss
        sd_model.get_loss = MethodType(weighted_loss, sd_model)

        # Run the standard forward function, but with the patched 'get_loss'
        return sd_model.forward(x, c, *args, **kwargs)
    finally:
        try:
            # Delete the temporary weights if they were attached
            del sd_model._custom_loss_weight
        except AttributeError:
            pass

        # If we have an old loss function, restore the original one
        if hasattr(sd_model, '_old_get_loss'):
            sd_model.get_loss = sd_model._old_get_loss
            del sd_model._old_get_loss


def apply_weighted_forward(sd_model):
    # Add a new method 'weighted_forward' that can be called to calculate the weighted loss
    sd_model.weighted_forward = MethodType(weighted_forward, sd_model)


def undo_weighted_forward(sd_model):
    try:
        del sd_model.weighted_forward
    except AttributeError:
        pass
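

# A minimal usage sketch (hypothetical training step, not part of this module): after
# apply_weighted_forward(sd_model), a per-sample weight tensor can be passed along with the
# batch and is applied to the unreduced loss before the mean is taken:
#
#     apply_weighted_forward(sd_model)
#     loss = sd_model.weighted_forward(x, c, w)  # w must broadcast against the element-wise loss
#     loss.backward()
#     undo_weighted_forward(sd_model)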


class StableDiffusionModelHijack:
    fixes = None
    layers = None
    circular_enabled = False
    clip = None
    optimization_method = None

    def __init__(self):
        import modules.textual_inversion.textual_inversion

        self.extra_generation_params = {}
        self.comments = []

        self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
        self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)

    def apply_optimizations(self, option=None):
        pass

    def convert_sdxl_to_ssd(self, m):
        pass

    def hijack(self, m):
        pass

    def undo_hijack(self, m):
        pass

    def apply_circular(self, enable):
        pass

    def clear_comments(self):
        self.comments = []
        self.extra_generation_params = {}
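
    # Returns a (token_count, target_token_count) pair for the given prompt: the number of tokens
    # the text uses, and the padded length the conditioning model will actually process.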
    def get_prompt_lengths(self, text, cond_stage_model):
        _, token_count = cond_stage_model.process_texts([text])
        return token_count, cond_stage_model.get_target_prompt_token_count(token_count)

    def redo_hijack(self, m):
        pass
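

# Wrapper around a token-embedding module: textual inversion embeddings recorded in
# `embeddings.fixes` are spliced into the wrapped module's output at their prompt offsets.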
class EmbeddingsWithFixes(torch.nn.Module):
    def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings
        self.textual_inversion_key = textual_inversion_key

    def forward(self, input_ids):
        batch_fixes = self.embeddings.fixes
        self.embeddings.fixes = None

        inputs_embeds = self.wrapped(input_ids)

        if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
            return inputs_embeds

        vecs = []
        for fixes, tensor in zip(batch_fixes, inputs_embeds):
            for offset, embedding in fixes:
                vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
                emb = devices.cond_cast_unet(vec)
                emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
                tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])

            vecs.append(tensor)

        return torch.stack(vecs)
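

# Monkeypatch torch.nn.Conv2d so that every convolution constructed afterwards uses circular
# padding, making generated images tile seamlessly; assumes callers don't pass padding_mode
# themselves.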
def add_circular_option_to_conv_2d():
    conv2d_constructor = torch.nn.Conv2d.__init__

    def conv2d_constructor_circular(self, *args, **kwargs):
        return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)

    torch.nn.Conv2d.__init__ = conv2d_constructor_circular


model_hijack = StableDiffusionModelHijack()
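

# Replacement register_buffer for the ldm samplers: moves tensors to the active device
# (casting to float32 on MPS) instead of assuming CUDA is available.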
def register_buffer(self, name, attr):
    """
    Fix register buffer bug for Mac OS.
    """

    if type(attr) == torch.Tensor:
        if attr.device != devices.device:
            attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))

    setattr(self, name, attr)


ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer