import torch
from modules import prompt_parser, devices, sd_samplers_common

from modules.shared import opts, state
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
from modules_forge import forge_sampler


def catenate_conds(conds):
    """Concatenate a sequence of conditionings into one.

    Args:
        conds: Non-empty sequence whose items are either all tensors or all
            dicts of tensors. The type of the first item decides which.

    Returns:
        ``torch.cat(conds)`` for tensors; for dicts, a new dict that maps each
        key of the first item to the concatenation of that key's tensors
        across all items.
    """
    if not isinstance(conds[0], dict):
        return torch.cat(conds)

    return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}


def subscript_cond(cond, a, b):
    """Slice a conditioning along its first dimension.

    Args:
        cond: A tensor, or a dict of tensors.
        a: Slice start.
        b: Slice end (exclusive).

    Returns:
        ``cond[a:b]`` for a tensor; for a dict, a new dict with every value
        sliced as ``value[a:b]``.
    """
    if not isinstance(cond, dict):
        return cond[a:b]

    return {key: vec[a:b] for key, vec in cond.items()}


def pad_cond(tensor, repeats, empty):
    """Extend a conditioning along dimension 1 with copies of ``empty``.

    Args:
        tensor: A tensor, or a dict whose ``'crossattn'`` entry is the tensor
            to pad. Assumed to be 3-D, since the repeat spec has three
            entries -- TODO confirm against callers.
        repeats: How many times ``empty`` is repeated along dimension 1.
        empty: Tensor used as padding material.

    Returns:
        A new padded tensor when ``tensor`` is a tensor. When it is a dict,
        the *same* dict object is returned with ``'crossattn'`` replaced by
        its padded version (the input dict is modified in place).
    """
    if not isinstance(tensor, dict):
        padding = empty.repeat((tensor.shape[0], repeats, 1))
        return torch.cat([tensor, padding], dim=1)

    tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
    return tensor


