From 9f1eb409b602b6282e49d9bb63a5ce6ad9c45d23 Mon Sep 17 00:00:00 2001
From: lllyasviel
Date: Thu, 25 Jan 2024 11:56:09 -0800
Subject: [PATCH] Update networks.py

---
 extensions-builtin/Lora/networks.py | 430 ++--------------------------
 1 file changed, 29 insertions(+), 401 deletions(-)

diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 83ea2802..a89c409f 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -5,14 +5,6 @@ import re

 import lora_patches
 import network
-import network_lora
-import network_glora
-import network_hada
-import network_ia3
-import network_lokr
-import network_full
-import network_norm
-import network_oft

 import torch
 from typing import Union
@@ -22,16 +14,7 @@ import modules.textual_inversion.textual_inversion as textual_inversion

 from lora_logger import logger

-module_types = [
-    network_lora.ModuleTypeLora(),
-    network_hada.ModuleTypeHada(),
-    network_ia3.ModuleTypeIa3(),
-    network_lokr.ModuleTypeLokr(),
-    network_full.ModuleTypeFull(),
-    network_norm.ModuleTypeNorm(),
-    network_glora.ModuleTypeGLora(),
-    network_oft.ModuleTypeOFT(),
-]
+module_types = []

 re_digits = re.compile(r"\d+")

@@ -118,449 +101,94 @@ def convert_diffusers_name_to_compvis(key, is_sd2):


 def assign_network_names_to_compvis_modules(sd_model):
-    network_layer_mapping = {}
-
-    if shared.sd_model.is_sdxl:
-        for i, embedder in enumerate(shared.sd_model.conditioner.embedders):
-            if not hasattr(embedder, 'wrapped'):
-                continue
-
-            for name, module in embedder.wrapped.named_modules():
-                network_name = f'{i}_{name.replace(".", "_")}'
-                network_layer_mapping[network_name] = module
-                module.network_layer_name = network_name
-    else:
-        for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
-            network_name = name.replace(".", "_")
-            network_layer_mapping[network_name] = module
-            module.network_layer_name = network_name
-
-    for name, module in shared.sd_model.model.named_modules():
-        network_name = name.replace(".", "_")
-        network_layer_mapping[network_name] = module
-        module.network_layer_name = network_name
-
-    sd_model.network_layer_mapping = network_layer_mapping
+    pass


 def load_network(name, network_on_disk):
-    net = network.Network(name, network_on_disk)
-    net.mtime = os.path.getmtime(network_on_disk.filename)
-
-    sd = sd_models.read_state_dict(network_on_disk.filename)
-
-    # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
-    if not hasattr(shared.sd_model, 'network_layer_mapping'):
-        assign_network_names_to_compvis_modules(shared.sd_model)
-
-    keys_failed_to_match = {}
-    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
-
-    matched_networks = {}
-    bundle_embeddings = {}
-
-    for key_network, weight in sd.items():
-        key_network_without_network_parts, _, network_part = key_network.partition(".")
-
-        if key_network_without_network_parts == "bundle_emb":
-            emb_name, vec_name = network_part.split(".", 1)
-            emb_dict = bundle_embeddings.get(emb_name, {})
-            if vec_name.split('.')[0] == 'string_to_param':
-                _, k2 = vec_name.split('.', 1)
-                emb_dict['string_to_param'] = {k2: weight}
-            else:
-                emb_dict[vec_name] = weight
-            bundle_embeddings[emb_name] = emb_dict
-
-        key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
-        sd_module = shared.sd_model.network_layer_mapping.get(key, None)
-
-        if sd_module is None:
-            m = re_x_proj.match(key)
-            if m:
-                sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)
-
-        # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
-        if sd_module is None and "lora_unet" in key_network_without_network_parts:
-            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
-            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
-        elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts:
-            key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
-            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
-
-            # some SD1 Loras also have correct compvis keys
-            if sd_module is None:
-                key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
-                sd_module = shared.sd_model.network_layer_mapping.get(key, None)
-
-        # kohya_ss OFT module
-        elif sd_module is None and "oft_unet" in key_network_without_network_parts:
-            key = key_network_without_network_parts.replace("oft_unet", "diffusion_model")
-            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
-
-        # KohakuBlueLeaf OFT module
-        if sd_module is None and "oft_diag" in key:
-            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
-            key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
-            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
-
-        if sd_module is None:
-            keys_failed_to_match[key_network] = key
-            continue
-
-        if key not in matched_networks:
-            matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)
-
-        matched_networks[key].w[network_part] = weight
-
-    for key, weights in matched_networks.items():
-        net_module = None
-        for nettype in module_types:
-            net_module = nettype.create_module(net, weights)
-            if net_module is not None:
-                break
-
-        if net_module is None:
-            raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}")
-
-        net.modules[key] = net_module
-
-    embeddings = {}
-    for emb_name, data in bundle_embeddings.items():
-        embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
-        embedding.loaded = None
-        embeddings[emb_name] = embedding
-
-    net.bundle_embeddings = embeddings
-
-    if keys_failed_to_match:
-        logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
-
-    return net
+    pass


 def purge_networks_from_memory():
