import gradio as gr

from modules import scripts
from ldm_patched.contrib.external_sag import SelfAttentionGuidance


# One shared node instance, created at import time; its `patch` method is
# invoked from `SAGForForge.process_batch` below.
opSelfAttentionGuidance = SelfAttentionGuidance()


class SAGForForge(scripts.Script):
    """Script that applies Self-Attention Guidance (SAG) to the Forge UNet.

    The UI offers an on/off checkbox plus "Scale" and "Blur Sigma" sliders;
    their values are forwarded to ``SelfAttentionGuidance.patch``.
    """

    def title(self):
        """Return the script's display name (also used as the accordion label)."""
        display_name = "SelfAttentionGuidance Integrated"
        return display_name

    def show(self, is_img2img):
        """Always return ``scripts.AlwaysVisible``; ``is_img2img`` is not consulted."""
        return scripts.AlwaysVisible

    def ui(self, *args, **kwargs):
        """Create the collapsed accordion of controls.

        Returns the components as ``(enabled, scale, blur_sigma)`` -- the exact
        order in which ``process_batch`` unpacks its ``script_args``.
        """
        with gr.Accordion(open=False, label=self.title()):
            enabled_box = gr.Checkbox(label='Enabled', value=False)
            scale_slider = gr.Slider(
                label='Scale',
                minimum=-2.0,
                maximum=5.0,
                step=0.01,
                value=0.5,
            )
            sigma_slider = gr.Slider(
                label='Blur Sigma',
                minimum=0.0,
                maximum=10.0,
                step=0.01,
                value=2.0,
            )

        return enabled_box, scale_slider, sigma_slider

    def process_batch(self, p, *script_args, **kwargs):
        """Swap in a SAG-patched UNet when the checkbox is enabled.

        Args:
            p: the processing object; ``p.sd_model.forge_objects.unet`` is read
                and replaced, and ``p.extra_generation_params`` is updated.
            *script_args: exactly ``(enabled, scale, blur_sigma)`` as produced
                by ``ui``; any other count raises ``ValueError`` on unpacking.

        When ``enabled`` is falsy, nothing on ``p`` is modified.
        """
        enabled, scale, blur_sigma = script_args

        if enabled:
            current_unet = p.sd_model.forge_objects.unet
            # The first element of patch()'s result becomes the new UNet.
            # NOTE(review): this runs per batch; if the host does not restore
            # forge_objects.unet between batches, patches may accumulate --
            # confirm against the Forge processing loop.
            p.sd_model.forge_objects.unet = opSelfAttentionGuidance.patch(
                current_unet, scale, blur_sigma
            )[0]

            # Record the SAG settings alongside the other generation params;
            # these entries are metadata only and do not affect sampling here.
            p.extra_generation_params.update({
                'sag_enabled': enabled,
                'sag_scale': scale,
                'sag_blur_sigma': blur_sigma,
            })