my-sd/modules/sd_samplers_cfg_denoiser.py

import torch
from modules import prompt_parser, sd_samplers_common
from modules.shared import state
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
from modules_forge import forge_sampler


def catenate_conds(conds):
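    # Conds are plain tensors for SD1/SD2; SDXL-style models pass dicts of
    # tensors (e.g. 'crossattn', 'vector'), so concatenate entry-wise per key.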
    if not isinstance(conds[0], dict):
        return torch.cat(conds)

    return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}


def subscript_cond(cond, a, b):
    if not isinstance(cond, dict):
        return cond[a:b]

    return {key: vec[a:b] for key, vec in cond.items()}


def pad_cond(tensor, repeats, empty):
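    # Pad along the token axis by appending `repeats` copies of the
    # empty-prompt embedding; for dict conds only 'crossattn' is padded.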
    if not isinstance(tensor, dict):
        return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], dim=1)

    tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
    return tensor


class CFGDenoiser(torch.nn.Module):
    """
    Classifier free guidance denoiser. A wrapper for a stable diffusion model (specifically for the unet)
    that takes a noisy picture and produces a noise-free picture using two guidances (prompts)
    instead of one. Originally the second prompt was just an empty string, but we use a non-empty
    negative prompt.
    """

    def __init__(self, sampler):
        super().__init__()
        self.model_wrap = None
        self.mask = None
        self.nmask = None
        self.init_latent = None
        self.steps = None
        """number of steps as specified by user in UI"""
        self.total_steps = None
        """expected number of calls to the denoiser, calculated from self.steps and the specifics of the selected sampler"""
        self.step = 0
        self.image_cfg_scale = None
        self.padded_cond_uncond = False
        self.padded_cond_uncond_v0 = False
        self.sampler = sampler
        self.p = None

        # Backward Compatibility
        self.mask_before_denoising = False
        self.classic_ddim_eps_estimation = False

    @property
    def inner_model(self):
        raise NotImplementedError()

    def combine_denoised(self, x_out, conds_list, uncond, cond_scale, timestep, x_in, cond):
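        # Classic CFG: start from the unconditional prediction and add the
        # weighted difference of each conditional prediction, i.e. roughly
        #   denoised[i] = uncond[i] + sum(weight * cond_scale * (cond_j - uncond[i]))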
        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = torch.clone(denoised_uncond)

        for i, conds in enumerate(conds_list):
            for cond_index, weight in conds:
                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)

        return denoised

    def combine_denoised_for_edit_model(self, x_out, cond_scale):
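        # Two-scale guidance for instruct-pix2pix style edit models: text
        # guidance scaled by cond_scale, image guidance by self.image_cfg_scale.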
        out_cond, out_img_cond, out_uncond = x_out.chunk(3)
        denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)

        return denoised

    def get_pred_x0(self, x_in, x_out, sigma):
        return x_out

    def update_inner_model(self):
        self.model_wrap = None

        c, uc = self.p.get_conds()
        self.sampler.sampler_extra_args['cond'] = c
        self.sampler.sampler_extra_args['uncond'] = uc

    def pad_cond_uncond(self, cond, uncond):
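        # Make cond and uncond the same token length by appending whole copies
        # of the empty-prompt embedding to whichever is shorter; the length
        # difference is assumed to be a multiple of the empty prompt's length.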
        empty = shared.sd_model.cond_stage_model_empty_prompt
        num_repeats = (cond.shape[1] - uncond.shape[1]) // empty.shape[1]

        if num_repeats < 0:
            cond = pad_cond(cond, -num_repeats, empty)
            self.padded_cond_uncond = True
        elif num_repeats > 0:
            uncond = pad_cond(uncond, num_repeats, empty)
            self.padded_cond_uncond = True

        return cond, uncond

    def pad_cond_uncond_v0(self, cond, uncond):
        """
        Pads the 'uncond' tensor to match the shape of the 'cond' tensor.

        If 'uncond' is a dictionary, it is assumed that the 'crossattn' key holds the tensor to be padded.
        If 'uncond' is a tensor, it is padded directly.

        If the number of columns in 'uncond' is less than the number of columns in 'cond', the last column of
        'uncond' is repeated to match the number of columns in 'cond'.

        If the number of columns in 'uncond' is greater than the number of columns in 'cond', 'uncond' is
        truncated to match the number of columns in 'cond'.

        Args:
            cond (torch.Tensor or DictWithShape): The condition tensor to match the shape of 'uncond'.
            uncond (torch.Tensor or DictWithShape): The tensor to be padded, or a dictionary containing the tensor to be padded.

        Returns:
            tuple: A tuple containing the 'cond' tensor and the padded 'uncond' tensor.

        Note:
            This is the padding that was always used in DDIM before version 1.6.0.
        """

        is_dict_cond = isinstance(uncond, dict)
        uncond_vec = uncond['crossattn'] if is_dict_cond else uncond

        if uncond_vec.shape[1] < cond.shape[1]:
            last_vector = uncond_vec[:, -1:]
            last_vector_repeated = last_vector.repeat([1, cond.shape[1] - uncond_vec.shape[1], 1])
            uncond_vec = torch.hstack([uncond_vec, last_vector_repeated])
            self.padded_cond_uncond_v0 = True
        elif uncond_vec.shape[1] > cond.shape[1]:
            uncond_vec = uncond_vec[:, :cond.shape[1]]
            self.padded_cond_uncond_v0 = True

        if is_dict_cond:
            uncond['crossattn'] = uncond_vec
        else:
            uncond = uncond_vec

        return cond, uncond

    def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
        if state.interrupted or state.skipped:
            raise sd_samplers_common.InterruptedException

        original_x_device = x.device
        original_x_dtype = x.dtype

        if self.classic_ddim_eps_estimation:
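            # Classic DDIM passes a discrete timestep index rather than a
            # continuous sigma; recover the corresponding sigma from
            # alphas_cumprod and rescale x to the variance-exploding convention.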
            acd = self.inner_model.inner_model.alphas_cumprod
            fake_sigmas = ((1 - acd) / acd) ** 0.5
            real_sigma = fake_sigmas[sigma.round().long().clip(0, int(fake_sigmas.shape[0]) - 1)]
            real_sigma_data = 1.0
            x = x * (real_sigma ** 2.0 + real_sigma_data ** 2.0) ** 0.5
            sigma = real_sigma

        if sd_samplers_common.apply_refiner(self, x):
            cond = self.sampler.sampler_extra_args['cond']
            uncond = self.sampler.sampler_extra_args['uncond']

        cond_composition, cond = prompt_parser.reconstruct_multicond_batch(cond, self.step)
        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)

        if self.mask is not None:
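            # Inpainting: re-noise the original latent to the current sigma and
            # blend it with x, keeping the masked region from the original image.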
            noisy_initial_latent = self.init_latent + sigma[:, None, None, None] * torch.randn_like(self.init_latent).to(self.init_latent)
            x = x * self.nmask + noisy_initial_latent * self.mask

        denoiser_params = CFGDenoiserParams(x, image_cond, sigma, state.sampling_step, state.sampling_steps, cond, uncond, self)
        cfg_denoiser_callback(denoiser_params)

        denoised = forge_sampler.forge_sample(self, denoiser_params=denoiser_params,
                                              cond_scale=cond_scale, cond_composition=cond_composition)

        if self.mask is not None:
            denoised = denoised * self.nmask + self.init_latent * self.mask

        preview = self.sampler.last_latent = denoised
        sd_samplers_common.store_latent(preview)

        after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
        cfg_after_cfg_callback(after_cfg_callback_params)
        denoised = after_cfg_callback_params.x

        self.step += 1

        if self.classic_ddim_eps_estimation:
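            # Classic DDIM expects an epsilon prediction, so convert the
            # denoised (x0-style) output back: eps = (x - denoised) / sigma.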
            eps = (x - denoised) / sigma
            return eps

        return denoised.to(device=original_x_device, dtype=original_x_dtype)