my-sd/extensions-builtin/sd_forge_ipadapter/scripts/forge_ipadapter.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

60 lines
1.9 KiB
Python
Raw Normal View History

2024-01-30 03:49:36 +00:00
from modules_forge.shared import add_supported_control_model
from modules_forge.supported_controlnet import ControlModelPatcher
2024-01-30 04:11:04 +00:00
from lib_ipadapter.IPAdapterPlus import IPAdapterApply
2024-01-30 04:01:42 +00:00
2024-01-30 04:14:14 +00:00
# Bound ``apply_ipadapter`` method of a single IPAdapterApply instance created
# at import time; IPAdapterPatcher.process_before_every_sampling calls it to
# patch the UNet.
opIPAdapterApply = IPAdapterApply().apply_ipadapter
2024-01-30 03:49:36 +00:00
class IPAdapterPatcher(ControlModelPatcher):
    """Forge control-model patcher that applies an IP-Adapter to the UNet.

    Instances are created by ``try_build_from_state_dict``. Before each sampling
    run, ``process_before_every_sampling`` replaces
    ``process.sd_model.forge_objects.unet`` with an IP-Adapter-patched copy.
    """

    # Top-level groups of an IP-Adapter checkpoint, in the order they are matched.
    _GROUPS = ("image_proj", "ip_adapter")

    @staticmethod
    def try_build_from_state_dict(state_dict, ckpt_path):
        """Build an ``IPAdapterPatcher`` if *state_dict* looks like an IP-Adapter.

        ``.safetensors`` checkpoints are flat: keys are prefixed with
        ``image_proj.`` or ``ip_adapter.``. They are regrouped into the nested
        ``{"image_proj": {...}, "ip_adapter": {...}}`` layout, and keys with
        neither prefix are dropped. Other checkpoints are used as-is and are
        expected to already have that nested layout.

        Args:
            state_dict: Loaded checkpoint weights (a mapping).
            ckpt_path: Path of the checkpoint file. Only its (case-insensitive)
                extension is inspected.

        Returns:
            A new ``IPAdapterPatcher`` wrapping the (possibly regrouped) weights,
            or ``None`` when there is no non-empty ``"ip_adapter"`` section.
        """
        model = state_dict
        if ckpt_path.lower().endswith(".safetensors"):
            model = {group: {} for group in IPAdapterPatcher._GROUPS}
            for key, value in state_dict.items():
                for group in IPAdapterPatcher._GROUPS:
                    prefix = group + "."
                    if key.startswith(prefix):
                        # Strip only the leading prefix. str.replace would also
                        # remove later occurrences inside the key and corrupt it.
                        model[group][key[len(prefix):]] = value
                        break
        if "ip_adapter" not in model or not model["ip_adapter"]:
            return None
        return IPAdapterPatcher(model)

    def __init__(self, model_patcher):
        """Wrap the IP-Adapter weights.

        Args:
            model_patcher: The IP-Adapter weights dict built by
                ``try_build_from_state_dict``. It is passed to
                ``ControlModelPatcher.__init__`` and also kept as
                ``self.ipadapter``.
        """
        super().__init__(model_patcher)
        self.ipadapter = model_patcher

    def process_before_every_sampling(self, process, cond, *args, **kwargs):
        """Patch the processing model's UNet with this IP-Adapter.

        Args:
            process: Processing object. Its ``sd_model.forge_objects.unet`` is
                read and then replaced with the patched UNet.
            cond: A ``(clip_vision, image)`` pair forwarded to the adapter.
            *args, **kwargs: Accepted for interface compatibility; unused.
        """
        clip_vision, image = cond

        unet = process.sd_model.forge_objects.unet
        # NOTE(review): strength / start_percent / end_percent are not assigned
        # in this class; presumably set by ControlModelPatcher or its caller.
        # apply_ipadapter returns a sequence; the patched model is element 0.
        unet = opIPAdapterApply(
            ipadapter=self.ipadapter,
            model=unet,
            weight=self.strength,
            clip_vision=clip_vision,
            image=image,
            weight_type="original",
            noise=0.0,
            embeds=None,
            attn_mask=None,
            start_at=self.start_percent,
            end_at=self.end_percent,
            unfold_batch=False,
            insightface=None,
            faceid_v2=False,
            weight_v2=False,
        )[0]
        process.sd_model.forge_objects.unet = unet
2024-01-30 03:49:36 +00:00
# Register IPAdapterPatcher with Forge's supported control models.
# NOTE(review): presumably this lets the control-model loader try
# IPAdapterPatcher.try_build_from_state_dict on loaded checkpoints — verify in
# modules_forge.shared.
add_supported_control_model(IPAdapterPatcher)