Add advanced weighting support (#754)
This commit is contained in:
parent
49c3a080b5
commit
eb1e12b0dc
@ -227,6 +227,7 @@ class ControlNetUiGroup(object):
|
|||||||
self.hr_option = None
|
self.hr_option = None
|
||||||
self.batch_image_dir_state = None
|
self.batch_image_dir_state = None
|
||||||
self.output_dir_state = None
|
self.output_dir_state = None
|
||||||
|
self.advanced_weighting = gr.State(None)
|
||||||
|
|
||||||
# Internal states for UI state pasting.
|
# Internal states for UI state pasting.
|
||||||
self.prevent_next_n_module_update = 0
|
self.prevent_next_n_module_update = 0
|
||||||
@ -607,6 +608,7 @@ class ControlNetUiGroup(object):
|
|||||||
self.guidance_end,
|
self.guidance_end,
|
||||||
self.pixel_perfect,
|
self.pixel_perfect,
|
||||||
self.control_mode,
|
self.control_mode,
|
||||||
|
self.advanced_weighting,
|
||||||
)
|
)
|
||||||
|
|
||||||
unit = gr.State(self.default_unit)
|
unit = gr.State(self.default_unit)
|
||||||
|
@ -221,6 +221,21 @@ class ControlNetUnit:
|
|||||||
pixel_perfect: bool = False
|
pixel_perfect: bool = False
|
||||||
# Control mode for the unit; defaults to balanced.
|
# Control mode for the unit; defaults to balanced.
|
||||||
control_mode: ControlMode = ControlMode.BALANCED
|
control_mode: ControlMode = ControlMode.BALANCED
|
||||||
|
# Weight for each layer of ControlNet params.
|
||||||
|
# For ControlNet:
|
||||||
|
# - SD1.5: 13 weights (4 encoder block * 3 + 1 middle block)
|
||||||
|
# - SDXL: 10 weights (3 encoder block * 3 + 1 middle block)
|
||||||
|
# For T2IAdapter
|
||||||
|
# - SD1.5: 5 weights (4 encoder block + 1 middle block)
|
||||||
|
# - SDXL: 4 weights (3 encoder block + 1 middle block)
|
||||||
|
# For IPAdapter
|
||||||
|
# - SD15: 16 (6 input blocks + 9 output blocks + 1 middle block)
|
||||||
|
# - SDXL: 11 weights (4 input blocks + 6 output blocks + 1 middle block)
|
||||||
|
# Note1: Setting advanced weighting will disable `soft_injection`, i.e.
|
||||||
|
# It is recommended to set ControlMode = BALANCED when using `advanced_weighting`.
|
||||||
|
# Note2: The field `weight` is still used in some places, e.g. reference_only,
|
||||||
|
# even advanced_weighting is set.
|
||||||
|
advanced_weighting: Optional[List[float]] = None
|
||||||
|
|
||||||
# Following fields should only be used in the API.
|
# Following fields should only be used in the API.
|
||||||
# ====== Start of API only fields ======
|
# ====== Start of API only fields ======
|
||||||
|
@ -59,6 +59,7 @@ class ControlNetCachedParameters:
|
|||||||
self.control_cond_for_hr_fix = None
|
self.control_cond_for_hr_fix = None
|
||||||
self.control_mask = None
|
self.control_mask = None
|
||||||
self.control_mask_for_hr_fix = None
|
self.control_mask_for_hr_fix = None
|
||||||
|
self.advanced_weighting = None
|
||||||
|
|
||||||
|
|
||||||
class ControlNetForForgeOfficial(scripts.Script):
|
class ControlNetForForgeOfficial(scripts.Script):
|
||||||
@ -505,6 +506,11 @@ class ControlNetForForgeOfficial(scripts.Script):
|
|||||||
params.model.positive_advanced_weighting = soft_weighting.copy()
|
params.model.positive_advanced_weighting = soft_weighting.copy()
|
||||||
params.model.negative_advanced_weighting = soft_weighting.copy()
|
params.model.negative_advanced_weighting = soft_weighting.copy()
|
||||||
|
|
||||||
|
if unit.advanced_weighting is not None:
|
||||||
|
if params.model.positive_advanced_weighting is None:
|
||||||
|
logger.warn("advanced_weighting overwrite control_mode")
|
||||||
|
params.model.positive_advanced_weighting = unit.advanced_weighting
|
||||||
|
|
||||||
cond, mask = params.preprocessor.process_before_every_sampling(p, cond, mask, *args, **kwargs)
|
cond, mask = params.preprocessor.process_before_every_sampling(p, cond, mask, *args, **kwargs)
|
||||||
|
|
||||||
params.model.advanced_mask_weighting = mask
|
params.model.advanced_mask_weighting = mask
|
||||||
|
@ -0,0 +1,43 @@
|
|||||||
|
from .template import (
|
||||||
|
APITestTemplate,
|
||||||
|
realistic_girl_face_img,
|
||||||
|
disable_in_cq,
|
||||||
|
get_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@disable_in_cq
|
||||||
|
def test_ipadapter_advanced_weighting():
|
||||||
|
weights = [0.0] * 16 # 16 weights for SD15 / 11 weights for SDXL
|
||||||
|
# SD15 composition
|
||||||
|
weights[4] = 0.25
|
||||||
|
weights[5] = 1.0
|
||||||
|
|
||||||
|
APITestTemplate(
|
||||||
|
"test_ipadapter_advanced_weighting",
|
||||||
|
"txt2img",
|
||||||
|
payload_overrides={
|
||||||
|
"width": 512,
|
||||||
|
"height": 512,
|
||||||
|
},
|
||||||
|
unit_overrides={
|
||||||
|
"image": realistic_girl_face_img,
|
||||||
|
"module": "CLIP-ViT-H (IPAdapter)",
|
||||||
|
"model": get_model("ip-adapter_sd15"),
|
||||||
|
"advanced_weighting": weights,
|
||||||
|
},
|
||||||
|
).exec()
|
||||||
|
|
||||||
|
APITestTemplate(
|
||||||
|
"test_ipadapter_advanced_weighting_ref",
|
||||||
|
"txt2img",
|
||||||
|
payload_overrides={
|
||||||
|
"width": 512,
|
||||||
|
"height": 512,
|
||||||
|
},
|
||||||
|
unit_overrides={
|
||||||
|
"image": realistic_girl_face_img,
|
||||||
|
"module": "CLIP-ViT-H (IPAdapter)",
|
||||||
|
"model": get_model("ip-adapter_sd15"),
|
||||||
|
},
|
||||||
|
).exec()
|
@ -248,13 +248,13 @@ def get_model(model_name: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
default_unit = {
|
default_unit = {
|
||||||
"control_mode": 0,
|
"control_mode": "Balanced",
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
"guidance_end": 1,
|
"guidance_end": 1,
|
||||||
"guidance_start": 0,
|
"guidance_start": 0,
|
||||||
"pixel_perfect": True,
|
"pixel_perfect": True,
|
||||||
"processor_res": 512,
|
"processor_res": 512,
|
||||||
"resize_mode": 1,
|
"resize_mode": "Crop and Resize",
|
||||||
"threshold_a": 64,
|
"threshold_a": 64,
|
||||||
"threshold_b": 64,
|
"threshold_b": 64,
|
||||||
"weight": 1,
|
"weight": 1,
|
||||||
|
@ -466,8 +466,18 @@ class CrossAttentionPatch:
|
|||||||
ip_v = ip_v_offset + ip_v_mean * W
|
ip_v = ip_v_offset + ip_v_mean * W
|
||||||
|
|
||||||
out_ip = optimized_attention(q, ip_k.to(org_dtype), ip_v.to(org_dtype), extra_options["n_heads"])
|
out_ip = optimized_attention(q, ip_k.to(org_dtype), ip_v.to(org_dtype), extra_options["n_heads"])
|
||||||
if weight_type.startswith("original"):
|
|
||||||
out_ip = out_ip * weight
|
if weight_type == "original":
|
||||||
|
assert isinstance(weight, (float, int))
|
||||||
|
weight = weight
|
||||||
|
elif weight_type == "advanced":
|
||||||
|
assert isinstance(weight, list)
|
||||||
|
transformer_index: int = extra_options["transformer_index"]
|
||||||
|
assert transformer_index < len(weight)
|
||||||
|
weight = weight[transformer_index]
|
||||||
|
else:
|
||||||
|
weight = 1.0
|
||||||
|
out_ip = out_ip * weight
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
# TODO: needs checking
|
# TODO: needs checking
|
||||||
|
@ -143,11 +143,17 @@ class IPAdapterPatcher(ControlModelPatcher):
|
|||||||
|
|
||||||
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
|
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
|
||||||
unet = process.sd_model.forge_objects.unet
|
unet = process.sd_model.forge_objects.unet
|
||||||
|
if self.positive_advanced_weighting is None:
|
||||||
|
weight = self.strength
|
||||||
|
cond["weight_type"] = "original"
|
||||||
|
else:
|
||||||
|
weight = self.positive_advanced_weighting
|
||||||
|
cond["weight_type"] = "advanced"
|
||||||
|
|
||||||
unet = opIPAdapterApply(
|
unet = opIPAdapterApply(
|
||||||
ipadapter=self.ip_adapter,
|
ipadapter=self.ip_adapter,
|
||||||
model=unet,
|
model=unet,
|
||||||
weight=self.strength,
|
weight=weight,
|
||||||
start_at=self.start_percent,
|
start_at=self.start_percent,
|
||||||
end_at=self.end_percent,
|
end_at=self.end_percent,
|
||||||
faceid_v2=self.faceid_v2,
|
faceid_v2=self.faceid_v2,
|
||||||
|
Loading…
Reference in New Issue
Block a user