3d039591fe
Before this commit, credits were already in entry and licenses were already in root. This commit makes that information clearer.
180 lines
6.8 KiB
Python
180 lines
6.8 KiB
Python
# Taken from https://github.com/comfyanonymous/ComfyUI
|
|
# This file is only for reference, and not used in the backend or runtime.
|
|
|
|
|
|
import ldm_patched.utils.path_utils
|
|
import ldm_patched.modules.sd
|
|
import ldm_patched.modules.model_sampling
|
|
import torch
|
|
|
|
class LCM(ldm_patched.modules.model_sampling.EPS):
    """EPS-derived prediction type that mixes the x0 estimate with the input.

    NOTE(review): the name suggests Latent Consistency Model sampling --
    confirm against the upstream project.
    """

    def calculate_denoised(self, sigma, model_output, model_input):
        """Return ``c_out * x0 + c_skip * model_input``.

        ``x0`` is ``model_input - model_output * sigma``. Both weights are
        computed from the timestep that ``self.timestep`` reports for
        ``sigma``, scaled by a fixed factor of 10.
        """
        # Per-sample values are reshaped to (N, 1, ..., 1) so they broadcast
        # against a tensor with as many dimensions as model_output.
        per_sample_shape = sigma.shape[:1] + (1,) * (model_output.ndim - 1)
        step = self.timestep(sigma).view(per_sample_shape)
        sigma = sigma.view(per_sample_shape)

        pred_x0 = model_input - model_output * sigma

        sigma_data = 0.5
        step_scaled = step * 10.0  # timestep_scaling

        # Both weights share the same denominator term.
        denom = step_scaled**2 + sigma_data**2
        c_skip = sigma_data**2 / denom
        c_out = step_scaled / denom ** 0.5

        return c_out * pred_x0 + c_skip * model_input
|
|
|
|
class ModelSamplingDiscreteDistilled(ldm_patched.modules.model_sampling.ModelSamplingDiscrete):
    """Discrete sampling that keeps only ``original_timesteps`` of the parent's sigmas."""

    # Number of sigmas retained from the parent's schedule.
    original_timesteps = 50

    def __init__(self, model_config=None):
        """Build the parent schedule, then replace it with the strided subset."""
        super().__init__(model_config)

        # Stride between retained entries of the parent's schedule.
        self.skip_steps = self.num_timesteps // self.original_timesteps

        # Pick every skip_steps-th sigma counting back from the parent's last
        # entry, so that the parent's final sigma lands in the final slot.
        final_index = self.num_timesteps - 1
        kept = torch.zeros((self.original_timesteps), dtype=torch.float32)
        for slot in range(self.original_timesteps):
            strides_back = self.original_timesteps - 1 - slot
            kept[slot] = self.sigmas[final_index - strides_back * self.skip_steps]

        self.set_sigmas(kept)

    def timestep(self, sigma):
        """Return, for each sigma, the timestep of the closest retained sigma.

        Closeness is measured as absolute difference of log-sigmas.
        """
        log_sigma = sigma.log()
        gaps = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
        nearest = gaps.abs().argmin(dim=0).view(sigma.shape)
        # Table index k corresponds to timestep k * skip_steps + (skip_steps - 1).
        steps = nearest * self.skip_steps + (self.skip_steps - 1)
        return steps.to(sigma.device)

    def sigma(self, timestep):
        """Convert timesteps to sigmas by interpolating the log-sigma table.

        The timestep is first mapped to a (possibly fractional) table position,
        clamped to the table's index range, and the two neighbouring log-sigmas
        are blended linearly before exponentiating.
        """
        position = (timestep.float().to(self.log_sigmas.device) - (self.skip_steps - 1)) / self.skip_steps
        position = torch.clamp(position.float(), min=0, max=(len(self.sigmas) - 1))

        lower = position.floor().long()
        upper = position.ceil().long()
        weight = position.frac()

        log_sigma = (1 - weight) * self.log_sigmas[lower] + weight * self.log_sigmas[upper]
        return log_sigma.exp().to(timestep.device)
