# my-sd/modules_forge/controlnet.py


def apply_controlnet_advanced(
        unet,
        controlnet,
        image_bhwc,
        strength,
        start_percent,
        end_percent,
        positive_advanced_weighting=None,
        negative_advanced_weighting=None,
        advanced_frame_weighting=None,
        advanced_sigma_weighting=None
):
    """
    # positive_advanced_weighting or negative_advanced_weighting

    The UNet has input, middle, and output blocks, and a different weight can be
    given to each layer in every block.
    Below is an example that gives stronger control in the middle block.
    This is helpful for some high-res fix passes.

        positive_advanced_weighting = {
            'input': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2],
            'middle': [1.0],
            'output': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2]
        }
        negative_advanced_weighting = {
            'input': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2],
            'middle': [1.0],
            'output': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2]
        }

    # advanced_frame_weighting

    The advanced_frame_weighting is a weight applied to each image in a batch.
    The length of this list must equal the batch size.
    For example, if the batch size is 5, you can use
    advanced_frame_weighting = [0, 0.25, 0.5, 0.75, 1.0].
    If you view the 5 images as 5 frames of a video, this gives progressively
    stronger control over time.

    # advanced_sigma_weighting

    The advanced_sigma_weighting lets you compute control weights dynamically
    from the diffusion timestep (sigma).
    For example, the code below smoothly makes the early (high-sigma) steps
    stronger than the late (low-sigma) steps: the weight is 1.0 at sigma_max
    and 0.0 at sigma_min.

        sigma_max = unet.model.model_sampling.sigma_max
        sigma_min = unet.model.model_sampling.sigma_min
        advanced_sigma_weighting = lambda s: (s - sigma_min) / (sigma_max - sigma_min)
    """

    # Work on a copy so the shared ControlNet is not mutated. The hint image is
    # converted from BHWC to BCHW and set together with its strength and the
    # (start_percent, end_percent) range in which the ControlNet is active.
    cnet = controlnet.copy().set_cond_hint(image_bhwc.movedim(-1, 1), strength, (start_percent, end_percent))

    # Store the optional advanced weightings on the copied ControlNet.
    cnet.positive_advanced_weighting = positive_advanced_weighting
    cnet.negative_advanced_weighting = negative_advanced_weighting
    cnet.advanced_frame_weighting = advanced_frame_weighting
    cnet.advanced_sigma_weighting = advanced_sigma_weighting

    # Patch a clone of the UNet so the caller's model is left unchanged.
    m = unet.clone()
    m.add_patched_controlnet(cnet)
    return m
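

# --- Illustrative sketch (hypothetical helper, not part of the original module) ---
# The apply_controlnet_advanced docstring says advanced_frame_weighting holds one
# weight per image in the batch, so its length must equal the batch size.
# _example_frame_weights is a hypothetical name; it only sketches how such a list
# could be validated and normalised to floats, assuming None means "weight 1.0
# for every frame". It is NOT how this module applies the weights.
def _example_frame_weights(advanced_frame_weighting, batch_size):
    """Return one float weight per frame, defaulting to 1.0 when unset."""
    if advanced_frame_weighting is None:
        return [1.0] * batch_size
    if len(advanced_frame_weighting) != batch_size:
        raise ValueError(
            f"advanced_frame_weighting has {len(advanced_frame_weighting)} entries, "
            f"but the batch size is {batch_size}"
        )
    return [float(w) for w in advanced_frame_weighting]


# Usage with the docstring's 5-frame ramp (progressively stronger control):
#   _example_frame_weights([0, 0.25, 0.5, 0.75, 1.0], 5)  # -> [0.0, 0.25, 0.5, 0.75, 1.0]
# A mismatched length, e.g. _example_frame_weights([1.0], 2), raises ValueError.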


def compute_controlnet_weighting(
        control,
        positive_advanced_weighting,
        negative_advanced_weighting,
        advanced_frame_weighting,
        advanced_sigma_weighting,
        transformer_options
):
    # Fast path: no advanced weighting configured, return control unchanged.
    if (positive_advanced_weighting is None and negative_advanced_weighting is None
            and advanced_frame_weighting is None and advanced_sigma_weighting is None):
        return control

    cond_or_uncond = transformer_options['cond_or_uncond']
    sigmas = transformer_options['sigmas']
    cond_mark = transformer_options['cond_mark']

    # Evaluate the user-supplied sigma callback on the current sigmas.
    if advanced_sigma_weighting is not None:
        advanced_sigma_weighting = advanced_sigma_weighting(sigmas)

    # In this revision the weights above are computed but not yet applied,
    # so control is returned unchanged.
    return control
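

# --- Illustrative sketch (hypothetical helper, not part of the original module) ---
# The docstring suggests
#   advanced_sigma_weighting = lambda s: (s - sigma_min) / (sigma_max - sigma_min)
# with sigma_min and sigma_max read from unet.model.model_sampling.
# _example_linear_sigma_weighting is a hypothetical factory that builds the same
# linear ramp from plain numbers so it can be tried without a loaded model.
# Assumptions: the clamp to [0, 1] for sigmas outside the range is an addition,
# and the sketch handles scalar floats only. compute_controlnet_weighting passes
# transformer_options['sigmas'] straight to the callback, which may be a tensor;
# in that case the clamp would need tensor operations instead of min/max.
def _example_linear_sigma_weighting(sigma_min, sigma_max):
    span = sigma_max - sigma_min
    if span <= 0:
        raise ValueError("sigma_max must be greater than sigma_min")

    def weight(s):
        # 1.0 at sigma_max (early, noisy steps), 0.0 at sigma_min (late steps).
        return min(1.0, max(0.0, (s - sigma_min) / span))

    return weight


# Usage with illustrative (not model-derived) bounds:
#   w = _example_linear_sigma_weighting(0.0, 10.0)
#   w(10.0)  # -> 1.0
#   w(5.0)   # -> 0.5
#   w(0.0)   # -> 0.0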