diff --git a/ldm_patched/modules/controlnet.py b/ldm_patched/modules/controlnet.py index 0fc90118..b9879235 100644 --- a/ldm_patched/modules/controlnet.py +++ b/ldm_patched/modules/controlnet.py @@ -11,9 +11,6 @@ import ldm_patched.controlnet.cldm import ldm_patched.t2ia.adapter -compute_controlnet_weighting = None - - def broadcast_image_to(tensor, target_batch_size, batched_number): current_batch_size = tensor.shape[0] #print(current_batch_size, target_batch_size) @@ -32,6 +29,71 @@ def broadcast_image_to(tensor, target_batch_size, batched_number): else: return torch.cat([tensor] * batched_number, dim=0) + +def get_at(array, index, default=None): + return array[index] if 0 <= index < len(array) else default + + +def compute_controlnet_weighting(control, cnet): + + positive_advanced_weighting = getattr(cnet, 'positive_advanced_weighting', None) + negative_advanced_weighting = getattr(cnet, 'negative_advanced_weighting', None) + advanced_frame_weighting = getattr(cnet, 'advanced_frame_weighting', None) + advanced_sigma_weighting = getattr(cnet, 'advanced_sigma_weighting', None) + advanced_mask_weighting = getattr(cnet, 'advanced_mask_weighting', None) + + transformer_options = cnet.transformer_options + + if positive_advanced_weighting is None and negative_advanced_weighting is None \ + and advanced_frame_weighting is None and advanced_sigma_weighting is None \ + and advanced_mask_weighting is None: + return control + + cond_or_uncond = transformer_options['cond_or_uncond'] + sigmas = transformer_options['sigmas'] + cond_mark = transformer_options['cond_mark'] + + if advanced_frame_weighting is not None: + advanced_frame_weighting = torch.Tensor(advanced_frame_weighting * len(cond_or_uncond)).to(sigmas) + assert advanced_frame_weighting.shape[0] == cond_mark.shape[0], \ + 'Frame weighting list length is different from batch size!' + + if advanced_sigma_weighting is not None: + advanced_sigma_weighting = torch.cat([advanced_sigma_weighting(sigmas)] * len(cond_or_uncond)) + + for k, v in control.items(): + for i in range(len(v)): + control_signal = control[k][i] + B, C, H, W = control_signal.shape + + positive_weight = 1.0 + negative_weight = 1.0 + sigma_weight = 1.0 + frame_weight = 1.0 + + if positive_advanced_weighting is not None: + positive_weight = get_at(positive_advanced_weighting.get(k, []), i, 1.0) + + if negative_advanced_weighting is not None: + negative_weight = get_at(negative_advanced_weighting.get(k, []), i, 1.0) + + if advanced_sigma_weighting is not None: + sigma_weight = advanced_sigma_weighting + + if advanced_frame_weighting is not None: + frame_weight = advanced_frame_weighting + + final_weight = positive_weight * (1.0 - cond_mark) + negative_weight * cond_mark + final_weight = final_weight * sigma_weight * frame_weight + + if isinstance(advanced_mask_weighting, torch.Tensor): + control_signal = control_signal * torch.nn.functional.interpolate(advanced_mask_weighting, size=(H, W), mode='bilinear') + + control[k][i] = control_signal * final_weight[:, None, None, None] + + return control + + class ControlBase: def __init__(self, device=None): self.cond_hint_original = None @@ -119,8 +181,7 @@ class ControlBase: out[key].append(x) - if compute_controlnet_weighting is not None: - out = compute_controlnet_weighting(out, self) + out = compute_controlnet_weighting(out, self) if control_prev is not None: for x in ['input', 'middle', 'output']: diff --git a/modules_forge/controlnet.py b/modules_forge/controlnet.py index b0ca32d1..e8585c53 100644 --- a/modules_forge/controlnet.py +++ b/modules_forge/controlnet.py @@ -1,10 +1,6 @@ import torch -def get_at(array, index, default=None): - return array[index] if 0 <= index < len(array) else default - - def apply_controlnet_advanced( unet, controlnet, @@ -79,61 +75,3 @@ def apply_controlnet_advanced( m.add_patched_controlnet(cnet) return m - -def compute_controlnet_weighting(control, cnet): - - positive_advanced_weighting = cnet.positive_advanced_weighting - negative_advanced_weighting = cnet.negative_advanced_weighting - advanced_frame_weighting = cnet.advanced_frame_weighting - advanced_sigma_weighting = cnet.advanced_sigma_weighting - advanced_mask_weighting = cnet.advanced_mask_weighting - transformer_options = cnet.transformer_options - - if positive_advanced_weighting is None and negative_advanced_weighting is None \ - and advanced_frame_weighting is None and advanced_sigma_weighting is None \ - and advanced_mask_weighting is None: - return control - - cond_or_uncond = transformer_options['cond_or_uncond'] - sigmas = transformer_options['sigmas'] - cond_mark = transformer_options['cond_mark'] - - if advanced_frame_weighting is not None: - advanced_frame_weighting = torch.Tensor(advanced_frame_weighting * len(cond_or_uncond)).to(sigmas) - assert advanced_frame_weighting.shape[0] == cond_mark.shape[0], \ - 'Frame weighting list length is different from batch size!' - - if advanced_sigma_weighting is not None: - advanced_sigma_weighting = torch.cat([advanced_sigma_weighting(sigmas)] * len(cond_or_uncond)) - - for k, v in control.items(): - for i in range(len(v)): - control_signal = control[k][i] - B, C, H, W = control_signal.shape - - positive_weight = 1.0 - negative_weight = 1.0 - sigma_weight = 1.0 - frame_weight = 1.0 - - if positive_advanced_weighting is not None: - positive_weight = get_at(positive_advanced_weighting.get(k, []), i, 1.0) - - if negative_advanced_weighting is not None: - negative_weight = get_at(negative_advanced_weighting.get(k, []), i, 1.0) - - if advanced_sigma_weighting is not None: - sigma_weight = advanced_sigma_weighting - - if advanced_frame_weighting is not None: - frame_weight = advanced_frame_weighting - - final_weight = positive_weight * (1.0 - cond_mark) + negative_weight * cond_mark - final_weight = final_weight * sigma_weight * frame_weight - - if isinstance(advanced_mask_weighting, torch.Tensor): - control_signal = control_signal * torch.nn.functional.interpolate(advanced_mask_weighting, size=(H, W), mode='bilinear') - - control[k][i] = control_signal * final_weight[:, None, None, None] - - return control diff --git a/modules_forge/patch_basic.py b/modules_forge/patch_basic.py index eab64573..5b61819a 100644 --- a/modules_forge/patch_basic.py +++ b/modules_forge/patch_basic.py @@ -38,10 +38,6 @@ def build_loaded(module, loader_name): def patch_all_basics(): - import ldm_patched.modules.controlnet - import modules_forge.controlnet - - ldm_patched.modules.controlnet.compute_controlnet_weighting = modules_forge.controlnet.compute_controlnet_weighting build_loaded(safetensors.torch, 'load_file') build_loaded(torch, 'load') return