From 6a854fcb3809fb4eacb0726eb58d2f606bfa1f12 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Sun, 11 Feb 2024 03:30:22 +0000 Subject: [PATCH] Add documentation on ControlNetUnit (#176) * remove dict from any * nit * nit --- .../lib_controlnet/external_code.py | 81 +++++++++++++++---- .../lib_controlnet/utils.py | 38 --------- .../sd_forge_controlnet/scripts/controlnet.py | 2 +- 3 files changed, 65 insertions(+), 56 deletions(-) diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py index 95ff32cb..c48f238a 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from enum import Enum -from typing import List, Optional, Union, Tuple, Dict +from typing import List, Optional, Union, Dict, TypedDict import numpy as np from modules import shared from lib_controlnet.logging import logger @@ -141,41 +141,88 @@ def pixel_perfect_resolution( return int(np.round(estimation)) -InputImage = Union[np.ndarray, str] -InputImage = Union[Dict[str, InputImage], Tuple[InputImage, InputImage], InputImage] +class GradioImageMaskPair(TypedDict): + """Represents the dict object from Gradio's image component if `tool="sketch"` + is specified. + { + "image": np.ndarray, + "mask": np.ndarray, + } + """ + image: np.ndarray + mask: np.ndarray @dataclass -class UiControlNetUnit: +class ControlNetUnit: + """Represents an entire ControlNet processing unit. + + To add a new field to this class + ## If the new field can be specified on UI, you need to + - Add a new field of the same name in constructor of `ControlNetUiGroup` + - Initialize the new `ControlNetUiGroup` field in `ControlNetUiGroup.render` + as a Gradio `IOComponent`. + - Add the new `ControlNetUiGroup` field to `unit_args` in + `ControlNetUiGroup.render`. The order of parameters matters. + + ## If the new field needs to appear in infotext, you need to + - Add a new item in `ControlNetUnit.infotext_fields`. + API-only fields cannot appear in infotext. + """ + # Following fields should only be used in the UI. + # ====== Start of UI only fields ====== + # Specifies the input mode for the unit, defaulting to a simple mode. input_mode: InputMode = InputMode.SIMPLE + # Determines whether to use the preview image as input; defaults to False. use_preview_as_input: bool = False + # Directory path for batch processing of images. batch_image_dir: str = '' + # Directory path for batch processing of masks. batch_mask_dir: str = '' + # Optional list of gallery images for batch input; defaults to None. batch_input_gallery: Optional[List[str]] = None + # Optional list of gallery masks for batch processing; defaults to None. batch_mask_gallery: Optional[List[str]] = None + # Holds the preview image as a NumPy array; defaults to None. generated_image: Optional[np.ndarray] = None + # ====== End of UI only fields ====== + + # Following fields are used in both the API and the UI. + # Holds the mask image as a NumPy array; defaults to None. mask_image: Optional[np.ndarray] = None - # If hires fix is enabled in A1111, how should this ControlNet unit be applied. - # The value is ignored if the generation is not using hires fix. + # Specifies how this unit should be applied in each pass of high-resolution fix. + # Ignored if high-resolution fix is not enabled. hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH + # Indicates whether the unit is enabled; defaults to True. enabled: bool = True + # Name of the module being used; defaults to "None". module: str = "None" + # Name of the model being used; defaults to "None". model: str = "None" + # Weight of the unit in the overall processing; defaults to 1.0. weight: float = 1.0 - image: Optional[Union[InputImage, List[InputImage]]] = None + # Optional image for input; defaults to None. + image: Optional[GradioImageMaskPair] = None + # Specifies the mode of image resizing; defaults to inner fit. resize_mode: Union[ResizeMode, int, str] = ResizeMode.INNER_FIT + # Resolution for processing by the unit; defaults to -1 (unspecified). processor_res: int = -1 + # Threshold A for processing; defaults to -1 (unspecified). threshold_a: float = -1 + # Threshold B for processing; defaults to -1 (unspecified). threshold_b: float = -1 + # Start value for guidance; defaults to 0.0. guidance_start: float = 0.0 + # End value for guidance; defaults to 1.0. guidance_end: float = 1.0 + # Enables pixel-perfect processing; defaults to False. pixel_perfect: bool = False + # Control mode for the unit; defaults to balanced. control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED + + # Following fields should only be used in the API. # ====== Start of API only fields ====== - # Whether save the detected map of this unit. Setting this option to False - # prevents saving the detected map or sending detected map along with - # generated images via API. Currently the option is only accessible in API - # calls. + # Whether to save the detected map for this unit; defaults to True. save_detected_map: bool = True # ====== End of API only fields ====== @@ -202,11 +249,11 @@ class UiControlNetUnit: ) @staticmethod - def from_dict(d: Dict) -> "UiControlNetUnit": - """Create UiControlNetUnit from dict. This is primarily used to convert - API json dict to UiControlNetUnit.""" - unit = UiControlNetUnit( - **{k: v for k, v in d.items() if k in vars(UiControlNetUnit)} + def from_dict(d: Dict) -> "ControlNetUnit": + """Create ControlNetUnit from dict. This is primarily used to convert + API json dict to ControlNetUnit.""" + unit = ControlNetUnit( + **{k: v for k, v in d.items() if k in vars(ControlNetUnit)} ) if isinstance(unit.image, str): img = np.array(api.decode_base64_to_image(unit.image)).astype('uint8') @@ -220,7 +267,7 @@ class UiControlNetUnit: # Backward Compatible -ControlNetUnit = UiControlNetUnit +UiControlNetUnit = ControlNetUnit def to_base64_nparray(encoding: str): diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/utils.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/utils.py index 2299581b..c911d1c4 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/utils.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/utils.py @@ -190,44 +190,6 @@ def align_dim_latent(x: int) -> int: return (x // 8) * 8 -def image_dict_from_any(image) -> Optional[Dict[str, np.ndarray]]: - if image is None: - return None - - if isinstance(image, (tuple, list)): - image = {'image': image[0], 'mask': image[1]} - elif not isinstance(image, dict): - image = {'image': image, 'mask': None} - else: # type(image) is dict - # copy to enable modifying the dict and prevent response serialization error - image = dict(image) - - if isinstance(image['image'], str): - if os.path.exists(image['image']): - image['image'] = np.array(Image.open(image['image'])).astype('uint8') - elif image['image']: - image['image'] = external_code.to_base64_nparray(image['image']) - else: - image['image'] = None - - # If there is no image, return image with None image and None mask - if image['image'] is None: - image['mask'] = None - return image - - if 'mask' not in image or image['mask'] is None: - image['mask'] = np.zeros_like(image['image'], dtype=np.uint8) - elif isinstance(image['mask'], str): - if os.path.exists(image['mask']): - image['mask'] = np.array(Image.open(image['mask'])).astype('uint8') - elif image['mask']: - image['mask'] = external_code.to_base64_nparray(image['mask']) - else: - image['mask'] = np.zeros_like(image['image'], dtype=np.uint8) - - return image - - def prepare_mask( mask: Image.Image, p: processing.StableDiffusionProcessing ) -> Image.Image: diff --git a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py index e7e228c6..f970dc8a 100644 --- a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py +++ b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py @@ -12,7 +12,7 @@ import gradio as gr from lib_controlnet import global_state, external_code from lib_controlnet.external_code import ControlNetUnit -from lib_controlnet.utils import align_dim_latent, image_dict_from_any, set_numpy_seed, crop_and_resize_image, \ +from lib_controlnet.utils import align_dim_latent, set_numpy_seed, crop_and_resize_image, \ prepare_mask, judge_image_type from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup from lib_controlnet.controlnet_ui.photopea import Photopea