From 49c3a080b5fdc624cf0ac46030c2d37eec68c57e Mon Sep 17 00:00:00 2001 From: Chakib Benziane Date: Wed, 22 May 2024 21:11:10 +0200 Subject: [PATCH] implement align your steps scheduler (#726) Signed-off-by: blob42 --- ldm_patched/k_diffusion/sampling.py | 30 +++++++++++++++++++++++++++ ldm_patched/modules/samplers.py | 6 ++++-- modules_forge/forge_alter_samplers.py | 7 ++++++- 3 files changed, 40 insertions(+), 3 deletions(-) diff --git a/ldm_patched/k_diffusion/sampling.py b/ldm_patched/k_diffusion/sampling.py index 6f2fbea7..498d1a6f 100644 --- a/ldm_patched/k_diffusion/sampling.py +++ b/ldm_patched/k_diffusion/sampling.py @@ -6,6 +6,7 @@ import math from scipy import integrate import torch +import numpy as np from torch import nn import torchsde from tqdm.auto import trange, tqdm @@ -38,6 +39,35 @@ def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'): sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min)) return append_zero(sigmas) +# align your steps +def get_sigmas_ays(n, sigma_min, sigma_max, is_sdxl=False, device='cpu'): + # https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html + def loglinear_interp(t_steps, num_steps): + """ + Performs log-linear interpolation of a given array of decreasing numbers. + """ + xs = torch.linspace(0, 1, len(t_steps)) + ys = torch.log(torch.tensor(t_steps[::-1])) + + new_xs = torch.linspace(0, 1, num_steps) + new_ys = np.interp(new_xs, xs, ys) + + interped_ys = torch.exp(torch.tensor(new_ys)).numpy()[::-1].copy() + return interped_ys + + if is_sdxl: + sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.029] + else: + # Default to SD 1.5 sigmas. + sigmas = [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029] + + if n != len(sigmas): + sigmas = np.append(loglinear_interp(sigmas, n), [0.0]) + else: + sigmas.append(0.0) + + return torch.FloatTensor(sigmas).to(device) + def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'): """Constructs a continuous VP noise schedule.""" diff --git a/ldm_patched/modules/samplers.py b/ldm_patched/modules/samplers.py index e8f53e13..7f49907a 100644 --- a/ldm_patched/modules/samplers.py +++ b/ldm_patched/modules/samplers.py @@ -662,14 +662,16 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) return model.process_latent_out(samples.to(torch.float32)) -SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"] +SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "ays"] SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"] -def calculate_sigmas_scheduler(model, scheduler_name, steps): +def calculate_sigmas_scheduler(model, scheduler_name, steps, is_sdxl=False): if scheduler_name == "karras": sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max)) elif scheduler_name == "exponential": sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max)) + elif scheduler_name == "ays": + sigmas = k_diffusion_sampling.get_sigmas_ays(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max), is_sdxl=is_sdxl) 
elif scheduler_name == "normal": sigmas = normal_scheduler(model, steps) elif scheduler_name == "simple": diff --git a/modules_forge/forge_alter_samplers.py b/modules_forge/forge_alter_samplers.py index 4e482208..bb9f2499 100644 --- a/modules_forge/forge_alter_samplers.py +++ b/modules_forge/forge_alter_samplers.py @@ -10,6 +10,7 @@ class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler): self.sampler_name = sampler_name self.scheduler_name = scheduler_name self.unet = sd_model.forge_objects.unet + self.model = sd_model sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name)) super().__init__(sampler_function, sd_model, None) @@ -20,7 +21,7 @@ class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler): sigmas = self.unet.model.model_sampling.sigma(timesteps) sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) else: - sigmas = calculate_sigmas_scheduler(self.unet.model, self.scheduler_name, steps) + sigmas = calculate_sigmas_scheduler(self.unet.model, self.scheduler_name, steps, is_sdxl=getattr(self.model, "is_sdxl", False)) return sigmas.to(self.unet.load_device) @@ -34,9 +35,13 @@ def build_constructor(sampler_name, scheduler_name): samplers_data_alter = [ sd_samplers_common.SamplerData('DDPM', build_constructor(sampler_name='ddpm', scheduler_name='normal'), ['ddpm'], {}), sd_samplers_common.SamplerData('DDPM Karras', build_constructor(sampler_name='ddpm', scheduler_name='karras'), ['ddpm_karras'], {}), + sd_samplers_common.SamplerData('Euler AYS', build_constructor(sampler_name='euler', scheduler_name='ays'), ['euler_ays'], {}), sd_samplers_common.SamplerData('Euler A Turbo', build_constructor(sampler_name='euler_ancestral', scheduler_name='turbo'), ['euler_ancestral_turbo'], {}), + sd_samplers_common.SamplerData('Euler A AYS', build_constructor(sampler_name='euler_ancestral', scheduler_name='ays'), ['euler_ancestral_ays'], {}), sd_samplers_common.SamplerData('DPM++ 2M Turbo', build_constructor(sampler_name='dpmpp_2m', scheduler_name='turbo'), ['dpmpp_2m_turbo'], {}), + sd_samplers_common.SamplerData('DPM++ 2M AYS', build_constructor(sampler_name='dpmpp_2m', scheduler_name='ays'), ['dpmpp_2m_ays'], {}), sd_samplers_common.SamplerData('DPM++ 2M SDE Turbo', build_constructor(sampler_name='dpmpp_2m_sde', scheduler_name='turbo'), ['dpmpp_2m_sde_turbo'], {}), + sd_samplers_common.SamplerData('DPM++ 2M SDE AYS', build_constructor(sampler_name='dpmpp_2m_sde', scheduler_name='ays'), ['dpmpp_2m_sde_ays'], {}), sd_samplers_common.SamplerData('LCM Karras', build_constructor(sampler_name='lcm', scheduler_name='karras'), ['lcm_karras'], {}), sd_samplers_common.SamplerData('Euler SGMUniform', build_constructor(sampler_name='euler', scheduler_name='sgm_uniform'), ['euler_sgm_uniform'], {}), sd_samplers_common.SamplerData('Euler A SGMUniform', build_constructor(sampler_name='euler_ancestral', scheduler_name='sgm_uniform'), ['euler_ancestral_sgm_uniform'], {}),
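
For reference, the scheduler added in ldm_patched/k_diffusion/sampling.py resamples a fixed 11-entry sigma table (SD 1.5 or SDXL) to the requested step count by interpolating linearly in log-sigma space and then appending a terminal 0.0. The snippet below is a minimal standalone sketch of that resampling, not part of the patch: the table is copied from the diff above, and the helper name resample_ays is illustrative rather than project API.

# Minimal sketch (assumption: not part of the patch) of the log-linear
# resampling used by loglinear_interp()/get_sigmas_ays() above.
import numpy as np

# SD 1.5 reference sigmas from the patch (11 entries).
AYS_SD15 = [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029]

def resample_ays(table, num_steps):
    """Interpolate a decreasing sigma table to num_steps values in log space."""
    xs = np.linspace(0, 1, len(table))
    ys = np.log(np.asarray(table)[::-1])               # ascending log-sigmas
    new_ys = np.interp(np.linspace(0, 1, num_steps), xs, ys)
    return np.exp(new_ys)[::-1]                        # back to descending sigmas

sigmas = np.append(resample_ays(AYS_SD15, 20), 0.0)    # 20 steps plus terminal zero
print(sigmas)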
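
From the UI side, the new 'Euler AYS', 'Euler A AYS', 'DPM++ 2M AYS' and 'DPM++ 2M SDE AYS' entries route through calculate_sigmas_scheduler(..., is_sdxl=...) as shown in the diff. The patched function can also be exercised directly; the call below is a hedged usage sketch that assumes this repository is importable, and note that get_sigmas_ays() accepts sigma_min/sigma_max only for signature parity with the other get_sigmas_* helpers while deriving its values from the fixed AYS table.

# Hedged usage sketch, assuming the patched tree is on sys.path.
from ldm_patched.k_diffusion.sampling import get_sigmas_ays

sigmas_sd15 = get_sigmas_ays(n=10, sigma_min=0.029, sigma_max=14.615)                # interpolated to 10 values + 0.0
sigmas_sdxl = get_sigmas_ays(n=11, sigma_min=0.029, sigma_max=14.615, is_sdxl=True)  # exact 11-entry table + 0.0
print(sigmas_sd15)
print(sigmas_sdxl)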