Update sd_samplers_cfg_denoiser.py

This commit is contained in:
lllyasviel 2024-01-25 19:32:17 -08:00
parent 4cd7437b62
commit c5d26cc807

View File

@ -65,7 +65,7 @@ class CFGDenoiser(torch.nn.Module):
def inner_model(self):
raise NotImplementedError()
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
def combine_denoised(self, x_out, conds_list, uncond, cond_scale, timestep, x):
model_options = self.inner_model.inner_model.forge_objects.unet.model_options
denoised_uncond = x_out[-uncond.shape[0]:]
@ -75,6 +75,12 @@ class CFGDenoiser(torch.nn.Module):
for cond_index, weight in conds:
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
if "sampler_cfg_function" in model_options or "sampler_post_cfg_function" in model_options:
cond_scale = float(cond_scale)
model = self.inner_model.inner_model.forge_objects.unet
a = 0
return denoised
def combine_denoised_for_edit_model(self, x_out, cond_scale):
@ -239,9 +245,9 @@ class CFGDenoiser(torch.nn.Module):
if is_edit_model:
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
elif skip_uncond:
denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0, sigma_in, x_in)
else:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale, sigma_in, x_in)
# Blend in the original latents (after)
if not self.mask_before_denoising and self.mask is not None: