2024-01-28 03:52:31 +00:00
|
|
|
def apply_controlnet_advanced(
        unet,
        controlnet,
        image_bhwc,
        strength,
        start_percent,
        end_percent,
        positive_advanced_weighting=None,
        negative_advanced_weighting=None,
        advanced_frame_weighting=None,
        advanced_sigma_weighting=None
):
    """Return a clone of ``unet`` with a configured copy of ``controlnet`` attached.

    ``controlnet.copy()`` receives the hint image, ``strength`` and the
    ``(start_percent, end_percent)`` range through ``set_cond_hint``. The four
    optional advanced-weighting settings are then stored on that copy as
    attributes with the same names. This function only stores them; it does
    not interpret them. Finally, the copy is registered on ``unet.clone()`` via
    ``add_patched_controlnet``, and that clone is returned.

    ``image_bhwc`` is passed through ``movedim(-1, 1)`` (its last axis is moved
    to position 1) before being used as the hint. Judging by its name it is
    presumably a (batch, height, width, channels) tensor, which would make the
    hint channels-first -- confirm with callers.

    positive_advanced_weighting / negative_advanced_weighting
        The UNet has input, middle and output blocks, and every layer in each
        block can be given its own weight. The example below gives stronger
        control in the middle block, which can help some high-res fix passes::

            positive_advanced_weighting = {
                'input': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2],
                'middle': [1.0],
                'output': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2]
            }

            negative_advanced_weighting = {
                'input': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2],
                'middle': [1.0],
                'output': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2]
            }

    advanced_frame_weighting
        One weight per image in the batch; the list length must equal the
        batch size. For example, with a batch size of 5,
        ``advanced_frame_weighting = [0, 0.25, 0.5, 0.75, 1.0]`` gives
        progressively stronger control if the 5 images are treated as
        consecutive frames of a video.

    advanced_sigma_weighting
        A callable that computes a control weight from the diffusion timestep
        (sigma). For example, the following softly makes the early steps
        stronger than the late ones::

            sigma_max = unet.model.model_sampling.sigma_max
            sigma_min = unet.model.model_sampling.sigma_min
            advanced_sigma_weighting = lambda s: (s - sigma_min) / (sigma_max - sigma_min)
    """
    # Configure a copy rather than the caller's ControlNet object.
    patched_cnet = controlnet.copy()
    patched_cnet = patched_cnet.set_cond_hint(
        image_bhwc.movedim(-1, 1), strength, (start_percent, end_percent)
    )

    # Attach the optional weighting settings verbatim (None means "unused").
    advanced_settings = (
        ('positive_advanced_weighting', positive_advanced_weighting),
        ('negative_advanced_weighting', negative_advanced_weighting),
        ('advanced_frame_weighting', advanced_frame_weighting),
        ('advanced_sigma_weighting', advanced_sigma_weighting),
    )
    for attr_name, attr_value in advanced_settings:
        setattr(patched_cnet, attr_name, attr_value)

    patched_unet = unet.clone()
    patched_unet.add_patched_controlnet(patched_cnet)
    return patched_unet
|
2024-01-28 15:56:42 +00:00
|
|
|
|
|
|
|
|
|
|
|
def compute_controlnet_weighting(
        control,
        positive_advanced_weighting,
        negative_advanced_weighting,
        advanced_frame_weighting,
        advanced_sigma_weighting,
        transformer_options
):
    """Return ``control`` after evaluating the advanced ControlNet weighting inputs.

    If all four weighting arguments are ``None``, ``control`` is returned
    immediately and ``transformer_options`` is never read. Otherwise the
    function looks up ``'cond_or_uncond'``, ``'sigmas'`` and ``'cond_mark'`` in
    ``transformer_options``. If it is a plain dict, a missing key raises
    ``KeyError``. When ``advanced_sigma_weighting`` is given, it is called with
    ``sigmas``.

    NOTE(review): none of these values are applied yet -- ``control`` is
    returned unchanged on every path, so this is currently a pass-through.
    The weighting arguments appear to be the same settings that
    ``apply_controlnet_advanced`` stores on the ControlNet; confirm against
    the caller.
    """
    weighting_options = (
        positive_advanced_weighting,
        negative_advanced_weighting,
        advanced_frame_weighting,
        advanced_sigma_weighting,
    )
    if all(option is None for option in weighting_options):
        return control

    # Looked up in this order; currently unused beyond the lookup itself.
    cond_or_uncond = transformer_options['cond_or_uncond']
    sigmas = transformer_options['sigmas']
    cond_mark = transformer_options['cond_mark']

    # Evaluate the per-sigma weight (result not yet applied to control).
    sigma_weight = advanced_sigma_weighting(sigmas) if advanced_sigma_weighting is not None else None

    return control
|