import os
import re
import lora_patches
import functools
import network
import torch

from typing import Union
from modules import shared, sd_models, errors, scripts
from ldm_patched.modules.utils import load_torch_file
from ldm_patched.modules.sd import load_lora_for_models


@functools.lru_cache(maxsize=5)
def load_lora_state_dict(filename):
    """Load the LoRA checkpoint at *filename* via ``load_torch_file`` (safe_load).

    Results are memoized per filename (at most 5 entries), so re-applying the
    same file does not re-read it. The cache is keyed on the path only: a file
    changed on disk after being cached is not re-read.
    """
    return load_torch_file(filename, safe_load=True)


# The functions below have empty bodies in this file and return None.
# NOTE(review): presumably kept so code importing these names keeps working;
# confirm before removing any of them.

def convert_diffusers_name_to_compvis(key, is_sd2):
    pass


def assign_network_names_to_compvis_modules(sd_model):
    pass


def load_network(name, network_on_disk):
    pass


def purge_networks_from_memory():
    pass


def _lookup_network(name):
    """Resolve a requested network *name* to its ``NetworkOnDisk``, or None.

    Names whose lowercase form is in ``forbidden_network_aliases`` are looked
    up by file stem only (``available_networks``); every other name goes
    through ``available_network_aliases``, which holds both stems and aliases.
    """
    if name.lower() in forbidden_network_aliases:
        return available_networks.get(name, None)
    return available_network_aliases.get(name, None)


def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
    """Apply the LoRA files named in *names* to the current SD model's unet/clip.

    Args:
        names: network names (file stems or aliases), applied in order.
        te_multipliers: per-network ``strength_clip`` values; when None,
            1.0 is used for every network.
        unet_multipliers: per-network ``strength_model`` values; when None,
            1.0 is used for every network.
        dyn_dims: accepted for interface compatibility; not used here.

    If any name is unknown, the LoRA directories are rescanned once; names
    still unresolved after that are reported and skipped. When the resolved
    (filename, strengths) list equals the one last applied, nothing is done.
    """
    current_sd = sd_models.model_data.get_sd_model()
    if current_sd is None:
        return

    names = list(names)  # iterated more than once below

    networks_on_disk = [_lookup_network(name) for name in names]
    if any(x is None for x in networks_on_disk):
        # A name may refer to a file added since the last scan: rescan once.
        list_available_networks()
        networks_on_disk = [_lookup_network(name) for name in names]

    # The documented None defaults used to crash zip(); use full strength.
    if unet_multipliers is None:
        unet_multipliers = [1.0] * len(names)
    if te_multipliers is None:
        te_multipliers = [1.0] * len(names)

    compiled_lora_targets = []
    for name, network_on_disk, strength_model, strength_clip in zip(
            names, networks_on_disk, unet_multipliers, te_multipliers):
        if network_on_disk is None:
            # Skip instead of crashing on None.filename.
            errors.report(f"Network not found: {name}")
            continue
        compiled_lora_targets.append([network_on_disk.filename, strength_model, strength_clip])

    compiled_lora_targets_hash = str(compiled_lora_targets)
    if current_sd.current_lora_hash == compiled_lora_targets_hash:
        return  # this exact target list was already applied

    # Invalidate first: if applying raises part-way, the model state matches no
    # recorded hash, so the next call rebuilds instead of short-circuiting.
    current_sd.current_lora_hash = None

    # Reset to forge_objects_original so targets from a previous call are not
    # applied on top of each other.
    current_sd.forge_objects.unet = current_sd.forge_objects_original.unet
    current_sd.forge_objects.clip = current_sd.forge_objects_original.clip

    for filename, strength_model, strength_clip in compiled_lora_targets:
        lora_sd = load_lora_state_dict(filename)
        current_sd.forge_objects.unet, current_sd.forge_objects.clip = load_lora_for_models(
            current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip)

    current_sd.current_lora_hash = compiled_lora_targets_hash
    current_sd.forge_objects_after_applying_lora = current_sd.forge_objects.shallow_copy()
    return


