Add documentation on ControlNetUnit (#176)

* remove dict from any

* nit

* nit
This commit is contained in:
Chenlei Hu 2024-02-11 03:30:22 +00:00 committed by GitHub
parent ee023f4fbf
commit 6a854fcb38
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 65 additions and 56 deletions

View File

@ -1,6 +1,6 @@
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Union, Tuple, Dict
from typing import List, Optional, Union, Dict, TypedDict
import numpy as np
from modules import shared
from lib_controlnet.logging import logger
@ -141,41 +141,88 @@ def pixel_perfect_resolution(
return int(np.round(estimation))
InputImage = Union[np.ndarray, str]
InputImage = Union[Dict[str, InputImage], Tuple[InputImage, InputImage], InputImage]
class GradioImageMaskPair(TypedDict):
"""Represents the dict object from Gradio's image component if `tool="sketch"`
is specified.
{
"image": np.ndarray,
"mask": np.ndarray,
}
"""
image: np.ndarray
mask: np.ndarray
@dataclass
class UiControlNetUnit:
class ControlNetUnit:
"""Represents an entire ControlNet processing unit.
To add a new field to this class
## If the new field can be specified on UI, you need to
- Add a new field of the same name in constructor of `ControlNetUiGroup`
- Initialize the new `ControlNetUiGroup` field in `ControlNetUiGroup.render`
as a Gradio `IOComponent`.
- Add the new `ControlNetUiGroup` field to `unit_args` in
`ControlNetUiGroup.render`. The order of parameters matters.
## If the new field needs to appear in infotext, you need to
- Add a new item in `ControlNetUnit.infotext_fields`.
API-only fields cannot appear in infotext.
"""
# Following fields should only be used in the UI.
# ====== Start of UI only fields ======
# Specifies the input mode for the unit, defaulting to a simple mode.
input_mode: InputMode = InputMode.SIMPLE
# Determines whether to use the preview image as input; defaults to False.
use_preview_as_input: bool = False
# Directory path for batch processing of images.
batch_image_dir: str = ''
# Directory path for batch processing of masks.
batch_mask_dir: str = ''
# Optional list of gallery images for batch input; defaults to None.
batch_input_gallery: Optional[List[str]] = None
# Optional list of gallery masks for batch processing; defaults to None.
batch_mask_gallery: Optional[List[str]] = None
# Holds the preview image as a NumPy array; defaults to None.
generated_image: Optional[np.ndarray] = None
# ====== End of UI only fields ======
# Following fields are used in both the API and the UI.
# Holds the mask image as a NumPy array; defaults to None.
mask_image: Optional[np.ndarray] = None
# If hires fix is enabled in A1111, how should this ControlNet unit be applied.
# The value is ignored if the generation is not using hires fix.
# Specifies how this unit should be applied in each pass of high-resolution fix.
# Ignored if high-resolution fix is not enabled.
hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH
# Indicates whether the unit is enabled; defaults to True.
enabled: bool = True
# Name of the module being used; defaults to "None".
module: str = "None"
# Name of the model being used; defaults to "None".
model: str = "None"
# Weight of the unit in the overall processing; defaults to 1.0.
weight: float = 1.0
image: Optional[Union[InputImage, List[InputImage]]] = None
# Optional image for input; defaults to None.
image: Optional[GradioImageMaskPair] = None
# Specifies the mode of image resizing; defaults to inner fit.
resize_mode: Union[ResizeMode, int, str] = ResizeMode.INNER_FIT
# Resolution for processing by the unit; defaults to -1 (unspecified).
processor_res: int = -1
# Threshold A for processing; defaults to -1 (unspecified).
threshold_a: float = -1
# Threshold B for processing; defaults to -1 (unspecified).
threshold_b: float = -1
# Start value for guidance; defaults to 0.0.
guidance_start: float = 0.0
# End value for guidance; defaults to 1.0.
guidance_end: float = 1.0
# Enables pixel-perfect processing; defaults to False.
pixel_perfect: bool = False
# Control mode for the unit; defaults to balanced.
control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED
# Following fields should only be used in the API.
# ====== Start of API only fields ======
# Whether save the detected map of this unit. Setting this option to False
# prevents saving the detected map or sending detected map along with
# generated images via API. Currently the option is only accessible in API
# calls.
# Whether to save the detected map for this unit; defaults to True.
save_detected_map: bool = True
# ====== End of API only fields ======
@ -202,11 +249,11 @@ class UiControlNetUnit:
)
@staticmethod
def from_dict(d: Dict) -> "UiControlNetUnit":
"""Create UiControlNetUnit from dict. This is primarily used to convert
API json dict to UiControlNetUnit."""
unit = UiControlNetUnit(
**{k: v for k, v in d.items() if k in vars(UiControlNetUnit)}
def from_dict(d: Dict) -> "ControlNetUnit":
"""Create ControlNetUnit from dict. This is primarily used to convert
API json dict to ControlNetUnit."""
unit = ControlNetUnit(
**{k: v for k, v in d.items() if k in vars(ControlNetUnit)}
)
if isinstance(unit.image, str):
img = np.array(api.decode_base64_to_image(unit.image)).astype('uint8')
@ -220,7 +267,7 @@ class UiControlNetUnit:
# Backward Compatible
ControlNetUnit = UiControlNetUnit
UiControlNetUnit = ControlNetUnit
def to_base64_nparray(encoding: str):

View File

@ -190,44 +190,6 @@ def align_dim_latent(x: int) -> int:
return (x // 8) * 8
def image_dict_from_any(image) -> Optional[Dict[str, np.ndarray]]:
if image is None:
return None
if isinstance(image, (tuple, list)):
image = {'image': image[0], 'mask': image[1]}
elif not isinstance(image, dict):
image = {'image': image, 'mask': None}
else: # type(image) is dict
# copy to enable modifying the dict and prevent response serialization error
image = dict(image)
if isinstance(image['image'], str):
if os.path.exists(image['image']):
image['image'] = np.array(Image.open(image['image'])).astype('uint8')
elif image['image']:
image['image'] = external_code.to_base64_nparray(image['image'])
else:
image['image'] = None
# If there is no image, return image with None image and None mask
if image['image'] is None:
image['mask'] = None
return image
if 'mask' not in image or image['mask'] is None:
image['mask'] = np.zeros_like(image['image'], dtype=np.uint8)
elif isinstance(image['mask'], str):
if os.path.exists(image['mask']):
image['mask'] = np.array(Image.open(image['mask'])).astype('uint8')
elif image['mask']:
image['mask'] = external_code.to_base64_nparray(image['mask'])
else:
image['mask'] = np.zeros_like(image['image'], dtype=np.uint8)
return image
def prepare_mask(
mask: Image.Image, p: processing.StableDiffusionProcessing
) -> Image.Image:

View File

@ -12,7 +12,7 @@ import gradio as gr
from lib_controlnet import global_state, external_code
from lib_controlnet.external_code import ControlNetUnit
from lib_controlnet.utils import align_dim_latent, image_dict_from_any, set_numpy_seed, crop_and_resize_image, \
from lib_controlnet.utils import align_dim_latent, set_numpy_seed, crop_and_resize_image, \
prepare_mask, judge_image_type
from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup
from lib_controlnet.controlnet_ui.photopea import Photopea