my-sd/extensions-builtin/sd_forge_ipadapter/scripts/forge_ipadapter.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

86 lines
2.9 KiB
Python
Raw Normal View History

2024-01-31 18:28:27 +00:00
from modules_forge.supported_preprocessor import PreprocessorClipVision
from modules_forge.shared import add_supported_preprocessor
from modules_forge.forge_util import numpy_to_pytorch
2024-01-30 03:49:36 +00:00
from modules_forge.shared import add_supported_control_model
from modules_forge.supported_controlnet import ControlModelPatcher
2024-01-30 04:11:04 +00:00
from lib_ipadapter.IPAdapterPlus import IPAdapterApply
2024-01-30 04:01:42 +00:00
2024-01-30 04:14:14 +00:00
# One IPAdapterApply instance is created at import time; its bound
# apply_ipadapter method is what IPAdapterPatcher calls before each sampling.
_ipadapter_apply_node = IPAdapterApply()
opIPAdapterApply = _ipadapter_apply_node.apply_ipadapter
2024-01-30 03:49:36 +00:00
2024-01-31 18:28:27 +00:00
class PreprocessorClipVisionForIPAdapter(PreprocessorClipVision):
    """CLIP-vision preprocessor whose output is the kwargs for the IP-Adapter apply op.

    Loading of the CLIP-vision model is delegated to the parent class; this
    subclass only sets IP-Adapter tags/filename filters and packages the
    preprocessed image together with fixed apply options.
    """

    def __init__(self, name, url, filename):
        super().__init__(name, url, filename)
        # Both spellings are listed; presumably matched against control-model
        # file names by the base class (TODO confirm in supported_preprocessor).
        self.tags, self.model_filename_filters = ['IP-Adapter'], ['IP-Adapter', 'IP_Adapter']

    def __call__(self, input_image, resolution, slider_1=None, slider_2=None, slider_3=None, **kwargs):
        """Return the keyword arguments later forwarded to the IP-Adapter apply op.

        Only ``input_image`` is used; ``resolution``, the slider values and any
        extra keyword arguments are accepted for interface compatibility and
        ignored.

        Returns:
            dict holding the loaded CLIP-vision model (``clip_vision``), the
            image converted by ``numpy_to_pytorch`` (``image``) and constant
            values for the remaining apply options.
        """
        clip_model = self.load_clipvision()
        image_tensor = numpy_to_pytorch(input_image)
        # The remaining options are constants; no slider argument affects them.
        return {
            'clip_vision': clip_model,
            'image': image_tensor,
            'weight_type': "original",
            'noise': 0.0,
            'embeds': None,
            'attn_mask': None,
            'unfold_batch': False,
        }
2024-01-31 18:28:27 +00:00
# (display name, download URL, local file name) of each CLIP-vision encoder
# offered for IP-Adapter, registered in this order. The bigG encoder is
# fetched from the repository's sdxl_models folder.
for _name, _url, _filename in (
    ('CLIP-ViT-H (IPAdapter)',
     'https://huggingface.co/h94/IP-Adapter/resolve/main/models/image_encoder/model.safetensors',
     'CLIP-ViT-H-14.safetensors'),
    ('CLIP-ViT-bigG (IPAdapter)',
     'https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/image_encoder/model.safetensors',
     'CLIP-ViT-bigG.safetensors'),
):
    add_supported_preprocessor(PreprocessorClipVisionForIPAdapter(
        name=_name,
        url=_url,
        filename=_filename,
    ))
# Keep the module namespace identical to the explicit-call form.
del _name, _url, _filename
2024-01-30 03:49:36 +00:00
class IPAdapterPatcher(ControlModelPatcher):
    """Control-model patcher that applies IP-Adapter weights to the UNet.

    Holds the IP-Adapter weights as a dict with ``image_proj`` and
    ``ip_adapter`` entries. Before every sampling pass it replaces
    ``process.sd_model.forge_objects.unet`` with the UNet returned by
    ``opIPAdapterApply``.
    """

    @staticmethod
    def try_build_from_state_dict(state_dict, ckpt_path):
        """Return an ``IPAdapterPatcher`` if ``state_dict`` holds IP-Adapter weights.

        For ``.safetensors`` checkpoints the flat keys ``image_proj.*`` and
        ``ip_adapter.*`` are regrouped into two sub-dicts with the leading
        prefix stripped; keys with neither prefix are dropped. Any other
        checkpoint is used as-is and must already contain a top-level
        ``ip_adapter`` entry.

        Args:
            state_dict: loaded checkpoint mapping.
            ckpt_path: checkpoint path (string); only its extension is checked,
                case-insensitively.

        Returns:
            A new ``IPAdapterPatcher``, or ``None`` when no non-empty
            ``ip_adapter`` weights are present (presumably so the caller can
            try other registered patcher types — confirm in modules_forge).
        """
        model = state_dict
        if ckpt_path.lower().endswith(".safetensors"):
            model = {"image_proj": {}, "ip_adapter": {}}
            for key, value in state_dict.items():
                # Slice off only the leading prefix: str.replace would also
                # delete any later occurrence of the prefix inside the key.
                if key.startswith("image_proj."):
                    model["image_proj"][key[len("image_proj."):]] = value
                elif key.startswith("ip_adapter."):
                    model["ip_adapter"][key[len("ip_adapter."):]] = value

        if "ip_adapter" not in model or len(model["ip_adapter"]) == 0:
            return None
        return IPAdapterPatcher(model)

    def __init__(self, state_dict):
        super().__init__()
        # Weights dict as built by try_build_from_state_dict; forwarded
        # unchanged as ``ipadapter=`` to opIPAdapterApply.
        self.ip_adapter = state_dict

    def process_before_every_sampling(self, process, cond, *args, **kwargs):
        """Patch the UNet with this IP-Adapter for the upcoming sampling pass.

        Args:
            process: processing object exposing ``sd_model.forge_objects.unet``;
                that attribute is overwritten with the patched UNet.
            cond: mapping of extra keyword arguments for ``opIPAdapterApply``
                (presumably the dict returned by the IP-Adapter preprocessor,
                e.g. ``clip_vision``/``image`` — confirm against the caller).
                It must not repeat any keyword passed explicitly below, or the
                call raises ``TypeError``.
            *args, **kwargs: accepted for interface compatibility; unused.
        """
        unet = process.sd_model.forge_objects.unet
        # strength / start_percent / end_percent are not set in this class;
        # presumably provided by ControlModelPatcher or its caller.
        result = opIPAdapterApply(
            ipadapter=self.ip_adapter,
            model=unet,
            weight=self.strength,
            start_at=self.start_percent,
            end_at=self.end_percent,
            faceid_v2=False,
            weight_v2=False,
            **cond,
        )
        # The apply op returns an indexable result whose first element is the
        # patched UNet (a node-style output tuple, presumably).
        process.sd_model.forge_objects.unet = result[0]
2024-01-30 03:49:36 +00:00
# Register the IP-Adapter patcher class. Presumably the registry calls
# IPAdapterPatcher.try_build_from_state_dict on loaded checkpoints and uses the
# first non-None result — confirm in modules_forge.shared.
add_supported_control_model(IPAdapterPatcher)