Only feed image embed to instant id ControlNet unit (#309)

* Only feed image embed to instant id ControlNet units

* Add tests

* Add more tests
This commit is contained in:
Chenlei Hu 2024-05-02 18:30:54 -04:00 committed by GitHub
parent b55f9e7212
commit d5cc799eec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 82 additions and 5 deletions

View File

@ -4,6 +4,7 @@ from .template import (
APITestTemplate,
girl_img,
mask_img,
portrait_imgs,
disable_in_cq,
get_model,
)
@ -169,3 +170,60 @@ def test_lama_outpaint():
"resize_mode": "Resize and Fill", # OUTER_FIT
},
).exec()
@disable_in_cq
def test_instant_id_sdxl():
assert len(portrait_imgs) > 0
assert APITestTemplate(
"instant_id_sdxl",
"txt2img",
payload_overrides={
"width": 1000,
"height": 1000,
"prompt": "1girl, red background",
},
unit_overrides=[
dict(
image=portrait_imgs[0],
model=get_model("ip-adapter_instant_id_sdxl"),
module="InsightFace (InstantID)",
),
dict(
image=portrait_imgs[1],
model=get_model("control_instant_id_sdxl"),
module="instant_id_face_keypoints",
),
],
).exec()
@disable_in_cq
def test_instant_id_sdxl_multiple_units():
assert len(portrait_imgs) > 0
assert APITestTemplate(
"instant_id_sdxl_multiple_units",
"txt2img",
payload_overrides={
"width": 1000,
"height": 1000,
"prompt": "1girl, red background",
},
unit_overrides=[
dict(
image=portrait_imgs[0],
model=get_model("ip-adapter_instant_id_sdxl"),
module="InsightFace (InstantID)",
),
dict(
image=portrait_imgs[1],
model=get_model("control_instant_id_sdxl"),
module="instant_id_face_keypoints",
),
dict(
image=portrait_imgs[1],
model=get_model("diffusers_xl_canny"),
module="canny",
),
],
).exec()

View File

@ -7,6 +7,7 @@ import math
import ldm_patched.modules.utils
import ldm_patched.modules.model_management
from ldm_patched.modules.controlnet import ControlNet
from ldm_patched.modules.clip_vision import clip_preprocess
from ldm_patched.ldm.modules.attention import optimized_attention
from ldm_patched.utils import path_utils as folder_paths
@ -749,13 +750,27 @@ class IPAdapterApply:
work_model = model.clone()
if self.is_instant_id:
def modifier(cnet, x_noisy, t, cond, batched_number):
def instant_id_modifier(cnet: ControlNet, x_noisy, t, cond, batched_number):
"""Overwrites crossattn inputs to InstantID ControlNet with ipadapter image embeds.
TODO: There can be multiple pairs of InstantID (ipadapter/controlnet) to control
rendering of multiple faces on canvas. We need to find a way to pair them. Currently,
the modifier is unconditionally applied to all instant id ControlNet units.
"""
if (
not isinstance(cnet, ControlNet) or
# model_file_name is None for Control LoRA.
cnet.control_model.model_file_name is None or
"instant_id" not in cnet.control_model.model_file_name.lower()
):
return x_noisy, t, cond, batched_number
cond_mark = cond['transformer_options']['cond_mark'][:, None, None].to(cond['c_crossattn']) # cond is 0
c_crossattn = image_prompt_embeds * (1.0 - cond_mark) + uncond_image_prompt_embeds * cond_mark
cond['c_crossattn'] = c_crossattn
return x_noisy, t, cond, batched_number
work_model.add_controlnet_conditioning_modifier(modifier)
work_model.add_controlnet_conditioning_modifier(instant_id_modifier)
if attn_mask is not None:
attn_mask = attn_mask.to(self.device)

View File

@ -4,6 +4,7 @@
import torch
import torch as th
import torch.nn as nn
from typing import Optional
from ldm_patched.ldm.modules.diffusionmodules.util import (
zero_module,
@ -54,9 +55,12 @@ class ControlNet(nn.Module):
transformer_depth_output=None,
device=None,
operations=ldm_patched.modules.ops.disable_weight_init,
model_file_name: Optional[str] = None, # Name of model file.
**kwargs,
):
super().__init__()
self.model_file_name = model_file_name
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

View File

@ -487,7 +487,7 @@ def load_controlnet(ckpt_path, model=None):
controlnet_config["operations"] = ldm_patched.modules.ops.manual_cast
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config)
control_model = ldm_patched.controlnet.cldm.ControlNet(model_file_name=ckpt_path, **controlnet_config)
if pth:
if 'difference' in controlnet_data:

View File

@ -113,7 +113,7 @@ class ControlNetPatcher(ControlModelPatcher):
controlnet_config["operations"] = ldm_patched.modules.ops.manual_cast
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config)
control_model = ldm_patched.controlnet.cldm.ControlNet(model_file_name=ckpt_path, **controlnet_config)
if pth:
if 'difference' in controlnet_data: