add_sampler_pre_cfg_function

This commit is contained in:
lllyasviel 2024-02-05 14:48:43 -08:00
parent 8f86e66e5c
commit 88f6df4dcd
3 changed files with 13 additions and 0 deletions

View File

@ -55,11 +55,17 @@ class PreprocessorInpaintOnly(PreprocessorInpaint):
unet = process.sd_model.forge_objects.unet.clone()
def pre_cfg(model, c, uc, x, timestep, model_options):
noisy_latent = latent_image.to(x) + timestep.to(x) * torch.randn_like(latent_image).to(x)
x = x * latent_mask.to(x) + noisy_latent.to(x) * (1.0 - latent_mask.to(x))
return model, c, uc, x, timestep, model_options
def post_cfg(args):
denoised = args['denoised']
denoised = denoised * latent_mask.to(denoised) + latent_image.to(denoised) * (1.0 - latent_mask.to(denoised))
return denoised
unet.add_sampler_pre_cfg_function(pre_cfg)
unet.set_model_sampler_post_cfg_function(post_cfg)
process.sd_model.forge_objects.unet = unet

View File

@ -276,6 +276,9 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
else:
uncond_ = uncond
for fn in model_options.get("sampler_pre_cfg_function", []):
model, cond, uncond_, x, timestep, model_options = fn(model, cond, uncond_, x, timestep, model_options)
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
if "sampler_cfg_function" in model_options:
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,

View File

@ -102,6 +102,10 @@ class UnetPatcher(ModelPatcher):
self.append_model_option('conditioning_modifiers', modifier, ensure_uniqueness)
return
def add_sampler_pre_cfg_function(self, modifier, ensure_uniqueness=False):
self.append_model_option('sampler_pre_cfg_function', modifier, ensure_uniqueness)
return
def set_memory_peak_estimation_modifier(self, modifier):
self.model_options['memory_peak_estimation_modifier'] = modifier
return