From d5cc799eeccd98d21980801143dfb9308480bee5 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Thu, 2 May 2024 18:30:54 -0400 Subject: [PATCH] Only feed image embed to instant id ControlNet unit (#309) * Only feed image embed to instant id ControlNet units * Add tests * Add more tests --- .../tests/web_api/generation_test.py | 58 +++++++++++++++++++ .../lib_ipadapter/IPAdapterPlus.py | 21 ++++++- ldm_patched/controlnet/cldm.py | 4 ++ ldm_patched/modules/controlnet.py | 2 +- modules_forge/supported_controlnet.py | 2 +- 5 files changed, 82 insertions(+), 5 deletions(-) diff --git a/extensions-builtin/sd_forge_controlnet/tests/web_api/generation_test.py b/extensions-builtin/sd_forge_controlnet/tests/web_api/generation_test.py index 433819d1..8b8a68c9 100644 --- a/extensions-builtin/sd_forge_controlnet/tests/web_api/generation_test.py +++ b/extensions-builtin/sd_forge_controlnet/tests/web_api/generation_test.py @@ -4,6 +4,7 @@ from .template import ( APITestTemplate, girl_img, mask_img, + portrait_imgs, disable_in_cq, get_model, ) @@ -169,3 +170,60 @@ def test_lama_outpaint(): "resize_mode": "Resize and Fill", # OUTER_FIT }, ).exec() + + +@disable_in_cq +def test_instant_id_sdxl(): + assert len(portrait_imgs) > 0 + assert APITestTemplate( + "instant_id_sdxl", + "txt2img", + payload_overrides={ + "width": 1000, + "height": 1000, + "prompt": "1girl, red background", + }, + unit_overrides=[ + dict( + image=portrait_imgs[0], + model=get_model("ip-adapter_instant_id_sdxl"), + module="InsightFace (InstantID)", + ), + dict( + image=portrait_imgs[1], + model=get_model("control_instant_id_sdxl"), + module="instant_id_face_keypoints", + ), + ], + ).exec() + + +@disable_in_cq +def test_instant_id_sdxl_multiple_units(): + assert len(portrait_imgs) > 0 + assert APITestTemplate( + "instant_id_sdxl_multiple_units", + "txt2img", + payload_overrides={ + "width": 1000, + "height": 1000, + "prompt": "1girl, red background", + }, + unit_overrides=[ + dict( + image=portrait_imgs[0], + model=get_model("ip-adapter_instant_id_sdxl"), + module="InsightFace (InstantID)", + ), + dict( + image=portrait_imgs[1], + model=get_model("control_instant_id_sdxl"), + module="instant_id_face_keypoints", + ), + dict( + image=portrait_imgs[1], + model=get_model("diffusers_xl_canny"), + module="canny", + ), + ], + ).exec() diff --git a/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py b/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py index e35b8d60..7e5a54f2 100644 --- a/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py +++ b/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py @@ -7,6 +7,7 @@ import math import ldm_patched.modules.utils import ldm_patched.modules.model_management +from ldm_patched.modules.controlnet import ControlNet from ldm_patched.modules.clip_vision import clip_preprocess from ldm_patched.ldm.modules.attention import optimized_attention from ldm_patched.utils import path_utils as folder_paths @@ -732,7 +733,7 @@ class IPAdapterApply: is_faceid=self.is_faceid, is_instant_id=self.is_instant_id ) - + self.ipadapter.to(self.device, dtype=self.dtype) if self.is_instant_id: @@ -749,13 +750,27 @@ class IPAdapterApply: work_model = model.clone() if self.is_instant_id: - def modifier(cnet, x_noisy, t, cond, batched_number): + def instant_id_modifier(cnet: ControlNet, x_noisy, t, cond, batched_number): + """Overwrites crossattn inputs to InstantID ControlNet with ipadapter image embeds. + + TODO: There can be multiple pairs of InstantID (ipadapter/controlnet) to control + rendering of multiple faces on canvas. We need to find a way to pair them. Currently, + the modifier is unconditionally applied to all instant id ControlNet units. + """ + if ( + not isinstance(cnet, ControlNet) or + # model_file_name is None for Control LoRA. + cnet.control_model.model_file_name is None or + "instant_id" not in cnet.control_model.model_file_name.lower() + ): + return x_noisy, t, cond, batched_number + cond_mark = cond['transformer_options']['cond_mark'][:, None, None].to(cond['c_crossattn']) # cond is 0 c_crossattn = image_prompt_embeds * (1.0 - cond_mark) + uncond_image_prompt_embeds * cond_mark cond['c_crossattn'] = c_crossattn return x_noisy, t, cond, batched_number - work_model.add_controlnet_conditioning_modifier(modifier) + work_model.add_controlnet_conditioning_modifier(instant_id_modifier) if attn_mask is not None: attn_mask = attn_mask.to(self.device) diff --git a/ldm_patched/controlnet/cldm.py b/ldm_patched/controlnet/cldm.py index 82265ef9..aa8de9bd 100644 --- a/ldm_patched/controlnet/cldm.py +++ b/ldm_patched/controlnet/cldm.py @@ -4,6 +4,7 @@ import torch import torch as th import torch.nn as nn +from typing import Optional from ldm_patched.ldm.modules.diffusionmodules.util import ( zero_module, @@ -54,9 +55,12 @@ class ControlNet(nn.Module): transformer_depth_output=None, device=None, operations=ldm_patched.modules.ops.disable_weight_init, + model_file_name: Optional[str] = None, # Name of model file. **kwargs, ): super().__init__() + self.model_file_name = model_file_name + assert use_spatial_transformer == True, "use_spatial_transformer has to be true" if use_spatial_transformer: assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' diff --git a/ldm_patched/modules/controlnet.py b/ldm_patched/modules/controlnet.py index 6192d7ae..449716c4 100644 --- a/ldm_patched/modules/controlnet.py +++ b/ldm_patched/modules/controlnet.py @@ -487,7 +487,7 @@ def load_controlnet(ckpt_path, model=None): controlnet_config["operations"] = ldm_patched.modules.ops.manual_cast controlnet_config.pop("out_channels") controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] - control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config) + control_model = ldm_patched.controlnet.cldm.ControlNet(model_file_name=ckpt_path, **controlnet_config) if pth: if 'difference' in controlnet_data: diff --git a/modules_forge/supported_controlnet.py b/modules_forge/supported_controlnet.py index 1490259e..747fe583 100644 --- a/modules_forge/supported_controlnet.py +++ b/modules_forge/supported_controlnet.py @@ -113,7 +113,7 @@ class ControlNetPatcher(ControlModelPatcher): controlnet_config["operations"] = ldm_patched.modules.ops.manual_cast controlnet_config.pop("out_channels") controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] - control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config) + control_model = ldm_patched.controlnet.cldm.ControlNet(model_file_name=ckpt_path, **controlnet_config) if pth: if 'difference' in controlnet_data: