Update forge_loader.py

This commit is contained in:
lllyasviel 2024-01-25 04:22:40 -08:00
parent f313726201
commit 41da08696a

View File

@ -18,7 +18,7 @@ import open_clip
from transformers import CLIPTextModel, CLIPTokenizer from transformers import CLIPTextModel, CLIPTokenizer
class FakeObject(torch.nn.Module): class FakeObject:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__() super().__init__()
self.visual = None self.visual = None
@ -145,6 +145,36 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
sd_model.first_stage_model = forge_object.vae.first_stage_model sd_model.first_stage_model = forge_object.vae.first_stage_model
sd_model.model.diffusion_model = forge_object.unet.model.diffusion_model sd_model.model.diffusion_model = forge_object.unet.model.diffusion_model
conditioner = getattr(sd_model, 'conditioner', None)
if conditioner:
text_cond_models = []
for i in range(len(conditioner.embedders)):
embedder = conditioner.embedders[i]
typename = type(embedder).__name__
if typename == 'FrozenOpenCLIPEmbedder':
pass
text_cond_models.append(embedder)
if typename == 'FrozenCLIPEmbedder':
embedder.tokenizer = forge_object.clip.tokenizer.clip_l.tokenizer
embedder.transformer = forge_object.clip.cond_stage_model.clip_l.transformer
text_cond_models.append(embedder)
if typename == 'FrozenOpenCLIPEmbedder2':
embedder.tokenizer = forge_object.clip.tokenizer.clip_g.tokenizer
embedder.transformer = forge_object.clip.cond_stage_model.clip_g.transformer
text_cond_models.append(embedder)
if len(text_cond_models) == 1:
sd_model.cond_stage_model = text_cond_models[0]
else:
sd_model.cond_stage_model = conditioner
elif type(sd_model.cond_stage_model).__name__ == 'FrozenCLIPEmbedder':
a = 0
pass
elif type(sd_model.cond_stage_model).__name__ == 'FrozenOpenCLIPEmbedder':
a = 0
pass
timer.record("forge set components") timer.record("forge set components")
return return