-    while len(networks_in_memory) > shared.opts.lora_in_memory_limit and len(networks_in_memory) > 0:
-        name = next(iter(networks_in_memory))
-        networks_in_memory.pop(name, None)
-
-    devices.torch_gc()
+    pass


 def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
-    emb_db = sd_hijack.model_hijack.embedding_db
-    already_loaded = {}
-
-    for net in loaded_networks:
-        if net.name in names:
-            already_loaded[net.name] = net
-            for emb_name, embedding in net.bundle_embeddings.items():
-                if embedding.loaded:
-                    emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)
-
-    loaded_networks.clear()
+    current_sd = sd_models.model_data.get_sd_model()
+    if current_sd is None:
+        return

     networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]
     if any(x is None for x in networks_on_disk):
         list_available_networks()

-        networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]

-    failed_to_load_networks = []
+    compiled_lora_targets = []
+    for a, b, c in zip(networks_on_disk, unet_multipliers, te_multipliers):
+        compiled_lora_targets.append([a.filename, b, c])

-    for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
-        net = already_loaded.get(name, None)
+    compiled_lora_targets_hash = str(compiled_lora_targets)

-        if network_on_disk is not None:
-            if net is None:
-                net = networks_in_memory.get(name)
+    if current_sd.current_lora_hash == compiled_lora_targets_hash:
+        return

-            if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
-                try:
-                    net = load_network(name, network_on_disk)
-
-                    networks_in_memory.pop(name, None)
-                    networks_in_memory[name] = net
-                except Exception as e:
-                    errors.display(e, f"loading network {network_on_disk.filename}")
-                    continue
-
-            net.mentioned_name = name
-
-            network_on_disk.read_hash()
-
-        if net is None:
-            failed_to_load_networks.append(name)
-            logging.info(f"Couldn't find network with name {name}")
-            continue
-
-        net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
-        net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0
-        net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
-        loaded_networks.append(net)
-
-        for emb_name, embedding in net.bundle_embeddings.items():
-            if embedding.loaded is None and emb_name in emb_db.word_embeddings:
-                logger.warning(
-                    f'Skip bundle embedding: "{emb_name}"'
-                    ' as it was already loaded from embeddings folder'
-                )
-                continue
-
-            embedding.loaded = False
-            if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
-                embedding.loaded = True
-                emb_db.register_embedding(embedding, shared.sd_model)
-            else:
-                emb_db.skipped_embeddings[name] = embedding
-
-    if failed_to_load_networks:
-        lora_not_found_message = f'Lora not found: {", ".join(failed_to_load_networks)}'
-        sd_hijack.model_hijack.comments.append(lora_not_found_message)
-        if shared.opts.lora_not_found_warning_console:
-            print(f'\n{lora_not_found_message}\n')
-        if shared.opts.lora_not_found_gradio_warning:
-            gr.Warning(lora_not_found_message)
-
-    purge_networks_from_memory()
+    current_sd.current_lora_hash = compiled_lora_targets_hash
+    return


 def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
-    weights_backup = getattr(self, "network_weights_backup", None)
-    bias_backup = getattr(self, "network_bias_backup", None)
-
-    if weights_backup is None and bias_backup is None:
-        return
-
-    if weights_backup is not None:
-        if isinstance(self, torch.nn.MultiheadAttention):
-            self.in_proj_weight.copy_(weights_backup[0])
-            self.out_proj.weight.copy_(weights_backup[1])
-        else:
-            self.weight.copy_(weights_backup)
-
-    if bias_backup is not None:
-        if isinstance(self, torch.nn.MultiheadAttention):
-            self.out_proj.bias.copy_(bias_backup)
-        else:
-            self.bias.copy_(bias_backup)
-    else:
-        if isinstance(self, torch.nn.MultiheadAttention):
-            self.out_proj.bias = None
-        else:
-            self.bias = None
+    pass


 def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
