import torch
from modules import prompt_parser, devices, sd_samplers_common

from modules.shared import opts, state
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
from modules_forge import forge_sampler


def catenate_conds(conds):
    """Concatenate a sequence of conditionings along dim 0.

    Args:
        conds: non-empty sequence whose items are either all tensors or all
            dicts of tensors. For dicts, the keys of ``conds[0]`` decide which
            entries are concatenated.

    Returns:
        A single tensor, or a dict mapping each key to the concatenation of
        that key's tensors across ``conds``.
    """
    if not isinstance(conds[0], dict):
        return torch.cat(conds)

    return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}


def subscript_cond(cond, a, b):
    """Return rows ``a:b`` (first dimension) of a conditioning.

    Accepts a plain tensor or a dict of tensors; for a dict every entry is
    sliced with the same bounds.
    """
    if not isinstance(cond, dict):
        return cond[a:b]

    return {key: vec[a:b] for key, vec in cond.items()}


def pad_cond(tensor, repeats, empty):
    """Pad a conditioning along dim 1 with ``repeats`` copies of ``empty``.

    ``empty`` is tiled by ``(tensor.shape[0], repeats, 1)`` and appended after
    ``tensor`` on dim 1. Presumably ``empty`` is the encoded empty prompt with
    at most 3 dims (e.g. (1, tokens, dim)) — TODO confirm against callers.

    For a dict conditioning only the ``'crossattn'`` entry is padded. A new
    dict is returned; the caller's dict is not modified in place.
    """
    if not isinstance(tensor, dict):
        return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], dim=1)

    # Shallow copy: other entries are shared, only 'crossattn' is replaced
    # (key order is preserved because 'crossattn' already exists).
    return {**tensor, 'crossattn': pad_cond(tensor['crossattn'], repeats, empty)}


class CFGDenoiser(torch.nn.Module):
    """
    Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
    that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
    instead of one. Originally, the second prompt is just an empty string, but we use non-empty
    negative prompt.
    """

    def __init__(self, sampler):
        super().__init__()
        self.model_wrap = None
        self.mask = None
        self.nmask = None
        self.init_latent = None
        # number of steps as specified by user in UI
        self.steps = None
        # expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler
        self.total_steps = None
        # step index passed to prompt_parser when rebuilding cond/uncond; incremented once per forward() call
        self.step = 0
        self.image_cfg_scale = None
        self.padded_cond_uncond = False
        self.sampler = sampler
        self.p = None

        # NOTE: masking before denoising can cause the original latents to be oversmoothed
        # as the original latents do not have noise
        self.mask_before_denoising = False

    @property
    def inner_model(self):
        """Underlying model; must be provided by a subclass."""
        raise NotImplementedError()

    def combine_denoised(self, x_out, conds_list, uncond, cond_scale, timestep, x_in, cond):
        """Combine per-prompt predictions into one guided prediction per image.

        Args:
            x_out: stacked model outputs; the last ``uncond.shape[0]`` rows are
                the unconditional predictions, one per image.
            conds_list: for image ``i``, a list of ``(cond_index, weight)``
                pairs indexing positive predictions in ``x_out``.
            uncond: unconditional conditioning; only its batch size is used.
            cond_scale: guidance scale multiplied into every weight.
            timestep, x_in, cond: not used by this implementation.

        Returns:
            For each image ``i``: ``u_i + sum(weight * cond_scale * (x_out[cond_index] - u_i))``
            where ``u_i`` is that image's unconditional prediction.
        """
        denoised_uncond = x_out[-uncond.shape[0]:]
        # Clone so accumulating into `denoised` does not alter denoised_uncond.
        denoised = torch.clone(denoised_uncond)

        for i, conds in enumerate(conds_list):
            for cond_index, weight in conds:
                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)

        return denoised

    def combine_denoised_for_edit_model(self, x_out, cond_scale):
        """Combine predictions for an edit (InstructPix2Pix-style) model.

        ``x_out`` is split into three equal chunks along dim 0, in the order
        cond, image-cond, uncond. The text term is scaled by ``cond_scale`` and
        the image term by ``self.image_cfg_scale``.
        """
        out_cond, out_img_cond, out_uncond = x_out.chunk(3)
        denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)

        return denoised

    def get_pred_x0(self, x_in, x_out, sigma):
        """Return the x0 prediction; this base implementation returns ``x_out`` unchanged."""
        return x_out

    def update_inner_model(self):
        """Drop the cached model wrapper and reload cond/uncond into the sampler's extra args."""
        self.model_wrap = None

        c, uc = self.p.get_conds()
        self.sampler.sampler_extra_args['cond'] = c
        self.sampler.sampler_extra_args['uncond'] = uc

    def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
        """Run one guided denoising step.

        Args:
            x: current noisy latent.
            sigma: noise level for this step; also forwarded to mask-blend scripts.
            uncond: negative conditioning, in the form accepted by
                ``prompt_parser.reconstruct_cond_batch``.
            cond: positive conditioning, in the form accepted by
                ``prompt_parser.reconstruct_multicond_batch``.
            cond_scale: guidance scale forwarded to ``forge_sampler.forge_sample``.
            s_min_uncond: not used by this implementation.
            image_cond: forwarded to ``CFGDenoiserParams``.

        Returns:
            The denoised latent, after mask blending (when a mask is set) and
            any modification made by after-CFG callbacks.

        Raises:
            sd_samplers_common.InterruptedException: if the job is interrupted or skipped.
        """
        if state.interrupted or state.skipped:
            raise sd_samplers_common.InterruptedException

        if sd_samplers_common.apply_refiner(self, x):
            cond = self.sampler.sampler_extra_args['cond']
            uncond = self.sampler.sampler_extra_args['uncond']

        # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
        # so is_edit_model is set to False to support AND composition.
        is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0

        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)

        assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"

        # If we use masks, blending between the denoised and original latent images occurs here.
        def apply_blend(current_latent):
            """Return current_latent * nmask + init_latent * mask, letting scripts override it."""
            blended_latent = current_latent * self.nmask + self.init_latent * self.mask

            if self.p.scripts is not None:
                from modules import scripts
                mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
                self.p.scripts.on_mask_blend(self.p, mba)
                blended_latent = mba.blended_latent

            return blended_latent

        # Blend in the original latents (before)
        if self.mask_before_denoising and self.mask is not None:
            x = apply_blend(x)

        denoiser_params = CFGDenoiserParams(x, image_cond, sigma, state.sampling_step, state.sampling_steps, tensor, uncond, self)
        cfg_denoiser_callback(denoiser_params)

        denoised = forge_sampler.forge_sample(self, denoiser_params=denoiser_params, cond_scale=cond_scale)

        # Blend in the original latents (after). This is the default path
        # (mask_before_denoising is False): masked regions of the prediction
        # are replaced with init_latent via apply_blend.
        if not self.mask_before_denoising and self.mask is not None:
            denoised = apply_blend(denoised)

        preview = self.sampler.last_latent = denoised
        sd_samplers_common.store_latent(preview)

        after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
        cfg_after_cfg_callback(after_cfg_callback_params)
        denoised = after_cfg_callback_params.x

        self.step += 1
        return denoised