# More no-op placeholders (empty bodies, return None); see the note above
# convert_diffusers_name_to_compvis.

def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    pass


def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    pass


def network_forward(org_module, input, original_forward):
    pass


def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
    pass


def network_Linear_forward(self, input):
    pass


def network_Linear_load_state_dict(self, *args, **kwargs):
    pass


def network_Conv2d_forward(self, input):
    pass


def network_Conv2d_load_state_dict(self, *args, **kwargs):
    pass


def network_GroupNorm_forward(self, input):
    pass


def network_GroupNorm_load_state_dict(self, *args, **kwargs):
    pass


def network_LayerNorm_forward(self, input):
    pass


def network_LayerNorm_load_state_dict(self, *args, **kwargs):
    pass


def network_MultiheadAttention_forward(self, *args, **kwargs):
    pass


def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
    pass


def list_available_networks():
    """Rescan the LoRA directories and rebuild the module-level lookup tables.

    Fills ``available_networks`` (file stem -> entry) and
    ``available_network_aliases`` (stem and alias -> entry). An alias seen on
    more than one file is recorded, lowercased, in ``forbidden_network_aliases``
    so that such names are resolved by file stem only.
    """
    available_networks.clear()
    available_network_aliases.clear()
    forbidden_network_aliases.clear()
    available_network_hash_lookup.clear()
    # Lookups test name.lower() against this dict, so keys must be lowercase
    # (a mixed-case "Addams" key could never match).
    forbidden_network_aliases.update({"none": 1, "addams": 1})

    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

    candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
    candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
    for filename in candidates:
        if os.path.isdir(filename):
            continue

        name = os.path.splitext(os.path.basename(filename))[0]
        try:
            entry = network.NetworkOnDisk(name, filename)
        except OSError:  # should catch FileNotFoundError and PermissionError etc.
            errors.report(f"Failed to load network {name} from {filename}", exc_info=True)
            continue

        available_networks[name] = entry

        if entry.alias in available_network_aliases:
            forbidden_network_aliases[entry.alias.lower()] = 1

        available_network_aliases[name] = entry
        available_network_aliases[entry.alias] = entry


# Matches "<name> (<hex hash>)"; group 1 may carry trailing whitespace.
re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")


def infotext_pasted(infotext, params):
    """Convert pasted "AddNet" LoRA fields in *params* into ``<lora:...>`` prompt tags.

    For every "AddNet Model N" whose "AddNet Module N" is "LoRA", appends
    ``<lora:name:weight>`` to ``params["Prompt"]``. The weight comes from
    "AddNet Weight A N" (default "1.0"). Does nothing when another script
    already registers the "AddNet Module 1" infotext field.
    """
    if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
        return  # if the other extension is active, it will handle those fields, no need to do anything

    added = []

    for k in params:
        if not k.startswith("AddNet Model "):
            continue

        num = k[len("AddNet Model "):]

        if params.get("AddNet Module " + num) != "LoRA":
            continue

        name = params.get("AddNet Model " + num)
        if name is None:
            continue

        m = re_network_name.match(name)
        if m:
            # The greedy group keeps the space before "(hash)"; drop it so
            # the name matches the file stem / alias.
            name = m.group(1).rstrip()

        multiplier = params.get("AddNet Weight A " + num, "1.0")

        added.append(f"<lora:{name}:{multiplier}>")

    if added:
        params["Prompt"] = params.get("Prompt", "") + "\n" + "".join(added)


originals: Union[lora_patches.LoraPatches, None] = None

extra_network_lora = None

available_networks = {}
available_network_aliases = {}
loaded_networks = []
loaded_bundle_embeddings = {}
networks_in_memory = {}
available_network_hash_lookup = {}
forbidden_network_aliases = {}

list_available_networks()