This commit is contained in:
lllyasviel 2024-01-27 18:53:24 -08:00
parent a4d743e5f8
commit 20f1fb6c0b
2 changed files with 16 additions and 4 deletions

View File

@ -268,7 +268,7 @@ def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None,
class DictWithShape(dict):
def __init__(self, x, shape):
def __init__(self, x):
super().__init__()
self.update(x)
@ -282,6 +282,13 @@ class DictWithShape(dict):
self[k] = self[k].to(*args, **kwargs)
return self
def advanced_indexing(self, item):
result = {}
for k in self.keys():
if isinstance(self[k], torch.Tensor):
result[k] = self[k][item]
return DictWithShape(result)
def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
param = c[0][0].cond
@ -290,7 +297,7 @@ def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_s
if is_dict:
dict_cond = param
res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}
res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape)
res = DictWithShape(res)
else:
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
@ -348,7 +355,7 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
if isinstance(tensors[0], dict):
keys = list(tensors[0].keys())
stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
stacked = DictWithShape(stacked, stacked['crossattn'].shape)
stacked = DictWithShape(stacked)
else:
stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)

View File

@ -38,7 +38,12 @@ def cond_from_a1111_to_patched_ldm_weighted(cond, weights):
for i, w in cond_pre:
current_indices.append(i)
current_weight = w
feed = cond[current_indices]
if hasattr(cond, 'advanced_indexing'):
feed = cond.advanced_indexing(current_indices)
else:
feed = cond[current_indices]
h = cond_from_a1111_to_patched_ldm(feed)
h[0]['strength'] = current_weight
results += h