-    """
-    Applies the currently selected set of networks to the weights of torch layer self.
-    If weights already have this particular set of networks applied, does nothing.
-    If not, restores orginal weights from backup and alters weights according to networks.
-    """
-
-    network_layer_name = getattr(self, 'network_layer_name', None)
-    if network_layer_name is None:
-        return
-
-    current_names = getattr(self, "network_current_names", ())
-    wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
-
-    weights_backup = getattr(self, "network_weights_backup", None)
-    if weights_backup is None and wanted_names != ():
-        if current_names != ():
-            raise RuntimeError("no backup weights found and current weights are not unchanged")
-
-        if isinstance(self, torch.nn.MultiheadAttention):
-            weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
-        else:
-            weights_backup = self.weight.to(devices.cpu, copy=True)
-
-        self.network_weights_backup = weights_backup
-
-    bias_backup = getattr(self, "network_bias_backup", None)
-    if bias_backup is None:
-        if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
-            bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
-        elif getattr(self, 'bias', None) is not None:
-            bias_backup = self.bias.to(devices.cpu, copy=True)
-        else:
-            bias_backup = None
-        self.network_bias_backup = bias_backup
-
-    if current_names != wanted_names:
-        network_restore_weights_from_backup(self)
-
-        for net in loaded_networks:
-            module = net.modules.get(network_layer_name, None)
-            if module is not None and hasattr(self, 'weight'):
-                try:
-                    with torch.no_grad():
-                        if getattr(self, 'fp16_weight', None) is None:
-                            weight = self.weight
-                            bias = self.bias
-                        else:
-                            weight = self.fp16_weight.clone().to(self.weight.device)
-                            bias = getattr(self, 'fp16_bias', None)
-                            if bias is not None:
-                                bias = bias.clone().to(self.bias.device)
-                        updown, ex_bias = module.calc_updown(weight)
-
-                        if len(weight.shape) == 4 and weight.shape[1] == 9:
-                            # inpainting model. zero pad updown to make channel[1] 4 to 9
-                            updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
-
-                        self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
-                        if ex_bias is not None and hasattr(self, 'bias'):
-                            if self.bias is None:
-                                self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
-                            else:
-                                self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
-                except RuntimeError as e:
-                    logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
-                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
-
-                continue
-
-            module_q = net.modules.get(network_layer_name + "_q_proj", None)
-            module_k = net.modules.get(network_layer_name + "_k_proj", None)
-            module_v = net.modules.get(network_layer_name + "_v_proj", None)
-            module_out = net.modules.get(network_layer_name + "_out_proj", None)
-
-            if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
-                try:
-                    with torch.no_grad():
-                        updown_q, _ = module_q.calc_updown(self.in_proj_weight)
-                        updown_k, _ = module_k.calc_updown(self.in_proj_weight)
-                        updown_v, _ = module_v.calc_updown(self.in_proj_weight)
-                        updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
-                        updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)
-
-                        self.in_proj_weight += updown_qkv
-                        self.out_proj.weight += updown_out
-                        if ex_bias is not None:
-                            if self.out_proj.bias is None:
-                                self.out_proj.bias = torch.nn.Parameter(ex_bias)
-                            else:
-                                self.out_proj.bias += ex_bias
-
-                except RuntimeError as e:
-                    logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
-                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
-
-                continue
-
-            if module is None:
-                continue
-
-            logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
-            extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
-
-        self.network_current_names = wanted_names
+    pass


 def network_forward(org_module, input, original_forward):
-    """
-    Old way of applying Lora by executing operations during layer's forward.
-    Stacking many loras this way results in big performance degradation.
- """ - - if len(loaded_networks) == 0: - return original_forward(org_module, input) - - input = devices.cond_cast_unet(input) - - network_restore_weights_from_backup(org_module) - network_reset_cached_weight(org_module) - - y = original_forward(org_module, input) - - network_layer_name = getattr(org_module, 'network_layer_name', None) - for lora in loaded_networks: - module = lora.modules.get(network_layer_name, None) - if module is None: - continue - - y = module.forward(input, y) - - return y + pass def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): - self.network_current_names = () - self.network_weights_backup = None - self.network_bias_backup = None + pass def network_Linear_forward(self, input): - if shared.opts.lora_functional: - return network_forward(self, input, originals.Linear_forward) - - network_apply_weights(self) - - return originals.Linear_forward(self, input) + pass def network_Linear_load_state_dict(self, *args, **kwargs): - network_reset_cached_weight(self) - - return originals.Linear_load_state_dict(self, *args, **kwargs) + pass def network_Conv2d_forward(self, input): - if shared.opts.lora_functional: - return network_forward(self, input, originals.Conv2d_forward) - - network_apply_weights(self) - - return originals.Conv2d_forward(self, input) + pass def network_Conv2d_load_state_dict(self, *args, **kwargs): - network_reset_cached_weight(self) - - return originals.Conv2d_load_state_dict(self, *args, **kwargs) + pass def network_GroupNorm_forward(self, input): - if shared.opts.lora_functional: - return network_forward(self, input, originals.GroupNorm_forward) - - network_apply_weights(self) - - return originals.GroupNorm_forward(self, input) + pass def network_GroupNorm_load_state_dict(self, *args, **kwargs): - network_reset_cached_weight(self) - - return originals.GroupNorm_load_state_dict(self, *args, **kwargs) + pass def network_LayerNorm_forward(self, input): - if shared.opts.lora_functional: - return network_forward(self, input, originals.LayerNorm_forward) - - network_apply_weights(self) - - return originals.LayerNorm_forward(self, input) + pass def network_LayerNorm_load_state_dict(self, *args, **kwargs): - network_reset_cached_weight(self) - - return originals.LayerNorm_load_state_dict(self, *args, **kwargs) + pass def network_MultiheadAttention_forward(self, *args, **kwargs): - network_apply_weights(self) - - return originals.MultiheadAttention_forward(self, *args, **kwargs) + pass def network_MultiheadAttention_load_state_dict(self, *args, **kwargs): - network_reset_cached_weight(self) - - return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs) + pass def list_available_networks():