From 41da08696a0c8f75865a6f081e2d7521daf69f7b Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Thu, 25 Jan 2024 04:22:40 -0800 Subject: [PATCH] Update forge_loader.py --- modules_forge/forge_loader.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py index 225f8d1a..9f1570fb 100644 --- a/modules_forge/forge_loader.py +++ b/modules_forge/forge_loader.py @@ -18,7 +18,7 @@ import open_clip from transformers import CLIPTextModel, CLIPTokenizer -class FakeObject(torch.nn.Module): +class FakeObject: def __init__(self, *args, **kwargs): super().__init__() self.visual = None @@ -145,6 +145,36 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None): sd_model.first_stage_model = forge_object.vae.first_stage_model sd_model.model.diffusion_model = forge_object.unet.model.diffusion_model + conditioner = getattr(sd_model, 'conditioner', None) + if conditioner: + text_cond_models = [] + + for i in range(len(conditioner.embedders)): + embedder = conditioner.embedders[i] + typename = type(embedder).__name__ + if typename == 'FrozenOpenCLIPEmbedder': + pass + text_cond_models.append(embedder) + if typename == 'FrozenCLIPEmbedder': + embedder.tokenizer = forge_object.clip.tokenizer.clip_l.tokenizer + embedder.transformer = forge_object.clip.cond_stage_model.clip_l.transformer + text_cond_models.append(embedder) + if typename == 'FrozenOpenCLIPEmbedder2': + embedder.tokenizer = forge_object.clip.tokenizer.clip_g.tokenizer + embedder.transformer = forge_object.clip.cond_stage_model.clip_g.transformer + text_cond_models.append(embedder) + + if len(text_cond_models) == 1: + sd_model.cond_stage_model = text_cond_models[0] + else: + sd_model.cond_stage_model = conditioner + elif type(sd_model.cond_stage_model).__name__ == 'FrozenCLIPEmbedder': + a = 0 + pass + elif type(sd_model.cond_stage_model).__name__ == 'FrozenOpenCLIPEmbedder': + a = 0 + pass + timer.record("forge set components") return