|
|
|
|
|
|
def rescale_zero_terminal_snr_sigmas(sigmas):
    """Rescale a sigma schedule so its last entry has a near-zero alpha_bar.

    The schedule is handled as sqrt(alpha_bar) = sqrt(1 / (sigma**2 + 1)):
    that curve is shifted so its last value is zero, stretched so its first
    value is unchanged, and then mapped back to sigmas. The input tensor is
    not modified; a new tensor is returned.
    """
    sqrt_alpha_bar = (1 / ((sigmas * sigmas) + 1)).sqrt()

    # Endpoints of the original curve.
    first = sqrt_alpha_bar[0]
    last = sqrt_alpha_bar[-1]

    # Shift so the last entry is zero, then scale so the first entry returns
    # to its old value.
    rescaled = (sqrt_alpha_bar - last) * (first / (first - last))

    # Back from sqrt(alpha_bar) to alpha_bar.
    alpha_bar = rescaled**2
    # The shifted last entry is exactly zero, which would map to an infinite
    # sigma below; substitute a tiny positive value instead.
    alpha_bar[-1] = 4.8973451890853435e-08

    # Invert alpha_bar = 1 / (sigma**2 + 1) to get sigmas again.
    return ((1 - alpha_bar) / alpha_bar) ** 0.5
