new samplers
DDPM DDPM Karras DPM++ 2M Turbo DPM++ 2M SDE Turbo LCM Karras
This commit is contained in:
parent
af017121d2
commit
0ba407fd9c
@ -2,11 +2,13 @@ from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_l
|
|||||||
|
|
||||||
# imports for functions that previously were here and are used by other modules
|
# imports for functions that previously were here and are used by other modules
|
||||||
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
|
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
|
||||||
|
from modules_forge import forge_alter_samplers
|
||||||
|
|
||||||
all_samplers = [
|
all_samplers = [
|
||||||
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
|
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
|
||||||
*sd_samplers_timesteps.samplers_data_timesteps,
|
*sd_samplers_timesteps.samplers_data_timesteps,
|
||||||
*sd_samplers_lcm.samplers_data_lcm,
|
*sd_samplers_lcm.samplers_data_lcm,
|
||||||
|
*forge_alter_samplers.samplers_data_alter
|
||||||
]
|
]
|
||||||
all_samplers_map = {x.name: x for x in all_samplers}
|
all_samplers_map = {x.name: x for x in all_samplers}
|
||||||
|
|
||||||
|
40
modules_forge/forge_alter_samplers.py
Normal file
40
modules_forge/forge_alter_samplers.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
import torch
|
||||||
|
from modules import sd_samplers_kdiffusion, sd_samplers_common
|
||||||
|
|
||||||
|
from ldm_patched.k_diffusion import sampling as k_diffusion_sampling
|
||||||
|
from ldm_patched.modules.samplers import calculate_sigmas_scheduler
|
||||||
|
|
||||||
|
|
||||||
|
class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
|
||||||
|
def __init__(self, sd_model, sampler_name, scheduler_name):
|
||||||
|
self.sampler_name = sampler_name
|
||||||
|
self.scheduler_name = scheduler_name
|
||||||
|
self.unet = sd_model.forge_objects.unet
|
||||||
|
|
||||||
|
sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
|
||||||
|
super().__init__(sampler_function, sd_model, None)
|
||||||
|
|
||||||
|
def get_sigmas(self, p, steps):
|
||||||
|
if self.scheduler_name == 'turbo':
|
||||||
|
timesteps = torch.flip(torch.arange(1, steps + 1) * float(1000.0 / steps) - 1, (0,)).round().long().clip(0, 999)
|
||||||
|
sigmas = self.unet.model.model_sampling.sigma(timesteps)
|
||||||
|
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
||||||
|
else:
|
||||||
|
sigmas = calculate_sigmas_scheduler(self.unet.model, self.scheduler_name, steps)
|
||||||
|
return sigmas.to(self.unet.load_device)
|
||||||
|
|
||||||
|
|
||||||
|
def build_constructor(sampler_name, scheduler_name):
|
||||||
|
def constructor(m):
|
||||||
|
return AlterSampler(m, sampler_name, scheduler_name)
|
||||||
|
|
||||||
|
return constructor
|
||||||
|
|
||||||
|
|
||||||
|
samplers_data_alter = [
|
||||||
|
sd_samplers_common.SamplerData('DDPM', build_constructor(sampler_name='ddpm', scheduler_name='normal'), ['ddpm'], {}),
|
||||||
|
sd_samplers_common.SamplerData('DDPM Karras', build_constructor(sampler_name='ddpm', scheduler_name='karras'), ['ddpm_karras'], {}),
|
||||||
|
sd_samplers_common.SamplerData('DPM++ 2M Turbo', build_constructor(sampler_name='dpmpp_2m', scheduler_name='turbo'), ['dpmpp_2m_turbo'], {}),
|
||||||
|
sd_samplers_common.SamplerData('DPM++ 2M SDE Turbo', build_constructor(sampler_name='dpmpp_2m_sde', scheduler_name='turbo'), ['dpmpp_2m_sde_turbo'], {}),
|
||||||
|
sd_samplers_common.SamplerData('LCM Karras', build_constructor(sampler_name='lcm', scheduler_name='karras'), ['lcm_karras'], {}),
|
||||||
|
]
|
Loading…
Reference in New Issue
Block a user