From 8059533eaff951ac6a1f24e179ae57296d2b9411 Mon Sep 17 00:00:00 2001
From: lllyasviel
Date: Fri, 9 Feb 2024 13:50:50 -0800
Subject: [PATCH] fix lora not loaded for text encoder #142

---
 modules/sd_models_xl.py       |  3 ++-
 modules_forge/forge_clip.py   | 14 ++++++++++++--
 modules_forge/forge_loader.py |  4 ----
 3 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py
index 30cd9c23..9ea8d690 100644
--- a/modules/sd_models_xl.py
+++ b/modules/sd_models_xl.py
@@ -9,10 +9,11 @@ from modules import devices, shared, prompt_parser
 from modules import torch_utils
 import ldm_patched.modules.model_management as model_management
 
+from modules_forge.forge_clip import move_clip_to_gpu
 
 
 def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
-    model_management.load_model_gpu(self.forge_objects.clip.patcher)
+    move_clip_to_gpu()
 
     for embedder in self.conditioner.embedders:
         embedder.ucg_rate = 0.0
diff --git a/modules_forge/forge_clip.py b/modules_forge/forge_clip.py
index 96936e8c..a12f4d20 100644
--- a/modules_forge/forge_clip.py
+++ b/modules_forge/forge_clip.py
@@ -1,11 +1,21 @@
 from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWords
 from ldm_patched.modules import model_management
+from modules import sd_models
 from modules.shared import opts
 
 
+def move_clip_to_gpu():
+    if sd_models.model_data.sd_model is None:
+        print('Error: CLIP called before SD is loaded!')
+        return
+
+    model_management.load_model_gpu(sd_models.model_data.sd_model.forge_objects.clip.patcher)
+    return
+
+
 class CLIP_SD_15_L(FrozenCLIPEmbedderWithCustomWords):
     def encode_with_transformers(self, tokens):
-        model_management.load_model_gpu(self.forge_objects.clip.patcher)
+        move_clip_to_gpu()
         self.wrapped.transformer.text_model.embeddings.to(tokens.device)
 
         outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
@@ -31,7 +41,7 @@ class CLIP_SD_21_H(FrozenCLIPEmbedderWithCustomWords):
         self.id_pad = 0
 
     def encode_with_transformers(self, tokens):
-        model_management.load_model_gpu(self.forge_objects.clip.patcher)
+        move_clip_to_gpu()
         self.wrapped.transformer.text_model.embeddings.to(tokens.device)
 
         outputs = self.wrapped.transformer(tokens, output_hidden_states=self.wrapped.layer == "hidden")
diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py
index d662f8dd..7b283958 100644
--- a/modules_forge/forge_loader.py
+++ b/modules_forge/forge_loader.py
@@ -180,7 +180,6 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
                 model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
                     model_embeddings.token_embedding, sd_hijack.model_hijack)
                 embedder = forge_clip.CLIP_SD_XL_L(embedder, sd_hijack.model_hijack)
-                embedder.forge_objects = forge_objects
                 conditioner.embedders[i] = embedder
                 text_cond_models.append(embedder)
             elif typename == 'FrozenOpenCLIPEmbedder2':  # SDXL Clip G
@@ -191,7 +190,6 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
                 model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
                     model_embeddings.token_embedding, sd_hijack.model_hijack, textual_inversion_key='clip_g')
                 embedder = forge_clip.CLIP_SD_XL_G(embedder, sd_hijack.model_hijack)
-                embedder.forge_objects = forge_objects
                 conditioner.embedders[i] = embedder
                 text_cond_models.append(embedder)
 
@@ -206,7 +204,6 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
             model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
                 model_embeddings.token_embedding, sd_hijack.model_hijack)
             sd_model.cond_stage_model = forge_clip.CLIP_SD_15_L(sd_model.cond_stage_model, sd_hijack.model_hijack)
-            sd_model.cond_stage_model.forge_objects = forge_objects
         elif type(sd_model.cond_stage_model).__name__ == 'FrozenOpenCLIPEmbedder':  # SD21 Clip
             sd_model.cond_stage_model.tokenizer = forge_objects.clip.tokenizer.clip_h.tokenizer
             sd_model.cond_stage_model.transformer = forge_objects.clip.cond_stage_model.clip_h.transformer
@@ -214,7 +211,6 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
             model_embeddings.token_embedding = sd_hijack.EmbeddingsWithFixes(
                 model_embeddings.token_embedding, sd_hijack.model_hijack)
             sd_model.cond_stage_model = forge_clip.CLIP_SD_21_H(sd_model.cond_stage_model, sd_hijack.model_hijack)
-            sd_model.cond_stage_model.forge_objects = forge_objects
         else:
             raise NotImplementedError('Bad Clip Class Name:' + type(sd_model.cond_stage_model).__name__)
 
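Note, not part of the patch: the deleted `embedder.forge_objects = forge_objects` assignments cached a forge_objects reference on each text-encoder wrapper at load time, so after a lora patched the model's CLIP the encoders apparently kept loading the stale, unpatched patcher; the new move_clip_to_gpu() helper instead reads sd_models.model_data.sd_model.forge_objects.clip.patcher at call time. The following minimal Python sketch only illustrates that cached-copy versus live-lookup distinction; Encoder, registry, and the string values are made-up names, not Forge APIs.

# Illustrative sketch only; Encoder and registry are invented names, not Forge code.
# A reference cached at construction time misses patches applied later,
# while a lookup through the live object at call time always sees them.

class Encoder:
    def __init__(self, registry):
        self.registry = registry
        self.cached_clip = registry['clip']    # old behaviour: stale per-embedder copy

    def encode_cached(self):
        return self.cached_clip                # never sees lora patches applied later

    def encode_live(self):
        return self.registry['clip']           # always the currently patched CLIP


registry = {'clip': 'clip-base'}
enc = Encoder(registry)
registry['clip'] = 'clip-base+lora'            # lora applied after the encoder was built

assert enc.encode_cached() == 'clip-base'      # the bug: text-encoder lora ignored
assert enc.encode_live() == 'clip-base+lora'   # the fix: read from the live model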