From cc2a88355a037784e56943645c0d915ec5f24384 Mon Sep 17 00:00:00 2001
From: lllyasviel
Date: Thu, 25 Jan 2024 06:09:01 -0800
Subject: [PATCH] Update sd_hijack_open_clip.py

---
 modules/sd_hijack_open_clip.py | 64 +++++++++++++++++++++++++++++++---
 1 file changed, 60 insertions(+), 4 deletions(-)

diff --git a/modules/sd_hijack_open_clip.py b/modules/sd_hijack_open_clip.py
index c690b5e7..25c5e983 100644
--- a/modules/sd_hijack_open_clip.py
+++ b/modules/sd_hijack_open_clip.py
@@ -1,15 +1,71 @@
+import open_clip.tokenizer
 import torch
 
-from modules import sd_hijack_clip
+from modules import sd_hijack_clip, devices
 from modules.shared import opts
 
+tokenizer = open_clip.tokenizer._tokenizer
 
-class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
+
+class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
     def __init__(self, wrapped, hijack):
         super().__init__(wrapped, hijack)
 
+        self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
+        self.id_start = tokenizer.encoder["<start_of_text>"]
+        self.id_end = tokenizer.encoder["<end_of_text>"]
+        self.id_pad = 0
 
-class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords):
+    def tokenize(self, texts):
+        assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
+
+        tokenized = [tokenizer.encode(text) for text in texts]
+
+        return tokenized
+
+    def encode_with_transformers(self, tokens):
+        # set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers
+        z = self.wrapped.encode_with_transformer(tokens)
+
+        return z
+
+    def encode_embedding_init_text(self, init_text, nvpt):
+        ids = tokenizer.encode(init_text)
+        ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
+        embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
+
+        return embedded
+
+
+class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
     def __init__(self, wrapped, hijack):
         super().__init__(wrapped, hijack)
-        a = 0
+
+        self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
+        self.id_start = tokenizer.encoder["<start_of_text>"]
+        self.id_end = tokenizer.encoder["<end_of_text>"]
+        self.id_pad = 0
+
+    def tokenize(self, texts):
+        assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
+
+        tokenized = [tokenizer.encode(text) for text in texts]
+
+        return tokenized
+
+    def encode_with_transformers(self, tokens):
+        d = self.wrapped.encode_with_transformer(tokens)
+        z = d[self.wrapped.layer]
+
+        pooled = d.get("pooled")
+        if pooled is not None:
+            z.pooled = pooled
+
+        return z
+
+    def encode_embedding_init_text(self, init_text, nvpt):
+        ids = tokenizer.encode(init_text)
+        ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
+        embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)
+
+        return embedded
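
Reviewer note, not part of the commit: the restored code leans on two open_clip details worth spelling out. Below is a minimal sketch of the special-token lookups both __init__ methods perform and of the pooled-output attachment in FrozenOpenCLIPEmbedder2WithCustomWords; the tensor values are stand-ins, and the token ids in the comments assume the standard CLIP BPE vocabulary shipped with open_clip.

import open_clip.tokenizer
import torch

tokenizer = open_clip.tokenizer._tokenizer  # module-level SimpleTokenizer instance

# Special-token lookups performed in both __init__ methods; open_clip names
# these "<start_of_text>" / "<end_of_text>", and ',</w>' is the word-final
# comma token consulted when splitting long prompts into 75-token chunks.
id_start = tokenizer.encoder["<start_of_text>"]  # 49406 in the standard vocab
id_end = tokenizer.encoder["<end_of_text>"]      # 49407
comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]

# encode() returns ids without the start/end markers, which is why the hijack
# base class brackets each chunk with id_start/id_end and pads with id_pad = 0.
print(tokenizer.encode("a photo of a cat"))

# FrozenOpenCLIPEmbedder2WithCustomWords staples the pooled projection onto
# the hidden-state tensor as a plain attribute; torch.Tensor instances accept
# ad-hoc attributes, so downstream SDXL code can read z.pooled from the same
# object it uses as the token embeddings.
z = torch.randn(1, 77, 1280)     # stand-in for d[self.wrapped.layer]
z.pooled = torch.randn(1, 1280)  # stand-in for d.get("pooled")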