From 20f1fb6c0bb9a20287af4e27b82be37969e492c0 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Sat, 27 Jan 2024 18:53:24 -0800 Subject: [PATCH] i --- modules/prompt_parser.py | 13 ++++++++++--- modules_forge/forge_sampler.py | 7 ++++++- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 2afd1b7c..c8b423a0 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -268,7 +268,7 @@ def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, class DictWithShape(dict): - def __init__(self, x, shape): + def __init__(self, x): super().__init__() self.update(x) @@ -282,6 +282,13 @@ class DictWithShape(dict): self[k] = self[k].to(*args, **kwargs) return self + def advanced_indexing(self, item): + result = {} + for k in self.keys(): + if isinstance(self[k], torch.Tensor): + result[k] = self[k][item] + return DictWithShape(result) + def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step): param = c[0][0].cond @@ -290,7 +297,7 @@ def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_s if is_dict: dict_cond = param res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()} - res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape) + res = DictWithShape(res) else: res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) @@ -348,7 +355,7 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step): if isinstance(tensors[0], dict): keys = list(tensors[0].keys()) stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys} - stacked = DictWithShape(stacked, stacked['crossattn'].shape) + stacked = DictWithShape(stacked) else: stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype) diff --git a/modules_forge/forge_sampler.py b/modules_forge/forge_sampler.py index 1f92fc28..7028ef62 100644 --- a/modules_forge/forge_sampler.py +++ b/modules_forge/forge_sampler.py @@ -38,7 +38,12 @@ def cond_from_a1111_to_patched_ldm_weighted(cond, weights): for i, w in cond_pre: current_indices.append(i) current_weight = w - feed = cond[current_indices] + + if hasattr(cond, 'advanced_indexing'): + feed = cond.advanced_indexing(current_indices) + else: + feed = cond[current_indices] + h = cond_from_a1111_to_patched_ldm(feed) h[0]['strength'] = current_weight results += h