This commit is contained in:
lllyasviel 2024-01-25 05:47:54 -08:00
parent feac0d7f2d
commit 4337710c4a
2 changed files with 31 additions and 55 deletions

View File

@ -142,57 +142,7 @@ class StableDiffusionModelHijack:
pass
def hijack(self, m):
conditioner = getattr(m, 'conditioner', None)
if conditioner:
text_cond_models = []
for i in range(len(conditioner.embedders)):
embedder = conditioner.embedders[i]
typename = type(embedder).__name__
if typename == 'FrozenOpenCLIPEmbedder':
model_embeddings = embedder.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
text_cond_models.append(conditioner.embedders[i])
if typename == 'FrozenCLIPEmbedder':
model_embeddings = embedder.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
text_cond_models.append(conditioner.embedders[i])
if typename == 'FrozenOpenCLIPEmbedder2':
model_embeddings = embedder.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self, textual_inversion_key='clip_g')
conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
text_cond_models.append(conditioner.embedders[i])
if len(text_cond_models) == 1:
m.cond_stage_model = text_cond_models[0]
else:
m.cond_stage_model = conditioner
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
self.clip = m.cond_stage_model
apply_weighted_forward(m)
def flatten(el):
flattened = [flatten(children) for children in el.children()]
res = [el]
for c in flattened:
res += c
return res
self.layers = flatten(m)
sd_unet.original_forward = None
pass
def undo_hijack(self, m):
pass

View File

@ -12,7 +12,9 @@ import ldm_patched.modules.clip_vision
from omegaconf import OmegaConf
from modules.sd_models_config import find_checkpoint_config
from modules.shared import cmd_opts
import modules.sd_hijack as sd_hijack
from modules import sd_hijack
from modules.sd_hijack import EmbeddingsWithFixes
from modules import sd_hijack_clip, sd_hijack_open_clip
from modules.sd_models_xl import extend_sdxl
from ldm.util import instantiate_from_config
@ -163,14 +165,33 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
if typename == 'FrozenOpenCLIPEmbedder':
embedder.tokenizer = forge_object.clip.tokenizer.clip_g.tokenizer
embedder.transformer = forge_object.clip.cond_stage_model.clip_g.transformer
model_embeddings = embedder.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding,
sd_hijack.model_hijack)
embedder = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, sd_hijack.model_hijack)
conditioner.embedders[i] = embedder
text_cond_models.append(embedder)
elif typename == 'FrozenCLIPEmbedder':
embedder.tokenizer = forge_object.clip.tokenizer.clip_l.tokenizer
embedder.transformer = forge_object.clip.cond_stage_model.clip_l.transformer
model_embeddings = embedder.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding,
sd_hijack.model_hijack)
embedder = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, sd_hijack.model_hijack)
conditioner.embedders[i] = embedder
text_cond_models.append(embedder)
elif typename == 'FrozenOpenCLIPEmbedder2':
embedder.tokenizer = forge_object.clip.tokenizer.clip_g.tokenizer
embedder.transformer = forge_object.clip.cond_stage_model.clip_g.transformer
model_embeddings = embedder.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding,
sd_hijack.model_hijack,
textual_inversion_key='clip_g')
embedder = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, sd_hijack.model_hijack)
conditioner.embedders[i] = embedder
text_cond_models.append(embedder)
if len(text_cond_models) == 1:
@ -180,17 +201,22 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None):
elif type(sd_model.cond_stage_model).__name__ == 'FrozenCLIPEmbedder':
sd_model.cond_stage_model.tokenizer = forge_object.clip.tokenizer.clip_l.tokenizer
sd_model.cond_stage_model.transformer = forge_object.clip.cond_stage_model.clip_l.transformer
model_embeddings = sd_model.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, sd_hijack.model_hijack)
sd_model.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(sd_model.cond_stage_model,
sd_hijack.model_hijack)
elif type(sd_model.cond_stage_model).__name__ == 'FrozenOpenCLIPEmbedder':
sd_model.cond_stage_model.tokenizer = forge_object.clip.tokenizer.clip_g.tokenizer
sd_model.cond_stage_model.transformer = forge_object.clip.cond_stage_model.clip_g.transformer
model_embeddings = sd_model.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, sd_hijack.model_hijack)
sd_model.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(sd_model.cond_stage_model,
sd_hijack.model_hijack)
else:
raise NotImplementedError('Bad Clip Class Name:' + type(sd_model.cond_stage_model).__name__)
timer.record("forge set components")
sd_hijack.model_hijack.hijack(sd_model)
timer.record("forge hijack")
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")