ee023f4fbf
Fix UniPC data cast and shape broadcast in #184. This also fixes potential problems in DDIM. The cause of this bug is that A1111’s `modules\models\diffusion\uni_pc\uni_pc.py` does not perform a data cast, and Forge’s DDIM estimator forgets to match the broadcast shape of sigmas. (While we are fixing this bug in A1111’s very original and high-quality samplers, comfyanonymous still believes that Forge is using ComfyUI to sample images, e.g., ComfyUI UniPC. Comfyanonymous is so cute. See also the jokes here: https://github.com/lllyasviel/stable-diffusion-webui-forge/discussions/169#discussioncomment-8428689)
203 lines
7.8 KiB
Python
203 lines
7.8 KiB
Python
import torch
|
|
from modules import prompt_parser, devices, sd_samplers_common
|
|
|
|
from modules.shared import opts, state
|
|
import modules.shared as shared
|
|
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
|
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
|
|
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
|
|
from modules_forge import forge_sampler
|
|
|
|
|
|
def catenate_conds(conds):
    """Concatenate a sequence of conds with ``torch.cat``.

    When the first entry is a dict, every entry is indexed by the keys of the
    first entry and a new plain dict is returned, mapping each of those keys
    to the concatenation of that key's values across all entries. Otherwise
    the entries themselves are concatenated.
    """
    first = conds[0]
    if isinstance(first, dict):
        merged = {}
        for key in first.keys():
            merged[key] = torch.cat([entry[key] for entry in conds])
        return merged

    return torch.cat(conds)
|
|
|
|
|
|
def subscript_cond(cond, a, b):
    """Return the ``[a:b]`` slice of a cond.

    A dict cond produces a new plain dict in which every value has been
    sliced with ``[a:b]``; any other cond is sliced directly.
    """
    if isinstance(cond, dict):
        sliced = {}
        for key, vec in cond.items():
            sliced[key] = vec[a:b]
        return sliced

    return cond[a:b]
|
|
|
|
|
|
def pad_cond(tensor, repeats, empty):
    """Append ``repeats`` copies of ``empty`` to a cond along dimension 1.

    ``empty`` is repeated ``tensor.shape[0]`` times along dimension 0 and
    ``repeats`` times along dimension 1 before being concatenated after
    ``tensor``.

    For a dict cond only the ``'crossattn'`` entry is padded; the dict is
    modified in place and the same dict object is returned.
    """
    if isinstance(tensor, dict):
        tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
        return tensor

    batch_size = tensor.shape[0]
    padding = empty.repeat((batch_size, repeats, 1))
    return torch.cat([tensor, padding], dim=1)