|
|
|
|
class ModelSamplingDiscrete:
    """Node that replaces a model's ``model_sampling`` object with a discrete one."""

    # Accepted values for the ``sampling`` input, in the order they are offered.
    _SAMPLING_CHOICES = ("eps", "v_prediction", "lcm")

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: the model, the sampling type and the zsnr flag."""
        return {
            "required": {
                "model": ("MODEL",),
                "sampling": (list(s._SAMPLING_CHOICES),),
                "zsnr": ("BOOLEAN", {"default": False}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "advanced/model"

    def patch(self, model, sampling, zsnr):
        """Return a one-tuple holding a clone of ``model`` with new sampling.

        Args:
            model: object exposing ``clone()`` and ``model.model_config``; the
                clone must expose ``add_object_patch``.
            sampling: one of ``"eps"``, ``"v_prediction"`` or ``"lcm"``.
            zsnr: when true, the sigmas of the new sampling object are passed
                through ``rescale_zero_terminal_snr_sigmas``.

        Raises:
            ValueError: if ``sampling`` is not one of the accepted values.
        """
        # Reject unknown values before doing any work; otherwise the class
        # statement below would fail with an unbound ``sampling_type``.
        if sampling not in self._SAMPLING_CHOICES:
            raise ValueError(
                "Unknown sampling type {!r}; expected one of {}".format(
                    sampling, list(self._SAMPLING_CHOICES)
                )
            )

        m = model.clone()

        sampling_base = ldm_patched.modules.model_sampling.ModelSamplingDiscrete
        if sampling == "eps":
            sampling_type = ldm_patched.modules.model_sampling.EPS
        elif sampling == "v_prediction":
            sampling_type = ldm_patched.modules.model_sampling.V_PREDICTION
        else:  # "lcm" -- the only remaining accepted value
            sampling_type = LCM
            sampling_base = ModelSamplingDiscreteDistilled

        # Combine the schedule class and the prediction type into one object.
        class ModelSamplingAdvanced(sampling_base, sampling_type):
            pass

        model_sampling = ModelSamplingAdvanced(model.model.model_config)
        if zsnr:
            model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))

        m.add_object_patch("model_sampling", model_sampling)
        return (m, )
|
|
|
|
class ModelSamplingContinuousEDM:
    """Node that replaces a model's ``model_sampling`` object with a continuous EDM one."""

    # Accepted values for the ``sampling`` input, in the order they are offered.
    _SAMPLING_CHOICES = ("v_prediction", "eps")

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: the model, the sampling type and the sigma range."""
        return {
            "required": {
                "model": ("MODEL",),
                "sampling": (list(s._SAMPLING_CHOICES),),
                "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
                "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "advanced/model"

    def patch(self, model, sampling, sigma_max, sigma_min):
        """Return a one-tuple holding a clone of ``model`` with new sampling.

        Args:
            model: object exposing ``clone()`` and ``model.model_config``; the
                clone must expose ``add_object_patch``.
            sampling: ``"eps"`` or ``"v_prediction"``.
            sigma_max: upper end of the range given to ``set_sigma_range``.
            sigma_min: lower end of the range given to ``set_sigma_range``.

        Raises:
            ValueError: if ``sampling`` is not one of the accepted values.
        """
        # Reject unknown values before doing any work; otherwise the class
        # statement below would fail with an unbound ``sampling_type``.
        if sampling not in self._SAMPLING_CHOICES:
            raise ValueError(
                "Unknown sampling type {!r}; expected one of {}".format(
                    sampling, list(self._SAMPLING_CHOICES)
                )
            )

        m = model.clone()

        if sampling == "eps":
            sampling_type = ldm_patched.modules.model_sampling.EPS
        else:  # "v_prediction" -- the only remaining accepted value
            sampling_type = ldm_patched.modules.model_sampling.V_PREDICTION

        # Combine the continuous schedule class and the prediction type.
        class ModelSamplingAdvanced(ldm_patched.modules.model_sampling.ModelSamplingContinuousEDM, sampling_type):
            pass

        model_sampling = ModelSamplingAdvanced(model.model.model_config)
        # Note the argument order: set_sigma_range takes (min, max).
        model_sampling.set_sigma_range(sigma_min, sigma_max)
        m.add_object_patch("model_sampling", model_sampling)
        return (m, )
|
|
|
|
class RescaleCFG:
    """Node that installs a rescaling classifier-free-guidance function on a model."""

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: the model and the rescale blend factor."""
        return {
            "required": {
                "model": ("MODEL",),
                "multiplier": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "advanced/model"

    def patch(self, model, multiplier):
        """Return a one-tuple holding a clone of ``model`` using rescaled CFG.

        Args:
            model: object exposing ``clone()``; the clone must expose
                ``set_model_sampler_cfg_function``.
            multiplier: blend weight; 0.0 keeps the plain guided result and
                1.0 uses the fully rescaled one.
        """
        def rescale_cfg(args):
            """Combine ``cond``/``uncond`` with guidance and std rescaling.

            ``args`` must provide ``cond``, ``uncond``, ``cond_scale``,
            ``sigma`` and ``input``. The tensors may have any rank >= 2; the
            first axis is treated as the batch axis.
            """
            cond = args["cond"]
            uncond = args["uncond"]
            cond_scale = args["cond_scale"]
            sigma = args["sigma"]
            # Reshape sigma to (N, 1, ..., 1) so it broadcasts against cond.
            sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1))
            x_orig = args["input"]

            #rescale cfg has to be done on v-pred model output
            x = x_orig / (sigma * sigma + 1.0)
            cond = ((x - (x_orig - cond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)
            uncond = ((x - (x_orig - uncond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)

            #rescalecfg
            # Take the std over every non-batch axis. This equals (1, 2, 3)
            # for 4-D tensors and stays consistent with the rank-generic
            # sigma reshape above for other ranks.
            reduce_dims = tuple(range(1, cond.ndim))
            x_cfg = uncond + cond_scale * (cond - uncond)
            ro_pos = torch.std(cond, dim=reduce_dims, keepdim=True)
            ro_cfg = torch.std(x_cfg, dim=reduce_dims, keepdim=True)

            # Match the guided result's std to the conditional one, then blend
            # with the unrescaled result according to multiplier.
            x_rescaled = x_cfg * (ro_pos / ro_cfg)
            x_final = multiplier * x_rescaled + (1.0 - multiplier) * x_cfg

            # Undo the v-pred conversion applied above.
            return x_orig - (x - x_final * sigma / (sigma * sigma + 1.0) ** 0.5)

        m = model.clone()
        m.set_model_sampler_cfg_function(rescale_cfg)
        return (m, )
|
|
|
|
# Node identifier -> implementing class, for the three node classes defined in
# this file. NOTE(review): presumably consumed by the node loader -- confirm.
NODE_CLASS_MAPPINGS = {
    "ModelSamplingDiscrete": ModelSamplingDiscrete,
    "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
    "RescaleCFG": RescaleCFG,
}
|