i
This commit is contained in:
parent
a4d743e5f8
commit
20f1fb6c0b
@ -268,7 +268,7 @@ def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None,
|
|||||||
|
|
||||||
|
|
||||||
class DictWithShape(dict):
|
class DictWithShape(dict):
|
||||||
def __init__(self, x, shape):
|
def __init__(self, x):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.update(x)
|
self.update(x)
|
||||||
|
|
||||||
@ -282,6 +282,13 @@ class DictWithShape(dict):
|
|||||||
self[k] = self[k].to(*args, **kwargs)
|
self[k] = self[k].to(*args, **kwargs)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def advanced_indexing(self, item):
    """Apply *item* as an index to every tensor value and wrap the result.

    Walks the stored entries, indexes each ``torch.Tensor`` value with
    *item* (any value that is not a tensor is skipped), and returns the
    selected slices as a new ``DictWithShape``.
    """
    picked = {
        key: value[item]
        for key, value in self.items()
        if isinstance(value, torch.Tensor)
    }
    return DictWithShape(picked)
|
||||||
|
|
||||||
|
|
||||||
def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
|
def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
|
||||||
param = c[0][0].cond
|
param = c[0][0].cond
|
||||||
@ -290,7 +297,7 @@ def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_s
|
|||||||
if is_dict:
|
if is_dict:
|
||||||
dict_cond = param
|
dict_cond = param
|
||||||
res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}
|
res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}
|
||||||
res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape)
|
res = DictWithShape(res)
|
||||||
else:
|
else:
|
||||||
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
|
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
|
||||||
|
|
||||||
@ -348,7 +355,7 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
|
|||||||
if isinstance(tensors[0], dict):
|
if isinstance(tensors[0], dict):
|
||||||
keys = list(tensors[0].keys())
|
keys = list(tensors[0].keys())
|
||||||
stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
|
stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
|
||||||
stacked = DictWithShape(stacked, stacked['crossattn'].shape)
|
stacked = DictWithShape(stacked)
|
||||||
else:
|
else:
|
||||||
stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)
|
stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)
|
||||||
|
|
||||||
|
@ -38,7 +38,12 @@ def cond_from_a1111_to_patched_ldm_weighted(cond, weights):
|
|||||||
for i, w in cond_pre:
|
for i, w in cond_pre:
|
||||||
current_indices.append(i)
|
current_indices.append(i)
|
||||||
current_weight = w
|
current_weight = w
|
||||||
feed = cond[current_indices]
|
|
||||||
|
if hasattr(cond, 'advanced_indexing'):
|
||||||
|
feed = cond.advanced_indexing(current_indices)
|
||||||
|
else:
|
||||||
|
feed = cond[current_indices]
|
||||||
|
|
||||||
h = cond_from_a1111_to_patched_ldm(feed)
|
h = cond_from_a1111_to_patched_ldm(feed)
|
||||||
h[0]['strength'] = current_weight
|
h[0]['strength'] = current_weight
|
||||||
results += h
|
results += h
|
||||||
|
Loading…
Reference in New Issue
Block a user