Update forge_loader.py
This commit is contained in:
parent
f313726201
commit
41da08696a
@ -18,7 +18,7 @@ import open_clip
|
|||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
|
|
||||||
class FakeObject(torch.nn.Module):
|
class FakeObject:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.visual = None
|
self.visual = None
|
||||||
@ -145,6 +145,36 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
|
|||||||
sd_model.first_stage_model = forge_object.vae.first_stage_model
|
sd_model.first_stage_model = forge_object.vae.first_stage_model
|
||||||
sd_model.model.diffusion_model = forge_object.unet.model.diffusion_model
|
sd_model.model.diffusion_model = forge_object.unet.model.diffusion_model
|
||||||
|
|
||||||
|
conditioner = getattr(sd_model, 'conditioner', None)
|
||||||
|
if conditioner:
|
||||||
|
text_cond_models = []
|
||||||
|
|
||||||
|
for i in range(len(conditioner.embedders)):
|
||||||
|
embedder = conditioner.embedders[i]
|
||||||
|
typename = type(embedder).__name__
|
||||||
|
if typename == 'FrozenOpenCLIPEmbedder':
|
||||||
|
pass
|
||||||
|
text_cond_models.append(embedder)
|
||||||
|
if typename == 'FrozenCLIPEmbedder':
|
||||||
|
embedder.tokenizer = forge_object.clip.tokenizer.clip_l.tokenizer
|
||||||
|
embedder.transformer = forge_object.clip.cond_stage_model.clip_l.transformer
|
||||||
|
text_cond_models.append(embedder)
|
||||||
|
if typename == 'FrozenOpenCLIPEmbedder2':
|
||||||
|
embedder.tokenizer = forge_object.clip.tokenizer.clip_g.tokenizer
|
||||||
|
embedder.transformer = forge_object.clip.cond_stage_model.clip_g.transformer
|
||||||
|
text_cond_models.append(embedder)
|
||||||
|
|
||||||
|
if len(text_cond_models) == 1:
|
||||||
|
sd_model.cond_stage_model = text_cond_models[0]
|
||||||
|
else:
|
||||||
|
sd_model.cond_stage_model = conditioner
|
||||||
|
elif type(sd_model.cond_stage_model).__name__ == 'FrozenCLIPEmbedder':
|
||||||
|
a = 0
|
||||||
|
pass
|
||||||
|
elif type(sd_model.cond_stage_model).__name__ == 'FrozenOpenCLIPEmbedder':
|
||||||
|
a = 0
|
||||||
|
pass
|
||||||
|
|
||||||
timer.record("forge set components")
|
timer.record("forge set components")
|
||||||
|
|
||||||
return
|
return
|
||||||
|
Loading…
Reference in New Issue
Block a user