my-sd/extensions-builtin/sd_forge_controlllite/scripts/forge_controllllite.py
lllyasviel 4f6b8ad079 i
2024-02-01 21:26:42 -08:00

39 lines
1.1 KiB
Python

from modules_forge.shared import add_supported_control_model
from modules_forge.supported_controlnet import ControlModelPatcher
from lib_controllllite.lib_controllllite import LLLiteLoader
opLLLiteLoader = LLLiteLoader().load_lllite
class ControlLLLitePatcher(ControlModelPatcher):
@staticmethod
def try_build_from_state_dict(state_dict, ckpt_path):
if not any('lllite' in k for k in state_dict.keys()):
return None
return ControlLLLitePatcher(state_dict)
def __init__(self, state_dict):
super().__init__()
self.state_dict = state_dict
return
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
unet = process.sd_model.forge_objects.unet
unet = opLLLiteLoader(
model=unet,
state_dict=self.state_dict,
cond_image=cond.movedim(1, -1),
strength=self.strength,
steps=process.steps,
start_percent=self.start_percent,
end_percent=self.end_percent
)[0]
process.sd_model.forge_objects.unet = unet
return
add_supported_control_model(ControlLLLitePatcher)