import torch
def apply_controlnet_advanced(
        unet,
        controlnet,
        image_bchw,
        strength,
        start_percent,
        end_percent,
        positive_advanced_weighting=None,
        negative_advanced_weighting=None,
        advanced_frame_weighting=None,
        advanced_sigma_weighting=None,
        advanced_mask_weighting=None
):
    """
    Attach a ControlNet to a UNet patcher with advanced, fine-grained weighting options.

    Returns a clone of `unet` with a patched copy of `controlnet` added; neither
    input object is mutated.

    # positive_advanced_weighting or negative_advanced_weighting

    Unet has input, middle, output blocks, and we can give different weights to each layers in all blocks.
    Below is an example for stronger control in middle block.
    This is helpful for some high-res fix passes.

        positive_advanced_weighting = {
            'input': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2],
            'middle': [1.0],
            'output': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2]
        }
        negative_advanced_weighting = {
            'input': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2],
            'middle': [1.0],
            'output': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2]
        }

    # advanced_frame_weighting

    The advanced_frame_weighting is a weight applied to each image in a batch.
    The length of this list must be same with batch size
    For example, if batch size is 5, you can use advanced_frame_weighting = [0, 0.25, 0.5, 0.75, 1.0]
    If you view the 5 images as 5 frames in a video, this will lead to progressively stronger control over time.

    # advanced_sigma_weighting

    The advanced_sigma_weighting allows you to dynamically compute control
    weights given diffusion timestep (sigma).
    For example below code can softly make beginning steps stronger than ending steps.

        sigma_max = unet.model.model_sampling.sigma_max
        sigma_min = unet.model.model_sampling.sigma_min
        advanced_sigma_weighting = lambda s: (s - sigma_min) / (sigma_max - sigma_min)

    # advanced_mask_weighting

    A mask can be applied to control signals.
    This should be a tensor with shape B 1 H W where the H and W can be arbitrary.
    This mask will be resized automatically to match the shape of all injection layers.
    """

    # Work on a copy so the caller's controlnet object is never mutated.
    cnet = controlnet.copy().set_cond_hint(image_bchw, strength, (start_percent, end_percent))
    cnet.positive_advanced_weighting = positive_advanced_weighting
    cnet.negative_advanced_weighting = negative_advanced_weighting
    cnet.advanced_frame_weighting = advanced_frame_weighting
    cnet.advanced_sigma_weighting = advanced_sigma_weighting

    if advanced_mask_weighting is not None:
        # Explicit raises rather than `assert`: asserts are stripped when
        # Python runs with -O, and these errors should carry a useful message.
        if not isinstance(advanced_mask_weighting, torch.Tensor):
            raise TypeError('advanced_mask_weighting must be a torch.Tensor with shape (B, 1, H, W).')
        if advanced_mask_weighting.ndim != 4:
            raise ValueError(
                f'advanced_mask_weighting must be 4-dimensional (B, 1, H, W), '
                f'got shape {tuple(advanced_mask_weighting.shape)}.')
        B, C, H, W = advanced_mask_weighting.shape
        if not (B > 0 and C == 1 and H > 0 and W > 0):
            raise ValueError(
                f'advanced_mask_weighting must have shape (B, 1, H, W) with positive B/H/W '
                f'and exactly one channel, got {(B, C, H, W)}.')

    cnet.advanced_mask_weighting = advanced_mask_weighting

    m = unet.clone()
    m.add_patched_controlnet(cnet)
    return m