import os from typing import Dict, Optional, Tuple, List, Union import torch import modules.scripts as scripts from modules import shared, script_callbacks, processing, masking, images from modules.api.api import decode_base64_to_image import gradio as gr from lib_controlnet import global_state, external_code from lib_controlnet.utils import align_dim_latent, image_dict_from_any, set_numpy_seed, crop_and_resize_image, \ prepare_mask, judge_image_type from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit from lib_controlnet.controlnet_ui.photopea import Photopea from lib_controlnet.logging import logger from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img, \ StableDiffusionProcessing from lib_controlnet.infotext import Infotext from modules_forge.forge_util import HWC3, numpy_to_pytorch import numpy as np import functools from PIL import Image from modules_forge.shared import try_load_supported_control_model from modules_forge.supported_controlnet import ControlModelPatcher # Gradio 3.32 bug fix import tempfile gradio_tempfile_path = os.path.join(tempfile.gettempdir(), 'gradio') os.makedirs(gradio_tempfile_path, exist_ok=True) global_state.update_controlnet_filenames() @functools.lru_cache(maxsize=shared.opts.data.get("control_net_model_cache_size", 5)) def cached_controlnet_loader(filename): return try_load_supported_control_model(filename) class ControlNetCachedParameters: def __init__(self): self.preprocessor = None self.model = None self.control_cond = None self.control_cond_for_hr_fix = None class ControlNetForForgeOfficial(scripts.Script): def title(self): return "ControlNet" def show(self, is_img2img): return scripts.AlwaysVisible def uigroup(self, tabname: str, is_img2img: bool, elem_id_tabname: str, photopea: Optional[Photopea]) -> Tuple[ControlNetUiGroup, gr.State]: default_unit = UiControlNetUnit(enabled=False, module="None", model="None") group = ControlNetUiGroup(is_img2img, default_unit, photopea) return group, group.render(tabname, elem_id_tabname) def ui(self, is_img2img): infotext = Infotext() ui_groups = [] controls = [] max_models = shared.opts.data.get("control_net_unit_count", 3) elem_id_tabname = ("img2img" if is_img2img else "txt2img") + "_controlnet" with gr.Group(elem_id=elem_id_tabname): with gr.Accordion(f"ControlNet Integrated", open=False, elem_id="controlnet"): photopea = Photopea() if not shared.opts.data.get("controlnet_disable_photopea_edit", False) else None if max_models > 1: with gr.Tabs(elem_id=f"{elem_id_tabname}_tabs"): for i in range(max_models): with gr.Tab(f"ControlNet Unit {i}", elem_classes=['cnet-unit-tab']): group, state = self.uigroup(f"ControlNet-{i}", is_img2img, elem_id_tabname, photopea) ui_groups.append(group) controls.append(state) else: with gr.Column(): group, state = self.uigroup(f"ControlNet", is_img2img, elem_id_tabname, photopea) ui_groups.append(group) controls.append(state) for i, ui_group in enumerate(ui_groups): infotext.register_unit(i, ui_group) if shared.opts.data.get("control_net_sync_field_args", True): self.infotext_fields = infotext.infotext_fields self.paste_field_names = infotext.paste_field_names return tuple(controls) def get_enabled_units(self, p): units = external_code.get_all_units_in_processing(p) enabled_units = [x for x in units if x.enabled] return enabled_units def choose_input_image( self, p: processing.StableDiffusionProcessing, unit: external_code.ControlNetUnit, ) -> Tuple[np.ndarray, external_code.ResizeMode]: """ Choose input image from following sources with descending priority: - p.image_control: [Deprecated] Lagacy way to pass image to controlnet. - p.control_net_input_image: [Deprecated] Lagacy way to pass image to controlnet. - unit.image: ControlNet tab input image. - p.init_images: A1111 img2img tab input image. Returns: - The input image in ndarray form. - The resize mode. """ def parse_unit_image(unit: external_code.ControlNetUnit) -> Union[ List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]: unit_has_multiple_images = ( isinstance(unit.image, list) and len(unit.image) > 0 and "image" in unit.image[0] ) if unit_has_multiple_images: return [ d for img in unit.image for d in (image_dict_from_any(img),) if d is not None ] return image_dict_from_any(unit.image) def decode_image(img) -> np.ndarray: """Need to check the image for API compatibility.""" if isinstance(img, str): return np.asarray(decode_base64_to_image(image['image'])) else: assert isinstance(img, np.ndarray) return img # 4 input image sources. image = parse_unit_image(unit) a1111_image = getattr(p, "init_images", [None])[0] resize_mode = external_code.resize_mode_from_value(unit.resize_mode) if image is not None: if isinstance(image, list): # Add mask logic if later there is a processor that accepts mask # on multiple inputs. input_image = [HWC3(decode_image(img['image'])) for img in image] else: input_image = HWC3(decode_image(image['image'])) if 'mask' in image and image['mask'] is not None: while len(image['mask'].shape) < 3: image['mask'] = image['mask'][..., np.newaxis] if 'inpaint' in unit.module: logger.info("using inpaint as input") color = HWC3(image['image']) alpha = image['mask'][:, :, 0:1] input_image = np.concatenate([color, alpha], axis=2) elif ( not shared.opts.data.get("controlnet_ignore_noninpaint_mask", False) and # There is wield gradio issue that would produce mask that is # not pure color when no scribble is made on canvas. # See https://github.com/Mikubill/sd-webui-controlnet/issues/1638. not ( (image['mask'][:, :, 0] <= 5).all() or (image['mask'][:, :, 0] >= 250).all() ) ): logger.info("using mask as input") input_image = HWC3(image['mask'][:, :, 0]) unit.module = 'none' # Always use black bg and white line elif a1111_image is not None: input_image = HWC3(np.asarray(a1111_image)) a1111_i2i_resize_mode = getattr(p, "resize_mode", None) assert a1111_i2i_resize_mode is not None resize_mode = external_code.resize_mode_from_value(a1111_i2i_resize_mode) a1111_mask_image: Optional[Image.Image] = getattr(p, "image_mask", None) if 'inpaint' in unit.module: if a1111_mask_image is not None: a1111_mask = np.array(prepare_mask(a1111_mask_image, p)) assert a1111_mask.ndim == 2 assert a1111_mask.shape[0] == input_image.shape[0] assert a1111_mask.shape[1] == input_image.shape[1] input_image = np.concatenate([input_image[:, :, 0:3], a1111_mask[:, :, None]], axis=2) else: input_image = np.concatenate([ input_image[:, :, 0:3], np.zeros_like(input_image, dtype=np.uint8)[:, :, 0:1], ], axis=2) else: raise ValueError("controlnet is enabled but no input image is given") assert isinstance(input_image, (np.ndarray, list)) return input_image, resize_mode @staticmethod def try_crop_image_with_a1111_mask( p: StableDiffusionProcessing, unit: external_code.ControlNetUnit, input_image: np.ndarray, resize_mode: external_code.ResizeMode, preprocessor ) -> np.ndarray: """ Crop ControlNet input image based on A1111 inpaint mask given. This logic is crutial in upscale scripts, as they use A1111 mask + inpaint_full_res to crop tiles. """ # Note: The method determining whether the active script is an upscale script is purely # based on `extra_generation_params` these scripts attach on `p`, and subject to change # in the future. # TODO: Change this to a more robust condition once A1111 offers a way to verify script name. is_upscale_script = any("upscale" in k.lower() for k in getattr(p, "extra_generation_params", {}).keys()) logger.debug(f"is_upscale_script={is_upscale_script}") # Note: `inpaint_full_res` is "inpaint area" on UI. The flag is `True` when "Only masked" # option is selected. a1111_mask_image: Optional[Image.Image] = getattr(p, "image_mask", None) is_only_masked_inpaint = ( issubclass(type(p), StableDiffusionProcessingImg2Img) and p.inpaint_full_res and a1111_mask_image is not None ) if ( preprocessor.corp_image_with_a1111_mask_when_in_img2img_inpaint_tab and is_only_masked_inpaint and (is_upscale_script or unit.inpaint_crop_input_image) ): logger.debug("Crop input image based on A1111 mask.") input_image = [input_image[:, :, i] for i in range(input_image.shape[2])] input_image = [Image.fromarray(x) for x in input_image] mask = prepare_mask(a1111_mask_image, p) crop_region = masking.get_crop_region(np.array(mask), p.inpaint_full_res_padding) crop_region = masking.expand_crop_region(crop_region, p.width, p.height, mask.width, mask.height) input_image = [ images.resize_image(resize_mode.int_value(), i, mask.width, mask.height) for i in input_image ] input_image = [x.crop(crop_region) for x in input_image] input_image = [ images.resize_image(external_code.ResizeMode.OUTER_FIT.int_value(), x, p.width, p.height) for x in input_image ] input_image = [np.asarray(x)[:, :, 0] for x in input_image] input_image = np.stack(input_image, axis=2) return input_image @staticmethod def bound_check_params(unit: external_code.ControlNetUnit) -> None: """ Checks and corrects negative parameters in ControlNetUnit 'unit'. Parameters 'processor_res', 'threshold_a', 'threshold_b' are reset to their default values if negative. Args: unit (external_code.ControlNetUnit): The ControlNetUnit instance to check. """ preprocessor = global_state.get_preprocessor(unit.module) if unit.processor_res < 0: unit.processor_res = int(preprocessor.slider_resolution.gradio_update_kwargs.get('value', 512)) if unit.threshold_a < 0: unit.threshold_a = int(preprocessor.slider_1.gradio_update_kwargs.get('value', 1.0)) if unit.threshold_b < 0: unit.threshold_b = int(preprocessor.slider_2.gradio_update_kwargs.get('value', 1.0)) return @staticmethod def get_target_dimensions(p: StableDiffusionProcessing) -> Tuple[int, int, int, int]: """Returns (h, w, hr_h, hr_w).""" h = align_dim_latent(p.height) w = align_dim_latent(p.width) high_res_fix = ( isinstance(p, StableDiffusionProcessingTxt2Img) and getattr(p, 'enable_hr', False) ) if high_res_fix: if p.hr_resize_x == 0 and p.hr_resize_y == 0: hr_y = int(p.height * p.hr_scale) hr_x = int(p.width * p.hr_scale) else: hr_y, hr_x = p.hr_resize_y, p.hr_resize_x hr_y = align_dim_latent(hr_y) hr_x = align_dim_latent(hr_x) else: hr_y = h hr_x = w return h, w, hr_y, hr_x def get_input_data(self, p, unit, preprocessor): mask = None input_image, resize_mode = self.choose_input_image(p, unit) assert isinstance(input_image, np.ndarray), 'Invalid input image!' input_image = self.try_crop_image_with_a1111_mask(p, unit, input_image, resize_mode, preprocessor) input_image = np.ascontiguousarray(input_image.copy()).copy() # safe numpy return input_image, mask, resize_mode @torch.no_grad() def process_unit_after_click_generate(self, p: StableDiffusionProcessing, unit: external_code.ControlNetUnit, params: ControlNetCachedParameters, *args, **kwargs): h, w, hr_y, hr_x = self.get_target_dimensions(p) has_high_res_fix = ( isinstance(p, StableDiffusionProcessingTxt2Img) and getattr(p, 'enable_hr', False) ) preprocessor = global_state.get_preprocessor(unit.module) input_image, input_mask, resize_mode = self.get_input_data(p, unit, preprocessor) if unit.pixel_perfect: unit.processor_res = external_code.pixel_perfect_resolution( input_image, target_H=h, target_W=w, resize_mode=resize_mode, ) seed = set_numpy_seed(p) logger.debug(f"Use numpy seed {seed}.") logger.info(f"Using preprocessor: {unit.module}") logger.info(f'preprocessor resolution = {unit.processor_res}') preprocessor_output = preprocessor( input_image=input_image, resolution=unit.processor_res, slider_1=unit.threshold_a, slider_2=unit.threshold_b, ) preprocessor_output_is_image = judge_image_type(preprocessor_output) if preprocessor_output_is_image: params.control_cond = crop_and_resize_image(preprocessor_output, resize_mode, h, w) p.extra_result_images.append(external_code.visualize_inpaint_mask(params.control_cond)) params.control_cond = numpy_to_pytorch(params.control_cond).movedim(-1, 1) if has_high_res_fix: params.control_cond_for_hr_fix = crop_and_resize_image(preprocessor_output, resize_mode, hr_y, hr_x) p.extra_result_images.append(external_code.visualize_inpaint_mask(params.control_cond_for_hr_fix)) params.control_cond_for_hr_fix = numpy_to_pytorch(params.control_cond_for_hr_fix).movedim(-1, 1) else: params.control_cond_for_hr_fix = params.control_cond else: params.control_cond = preprocessor_output params.control_cond_for_hr_fix = preprocessor_output p.extra_result_images.append(input_image) if preprocessor.do_not_need_model: model_filename = 'Not Needed' params.model = ControlModelPatcher() else: model_filename = global_state.get_controlnet_filename(unit.model) params.model = cached_controlnet_loader(model_filename) assert params.model is not None, logger.error(f"Recognizing Control Model failed: {model_filename}") params.preprocessor = preprocessor params.preprocessor.process_after_running_preprocessors(process=p, params=params, **kwargs) params.model.process_after_running_preprocessors(process=p, params=params, **kwargs) logger.info(f"Current ControlNet {type(params.model).__name__}: {model_filename}") return @torch.no_grad() def process_unit_before_every_sampling(self, p: StableDiffusionProcessing, unit: external_code.ControlNetUnit, params: ControlNetCachedParameters, *args, **kwargs): is_hr_pass = getattr(p, 'is_hr_pass', False) if is_hr_pass: cond = params.control_cond_for_hr_fix else: cond = params.control_cond kwargs.update(dict(unit=unit, params=params)) # CN inpaint fix if isinstance(cond, torch.Tensor) and cond.ndim == 4 and cond.shape[1] == 4: kwargs['cond_before_inpaint_fix'] = cond.clone() cond = cond[:, :3] * (1.0 - cond[:, 3:]) - cond[:, 3:] params.model.strength = float(unit.weight) params.model.start_percent = float(unit.guidance_start) params.model.end_percent = float(unit.guidance_end) params.model.positive_advanced_weighting = None params.model.negative_advanced_weighting = None params.model.advanced_frame_weighting = None params.model.advanced_sigma_weighting = None soft_weighting = { 'input': [0.09941396206337118, 0.12050177219802567, 0.14606275417942507, 0.17704576264172736, 0.214600924414215, 0.26012233262329093, 0.3152997971191405, 0.3821815722656249, 0.4632503906249999, 0.561515625, 0.6806249999999999, 0.825], 'middle': [1.0], 'output': [0.09941396206337118, 0.12050177219802567, 0.14606275417942507, 0.17704576264172736, 0.214600924414215, 0.26012233262329093, 0.3152997971191405, 0.3821815722656249, 0.4632503906249999, 0.561515625, 0.6806249999999999, 0.825] } zero_weighting = { 'input': [0.0] * 12, 'middle': [0.0], 'output': [0.0] * 12 } if unit.control_mode == external_code.ControlMode.CONTROL.value: params.model.positive_advanced_weighting = soft_weighting.copy() params.model.negative_advanced_weighting = zero_weighting.copy() # high-ref fix pass always use softer injections if is_hr_pass or unit.control_mode == external_code.ControlMode.PROMPT.value: params.model.positive_advanced_weighting = soft_weighting.copy() params.model.negative_advanced_weighting = soft_weighting.copy() params.preprocessor.process_before_every_sampling(p, cond, *args, **kwargs) params.model.process_before_every_sampling(p, cond, *args, **kwargs) logger.info(f"ControlNet Method {params.preprocessor.name} patched.") return @torch.no_grad() def process_unit_after_every_sampling(self, p: StableDiffusionProcessing, unit: external_code.ControlNetUnit, params: ControlNetCachedParameters, *args, **kwargs): params.preprocessor.process_after_every_sampling(p, params, *args, **kwargs) params.model.process_after_every_sampling(p, params, *args, **kwargs) return def process(self, p, *args, **kwargs): self.current_params = {} for i, unit in enumerate(self.get_enabled_units(p)): self.bound_check_params(unit) params = ControlNetCachedParameters() self.process_unit_after_click_generate(p, unit, params, *args, **kwargs) self.current_params[i] = params return def process_before_every_sampling(self, p, *args, **kwargs): for i, unit in enumerate(self.get_enabled_units(p)): self.process_unit_before_every_sampling(p, unit, self.current_params[i], *args, **kwargs) return def postprocess_batch_list(self, p, *args, **kwargs): for i, unit in enumerate(self.get_enabled_units(p)): self.process_unit_after_every_sampling(p, unit, self.current_params[i], *args, **kwargs) self.current_params = {} return def on_ui_settings(): section = ('control_net', "ControlNet") shared.opts.add_option("control_net_detectedmap_dir", shared.OptionInfo( "detected_maps", "Directory for detected maps auto saving", section=section)) shared.opts.add_option("control_net_models_path", shared.OptionInfo( "", "Extra path to scan for ControlNet models (e.g. training output directory)", section=section)) shared.opts.add_option("control_net_modules_path", shared.OptionInfo( "", "Path to directory containing annotator model directories (requires restart, overrides corresponding command line flag)", section=section)) shared.opts.add_option("control_net_unit_count", shared.OptionInfo( 3, "Multi-ControlNet: ControlNet unit number (requires restart)", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}, section=section)) shared.opts.add_option("control_net_model_cache_size", shared.OptionInfo( 5, "Model cache size (requires restart)", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}, section=section)) shared.opts.add_option("control_net_no_detectmap", shared.OptionInfo( False, "Do not append detectmap to output", gr.Checkbox, {"interactive": True}, section=section)) shared.opts.add_option("control_net_detectmap_autosaving", shared.OptionInfo( False, "Allow detectmap auto saving", gr.Checkbox, {"interactive": True}, section=section)) shared.opts.add_option("control_net_allow_script_control", shared.OptionInfo( False, "Allow other script to control this extension", gr.Checkbox, {"interactive": True}, section=section)) shared.opts.add_option("control_net_sync_field_args", shared.OptionInfo( True, "Paste ControlNet parameters in infotext", gr.Checkbox, {"interactive": True}, section=section)) shared.opts.add_option("controlnet_show_batch_images_in_ui", shared.OptionInfo( False, "Show batch images in gradio gallery output", gr.Checkbox, {"interactive": True}, section=section)) shared.opts.add_option("controlnet_increment_seed_during_batch", shared.OptionInfo( False, "Increment seed after each controlnet batch iteration", gr.Checkbox, {"interactive": True}, section=section)) shared.opts.add_option("controlnet_disable_openpose_edit", shared.OptionInfo( False, "Disable openpose edit", gr.Checkbox, {"interactive": True}, section=section)) shared.opts.add_option("controlnet_disable_photopea_edit", shared.OptionInfo( False, "Disable photopea edit", gr.Checkbox, {"interactive": True}, section=section)) shared.opts.add_option("controlnet_photopea_warning", shared.OptionInfo( True, "Photopea popup warning", gr.Checkbox, {"interactive": True}, section=section)) shared.opts.add_option("controlnet_ignore_noninpaint_mask", shared.OptionInfo( False, "Ignore mask on ControlNet input image if control type is not inpaint", gr.Checkbox, {"interactive": True}, section=section)) shared.opts.add_option("controlnet_clip_detector_on_cpu", shared.OptionInfo( False, "Load CLIP preprocessor model on CPU", gr.Checkbox, {"interactive": True}, section=section)) script_callbacks.on_ui_settings(on_ui_settings) script_callbacks.on_infotext_pasted(Infotext.on_infotext_pasted) script_callbacks.on_after_component(ControlNetUiGroup.on_after_component) script_callbacks.on_before_reload(ControlNetUiGroup.reset)