i
This commit is contained in:
parent
a4d743e5f8
commit
20f1fb6c0b
@ -268,7 +268,7 @@ def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None,
|
||||
|
||||
|
||||
class DictWithShape(dict):
|
||||
def __init__(self, x, shape):
|
||||
def __init__(self, x):
|
||||
super().__init__()
|
||||
self.update(x)
|
||||
|
||||
@ -282,6 +282,13 @@ class DictWithShape(dict):
|
||||
self[k] = self[k].to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
def advanced_indexing(self, item):
|
||||
result = {}
|
||||
for k in self.keys():
|
||||
if isinstance(self[k], torch.Tensor):
|
||||
result[k] = self[k][item]
|
||||
return DictWithShape(result)
|
||||
|
||||
|
||||
def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
|
||||
param = c[0][0].cond
|
||||
@ -290,7 +297,7 @@ def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_s
|
||||
if is_dict:
|
||||
dict_cond = param
|
||||
res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}
|
||||
res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape)
|
||||
res = DictWithShape(res)
|
||||
else:
|
||||
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
|
||||
|
||||
@ -348,7 +355,7 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
|
||||
if isinstance(tensors[0], dict):
|
||||
keys = list(tensors[0].keys())
|
||||
stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
|
||||
stacked = DictWithShape(stacked, stacked['crossattn'].shape)
|
||||
stacked = DictWithShape(stacked)
|
||||
else:
|
||||
stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)
|
||||
|
||||
|
@ -38,7 +38,12 @@ def cond_from_a1111_to_patched_ldm_weighted(cond, weights):
|
||||
for i, w in cond_pre:
|
||||
current_indices.append(i)
|
||||
current_weight = w
|
||||
feed = cond[current_indices]
|
||||
|
||||
if hasattr(cond, 'advanced_indexing'):
|
||||
feed = cond.advanced_indexing(current_indices)
|
||||
else:
|
||||
feed = cond[current_indices]
|
||||
|
||||
h = cond_from_a1111_to_patched_ldm(feed)
|
||||
h[0]['strength'] = current_weight
|
||||
results += h
|
||||
|
Loading…
Reference in New Issue
Block a user