implement align your steps scheduler (#726)

Signed-off-by: blob42 <contact@blob42.xyz>
This commit is contained in:
Chakib Benziane 2024-05-22 21:11:10 +02:00 committed by GitHub
parent 62e60ad403
commit 49c3a080b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 40 additions and 3 deletions

View File

@ -6,6 +6,7 @@ import math
from scipy import integrate
import torch
import numpy as np
from torch import nn
import torchsde
from tqdm.auto import trange, tqdm
@ -38,6 +39,35 @@ def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min))
return append_zero(sigmas)
# align your steps
def get_sigmas_ays(n, sigma_min, sigma_max, is_sdxl=False, device='cpu'):
# https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
def loglinear_interp(t_steps, num_steps):
"""
Performs log-linear interpolation of a given array of decreasing numbers.
"""
xs = torch.linspace(0, 1, len(t_steps))
ys = torch.log(torch.tensor(t_steps[::-1]))
new_xs = torch.linspace(0, 1, num_steps)
new_ys = np.interp(new_xs, xs, ys)
interped_ys = torch.exp(torch.tensor(new_ys)).numpy()[::-1].copy()
return interped_ys
if is_sdxl:
sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.029]
else:
# Default to SD 1.5 sigmas.
sigmas = [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029]
if n != len(sigmas):
sigmas = np.append(loglinear_interp(sigmas, n), [0.0])
else:
sigmas.append(0.0)
return torch.FloatTensor(sigmas).to(device)
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
"""Constructs a continuous VP noise schedule."""

View File

@ -662,14 +662,16 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return model.process_latent_out(samples.to(torch.float32))
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "ays"]
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
def calculate_sigmas_scheduler(model, scheduler_name, steps):
def calculate_sigmas_scheduler(model, scheduler_name, steps, is_sdxl=False):
if scheduler_name == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
elif scheduler_name == "exponential":
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
elif scheduler_name == "ays":
sigmas = k_diffusion_sampling.get_sigmas_ays(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max), is_sdxl=is_sdxl)
elif scheduler_name == "normal":
sigmas = normal_scheduler(model, steps)
elif scheduler_name == "simple":

View File

@ -10,6 +10,7 @@ class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
self.sampler_name = sampler_name
self.scheduler_name = scheduler_name
self.unet = sd_model.forge_objects.unet
self.model = sd_model
sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
super().__init__(sampler_function, sd_model, None)
@ -20,7 +21,7 @@ class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
sigmas = self.unet.model.model_sampling.sigma(timesteps)
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
else:
sigmas = calculate_sigmas_scheduler(self.unet.model, self.scheduler_name, steps)
sigmas = calculate_sigmas_scheduler(self.unet.model, self.scheduler_name, steps, is_sdxl=getattr(self.model, "is_sdxl", False))
return sigmas.to(self.unet.load_device)
@ -34,9 +35,13 @@ def build_constructor(sampler_name, scheduler_name):
samplers_data_alter = [
sd_samplers_common.SamplerData('DDPM', build_constructor(sampler_name='ddpm', scheduler_name='normal'), ['ddpm'], {}),
sd_samplers_common.SamplerData('DDPM Karras', build_constructor(sampler_name='ddpm', scheduler_name='karras'), ['ddpm_karras'], {}),
sd_samplers_common.SamplerData('Euler AYS', build_constructor(sampler_name='euler', scheduler_name='ays'), ['euler_ays'], {}),
sd_samplers_common.SamplerData('Euler A Turbo', build_constructor(sampler_name='euler_ancestral', scheduler_name='turbo'), ['euler_ancestral_turbo'], {}),
sd_samplers_common.SamplerData('Euler A AYS', build_constructor(sampler_name='euler_ancestral', scheduler_name='ays'), ['euler_ancestral_ays'], {}),
sd_samplers_common.SamplerData('DPM++ 2M Turbo', build_constructor(sampler_name='dpmpp_2m', scheduler_name='turbo'), ['dpmpp_2m_turbo'], {}),
sd_samplers_common.SamplerData('DPM++ 2M AYS', build_constructor(sampler_name='dpmpp_2m', scheduler_name='ays'), ['dpmpp_2m_ays'], {}),
sd_samplers_common.SamplerData('DPM++ 2M SDE Turbo', build_constructor(sampler_name='dpmpp_2m_sde', scheduler_name='turbo'), ['dpmpp_2m_sde_turbo'], {}),
sd_samplers_common.SamplerData('DPM++ 2M SDE AYS', build_constructor(sampler_name='dpmpp_2m_sde', scheduler_name='ays'), ['dpmpp_2m_sde_ays'], {}),
sd_samplers_common.SamplerData('LCM Karras', build_constructor(sampler_name='lcm', scheduler_name='karras'), ['lcm_karras'], {}),
sd_samplers_common.SamplerData('Euler SGMUniform', build_constructor(sampler_name='euler', scheduler_name='sgm_uniform'), ['euler_sgm_uniform'], {}),
sd_samplers_common.SamplerData('Euler A SGMUniform', build_constructor(sampler_name='euler_ancestral', scheduler_name='sgm_uniform'), ['euler_ancestral_sgm_uniform'], {}),