Only feed image embed to instant id ControlNet unit (#309)
* Only feed image embed to instant id ControlNet units * Add tests * Add more tests
This commit is contained in:
parent
b55f9e7212
commit
d5cc799eec
@ -4,6 +4,7 @@ from .template import (
|
||||
APITestTemplate,
|
||||
girl_img,
|
||||
mask_img,
|
||||
portrait_imgs,
|
||||
disable_in_cq,
|
||||
get_model,
|
||||
)
|
||||
@ -169,3 +170,60 @@ def test_lama_outpaint():
|
||||
"resize_mode": "Resize and Fill", # OUTER_FIT
|
||||
},
|
||||
).exec()
|
||||
|
||||
|
||||
@disable_in_cq
|
||||
def test_instant_id_sdxl():
|
||||
assert len(portrait_imgs) > 0
|
||||
assert APITestTemplate(
|
||||
"instant_id_sdxl",
|
||||
"txt2img",
|
||||
payload_overrides={
|
||||
"width": 1000,
|
||||
"height": 1000,
|
||||
"prompt": "1girl, red background",
|
||||
},
|
||||
unit_overrides=[
|
||||
dict(
|
||||
image=portrait_imgs[0],
|
||||
model=get_model("ip-adapter_instant_id_sdxl"),
|
||||
module="InsightFace (InstantID)",
|
||||
),
|
||||
dict(
|
||||
image=portrait_imgs[1],
|
||||
model=get_model("control_instant_id_sdxl"),
|
||||
module="instant_id_face_keypoints",
|
||||
),
|
||||
],
|
||||
).exec()
|
||||
|
||||
|
||||
@disable_in_cq
|
||||
def test_instant_id_sdxl_multiple_units():
|
||||
assert len(portrait_imgs) > 0
|
||||
assert APITestTemplate(
|
||||
"instant_id_sdxl_multiple_units",
|
||||
"txt2img",
|
||||
payload_overrides={
|
||||
"width": 1000,
|
||||
"height": 1000,
|
||||
"prompt": "1girl, red background",
|
||||
},
|
||||
unit_overrides=[
|
||||
dict(
|
||||
image=portrait_imgs[0],
|
||||
model=get_model("ip-adapter_instant_id_sdxl"),
|
||||
module="InsightFace (InstantID)",
|
||||
),
|
||||
dict(
|
||||
image=portrait_imgs[1],
|
||||
model=get_model("control_instant_id_sdxl"),
|
||||
module="instant_id_face_keypoints",
|
||||
),
|
||||
dict(
|
||||
image=portrait_imgs[1],
|
||||
model=get_model("diffusers_xl_canny"),
|
||||
module="canny",
|
||||
),
|
||||
],
|
||||
).exec()
|
||||
|
@ -7,6 +7,7 @@ import math
|
||||
|
||||
import ldm_patched.modules.utils
|
||||
import ldm_patched.modules.model_management
|
||||
from ldm_patched.modules.controlnet import ControlNet
|
||||
from ldm_patched.modules.clip_vision import clip_preprocess
|
||||
from ldm_patched.ldm.modules.attention import optimized_attention
|
||||
from ldm_patched.utils import path_utils as folder_paths
|
||||
@ -749,13 +750,27 @@ class IPAdapterApply:
|
||||
work_model = model.clone()
|
||||
|
||||
if self.is_instant_id:
|
||||
def modifier(cnet, x_noisy, t, cond, batched_number):
|
||||
def instant_id_modifier(cnet: ControlNet, x_noisy, t, cond, batched_number):
|
||||
"""Overwrites crossattn inputs to InstantID ControlNet with ipadapter image embeds.
|
||||
|
||||
TODO: There can be multiple pairs of InstantID (ipadapter/controlnet) to control
|
||||
rendering of multiple faces on canvas. We need to find a way to pair them. Currently,
|
||||
the modifier is unconditionally applied to all instant id ControlNet units.
|
||||
"""
|
||||
if (
|
||||
not isinstance(cnet, ControlNet) or
|
||||
# model_file_name is None for Control LoRA.
|
||||
cnet.control_model.model_file_name is None or
|
||||
"instant_id" not in cnet.control_model.model_file_name.lower()
|
||||
):
|
||||
return x_noisy, t, cond, batched_number
|
||||
|
||||
cond_mark = cond['transformer_options']['cond_mark'][:, None, None].to(cond['c_crossattn']) # cond is 0
|
||||
c_crossattn = image_prompt_embeds * (1.0 - cond_mark) + uncond_image_prompt_embeds * cond_mark
|
||||
cond['c_crossattn'] = c_crossattn
|
||||
return x_noisy, t, cond, batched_number
|
||||
|
||||
work_model.add_controlnet_conditioning_modifier(modifier)
|
||||
work_model.add_controlnet_conditioning_modifier(instant_id_modifier)
|
||||
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.to(self.device)
|
||||
|
@ -4,6 +4,7 @@
|
||||
import torch
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
from typing import Optional
|
||||
|
||||
from ldm_patched.ldm.modules.diffusionmodules.util import (
|
||||
zero_module,
|
||||
@ -54,9 +55,12 @@ class ControlNet(nn.Module):
|
||||
transformer_depth_output=None,
|
||||
device=None,
|
||||
operations=ldm_patched.modules.ops.disable_weight_init,
|
||||
model_file_name: Optional[str] = None, # Name of model file.
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.model_file_name = model_file_name
|
||||
|
||||
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
|
||||
if use_spatial_transformer:
|
||||
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
||||
|
@ -487,7 +487,7 @@ def load_controlnet(ckpt_path, model=None):
|
||||
controlnet_config["operations"] = ldm_patched.modules.ops.manual_cast
|
||||
controlnet_config.pop("out_channels")
|
||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||
control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config)
|
||||
control_model = ldm_patched.controlnet.cldm.ControlNet(model_file_name=ckpt_path, **controlnet_config)
|
||||
|
||||
if pth:
|
||||
if 'difference' in controlnet_data:
|
||||
|
@ -113,7 +113,7 @@ class ControlNetPatcher(ControlModelPatcher):
|
||||
controlnet_config["operations"] = ldm_patched.modules.ops.manual_cast
|
||||
controlnet_config.pop("out_channels")
|
||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||
control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config)
|
||||
control_model = ldm_patched.controlnet.cldm.ControlNet(model_file_name=ckpt_path, **controlnet_config)
|
||||
|
||||
if pth:
|
||||
if 'difference' in controlnet_data:
|
||||
|
Loading…
Reference in New Issue
Block a user