Add documentation on ControlNetUnit (#176)
* remove dict from any * nit * nit
This commit is contained in:
parent
ee023f4fbf
commit
6a854fcb38
@ -1,6 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Union, Tuple, Dict
|
||||
from typing import List, Optional, Union, Dict, TypedDict
|
||||
import numpy as np
|
||||
from modules import shared
|
||||
from lib_controlnet.logging import logger
|
||||
@ -141,41 +141,88 @@ def pixel_perfect_resolution(
|
||||
return int(np.round(estimation))
|
||||
|
||||
|
||||
InputImage = Union[np.ndarray, str]
|
||||
InputImage = Union[Dict[str, InputImage], Tuple[InputImage, InputImage], InputImage]
|
||||
class GradioImageMaskPair(TypedDict):
|
||||
"""Represents the dict object from Gradio's image component if `tool="sketch"`
|
||||
is specified.
|
||||
{
|
||||
"image": np.ndarray,
|
||||
"mask": np.ndarray,
|
||||
}
|
||||
"""
|
||||
image: np.ndarray
|
||||
mask: np.ndarray
|
||||
|
||||
|
||||
@dataclass
|
||||
class UiControlNetUnit:
|
||||
class ControlNetUnit:
|
||||
"""Represents an entire ControlNet processing unit.
|
||||
|
||||
To add a new field to this class
|
||||
## If the new field can be specified on UI, you need to
|
||||
- Add a new field of the same name in constructor of `ControlNetUiGroup`
|
||||
- Initialize the new `ControlNetUiGroup` field in `ControlNetUiGroup.render`
|
||||
as a Gradio `IOComponent`.
|
||||
- Add the new `ControlNetUiGroup` field to `unit_args` in
|
||||
`ControlNetUiGroup.render`. The order of parameters matters.
|
||||
|
||||
## If the new field needs to appear in infotext, you need to
|
||||
- Add a new item in `ControlNetUnit.infotext_fields`.
|
||||
API-only fields cannot appear in infotext.
|
||||
"""
|
||||
# Following fields should only be used in the UI.
|
||||
# ====== Start of UI only fields ======
|
||||
# Specifies the input mode for the unit, defaulting to a simple mode.
|
||||
input_mode: InputMode = InputMode.SIMPLE
|
||||
# Determines whether to use the preview image as input; defaults to False.
|
||||
use_preview_as_input: bool = False
|
||||
# Directory path for batch processing of images.
|
||||
batch_image_dir: str = ''
|
||||
# Directory path for batch processing of masks.
|
||||
batch_mask_dir: str = ''
|
||||
# Optional list of gallery images for batch input; defaults to None.
|
||||
batch_input_gallery: Optional[List[str]] = None
|
||||
# Optional list of gallery masks for batch processing; defaults to None.
|
||||
batch_mask_gallery: Optional[List[str]] = None
|
||||
# Holds the preview image as a NumPy array; defaults to None.
|
||||
generated_image: Optional[np.ndarray] = None
|
||||
# ====== End of UI only fields ======
|
||||
|
||||
# Following fields are used in both the API and the UI.
|
||||
# Holds the mask image as a NumPy array; defaults to None.
|
||||
mask_image: Optional[np.ndarray] = None
|
||||
# If hires fix is enabled in A1111, how should this ControlNet unit be applied.
|
||||
# The value is ignored if the generation is not using hires fix.
|
||||
# Specifies how this unit should be applied in each pass of high-resolution fix.
|
||||
# Ignored if high-resolution fix is not enabled.
|
||||
hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH
|
||||
# Indicates whether the unit is enabled; defaults to True.
|
||||
enabled: bool = True
|
||||
# Name of the module being used; defaults to "None".
|
||||
module: str = "None"
|
||||
# Name of the model being used; defaults to "None".
|
||||
model: str = "None"
|
||||
# Weight of the unit in the overall processing; defaults to 1.0.
|
||||
weight: float = 1.0
|
||||
image: Optional[Union[InputImage, List[InputImage]]] = None
|
||||
# Optional image for input; defaults to None.
|
||||
image: Optional[GradioImageMaskPair] = None
|
||||
# Specifies the mode of image resizing; defaults to inner fit.
|
||||
resize_mode: Union[ResizeMode, int, str] = ResizeMode.INNER_FIT
|
||||
# Resolution for processing by the unit; defaults to -1 (unspecified).
|
||||
processor_res: int = -1
|
||||
# Threshold A for processing; defaults to -1 (unspecified).
|
||||
threshold_a: float = -1
|
||||
# Threshold B for processing; defaults to -1 (unspecified).
|
||||
threshold_b: float = -1
|
||||
# Start value for guidance; defaults to 0.0.
|
||||
guidance_start: float = 0.0
|
||||
# End value for guidance; defaults to 1.0.
|
||||
guidance_end: float = 1.0
|
||||
# Enables pixel-perfect processing; defaults to False.
|
||||
pixel_perfect: bool = False
|
||||
# Control mode for the unit; defaults to balanced.
|
||||
control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED
|
||||
|
||||
# Following fields should only be used in the API.
|
||||
# ====== Start of API only fields ======
|
||||
# Whether save the detected map of this unit. Setting this option to False
|
||||
# prevents saving the detected map or sending detected map along with
|
||||
# generated images via API. Currently the option is only accessible in API
|
||||
# calls.
|
||||
# Whether to save the detected map for this unit; defaults to True.
|
||||
save_detected_map: bool = True
|
||||
# ====== End of API only fields ======
|
||||
|
||||
@ -202,11 +249,11 @@ class UiControlNetUnit:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_dict(d: Dict) -> "UiControlNetUnit":
|
||||
"""Create UiControlNetUnit from dict. This is primarily used to convert
|
||||
API json dict to UiControlNetUnit."""
|
||||
unit = UiControlNetUnit(
|
||||
**{k: v for k, v in d.items() if k in vars(UiControlNetUnit)}
|
||||
def from_dict(d: Dict) -> "ControlNetUnit":
|
||||
"""Create ControlNetUnit from dict. This is primarily used to convert
|
||||
API json dict to ControlNetUnit."""
|
||||
unit = ControlNetUnit(
|
||||
**{k: v for k, v in d.items() if k in vars(ControlNetUnit)}
|
||||
)
|
||||
if isinstance(unit.image, str):
|
||||
img = np.array(api.decode_base64_to_image(unit.image)).astype('uint8')
|
||||
@ -220,7 +267,7 @@ class UiControlNetUnit:
|
||||
|
||||
|
||||
# Backward Compatible
|
||||
ControlNetUnit = UiControlNetUnit
|
||||
UiControlNetUnit = ControlNetUnit
|
||||
|
||||
|
||||
def to_base64_nparray(encoding: str):
|
||||
|
@ -190,44 +190,6 @@ def align_dim_latent(x: int) -> int:
|
||||
return (x // 8) * 8
|
||||
|
||||
|
||||
def image_dict_from_any(image) -> Optional[Dict[str, np.ndarray]]:
|
||||
if image is None:
|
||||
return None
|
||||
|
||||
if isinstance(image, (tuple, list)):
|
||||
image = {'image': image[0], 'mask': image[1]}
|
||||
elif not isinstance(image, dict):
|
||||
image = {'image': image, 'mask': None}
|
||||
else: # type(image) is dict
|
||||
# copy to enable modifying the dict and prevent response serialization error
|
||||
image = dict(image)
|
||||
|
||||
if isinstance(image['image'], str):
|
||||
if os.path.exists(image['image']):
|
||||
image['image'] = np.array(Image.open(image['image'])).astype('uint8')
|
||||
elif image['image']:
|
||||
image['image'] = external_code.to_base64_nparray(image['image'])
|
||||
else:
|
||||
image['image'] = None
|
||||
|
||||
# If there is no image, return image with None image and None mask
|
||||
if image['image'] is None:
|
||||
image['mask'] = None
|
||||
return image
|
||||
|
||||
if 'mask' not in image or image['mask'] is None:
|
||||
image['mask'] = np.zeros_like(image['image'], dtype=np.uint8)
|
||||
elif isinstance(image['mask'], str):
|
||||
if os.path.exists(image['mask']):
|
||||
image['mask'] = np.array(Image.open(image['mask'])).astype('uint8')
|
||||
elif image['mask']:
|
||||
image['mask'] = external_code.to_base64_nparray(image['mask'])
|
||||
else:
|
||||
image['mask'] = np.zeros_like(image['image'], dtype=np.uint8)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def prepare_mask(
|
||||
mask: Image.Image, p: processing.StableDiffusionProcessing
|
||||
) -> Image.Image:
|
||||
|
@ -12,7 +12,7 @@ import gradio as gr
|
||||
|
||||
from lib_controlnet import global_state, external_code
|
||||
from lib_controlnet.external_code import ControlNetUnit
|
||||
from lib_controlnet.utils import align_dim_latent, image_dict_from_any, set_numpy_seed, crop_and_resize_image, \
|
||||
from lib_controlnet.utils import align_dim_latent, set_numpy_seed, crop_and_resize_image, \
|
||||
prepare_mask, judge_image_type
|
||||
from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup
|
||||
from lib_controlnet.controlnet_ui.photopea import Photopea
|
||||
|
Loading…
Reference in New Issue
Block a user