implement align your steps scheduler (#726)
Signed-off-by: blob42 <contact@blob42.xyz>
This commit is contained in:
parent
62e60ad403
commit
49c3a080b5
@ -6,6 +6,7 @@ import math
|
||||
|
||||
from scipy import integrate
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torchsde
|
||||
from tqdm.auto import trange, tqdm
|
||||
@ -38,6 +39,35 @@ def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
|
||||
sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min))
|
||||
return append_zero(sigmas)
|
||||
|
||||
# align your steps
|
||||
def get_sigmas_ays(n, sigma_min, sigma_max, is_sdxl=False, device='cpu'):
|
||||
# https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
|
||||
def loglinear_interp(t_steps, num_steps):
|
||||
"""
|
||||
Performs log-linear interpolation of a given array of decreasing numbers.
|
||||
"""
|
||||
xs = torch.linspace(0, 1, len(t_steps))
|
||||
ys = torch.log(torch.tensor(t_steps[::-1]))
|
||||
|
||||
new_xs = torch.linspace(0, 1, num_steps)
|
||||
new_ys = np.interp(new_xs, xs, ys)
|
||||
|
||||
interped_ys = torch.exp(torch.tensor(new_ys)).numpy()[::-1].copy()
|
||||
return interped_ys
|
||||
|
||||
if is_sdxl:
|
||||
sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.029]
|
||||
else:
|
||||
# Default to SD 1.5 sigmas.
|
||||
sigmas = [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029]
|
||||
|
||||
if n != len(sigmas):
|
||||
sigmas = np.append(loglinear_interp(sigmas, n), [0.0])
|
||||
else:
|
||||
sigmas.append(0.0)
|
||||
|
||||
return torch.FloatTensor(sigmas).to(device)
|
||||
|
||||
|
||||
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
|
||||
"""Constructs a continuous VP noise schedule."""
|
||||
|
@ -662,14 +662,16 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
||||
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||
return model.process_latent_out(samples.to(torch.float32))
|
||||
|
||||
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
|
||||
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "ays"]
|
||||
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
||||
|
||||
def calculate_sigmas_scheduler(model, scheduler_name, steps):
|
||||
def calculate_sigmas_scheduler(model, scheduler_name, steps, is_sdxl=False):
|
||||
if scheduler_name == "karras":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
|
||||
elif scheduler_name == "exponential":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
|
||||
elif scheduler_name == "ays":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_ays(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max), is_sdxl=is_sdxl)
|
||||
elif scheduler_name == "normal":
|
||||
sigmas = normal_scheduler(model, steps)
|
||||
elif scheduler_name == "simple":
|
||||
|
@ -10,6 +10,7 @@ class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
|
||||
self.sampler_name = sampler_name
|
||||
self.scheduler_name = scheduler_name
|
||||
self.unet = sd_model.forge_objects.unet
|
||||
self.model = sd_model
|
||||
|
||||
sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
|
||||
super().__init__(sampler_function, sd_model, None)
|
||||
@ -20,7 +21,7 @@ class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
|
||||
sigmas = self.unet.model.model_sampling.sigma(timesteps)
|
||||
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
||||
else:
|
||||
sigmas = calculate_sigmas_scheduler(self.unet.model, self.scheduler_name, steps)
|
||||
sigmas = calculate_sigmas_scheduler(self.unet.model, self.scheduler_name, steps, is_sdxl=getattr(self.model, "is_sdxl", False))
|
||||
return sigmas.to(self.unet.load_device)
|
||||
|
||||
|
||||
@ -34,9 +35,13 @@ def build_constructor(sampler_name, scheduler_name):
|
||||
samplers_data_alter = [
|
||||
sd_samplers_common.SamplerData('DDPM', build_constructor(sampler_name='ddpm', scheduler_name='normal'), ['ddpm'], {}),
|
||||
sd_samplers_common.SamplerData('DDPM Karras', build_constructor(sampler_name='ddpm', scheduler_name='karras'), ['ddpm_karras'], {}),
|
||||
sd_samplers_common.SamplerData('Euler AYS', build_constructor(sampler_name='euler', scheduler_name='ays'), ['euler_ays'], {}),
|
||||
sd_samplers_common.SamplerData('Euler A Turbo', build_constructor(sampler_name='euler_ancestral', scheduler_name='turbo'), ['euler_ancestral_turbo'], {}),
|
||||
sd_samplers_common.SamplerData('Euler A AYS', build_constructor(sampler_name='euler_ancestral', scheduler_name='ays'), ['euler_ancestral_ays'], {}),
|
||||
sd_samplers_common.SamplerData('DPM++ 2M Turbo', build_constructor(sampler_name='dpmpp_2m', scheduler_name='turbo'), ['dpmpp_2m_turbo'], {}),
|
||||
sd_samplers_common.SamplerData('DPM++ 2M AYS', build_constructor(sampler_name='dpmpp_2m', scheduler_name='ays'), ['dpmpp_2m_ays'], {}),
|
||||
sd_samplers_common.SamplerData('DPM++ 2M SDE Turbo', build_constructor(sampler_name='dpmpp_2m_sde', scheduler_name='turbo'), ['dpmpp_2m_sde_turbo'], {}),
|
||||
sd_samplers_common.SamplerData('DPM++ 2M SDE AYS', build_constructor(sampler_name='dpmpp_2m_sde', scheduler_name='ays'), ['dpmpp_2m_sde_ays'], {}),
|
||||
sd_samplers_common.SamplerData('LCM Karras', build_constructor(sampler_name='lcm', scheduler_name='karras'), ['lcm_karras'], {}),
|
||||
sd_samplers_common.SamplerData('Euler SGMUniform', build_constructor(sampler_name='euler', scheduler_name='sgm_uniform'), ['euler_sgm_uniform'], {}),
|
||||
sd_samplers_common.SamplerData('Euler A SGMUniform', build_constructor(sampler_name='euler_ancestral', scheduler_name='sgm_uniform'), ['euler_ancestral_sgm_uniform'], {}),
|
||||
|
Loading…
Reference in New Issue
Block a user