From 24db0e241a2a1d2c80276f61602be1779924d760 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Thu, 1 Feb 2024 23:04:24 -0800 Subject: [PATCH] elegant codes --- .../sd_forge_photomaker/scripts/forge_photomaker.py | 9 +-------- modules_forge/unet_patcher.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/extensions-builtin/sd_forge_photomaker/scripts/forge_photomaker.py b/extensions-builtin/sd_forge_photomaker/scripts/forge_photomaker.py index 37cab46d..9d230be8 100644 --- a/extensions-builtin/sd_forge_photomaker/scripts/forge_photomaker.py +++ b/extensions-builtin/sd_forge_photomaker/scripts/forge_photomaker.py @@ -52,14 +52,7 @@ class PhotomakerPatcher(ControlModelPatcher): text = process.prompts[0] cond_modified = opPhotoMakerEncode(photomaker=self.model, image=cond.movedim(1, -1), clip=clip, text=text)[0] - noise = kwargs['x'] - cond_modified = encode_model_conds( - model_function=unet.model.extra_conds, - conds=convert_cond(cond_modified), - noise=noise, - device=noise.device, - prompt_type="positive" - )[0] + cond_modified = unet.encode_conds_from_clip(conds=cond_modified, noise=kwargs['x'])[0] def conditioning_modifier(model, x, timestep, uncond, cond, cond_scale, model_options, seed): cond = cond.copy() diff --git a/modules_forge/unet_patcher.py b/modules_forge/unet_patcher.py index 8e2b6f82..89ebee12 100644 --- a/modules_forge/unet_patcher.py +++ b/modules_forge/unet_patcher.py @@ -1,6 +1,8 @@ import copy from ldm_patched.modules.model_patcher import ModelPatcher +from ldm_patched.modules.sample import convert_cond +from ldm_patched.modules.samplers import encode_model_conds class UnetPatcher(ModelPatcher): @@ -96,3 +98,12 @@ class UnetPatcher(ModelPatcher): for transformer_index in range(16): self.set_model_patch_replace(patch, target, block_name, number, transformer_index) return + + def encode_conds_from_clip(self, conds, noise, prompt_type="positive"): + return encode_model_conds( + model_function=self.model.extra_conds, + conds=convert_cond(conds), + noise=noise, + device=noise.device, + prompt_type=prompt_type + )