my-sd/extensions-builtin/Lora/networks.py
Kohaku-Blueleaf 393c19bbcf
Remove Lycoris back compact (#41)
Forge uses a totally different mechanism to handle lora/lycoris models.
I think it is good to remove the backward-compatibility code so the "lyco" alias can be used by legacy lycoris extensions (some other extensions may rely on it).
2024-02-06 11:17:15 +08:00

204 lines
5.7 KiB
Python

import os
import re
import lora_patches
import functools
import network
import torch
from typing import Union
from modules import shared, sd_models, errors, scripts
from ldm_patched.modules.utils import load_torch_file
from ldm_patched.modules.sd import load_lora_for_models
@functools.lru_cache(maxsize=5)
def load_lora_state_dict(filename):
    """Load the state dict stored in ``filename`` via ``load_torch_file``.

    The file is read with ``safe_load=True``. Results are memoised per
    ``filename`` (up to 5 distinct files), so asking for the same path again
    returns the cached object instead of re-reading the file.
    """
    state_dict = load_torch_file(filename, safe_load=True)
    return state_dict
def convert_diffusers_name_to_compvis(key, is_sd2):
    """No-op placeholder: ignores ``key`` and ``is_sd2`` and returns None.

    NOTE(review): presumably kept so code that still imports this name keeps
    working -- confirm against callers.
    """
    return None
def assign_network_names_to_compvis_modules(sd_model):
    """No-op placeholder: ignores ``sd_model`` and returns None.

    NOTE(review): presumably kept so code that still calls this name keeps
    working -- confirm against callers.
    """
    return None
def load_network(name, network_on_disk):
    """No-op placeholder: ignores both arguments and returns None.

    NOTE(review): presumably kept so code that still calls this name keeps
    working -- confirm against callers.
    """
    return None
def purge_networks_from_memory():
    """No-op placeholder: does nothing and returns None.

    NOTE(review): presumably kept so code that still calls this name keeps
    working -- confirm against callers.
    """
    return None
def _find_network_on_disk(name):
    """Resolve ``name`` to its registered on-disk entry, or None if unknown.

    Names whose lower-cased form is a forbidden alias are resolved by exact
    network name only; every other name goes through the alias table.
    """
    if name.lower() in forbidden_network_aliases:
        return available_networks.get(name, None)
    return available_network_aliases.get(name, None)


def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
    """Apply the named networks to the current model's UNet and CLIP objects.

    Args:
        names: network names (or aliases) to apply, in order.
        te_multipliers: per-network strength passed as the CLIP strength to
            ``load_lora_for_models``; ``None`` means 1.0 for every network.
        unet_multipliers: per-network strength passed as the model strength to
            ``load_lora_for_models``; ``None`` means 1.0 for every network.
        dyn_dims: accepted for interface compatibility; not used here.

    Returns:
        None. The function returns early when there is no current model or
        when the requested (file, strengths) combination is already applied.

    Names that cannot be resolved, even after rescanning the network
    directory, are reported and skipped instead of aborting the whole call.
    """
    current_sd = sd_models.model_data.get_sd_model()
    if current_sd is None:
        return

    networks_on_disk = [_find_network_on_disk(name) for name in names]
    if any(entry is None for entry in networks_on_disk):
        # Something was not found: the file may have been added after the
        # last scan, so refresh the registry once and retry.
        list_available_networks()
        networks_on_disk = [_find_network_on_disk(name) for name in names]

    # The signature allows the multipliers to be omitted; zip() would raise on
    # None, so fall back to full strength for every requested network.
    if unet_multipliers is None:
        unet_multipliers = [1.0] * len(networks_on_disk)
    if te_multipliers is None:
        te_multipliers = [1.0] * len(networks_on_disk)

    compiled_lora_targets = []
    for name, entry, unet_strength, te_strength in zip(
            names, networks_on_disk, unet_multipliers, te_multipliers):
        if entry is None:
            # Still unknown after the rescan: skip it rather than crash on
            # entry.filename.
            errors.report(f"Couldn't find network with name {name}", exc_info=False)
            continue
        compiled_lora_targets.append([entry.filename, unet_strength, te_strength])

    # The string form of the target list is the cache key that tells whether
    # the model already carries exactly this set of networks.
    compiled_lora_targets_hash = str(compiled_lora_targets)
    if current_sd.current_lora_hash == compiled_lora_targets_hash:
        return

    current_sd.current_lora_hash = compiled_lora_targets_hash

    # Start again from the unpatched objects so strengths do not accumulate
    # across calls.
    current_sd.forge_objects.unet = current_sd.forge_objects_original.unet
    current_sd.forge_objects.clip = current_sd.forge_objects_original.clip

    for filename, strength_model, strength_clip in compiled_lora_targets:
        lora_sd = load_lora_state_dict(filename)
        current_sd.forge_objects.unet, current_sd.forge_objects.clip = load_lora_for_models(
            current_sd.forge_objects.unet, current_sd.forge_objects.clip,
            lora_sd, strength_model, strength_clip)

    # Snapshot the patched objects, taken even when no network was applied.
    current_sd.forge_objects_after_applying_lora = current_sd.forge_objects.shallow_copy()
    return
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    """No-op placeholder: ignores the module passed as ``self`` and returns None.

    NOTE(review): presumably kept so code that still references this name
    keeps working -- confirm against callers.
    """
    return None
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    """No-op placeholder: ignores the module passed as ``self`` and returns None.

    NOTE(review): presumably kept so code that still references this name
    keeps working -- confirm against callers.
    """
    return None
def network_forward(org_module, input, original_forward):
    """No-op placeholder: ignores all three arguments and returns None.

    ``original_forward`` is never invoked.

    NOTE(review): presumably kept so code that still references this name
    keeps working -- confirm against callers.
    """
    return None
def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
    """No-op placeholder: ignores the module passed as ``self`` and returns None.

    NOTE(review): presumably kept so code that still references this name
    keeps working -- confirm against callers.
    """
    return None
def network_Linear_forward(self, input):
    """No-op placeholder: ignores ``self`` and ``input`` and returns None.

    NOTE(review): presumably kept so code that still references this name
    keeps working -- confirm against callers.
    """
    return None
def network_Linear_load_state_dict(self, *args, **kwargs):
    """No-op placeholder: accepts any arguments, ignores them and returns None.

    NOTE(review): presumably kept so code that still references this name
    keeps working -- confirm against callers.
    """
    return None
def network_Conv2d_forward(self, input):
    """No-op placeholder: ignores ``self`` and ``input`` and returns None.

    NOTE(review): presumably kept so code that still references this name
    keeps working -- confirm against callers.
    """
    return None
def network_Conv2d_load_state_dict(self, *args, **kwargs):
    """No-op placeholder: accepts any arguments, ignores them and returns None.

    NOTE(review): presumably kept so code that still references this name
    keeps working -- confirm against callers.
    """
    return None
def network_GroupNorm_forward(self, input):
    """No-op placeholder: ignores ``self`` and ``input`` and returns None.

    NOTE(review): presumably kept so code that still references this name
    keeps working -- confirm against callers.
    """
    return None
def network_GroupNorm_load_state_dict(self, *args, **kwargs):
    """No-op placeholder: accepts any arguments, ignores them and returns None.

    NOTE(review): presumably kept so code that still references this name
    keeps working -- confirm against callers.
    """
    return None
def network_LayerNorm_forward(self, input):
    """No-op placeholder: ignores ``self`` and ``input`` and returns None.

    NOTE(review): presumably kept so code that still references this name
    keeps working -- confirm against callers.
    """
    return None
def network_LayerNorm_load_state_dict(self, *args, **kwargs):
    """No-op placeholder: accepts any arguments, ignores them and returns None.

    NOTE(review): presumably kept so code that still references this name
    keeps working -- confirm against callers.
    """
    return None
def network_MultiheadAttention_forward(self, *args, **kwargs):
    """No-op placeholder: accepts any arguments, ignores them and returns None.

    NOTE(review): presumably kept so code that still references this name
    keeps working -- confirm against callers.
    """
    return None
def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
    """No-op placeholder: accepts any arguments, ignores them and returns None.

    NOTE(review): presumably kept so code that still references this name
    keeps working -- confirm against callers.
    """
    return None
def list_available_networks():
    """Rescan the LoRA directory and rebuild the module-level registries.

    Clears ``available_networks``, ``available_network_aliases``,
    ``forbidden_network_aliases`` and ``available_network_hash_lookup``, then
    walks ``shared.cmd_opts.lora_dir`` (created if missing) for ``.pt``,
    ``.ckpt`` and ``.safetensors`` files and registers one
    ``network.NetworkOnDisk`` entry per file.

    A file whose entry cannot be created because of an ``OSError`` is reported
    and skipped; the scan continues with the remaining files.
    """
    available_networks.clear()
    available_network_aliases.clear()
    forbidden_network_aliases.clear()
    available_network_hash_lookup.clear()

    # Keys must be lower-case: lookups test `name.lower()` against this dict,
    # so a mixed-case key such as "Addams" could never match.
    forbidden_network_aliases.update({"none": 1, "addams": 1})

    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

    candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
    for filename in candidates:
        if os.path.isdir(filename):
            continue

        # The registry key is the file name without directory or extension.
        name = os.path.splitext(os.path.basename(filename))[0]
        try:
            entry = network.NetworkOnDisk(name, filename)
        except OSError:  # should catch FileNotFoundError and PermissionError etc.
            errors.report(f"Failed to load network {name} from {filename}", exc_info=True)
            continue

        available_networks[name] = entry

        # An alias (or name) seen before is ambiguous, so forbid using it as
        # an alias. This check must run before the entry itself is added to
        # the alias table below.
        if entry.alias in available_network_aliases:
            forbidden_network_aliases[entry.alias.lower()] = 1

        available_network_aliases[name] = entry
        available_network_aliases[entry.alias] = entry
# Matches "<name>(<hex hash>)" as written by the AddNet infotext fields and
# captures <name> in group 1. The lookbehind stops the greedy capture from
# ending on whitespace, so "name (abcd)" yields "name" rather than "name ";
# without it the following `\s*` could never consume anything.
re_network_name = re.compile(r"(.*(?<!\s))\s*\([0-9a-fA-F]+\)")
def infotext_pasted(infotext, params):
    """Convert pasted "AddNet" LoRA fields into ``<lora:...>`` prompt tags.

    For every ``"AddNet Model N"`` key in ``params`` whose matching
    ``"AddNet Module N"`` value is ``"LoRA"``, a ``<lora:name:weight>`` tag is
    built (weight taken from ``"AddNet Weight A N"``, default ``"1.0"``) and
    all tags are appended to ``params["Prompt"]`` on a new line.

    Args:
        infotext: the raw pasted text; not used here.
        params: parsed infotext fields; modified in place.
    """
    registered_keys = [field[1] for field in scripts.scripts_txt2img.infotext_fields]
    if "AddNet Module 1" in registered_keys:
        # if the other extension is active, it will handle those fields, no need to do anything
        return

    lora_tags = []
    for key in params:
        if not key.startswith("AddNet Model "):
            continue

        # Everything after the 13-character "AddNet Model " prefix is the slot id.
        slot = key[13:]
        is_lora = params.get("AddNet Module " + slot) == "LoRA"
        model_name = params.get("AddNet Model " + slot) if is_lora else None
        if model_name is None:
            continue

        # Drop a trailing "(hash)" suffix when the value carries one.
        matched = re_network_name.match(model_name)
        if matched:
            model_name = matched.group(1)

        weight = params.get("AddNet Weight A " + slot, "1.0")
        lora_tags.append(f"<lora:{model_name}:{weight}>")

    if lora_tags:
        params["Prompt"] += "\n" + "".join(lora_tags)
# Module-level state shared by the functions in this file.
# Starts unset. NOTE(review): presumably assigned by start-up code outside this view -- confirm.
originals: lora_patches.LoraPatches = None
# Starts unset. NOTE(review): presumably assigned by start-up code outside this view -- confirm.
extra_network_lora = None
# Network name (file name without directory or extension) -> network.NetworkOnDisk entry.
available_networks = {}
# Network name AND network alias -> the same NetworkOnDisk entry.
available_network_aliases = {}
# The next three containers are not read or written by the code visible in this file.
# NOTE(review): presumably used by other modules of the extension -- confirm.
loaded_networks = []
loaded_bundle_embeddings = {}
networks_in_memory = {}
# Cleared on every rescan by list_available_networks(); not populated in the code visible here.
available_network_hash_lookup = {}
# Aliases that must not be used for lookup (value is always 1); lookups test the lower-cased name against the keys.
forbidden_network_aliases = {}
# Populate the registries once at import time; must run after the containers above exist.
list_available_networks()