class CFGDenoiser(torch.nn.Module):
    """
    Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
    that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
    instead of one. Originally, the second prompt is just an empty string, but we use non-empty
    negative prompt.
    """

    def __init__(self, sampler):
        """Create a denoiser bound to ``sampler``; all other state starts unset."""
        super().__init__()
        self.model_wrap = None
        self.mask = None
        self.nmask = None
        self.init_latent = None

        # number of steps as specified by user in UI
        self.steps = None

        # expected number of calls to denoiser calculated from self.steps and
        # specifics of the selected sampler
        self.total_steps = None

        # incremented once at the end of every forward() call
        self.step = 0
        self.image_cfg_scale = None
        self.padded_cond_uncond = False
        self.padded_cond_uncond_v0 = False
        self.sampler = sampler
        self.p = None

        # Backward Compatibility
        self.mask_before_denoising = False

        self.classic_ddim_eps_estimation = False

    @property
    def inner_model(self):
        """The wrapped model; this base implementation always raises.

        Raises:
            NotImplementedError: Always. ``forward`` reads
                ``self.inner_model.inner_model.alphas_cumprod`` when
                ``classic_ddim_eps_estimation`` is set, so that mode needs a
                working override.
        """
        raise NotImplementedError()

    def combine_denoised(self, x_out, conds_list, uncond, cond_scale, timestep, x_in, cond):
        """Blend conditional and unconditional predictions into one result.

        The last ``uncond.shape[0]`` entries of ``x_out`` are taken as the
        unconditional predictions. For every ``(cond_index, weight)`` pair in
        ``conds_list[i]``, the difference between ``x_out[cond_index]`` and
        the i-th unconditional prediction, scaled by ``weight * cond_scale``,
        is added to the i-th result.

        ``timestep``, ``x_in`` and ``cond`` are accepted but not used here.

        Returns:
            A new tensor shaped like the unconditional slice of ``x_out``.
        """
        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = torch.clone(denoised_uncond)

        for i, conds in enumerate(conds_list):
            for cond_index, weight in conds:
                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)

        return denoised

    def combine_denoised_for_edit_model(self, x_out, cond_scale):
        """Blend the three chunks of ``x_out`` using two guidance scales.

        ``x_out`` is split into three equal chunks, in the order
        (cond, image-cond, uncond). The result is
        ``uncond + cond_scale * (cond - img_cond)
        + self.image_cfg_scale * (img_cond - uncond)``.
        """
        out_cond, out_img_cond, out_uncond = x_out.chunk(3)
        denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)

        return denoised

    def get_pred_x0(self, x_in, x_out, sigma):
        """Return ``x_out`` unchanged; ``x_in`` and ``sigma`` are ignored here."""
        return x_out

    def update_inner_model(self):
        """Drop the cached model wrapper and refresh the sampler's conditionings.

        Resets ``self.model_wrap`` to ``None`` and stores the pair returned by
        ``self.p.get_conds()`` under the ``'cond'`` and ``'uncond'`` keys of
        ``self.sampler.sampler_extra_args``.
        """
        self.model_wrap = None

        c, uc = self.p.get_conds()
        self.sampler.sampler_extra_args['cond'] = c
        self.sampler.sampler_extra_args['uncond'] = uc

    def pad_cond_uncond(self, cond, uncond):
        """Pad whichever of ``cond`` / ``uncond`` is shorter along dimension 1.

        The padding material is ``shared.sd_model.cond_stage_model_empty_prompt``.
        The number of repeats is the length difference floor-divided by the
        empty prompt's length along dimension 1. ``self.padded_cond_uncond``
        is set to ``True`` when any padding is applied.

        Returns:
            tuple: ``(cond, uncond)`` after padding.
        """
        empty = shared.sd_model.cond_stage_model_empty_prompt
        num_repeats = (cond.shape[1] - uncond.shape[1]) // empty.shape[1]

        if num_repeats < 0:
            cond = pad_cond(cond, -num_repeats, empty)
            self.padded_cond_uncond = True
        elif num_repeats > 0:
            uncond = pad_cond(uncond, num_repeats, empty)
            self.padded_cond_uncond = True

        return cond, uncond

    def pad_cond_uncond_v0(self, cond, uncond):
        """
        Pads the 'uncond' tensor to match the shape of the 'cond' tensor.

        If 'uncond' is a dictionary, it is assumed that the 'crossattn' key holds the tensor to be padded.
        If 'uncond' is a tensor, it is padded directly.

        If the number of columns in 'uncond' is less than the number of columns in 'cond', the last column of 'uncond'
        is repeated to match the number of columns in 'cond'.

        If the number of columns in 'uncond' is greater than the number of columns in 'cond', 'uncond' is truncated
        to match the number of columns in 'cond'.

        Sets ``self.padded_cond_uncond_v0`` to True whenever 'uncond' is padded or truncated. When 'uncond' is a
        dictionary, its 'crossattn' entry is replaced in place.

        Args:
            cond (torch.Tensor or DictWithShape): The condition tensor whose shape 'uncond' is matched to.
            uncond (torch.Tensor or DictWithShape): The tensor to be padded, or a dictionary containing the tensor to be padded.

        Returns:
            tuple: A tuple containing the 'cond' tensor and the padded 'uncond' tensor.

        Note:
            This is the padding that was always used in DDIM before version 1.6.0
        """
        is_dict_cond = isinstance(uncond, dict)
        uncond_vec = uncond['crossattn'] if is_dict_cond else uncond

        if uncond_vec.shape[1] < cond.shape[1]:
            last_vector = uncond_vec[:, -1:]
            last_vector_repeated = last_vector.repeat([1, cond.shape[1] - uncond_vec.shape[1], 1])
            uncond_vec = torch.hstack([uncond_vec, last_vector_repeated])
            self.padded_cond_uncond_v0 = True
        elif uncond_vec.shape[1] > cond.shape[1]:
            uncond_vec = uncond_vec[:, :cond.shape[1]]
            self.padded_cond_uncond_v0 = True

        if is_dict_cond:
            uncond['crossattn'] = uncond_vec
        else:
            uncond = uncond_vec

        return cond, uncond

    def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
        """Run one guided denoising step.

        Args:
            x: Noisy latent input.
            sigma: Per-sample noise level. When ``classic_ddim_eps_estimation``
                is set it is instead rounded and used as an index into a sigma
                table derived from ``alphas_cumprod``.
            uncond: Scheduled negative conditioning, resolved for the current
                step via ``prompt_parser.reconstruct_cond_batch``.
            cond: Scheduled positive conditioning, resolved for the current
                step via ``prompt_parser.reconstruct_multicond_batch``.
            cond_scale: Guidance scale forwarded to ``forge_sampler.forge_sample``.
            s_min_uncond: Accepted for interface compatibility; not used here.
            image_cond: Image conditioning passed into ``CFGDenoiserParams``.

        Returns:
            The denoised latent cast back to the device and dtype of the input
            ``x``; or, in classic DDIM mode, the epsilon estimate
            ``(x - denoised) / sigma`` (returned without that cast).

        Raises:
            sd_samplers_common.InterruptedException: If ``state.interrupted``
                or ``state.skipped`` is set.
        """
        if state.interrupted or state.skipped:
            raise sd_samplers_common.InterruptedException

        original_x_device = x.device
        original_x_dtype = x.dtype

        if self.classic_ddim_eps_estimation:
            acd = self.inner_model.inner_model.alphas_cumprod
            fake_sigmas = ((1 - acd) / acd) ** 0.5
            # Clamp to the last valid index (len - 1). An upper bound of len
            # itself would let an index one past the end through and raise
            # IndexError instead of being clipped.
            last_index = int(fake_sigmas.shape[0]) - 1
            real_sigma = fake_sigmas[sigma.round().long().clip(0, last_index)]
            real_sigma_data = 1.0
            x = x * (((real_sigma ** 2.0 + real_sigma_data ** 2.0) ** 0.5)[:, None, None, None])
            sigma = real_sigma

        # When apply_refiner reports a change, re-read the conditionings from
        # the sampler's extra args instead of using the ones passed in.
        if sd_samplers_common.apply_refiner(self, x):
            cond = self.sampler.sampler_extra_args['cond']
            uncond = self.sampler.sampler_extra_args['uncond']

        cond_composition, cond = prompt_parser.reconstruct_multicond_batch(cond, self.step)
        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)

        if self.mask is not None:
            # Replace the masked region with the initial latent, re-noised to
            # the current sigma.
            noisy_initial_latent = self.init_latent + sigma[:, None, None, None] * torch.randn_like(self.init_latent).to(self.init_latent)
            x = x * self.nmask + noisy_initial_latent * self.mask

        denoiser_params = CFGDenoiserParams(x, image_cond, sigma, state.sampling_step, state.sampling_steps, cond, uncond, self)
        cfg_denoiser_callback(denoiser_params)

        denoised = forge_sampler.forge_sample(self, denoiser_params=denoiser_params,
                                              cond_scale=cond_scale, cond_composition=cond_composition)

        if self.mask is not None:
            # Put the clean initial latent back into the masked region.
            denoised = denoised * self.nmask + self.init_latent * self.mask

        preview = self.sampler.last_latent = denoised
        sd_samplers_common.store_latent(preview)

        # Callbacks may replace the result through params.x.
        after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
        cfg_after_cfg_callback(after_cfg_callback_params)
        denoised = after_cfg_callback_params.x

        self.step += 1

        if self.classic_ddim_eps_estimation:
            eps = (x - denoised) / sigma[:, None, None, None]
            return eps

        return denoised.to(device=original_x_device, dtype=original_x_dtype)