From eb1e12b0dc536b6d5f5482d61d7f922c56f01f7a Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Thu, 23 May 2024 13:11:11 -0400 Subject: [PATCH] Add advanced weighting support (#754) --- .../controlnet_ui/controlnet_ui_group.py | 2 + .../lib_controlnet/external_code.py | 15 +++++++ .../sd_forge_controlnet/scripts/controlnet.py | 6 +++ .../web_api/ipadapter_advanced_weighting.py | 43 +++++++++++++++++++ .../tests/web_api/template.py | 4 +- .../lib_ipadapter/IPAdapterPlus.py | 16 +++++-- .../scripts/forge_ipadapter.py | 8 +++- 7 files changed, 88 insertions(+), 6 deletions(-) create mode 100644 extensions-builtin/sd_forge_controlnet/tests/web_api/ipadapter_advanced_weighting.py diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py index 928348dd..0f58b24a 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py @@ -227,6 +227,7 @@ class ControlNetUiGroup(object): self.hr_option = None self.batch_image_dir_state = None self.output_dir_state = None + self.advanced_weighting = gr.State(None) # Internal states for UI state pasting. self.prevent_next_n_module_update = 0 @@ -607,6 +608,7 @@ class ControlNetUiGroup(object): self.guidance_end, self.pixel_perfect, self.control_mode, + self.advanced_weighting, ) unit = gr.State(self.default_unit) diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py index 4e47bd05..857387e0 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py @@ -221,6 +221,21 @@ class ControlNetUnit: pixel_perfect: bool = False # Control mode for the unit; defaults to balanced. control_mode: ControlMode = ControlMode.BALANCED + # Weight for each layer of ControlNet params. + # For ControlNet: + # - SD1.5: 13 weights (4 encoder block * 3 + 1 middle block) + # - SDXL: 10 weights (3 encoder block * 3 + 1 middle block) + # For T2IAdapter + # - SD1.5: 5 weights (4 encoder block + 1 middle block) + # - SDXL: 4 weights (3 encoder block + 1 middle block) + # For IPAdapter + # - SD15: 16 (6 input blocks + 9 output blocks + 1 middle block) + # - SDXL: 11 weights (4 input blocks + 6 output blocks + 1 middle block) + # Note1: Setting advanced weighting will disable `soft_injection`, i.e. + # It is recommended to set ControlMode = BALANCED when using `advanced_weighting`. + # Note2: The field `weight` is still used in some places, e.g. reference_only, + # even advanced_weighting is set. + advanced_weighting: Optional[List[float]] = None # Following fields should only be used in the API. # ====== Start of API only fields ====== diff --git a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py index f6ba447c..436487a0 100644 --- a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py +++ b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py @@ -59,6 +59,7 @@ class ControlNetCachedParameters: self.control_cond_for_hr_fix = None self.control_mask = None self.control_mask_for_hr_fix = None + self.advanced_weighting = None class ControlNetForForgeOfficial(scripts.Script): @@ -505,6 +506,11 @@ class ControlNetForForgeOfficial(scripts.Script): params.model.positive_advanced_weighting = soft_weighting.copy() params.model.negative_advanced_weighting = soft_weighting.copy() + if unit.advanced_weighting is not None: + if params.model.positive_advanced_weighting is None: + logger.warn("advanced_weighting overwrite control_mode") + params.model.positive_advanced_weighting = unit.advanced_weighting + cond, mask = params.preprocessor.process_before_every_sampling(p, cond, mask, *args, **kwargs) params.model.advanced_mask_weighting = mask diff --git a/extensions-builtin/sd_forge_controlnet/tests/web_api/ipadapter_advanced_weighting.py b/extensions-builtin/sd_forge_controlnet/tests/web_api/ipadapter_advanced_weighting.py new file mode 100644 index 00000000..c61ef568 --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/tests/web_api/ipadapter_advanced_weighting.py @@ -0,0 +1,43 @@ +from .template import ( + APITestTemplate, + realistic_girl_face_img, + disable_in_cq, + get_model, +) + + +@disable_in_cq +def test_ipadapter_advanced_weighting(): + weights = [0.0] * 16 # 16 weights for SD15 / 11 weights for SDXL + # SD15 composition + weights[4] = 0.25 + weights[5] = 1.0 + + APITestTemplate( + "test_ipadapter_advanced_weighting", + "txt2img", + payload_overrides={ + "width": 512, + "height": 512, + }, + unit_overrides={ + "image": realistic_girl_face_img, + "module": "CLIP-ViT-H (IPAdapter)", + "model": get_model("ip-adapter_sd15"), + "advanced_weighting": weights, + }, + ).exec() + + APITestTemplate( + "test_ipadapter_advanced_weighting_ref", + "txt2img", + payload_overrides={ + "width": 512, + "height": 512, + }, + unit_overrides={ + "image": realistic_girl_face_img, + "module": "CLIP-ViT-H (IPAdapter)", + "model": get_model("ip-adapter_sd15"), + }, + ).exec() \ No newline at end of file diff --git a/extensions-builtin/sd_forge_controlnet/tests/web_api/template.py b/extensions-builtin/sd_forge_controlnet/tests/web_api/template.py index 5129e541..5a4eb49f 100644 --- a/extensions-builtin/sd_forge_controlnet/tests/web_api/template.py +++ b/extensions-builtin/sd_forge_controlnet/tests/web_api/template.py @@ -248,13 +248,13 @@ def get_model(model_name: str) -> str: default_unit = { - "control_mode": 0, + "control_mode": "Balanced", "enabled": True, "guidance_end": 1, "guidance_start": 0, "pixel_perfect": True, "processor_res": 512, - "resize_mode": 1, + "resize_mode": "Crop and Resize", "threshold_a": 64, "threshold_b": 64, "weight": 1, diff --git a/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py b/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py index 7e5a54f2..1350c1f6 100644 --- a/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py +++ b/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py @@ -403,7 +403,7 @@ class CrossAttentionPatch: batch_prompt = b // len(cond_or_uncond) out = optimized_attention(q, k, v, extra_options["n_heads"]) _, _, lh, lw = extra_options["original_shape"] - + for weight, cond, uncond, ipadapter, mask, weight_type, sigma_start, sigma_end, unfold_batch in zip(self.weights, self.conds, self.unconds, self.ipadapters, self.masks, self.weight_type, self.sigma_start, self.sigma_end, self.unfold_batch): if sigma > sigma_start or sigma < sigma_end: continue @@ -466,8 +466,18 @@ class CrossAttentionPatch: ip_v = ip_v_offset + ip_v_mean * W out_ip = optimized_attention(q, ip_k.to(org_dtype), ip_v.to(org_dtype), extra_options["n_heads"]) - if weight_type.startswith("original"): - out_ip = out_ip * weight + + if weight_type == "original": + assert isinstance(weight, (float, int)) + weight = weight + elif weight_type == "advanced": + assert isinstance(weight, list) + transformer_index: int = extra_options["transformer_index"] + assert transformer_index < len(weight) + weight = weight[transformer_index] + else: + weight = 1.0 + out_ip = out_ip * weight if mask is not None: # TODO: needs checking diff --git a/extensions-builtin/sd_forge_ipadapter/scripts/forge_ipadapter.py b/extensions-builtin/sd_forge_ipadapter/scripts/forge_ipadapter.py index 47b72044..c3d8c1a9 100644 --- a/extensions-builtin/sd_forge_ipadapter/scripts/forge_ipadapter.py +++ b/extensions-builtin/sd_forge_ipadapter/scripts/forge_ipadapter.py @@ -143,11 +143,17 @@ class IPAdapterPatcher(ControlModelPatcher): def process_before_every_sampling(self, process, cond, mask, *args, **kwargs): unet = process.sd_model.forge_objects.unet + if self.positive_advanced_weighting is None: + weight = self.strength + cond["weight_type"] = "original" + else: + weight = self.positive_advanced_weighting + cond["weight_type"] = "advanced" unet = opIPAdapterApply( ipadapter=self.ip_adapter, model=unet, - weight=self.strength, + weight=weight, start_at=self.start_percent, end_at=self.end_percent, faceid_v2=self.faceid_v2,