|
|
|
|
|
|
class CFGDenoiser(torch.nn.Module):
    """
    Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
    that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
    instead of one. Originally, the second prompt is just an empty string, but we use non-empty
    negative prompt.
    """

    def __init__(self, sampler):
        """Initialise per-run state; `sampler` is stored and later used for its extra args and last latent."""
        super().__init__()
        self.model_wrap = None
        self.mask = None
        self.nmask = None
        self.init_latent = None

        # number of steps as specified by user in UI
        self.steps = None

        # expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler
        self.total_steps = None

        # incremented once per forward() call
        self.step = 0
        self.image_cfg_scale = None
        self.padded_cond_uncond = False
        self.padded_cond_uncond_v0 = False
        self.sampler = sampler
        self.p = None

        # Backward Compatibility
        self.mask_before_denoising = False

        # When True, forward() treats `sigma` as indices into a sigma table
        # derived from alphas_cumprod and returns an eps estimate instead of
        # the denoised latent.
        self.classic_ddim_eps_estimation = False

    @property
    def inner_model(self):
        """The wrapped model; must be provided by subclasses."""
        raise NotImplementedError()

    def combine_denoised(self, x_out, conds_list, uncond, cond_scale, timestep, x_in, cond):
        """Blend conditional outputs into the unconditional output.

        The last ``uncond.shape[0]`` rows of ``x_out`` are taken as the
        unconditional result. For every ``(cond_index, weight)`` pair listed
        for image ``i`` in ``conds_list``, the difference between
        ``x_out[cond_index]`` and the unconditional row ``i`` is added, scaled
        by ``weight * cond_scale``.

        ``timestep``, ``x_in`` and ``cond`` are accepted but not used here.
        """
        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = torch.clone(denoised_uncond)

        for i, conds in enumerate(conds_list):
            for cond_index, weight in conds:
                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)

        return denoised

    def combine_denoised_for_edit_model(self, x_out, cond_scale):
        """Combine a three-way chunked output using ``cond_scale`` and ``self.image_cfg_scale``.

        ``x_out`` is split into three equal chunks, interpreted in order as
        (cond, image-cond, uncond) outputs.
        """
        out_cond, out_img_cond, out_uncond = x_out.chunk(3)
        denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)

        return denoised

    def get_pred_x0(self, x_in, x_out, sigma):
        """Return ``x_out`` unchanged; ``x_in`` and ``sigma`` are ignored in this base implementation."""
        return x_out

    def update_inner_model(self):
        """Drop the cached model wrapper and refresh the sampler's cond/uncond from ``self.p.get_conds()``."""
        self.model_wrap = None

        c, uc = self.p.get_conds()
        self.sampler.sampler_extra_args['cond'] = c
        self.sampler.sampler_extra_args['uncond'] = uc

    def pad_cond_uncond(self, cond, uncond):
        """Pad whichever of ``cond`` / ``uncond`` is shorter along dimension 1 with the empty prompt.

        The number of repeats is the difference of the two ``shape[1]`` values
        floor-divided by ``empty.shape[1]``. ``self.padded_cond_uncond`` is set
        to True when any padding was applied.

        Returns:
            tuple: ``(cond, uncond)`` after padding.
        """
        empty = shared.sd_model.cond_stage_model_empty_prompt
        num_repeats = (cond.shape[1] - uncond.shape[1]) // empty.shape[1]

        if num_repeats < 0:
            # cond is the shorter one
            cond = pad_cond(cond, -num_repeats, empty)
            self.padded_cond_uncond = True
        elif num_repeats > 0:
            # uncond is the shorter one
            uncond = pad_cond(uncond, num_repeats, empty)
            self.padded_cond_uncond = True

        return cond, uncond

    def pad_cond_uncond_v0(self, cond, uncond):
        """
        Pads the 'uncond' tensor to match the shape of the 'cond' tensor.

        If 'uncond' is a dictionary, it is assumed that the 'crossattn' key holds the tensor to be padded.
        If 'uncond' is a tensor, it is padded directly.

        If the number of columns in 'uncond' is less than the number of columns in 'cond', the last column of 'uncond'
        is repeated to match the number of columns in 'cond'.

        If the number of columns in 'uncond' is greater than the number of columns in 'cond', 'uncond' is truncated
        to match the number of columns in 'cond'.

        Args:
            cond (torch.Tensor or DictWithShape): The condition tensor whose shape 'uncond' is matched to.
            uncond (torch.Tensor or DictWithShape): The tensor to be padded, or a dictionary containing the tensor to be padded.

        Returns:
            tuple: A tuple containing the 'cond' tensor and the padded 'uncond' tensor.

        Note:
            This is the padding that was always used in DDIM before version 1.6.0
        """

        is_dict_cond = isinstance(uncond, dict)
        uncond_vec = uncond['crossattn'] if is_dict_cond else uncond

        if uncond_vec.shape[1] < cond.shape[1]:
            last_vector = uncond_vec[:, -1:]
            last_vector_repeated = last_vector.repeat([1, cond.shape[1] - uncond_vec.shape[1], 1])
            uncond_vec = torch.hstack([uncond_vec, last_vector_repeated])
            self.padded_cond_uncond_v0 = True
        elif uncond_vec.shape[1] > cond.shape[1]:
            uncond_vec = uncond_vec[:, :cond.shape[1]]
            self.padded_cond_uncond_v0 = True

        if is_dict_cond:
            # the dict is updated in place
            uncond['crossattn'] = uncond_vec
        else:
            uncond = uncond_vec

        return cond, uncond

    def _classic_ddim_rescale(self, x, sigma):
        """Map index-like ``sigma`` values to real sigmas and rescale ``x`` accordingly.

        A sigma table is computed from ``alphas_cumprod`` as
        ``sqrt((1 - acd) / acd)``. ``sigma`` is rounded, cast to integer and
        clamped to the valid index range of that table before lookup.

        Returns:
            tuple: ``(x * sqrt(real_sigma**2 + 1), real_sigma)``; the scale
            factor is broadcast over the three trailing dimensions of ``x``.
        """
        acd = self.inner_model.inner_model.alphas_cumprod
        fake_sigmas = ((1 - acd) / acd) ** 0.5
        # The highest valid index is len - 1; clamping to len would allow an
        # out-of-range lookup.
        last_index = int(fake_sigmas.shape[0]) - 1
        real_sigma = fake_sigmas[sigma.round().long().clip(0, last_index)]
        real_sigma_data = 1.0
        x = x * (((real_sigma ** 2.0 + real_sigma_data ** 2.0) ** 0.5)[:, None, None, None])
        return x, real_sigma

    def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
        """Run one guided denoising step.

        Steps, in order: abort if the shared state is interrupted/skipped;
        optionally convert ``sigma``/``x`` for classic DDIM eps estimation;
        reload cond/uncond from the sampler when ``apply_refiner`` reports a
        switch; rebuild the per-step cond batches; apply the mask to the
        input (if any); fire the denoiser callback; sample through
        ``forge_sampler.forge_sample``; apply the mask to the output (if any);
        store the latent for preview; fire the after-CFG callback; increment
        ``self.step``.

        ``s_min_uncond`` is accepted but not used in this method.

        Returns:
            The denoised latent cast back to the device and dtype ``x`` had on
            entry, or, when ``self.classic_ddim_eps_estimation`` is set, the
            eps estimate ``(x - denoised) / sigma``.

        Raises:
            sd_samplers_common.InterruptedException: when ``state.interrupted``
                or ``state.skipped`` is set.
        """
        if state.interrupted or state.skipped:
            raise sd_samplers_common.InterruptedException

        # Remember the caller's device/dtype so the result can be cast back.
        original_x_device = x.device
        original_x_dtype = x.dtype

        if self.classic_ddim_eps_estimation:
            x, sigma = self._classic_ddim_rescale(x, sigma)

        if sd_samplers_common.apply_refiner(self, x):
            cond = self.sampler.sampler_extra_args['cond']
            uncond = self.sampler.sampler_extra_args['uncond']

        cond_composition, cond = prompt_parser.reconstruct_multicond_batch(cond, self.step)
        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)

        if self.mask is not None:
            # Blend a freshly noised copy of the initial latent into x using mask / nmask as weights.
            noisy_initial_latent = self.init_latent + sigma[:, None, None, None] * torch.randn_like(self.init_latent).to(self.init_latent)
            x = x * self.nmask + noisy_initial_latent * self.mask

        denoiser_params = CFGDenoiserParams(x, image_cond, sigma, state.sampling_step, state.sampling_steps, cond, uncond, self)
        cfg_denoiser_callback(denoiser_params)

        denoised = forge_sampler.forge_sample(self, denoiser_params=denoiser_params,
                                              cond_scale=cond_scale, cond_composition=cond_composition)

        if self.mask is not None:
            # Blend the (un-noised) initial latent into the result using the same weights.
            denoised = denoised * self.nmask + self.init_latent * self.mask

        preview = self.sampler.last_latent = denoised
        sd_samplers_common.store_latent(preview)

        after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
        cfg_after_cfg_callback(after_cfg_callback_params)
        # Callbacks may replace the latent via the params object.
        denoised = after_cfg_callback_params.x

        self.step += 1

        if self.classic_ddim_eps_estimation:
            # NOTE(review): this path returns eps without the device/dtype cast
            # applied on the normal path below — confirm that is intended.
            eps = (x - denoised) / sigma[:, None, None, None]
            return eps

        return denoised.to(device=original_x_device, dtype=original_x_dtype)
|
|
|