From 72139b000cf227910d63f0b93204037df7c34053 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sun, 3 Mar 2024 20:09:04 -0600 Subject: [PATCH] fix alphas cumprod (#475) --- ldm_patched/modules/model_sampling.py | 8 +++++++- modules/processing.py | 2 ++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/ldm_patched/modules/model_sampling.py b/ldm_patched/modules/model_sampling.py index da5cc3a6..b438ea56 100644 --- a/ldm_patched/modules/model_sampling.py +++ b/ldm_patched/modules/model_sampling.py @@ -76,7 +76,13 @@ class ModelSamplingDiscrete(torch.nn.Module): def timestep(self, sigma): log_sigma = sigma.log() dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] - return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device) + low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx] + w = (low - log_sigma) / (low - high) + w = w.clamp(0, 1) + t = (1 - w) * low_idx + w * high_idx + return t.view(sigma.shape) def sigma(self, timestep): t = torch.clamp(timestep.float().to(self.log_sigmas.device), min=0, max=(len(self.sigmas) - 1)) diff --git a/modules/processing.py b/modules/processing.py index 2453bf99..4528866c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -917,11 +917,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: alphas_cumprod_backup = p.sd_model.alphas_cumprod for modifier in alphas_cumprod_modifiers: p.sd_model.alphas_cumprod = modifier(p.sd_model.alphas_cumprod) + p.sd_model.forge_objects.unet.model.model_sampling.set_sigmas(((1 - p.sd_model.alphas_cumprod) / p.sd_model.alphas_cumprod) ** 0.5) samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) if alphas_cumprod_backup is not None: p.sd_model.alphas_cumprod = alphas_cumprod_backup + p.sd_model.forge_objects.unet.model.model_sampling.set_sigmas(((1 - p.sd_model.alphas_cumprod) / p.sd_model.alphas_cumprod) ** 0.5) if p.scripts is not None: ps = scripts.PostSampleArgs(samples_ddim)