qol
This commit is contained in:
parent
0f7c71b400
commit
a203113f43
@ -11,9 +11,6 @@ import ldm_patched.controlnet.cldm
|
|||||||
import ldm_patched.t2ia.adapter
|
import ldm_patched.t2ia.adapter
|
||||||
|
|
||||||
|
|
||||||
compute_controlnet_weighting = None
|
|
||||||
|
|
||||||
|
|
||||||
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
def broadcast_image_to(tensor, target_batch_size, batched_number):
|
||||||
current_batch_size = tensor.shape[0]
|
current_batch_size = tensor.shape[0]
|
||||||
#print(current_batch_size, target_batch_size)
|
#print(current_batch_size, target_batch_size)
|
||||||
@ -32,6 +29,71 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
|
|||||||
else:
|
else:
|
||||||
return torch.cat([tensor] * batched_number, dim=0)
|
return torch.cat([tensor] * batched_number, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
def get_at(array, index, default=None):
|
||||||
|
return array[index] if 0 <= index < len(array) else default
|
||||||
|
|
||||||
|
|
||||||
|
def compute_controlnet_weighting(control, cnet):
|
||||||
|
|
||||||
|
positive_advanced_weighting = getattr(cnet, 'positive_advanced_weighting', None)
|
||||||
|
negative_advanced_weighting = getattr(cnet, 'negative_advanced_weighting', None)
|
||||||
|
advanced_frame_weighting = getattr(cnet, 'advanced_frame_weighting', None)
|
||||||
|
advanced_sigma_weighting = getattr(cnet, 'advanced_sigma_weighting', None)
|
||||||
|
advanced_mask_weighting = getattr(cnet, 'advanced_mask_weighting', None)
|
||||||
|
|
||||||
|
transformer_options = cnet.transformer_options
|
||||||
|
|
||||||
|
if positive_advanced_weighting is None and negative_advanced_weighting is None \
|
||||||
|
and advanced_frame_weighting is None and advanced_sigma_weighting is None \
|
||||||
|
and advanced_mask_weighting is None:
|
||||||
|
return control
|
||||||
|
|
||||||
|
cond_or_uncond = transformer_options['cond_or_uncond']
|
||||||
|
sigmas = transformer_options['sigmas']
|
||||||
|
cond_mark = transformer_options['cond_mark']
|
||||||
|
|
||||||
|
if advanced_frame_weighting is not None:
|
||||||
|
advanced_frame_weighting = torch.Tensor(advanced_frame_weighting * len(cond_or_uncond)).to(sigmas)
|
||||||
|
assert advanced_frame_weighting.shape[0] == cond_mark.shape[0], \
|
||||||
|
'Frame weighting list length is different from batch size!'
|
||||||
|
|
||||||
|
if advanced_sigma_weighting is not None:
|
||||||
|
advanced_sigma_weighting = torch.cat([advanced_sigma_weighting(sigmas)] * len(cond_or_uncond))
|
||||||
|
|
||||||
|
for k, v in control.items():
|
||||||
|
for i in range(len(v)):
|
||||||
|
control_signal = control[k][i]
|
||||||
|
B, C, H, W = control_signal.shape
|
||||||
|
|
||||||
|
positive_weight = 1.0
|
||||||
|
negative_weight = 1.0
|
||||||
|
sigma_weight = 1.0
|
||||||
|
frame_weight = 1.0
|
||||||
|
|
||||||
|
if positive_advanced_weighting is not None:
|
||||||
|
positive_weight = get_at(positive_advanced_weighting.get(k, []), i, 1.0)
|
||||||
|
|
||||||
|
if negative_advanced_weighting is not None:
|
||||||
|
negative_weight = get_at(negative_advanced_weighting.get(k, []), i, 1.0)
|
||||||
|
|
||||||
|
if advanced_sigma_weighting is not None:
|
||||||
|
sigma_weight = advanced_sigma_weighting
|
||||||
|
|
||||||
|
if advanced_frame_weighting is not None:
|
||||||
|
frame_weight = advanced_frame_weighting
|
||||||
|
|
||||||
|
final_weight = positive_weight * (1.0 - cond_mark) + negative_weight * cond_mark
|
||||||
|
final_weight = final_weight * sigma_weight * frame_weight
|
||||||
|
|
||||||
|
if isinstance(advanced_mask_weighting, torch.Tensor):
|
||||||
|
control_signal = control_signal * torch.nn.functional.interpolate(advanced_mask_weighting, size=(H, W), mode='bilinear')
|
||||||
|
|
||||||
|
control[k][i] = control_signal * final_weight[:, None, None, None]
|
||||||
|
|
||||||
|
return control
|
||||||
|
|
||||||
|
|
||||||
class ControlBase:
|
class ControlBase:
|
||||||
def __init__(self, device=None):
|
def __init__(self, device=None):
|
||||||
self.cond_hint_original = None
|
self.cond_hint_original = None
|
||||||
@ -119,8 +181,7 @@ class ControlBase:
|
|||||||
|
|
||||||
out[key].append(x)
|
out[key].append(x)
|
||||||
|
|
||||||
if compute_controlnet_weighting is not None:
|
out = compute_controlnet_weighting(out, self)
|
||||||
out = compute_controlnet_weighting(out, self)
|
|
||||||
|
|
||||||
if control_prev is not None:
|
if control_prev is not None:
|
||||||
for x in ['input', 'middle', 'output']:
|
for x in ['input', 'middle', 'output']:
|
||||||
|
@ -1,10 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def get_at(array, index, default=None):
|
|
||||||
return array[index] if 0 <= index < len(array) else default
|
|
||||||
|
|
||||||
|
|
||||||
def apply_controlnet_advanced(
|
def apply_controlnet_advanced(
|
||||||
unet,
|
unet,
|
||||||
controlnet,
|
controlnet,
|
||||||
@ -79,61 +75,3 @@ def apply_controlnet_advanced(
|
|||||||
m.add_patched_controlnet(cnet)
|
m.add_patched_controlnet(cnet)
|
||||||
return m
|
return m
|
||||||
|
|
||||||
|
|
||||||
def compute_controlnet_weighting(control, cnet):
|
|
||||||
|
|
||||||
positive_advanced_weighting = cnet.positive_advanced_weighting
|
|
||||||
negative_advanced_weighting = cnet.negative_advanced_weighting
|
|
||||||
advanced_frame_weighting = cnet.advanced_frame_weighting
|
|
||||||
advanced_sigma_weighting = cnet.advanced_sigma_weighting
|
|
||||||
advanced_mask_weighting = cnet.advanced_mask_weighting
|
|
||||||
transformer_options = cnet.transformer_options
|
|
||||||
|
|
||||||
if positive_advanced_weighting is None and negative_advanced_weighting is None \
|
|
||||||
and advanced_frame_weighting is None and advanced_sigma_weighting is None \
|
|
||||||
and advanced_mask_weighting is None:
|
|
||||||
return control
|
|
||||||
|
|
||||||
cond_or_uncond = transformer_options['cond_or_uncond']
|
|
||||||
sigmas = transformer_options['sigmas']
|
|
||||||
cond_mark = transformer_options['cond_mark']
|
|
||||||
|
|
||||||
if advanced_frame_weighting is not None:
|
|
||||||
advanced_frame_weighting = torch.Tensor(advanced_frame_weighting * len(cond_or_uncond)).to(sigmas)
|
|
||||||
assert advanced_frame_weighting.shape[0] == cond_mark.shape[0], \
|
|
||||||
'Frame weighting list length is different from batch size!'
|
|
||||||
|
|
||||||
if advanced_sigma_weighting is not None:
|
|
||||||
advanced_sigma_weighting = torch.cat([advanced_sigma_weighting(sigmas)] * len(cond_or_uncond))
|
|
||||||
|
|
||||||
for k, v in control.items():
|
|
||||||
for i in range(len(v)):
|
|
||||||
control_signal = control[k][i]
|
|
||||||
B, C, H, W = control_signal.shape
|
|
||||||
|
|
||||||
positive_weight = 1.0
|
|
||||||
negative_weight = 1.0
|
|
||||||
sigma_weight = 1.0
|
|
||||||
frame_weight = 1.0
|
|
||||||
|
|
||||||
if positive_advanced_weighting is not None:
|
|
||||||
positive_weight = get_at(positive_advanced_weighting.get(k, []), i, 1.0)
|
|
||||||
|
|
||||||
if negative_advanced_weighting is not None:
|
|
||||||
negative_weight = get_at(negative_advanced_weighting.get(k, []), i, 1.0)
|
|
||||||
|
|
||||||
if advanced_sigma_weighting is not None:
|
|
||||||
sigma_weight = advanced_sigma_weighting
|
|
||||||
|
|
||||||
if advanced_frame_weighting is not None:
|
|
||||||
frame_weight = advanced_frame_weighting
|
|
||||||
|
|
||||||
final_weight = positive_weight * (1.0 - cond_mark) + negative_weight * cond_mark
|
|
||||||
final_weight = final_weight * sigma_weight * frame_weight
|
|
||||||
|
|
||||||
if isinstance(advanced_mask_weighting, torch.Tensor):
|
|
||||||
control_signal = control_signal * torch.nn.functional.interpolate(advanced_mask_weighting, size=(H, W), mode='bilinear')
|
|
||||||
|
|
||||||
control[k][i] = control_signal * final_weight[:, None, None, None]
|
|
||||||
|
|
||||||
return control
|
|
||||||
|
@ -38,10 +38,6 @@ def build_loaded(module, loader_name):
|
|||||||
|
|
||||||
|
|
||||||
def patch_all_basics():
|
def patch_all_basics():
|
||||||
import ldm_patched.modules.controlnet
|
|
||||||
import modules_forge.controlnet
|
|
||||||
|
|
||||||
ldm_patched.modules.controlnet.compute_controlnet_weighting = modules_forge.controlnet.compute_controlnet_weighting
|
|
||||||
build_loaded(safetensors.torch, 'load_file')
|
build_loaded(safetensors.torch, 'load_file')
|
||||||
build_loaded(torch, 'load')
|
build_loaded(torch, 'load')
|
||||||
return
|
return
|
||||||
|
Loading…
Reference in New Issue
Block a user