2022-09-03 09:08:45 +00:00
import torch
2023-08-08 15:35:31 +00:00
from modules import prompt_parser , devices , sd_samplers_common
2022-09-03 09:08:45 +00:00
2023-01-30 06:51:06 +00:00
from modules . shared import opts , state
2022-09-03 09:08:45 +00:00
import modules . shared as shared
2022-11-02 00:38:17 +00:00
from modules . script_callbacks import CFGDenoiserParams , cfg_denoiser_callback
2023-02-11 02:18:38 +00:00
from modules . script_callbacks import CFGDenoisedParams , cfg_denoised_callback
2023-05-14 01:49:41 +00:00
from modules . script_callbacks import AfterCFGCallbackParams , cfg_after_cfg_callback
2024-01-27 21:21:25 +00:00
from modules_forge import forge_sampler
2022-09-03 09:08:45 +00:00
2022-10-22 17:48:13 +00:00
2023-07-11 18:16:43 +00:00
def catenate_conds(conds):
    """Concatenate a sequence of conditionings along dim 0 (torch.cat's default).

    The type of the first element decides the strategy. If it is a dict, every key of
    that first dict is concatenated across all elements and a new dict is returned.
    Otherwise the elements are handed to torch.cat as-is.
    """
    head = conds[0]
    if isinstance(head, dict):
        return {name: torch.cat([item[name] for item in conds]) for name in head}
    return torch.cat(conds)
def subscript_cond(cond, a, b):
    """Return the slice ``[a:b]`` of a conditioning.

    A dict conditioning is sliced value by value into a new dict. Anything else is
    sliced directly.
    """
    if isinstance(cond, dict):
        return {name: value[a:b] for name, value in cond.items()}
    return cond[a:b]
def pad_cond(tensor, repeats, empty):
    """Append ``repeats`` copies of ``empty`` to a conditioning along dim 1.

    For tensor input the padding is ``empty.repeat((tensor.shape[0], repeats, 1))``,
    concatenated after ``tensor`` on dim 1.

    For dict input only the ``'crossattn'`` entry is padded. The dict is updated in
    place and the same object is returned; all other keys are left untouched.

    NOTE(review): presumably ``empty`` is the encoding of an empty prompt and this is
    used to equalise cond/uncond token lengths. Confirm at the call sites.
    """
    if isinstance(tensor, dict):
        tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
        return tensor

    padding = empty.repeat((tensor.shape[0], repeats, 1))
    return torch.cat([tensor, padding], dim=1)
