add_sampler_pre_cfg_function
This commit is contained in:
parent
8f86e66e5c
commit
88f6df4dcd
@ -55,11 +55,17 @@ class PreprocessorInpaintOnly(PreprocessorInpaint):
|
||||
|
||||
unet = process.sd_model.forge_objects.unet.clone()
|
||||
|
||||
def pre_cfg(model, c, uc, x, timestep, model_options):
|
||||
noisy_latent = latent_image.to(x) + timestep.to(x) * torch.randn_like(latent_image).to(x)
|
||||
x = x * latent_mask.to(x) + noisy_latent.to(x) * (1.0 - latent_mask.to(x))
|
||||
return model, c, uc, x, timestep, model_options
|
||||
|
||||
def post_cfg(args):
|
||||
denoised = args['denoised']
|
||||
denoised = denoised * latent_mask.to(denoised) + latent_image.to(denoised) * (1.0 - latent_mask.to(denoised))
|
||||
return denoised
|
||||
|
||||
unet.add_sampler_pre_cfg_function(pre_cfg)
|
||||
unet.set_model_sampler_post_cfg_function(post_cfg)
|
||||
|
||||
process.sd_model.forge_objects.unet = unet
|
||||
|
@ -276,6 +276,9 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
|
||||
else:
|
||||
uncond_ = uncond
|
||||
|
||||
for fn in model_options.get("sampler_pre_cfg_function", []):
|
||||
model, cond, uncond_, x, timestep, model_options = fn(model, cond, uncond_, x, timestep, model_options)
|
||||
|
||||
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
|
||||
if "sampler_cfg_function" in model_options:
|
||||
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
|
||||
|
@ -102,6 +102,10 @@ class UnetPatcher(ModelPatcher):
|
||||
self.append_model_option('conditioning_modifiers', modifier, ensure_uniqueness)
|
||||
return
|
||||
|
||||
def add_sampler_pre_cfg_function(self, modifier, ensure_uniqueness=False):
|
||||
self.append_model_option('sampler_pre_cfg_function', modifier, ensure_uniqueness)
|
||||
return
|
||||
|
||||
def set_memory_peak_estimation_modifier(self, modifier):
|
||||
self.model_options['memory_peak_estimation_modifier'] = modifier
|
||||
return
|
||||
|
Loading…
Reference in New Issue
Block a user