This commit is contained in:
lllyasviel 2024-01-31 23:21:58 -08:00
parent 0f7c71b400
commit a203113f43
3 changed files with 66 additions and 71 deletions

View File

@ -11,9 +11,6 @@ import ldm_patched.controlnet.cldm
import ldm_patched.t2ia.adapter
compute_controlnet_weighting = None
def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
#print(current_batch_size, target_batch_size)
@ -32,6 +29,71 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
else:
return torch.cat([tensor] * batched_number, dim=0)
def get_at(array, index, default=None):
return array[index] if 0 <= index < len(array) else default
def compute_controlnet_weighting(control, cnet):
positive_advanced_weighting = getattr(cnet, 'positive_advanced_weighting', None)
negative_advanced_weighting = getattr(cnet, 'negative_advanced_weighting', None)
advanced_frame_weighting = getattr(cnet, 'advanced_frame_weighting', None)
advanced_sigma_weighting = getattr(cnet, 'advanced_sigma_weighting', None)
advanced_mask_weighting = getattr(cnet, 'advanced_mask_weighting', None)
transformer_options = cnet.transformer_options
if positive_advanced_weighting is None and negative_advanced_weighting is None \
and advanced_frame_weighting is None and advanced_sigma_weighting is None \
and advanced_mask_weighting is None:
return control
cond_or_uncond = transformer_options['cond_or_uncond']
sigmas = transformer_options['sigmas']
cond_mark = transformer_options['cond_mark']
if advanced_frame_weighting is not None:
advanced_frame_weighting = torch.Tensor(advanced_frame_weighting * len(cond_or_uncond)).to(sigmas)
assert advanced_frame_weighting.shape[0] == cond_mark.shape[0], \
'Frame weighting list length is different from batch size!'
if advanced_sigma_weighting is not None:
advanced_sigma_weighting = torch.cat([advanced_sigma_weighting(sigmas)] * len(cond_or_uncond))
for k, v in control.items():
for i in range(len(v)):
control_signal = control[k][i]
B, C, H, W = control_signal.shape
positive_weight = 1.0
negative_weight = 1.0
sigma_weight = 1.0
frame_weight = 1.0
if positive_advanced_weighting is not None:
positive_weight = get_at(positive_advanced_weighting.get(k, []), i, 1.0)
if negative_advanced_weighting is not None:
negative_weight = get_at(negative_advanced_weighting.get(k, []), i, 1.0)
if advanced_sigma_weighting is not None:
sigma_weight = advanced_sigma_weighting
if advanced_frame_weighting is not None:
frame_weight = advanced_frame_weighting
final_weight = positive_weight * (1.0 - cond_mark) + negative_weight * cond_mark
final_weight = final_weight * sigma_weight * frame_weight
if isinstance(advanced_mask_weighting, torch.Tensor):
control_signal = control_signal * torch.nn.functional.interpolate(advanced_mask_weighting, size=(H, W), mode='bilinear')
control[k][i] = control_signal * final_weight[:, None, None, None]
return control
class ControlBase:
def __init__(self, device=None):
self.cond_hint_original = None
@ -119,7 +181,6 @@ class ControlBase:
out[key].append(x)
if compute_controlnet_weighting is not None:
out = compute_controlnet_weighting(out, self)
if control_prev is not None:

View File

@ -1,10 +1,6 @@
import torch
def get_at(array, index, default=None):
return array[index] if 0 <= index < len(array) else default
def apply_controlnet_advanced(
unet,
controlnet,
@ -79,61 +75,3 @@ def apply_controlnet_advanced(
m.add_patched_controlnet(cnet)
return m
def compute_controlnet_weighting(control, cnet):
positive_advanced_weighting = cnet.positive_advanced_weighting
negative_advanced_weighting = cnet.negative_advanced_weighting
advanced_frame_weighting = cnet.advanced_frame_weighting
advanced_sigma_weighting = cnet.advanced_sigma_weighting
advanced_mask_weighting = cnet.advanced_mask_weighting
transformer_options = cnet.transformer_options
if positive_advanced_weighting is None and negative_advanced_weighting is None \
and advanced_frame_weighting is None and advanced_sigma_weighting is None \
and advanced_mask_weighting is None:
return control
cond_or_uncond = transformer_options['cond_or_uncond']
sigmas = transformer_options['sigmas']
cond_mark = transformer_options['cond_mark']
if advanced_frame_weighting is not None:
advanced_frame_weighting = torch.Tensor(advanced_frame_weighting * len(cond_or_uncond)).to(sigmas)
assert advanced_frame_weighting.shape[0] == cond_mark.shape[0], \
'Frame weighting list length is different from batch size!'
if advanced_sigma_weighting is not None:
advanced_sigma_weighting = torch.cat([advanced_sigma_weighting(sigmas)] * len(cond_or_uncond))
for k, v in control.items():
for i in range(len(v)):
control_signal = control[k][i]
B, C, H, W = control_signal.shape
positive_weight = 1.0
negative_weight = 1.0
sigma_weight = 1.0
frame_weight = 1.0
if positive_advanced_weighting is not None:
positive_weight = get_at(positive_advanced_weighting.get(k, []), i, 1.0)
if negative_advanced_weighting is not None:
negative_weight = get_at(negative_advanced_weighting.get(k, []), i, 1.0)
if advanced_sigma_weighting is not None:
sigma_weight = advanced_sigma_weighting
if advanced_frame_weighting is not None:
frame_weight = advanced_frame_weighting
final_weight = positive_weight * (1.0 - cond_mark) + negative_weight * cond_mark
final_weight = final_weight * sigma_weight * frame_weight
if isinstance(advanced_mask_weighting, torch.Tensor):
control_signal = control_signal * torch.nn.functional.interpolate(advanced_mask_weighting, size=(H, W), mode='bilinear')
control[k][i] = control_signal * final_weight[:, None, None, None]
return control

View File

@ -38,10 +38,6 @@ def build_loaded(module, loader_name):
def patch_all_basics():
import ldm_patched.modules.controlnet
import modules_forge.controlnet
ldm_patched.modules.controlnet.compute_controlnet_weighting = modules_forge.controlnet.compute_controlnet_weighting
build_loaded(safetensors.torch, 'load_file')
build_loaded(torch, 'load')
return