2022-09-03 09:08:45 +00:00
class CFGDenoiser(torch.nn.Module):
    """
    Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
    that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
    instead of one. Originally, the second prompt is just an empty string, but we use non-empty
    negative prompt.
    """

    def __init__(self, sampler):
        super().__init__()

        # Wrapped model; reset to None by update_inner_model().
        self.model_wrap = None

        # Inpainting state. When `mask` is set, forward() blends latents as
        # `latent * nmask + init_latent * mask`.
        self.mask = None
        self.nmask = None
        self.init_latent = None

        # number of steps as specified by user in UI
        self.steps = None
        # expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler
        self.total_steps = None

        # Index of the current denoiser call; forward() increments it at the end of each call.
        self.step = 0
        # Image CFG scale for edit models. forward() treats None or 1.0 as "not an edit model".
        self.image_cfg_scale = None
        # NOTE(review): not read anywhere in this class; presumably set/read by callers or subclasses.
        self.padded_cond_uncond = False
        self.sampler = sampler
        # Owning processing object. forward() reads `p.scripts`; update_inner_model() calls `p.get_conds()`.
        self.p = None

        # NOTE: masking before denoising can cause the original latents to be oversmoothed
        # as the original latents do not have noise
        self.mask_before_denoising = False

    @property
    def inner_model(self):
        """The model being denoised. Must be overridden by concrete subclasses."""
        raise NotImplementedError()

    def combine_denoised(self, x_out, conds_list, uncond, cond_scale, timestep, x_in, cond):
        """Combine conditional and unconditional model outputs with classifier-free guidance.

        The last ``uncond.shape[0]`` rows of ``x_out`` are the unconditional outputs, one per
        image. For image ``i``, each ``(cond_index, weight)`` pair in ``conds_list[i]`` adds
        ``(x_out[cond_index] - uncond_i) * (weight * cond_scale)`` on top of ``uncond_i``.

        ``timestep``, ``x_in`` and ``cond`` are part of the interface but are not used by
        this implementation. The input ``x_out`` is not modified; a clone is returned.
        """
        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = torch.clone(denoised_uncond)

        for i, conds in enumerate(conds_list):
            for cond_index, weight in conds:
                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)

        return denoised

    def combine_denoised_for_edit_model(self, x_out, cond_scale):
        """Guidance for edit models.

        ``x_out`` is split into (cond, image-cond, uncond) with ``chunk(3)`` along dim 0. The
        result is ``uncond + cond_scale * (cond - img_cond) + image_cfg_scale * (img_cond - uncond)``.
        """
        out_cond, out_img_cond, out_uncond = x_out.chunk(3)
        denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
        return denoised

    def get_pred_x0(self, x_in, x_out, sigma):
        """Hook for subclasses. The base implementation returns ``x_out`` unchanged."""
        return x_out

    def update_inner_model(self):
        """Reset ``model_wrap`` and refresh the sampler's cond/uncond from ``self.p.get_conds()``."""
        self.model_wrap = None

        c, uc = self.p.get_conds()
        self.sampler.sampler_extra_args['cond'] = c
        self.sampler.sampler_extra_args['uncond'] = uc

    def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
        """Run one guided denoising call and return the denoised latent.

        Raises ``sd_samplers_common.InterruptedException`` when the job was interrupted or
        skipped. When ``self.mask`` is set, the original latents are blended in at one of two
        points. With ``mask_before_denoising`` True the blend is applied to the input ``x``.
        Otherwise (the default) it is applied to the denoised output. The returned value is
        whatever ``cfg_after_cfg_callback`` leaves in its params. ``self.step`` is incremented
        once per call.
        """
        if state.interrupted or state.skipped:
            raise sd_samplers_common.InterruptedException

        # A truthy result means the sampler's cond/uncond may have been refreshed, so re-read them.
        if sd_samplers_common.apply_refiner(self, x):
            cond = self.sampler.sampler_extra_args['cond']
            uncond = self.sampler.sampler_extra_args['uncond']

        # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
        # so is_edit_model is set to False to support AND composition.
        is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0

        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)

        assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"

        # If we use masks, blending between the denoised and original latent images occurs here.
        def apply_blend(current_latent):
            blended_latent = current_latent * self.nmask + self.init_latent * self.mask

            # Scripts may replace the blended result through the on_mask_blend hook.
            if self.p.scripts is not None:
                from modules import scripts
                mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
                self.p.scripts.on_mask_blend(self.p, mba)
                blended_latent = mba.blended_latent

            return blended_latent

        # Blend in the original latents (before)
        if self.mask_before_denoising and self.mask is not None:
            x = apply_blend(x)

        denoiser_params = CFGDenoiserParams(x, image_cond, sigma, state.sampling_step, state.sampling_steps, tensor, uncond, self)
        cfg_denoiser_callback(denoiser_params)

        denoised = forge_sampler.forge_sample(self, denoiser_params=denoiser_params, cond_scale=cond_scale)

        # Blend in the original latents (after). This is the default path. Without it, a mask
        # would never be applied while mask_before_denoising is False.
        if not self.mask_before_denoising and self.mask is not None:
            denoised = apply_blend(denoised)

        preview = self.sampler.last_latent = denoised
        sd_samplers_common.store_latent(preview)

        after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
        cfg_after_cfg_callback(after_cfg_callback_params)
        denoised = after_cfg_callback_params.x

        self.step += 1
        return denoised