Update controlnet.py

This commit is contained in:
lllyasviel 2024-01-29 17:01:52 -08:00
parent 6933e0c135
commit fe2c1b8279

View File

@ -7,12 +7,14 @@ from modules.api.api import decode_base64_to_image
import gradio as gr
from lib_controlnet import global_state, external_code
from lib_controlnet.utils import align_dim_latent, image_dict_from_any, set_numpy_seed, crop_and_resize_image, prepare_mask
from lib_controlnet.utils import align_dim_latent, image_dict_from_any, set_numpy_seed, crop_and_resize_image, \
prepare_mask
from lib_controlnet.enums import StableDiffusionVersion
from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit
from lib_controlnet.controlnet_ui.photopea import Photopea
from lib_controlnet.logging import logger
from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img, StableDiffusionProcessing
from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img, \
StableDiffusionProcessing
from lib_controlnet.infotext import Infotext
from modules_forge.forge_util import HWC3, numpy_to_pytorch
@ -22,13 +24,12 @@ import functools
from PIL import Image
from modules_forge.shared import try_load_supported_control_model
# Gradio 3.32 bug fix
import tempfile
gradio_tempfile_path = os.path.join(tempfile.gettempdir(), 'gradio')
os.makedirs(gradio_tempfile_path, exist_ok=True)
global_state.update_controlnet_filenames()
@ -61,7 +62,8 @@ class ControlNetForForgeOfficial(scripts.Script):
model="None"
)
def uigroup(self, tabname: str, is_img2img: bool, elem_id_tabname: str, photopea: Optional[Photopea]) -> Tuple[ControlNetUiGroup, gr.State]:
def uigroup(self, tabname: str, is_img2img: bool, elem_id_tabname: str, photopea: Optional[Photopea]) -> Tuple[
ControlNetUiGroup, gr.State]:
group = ControlNetUiGroup(
is_img2img,
self.get_default_ui_unit(),
@ -168,7 +170,7 @@ class ControlNetForForgeOfficial(scripts.Script):
self,
p: processing.StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
) -> Tuple[np.ndarray, external_code.ResizeMode]:
) -> Tuple[np.ndarray, external_code.ResizeMode]:
""" Choose input image from following sources with descending priority:
- p.image_control: [Deprecated] Lagacy way to pass image to controlnet.
- p.control_net_input_image: [Deprecated] Lagacy way to pass image to controlnet.
@ -179,11 +181,13 @@ class ControlNetForForgeOfficial(scripts.Script):
- The input image in ndarray form.
- The resize mode.
"""
def parse_unit_image(unit: external_code.ControlNetUnit) -> Union[List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]:
def parse_unit_image(unit: external_code.ControlNetUnit) -> Union[
List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]:
unit_has_multiple_images = (
isinstance(unit.image, list) and
len(unit.image) > 0 and
"image" in unit.image[0]
isinstance(unit.image, list) and
len(unit.image) > 0 and
"image" in unit.image[0]
)
if unit_has_multiple_images:
return [
@ -224,14 +228,14 @@ class ControlNetForForgeOfficial(scripts.Script):
alpha = image['mask'][:, :, 0:1]
input_image = np.concatenate([color, alpha], axis=2)
elif (
not shared.opts.data.get("controlnet_ignore_noninpaint_mask", False) and
# There is wield gradio issue that would produce mask that is
# not pure color when no scribble is made on canvas.
# See https://github.com/Mikubill/sd-webui-controlnet/issues/1638.
not (
(image['mask'][:, :, 0] <= 5).all() or
(image['mask'][:, :, 0] >= 250).all()
)
not shared.opts.data.get("controlnet_ignore_noninpaint_mask", False) and
# There is wield gradio issue that would produce mask that is
# not pure color when no scribble is made on canvas.
# See https://github.com/Mikubill/sd-webui-controlnet/issues/1638.
not (
(image['mask'][:, :, 0] <= 5).all() or
(image['mask'][:, :, 0] >= 250).all()
)
):
logger.info("using mask as input")
input_image = HWC3(image['mask'][:, :, 0])
@ -242,7 +246,7 @@ class ControlNetForForgeOfficial(scripts.Script):
assert a1111_i2i_resize_mode is not None
resize_mode = external_code.resize_mode_from_value(a1111_i2i_resize_mode)
a1111_mask_image : Optional[Image.Image] = getattr(p, "image_mask", None)
a1111_mask_image: Optional[Image.Image] = getattr(p, "image_mask", None)
if 'inpaint' in unit.module:
if a1111_mask_image is not None:
a1111_mask = np.array(prepare_mask(a1111_mask_image, p))
@ -263,10 +267,10 @@ class ControlNetForForgeOfficial(scripts.Script):
@staticmethod
def try_crop_image_with_a1111_mask(
p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
input_image: np.ndarray,
resize_mode: external_code.ResizeMode,
p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
input_image: np.ndarray,
resize_mode: external_code.ResizeMode,
) -> np.ndarray:
"""
Crop ControlNet input image based on A1111 inpaint mask given.
@ -281,16 +285,16 @@ class ControlNetForForgeOfficial(scripts.Script):
logger.debug(f"is_upscale_script={is_upscale_script}")
# Note: `inpaint_full_res` is "inpaint area" on UI. The flag is `True` when "Only masked"
# option is selected.
a1111_mask_image : Optional[Image.Image] = getattr(p, "image_mask", None)
a1111_mask_image: Optional[Image.Image] = getattr(p, "image_mask", None)
is_only_masked_inpaint = (
issubclass(type(p), StableDiffusionProcessingImg2Img) and
p.inpaint_full_res and
a1111_mask_image is not None
issubclass(type(p), StableDiffusionProcessingImg2Img) and
p.inpaint_full_res and
a1111_mask_image is not None
)
if (
'reference' not in unit.module
and is_only_masked_inpaint
and (is_upscale_script or unit.inpaint_crop_input_image)
'reference' not in unit.module
and is_only_masked_inpaint
and (is_upscale_script or unit.inpaint_crop_input_image)
):
logger.debug("Crop input image based on A1111 mask.")
input_image = [input_image[:, :, i] for i in range(input_image.shape[2])]
@ -363,7 +367,8 @@ class ControlNetForForgeOfficial(scripts.Script):
return
if not sd_version.is_compatible_with(cnet_sd_version):
raise Exception(f"ControlNet model {unit.model}({cnet_sd_version}) is not compatible with sd model({sd_version})")
raise Exception(
f"ControlNet model {unit.model}({cnet_sd_version}) is not compatible with sd model({sd_version})")
@staticmethod
def get_target_dimensions(p: StableDiffusionProcessing) -> Tuple[int, int, int, int]:
@ -372,8 +377,8 @@ class ControlNetForForgeOfficial(scripts.Script):
w = align_dim_latent(p.width)
high_res_fix = (
isinstance(p, StableDiffusionProcessingTxt2Img)
and getattr(p, 'enable_hr', False)
isinstance(p, StableDiffusionProcessingTxt2Img)
and getattr(p, 'enable_hr', False)
)
if high_res_fix:
if p.hr_resize_x == 0 and p.hr_resize_y == 0:
@ -398,8 +403,8 @@ class ControlNetForForgeOfficial(scripts.Script):
h, w, hr_y, hr_x = self.get_target_dimensions(p)
has_high_res_fix = (
isinstance(p, StableDiffusionProcessingTxt2Img)
and getattr(p, 'enable_hr', False)
isinstance(p, StableDiffusionProcessingTxt2Img)
and getattr(p, 'enable_hr', False)
)
input_image, resize_mode = self.choose_input_image(p, unit)
@ -509,9 +514,12 @@ def on_ui_settings():
shared.opts.add_option("control_net_models_path", shared.OptionInfo(
"", "Extra path to scan for ControlNet models (e.g. training output directory)", section=section))
shared.opts.add_option("control_net_modules_path", shared.OptionInfo(
"", "Path to directory containing annotator model directories (requires restart, overrides corresponding command line flag)", section=section))
"",
"Path to directory containing annotator model directories (requires restart, overrides corresponding command line flag)",
section=section))
shared.opts.add_option("control_net_unit_count", shared.OptionInfo(
3, "Multi-ControlNet: ControlNet unit number (requires restart)", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}, section=section))
3, "Multi-ControlNet: ControlNet unit number (requires restart)", gr.Slider,
{"minimum": 1, "maximum": 10, "step": 1}, section=section))
shared.opts.add_option("control_net_model_cache_size", shared.OptionInfo(
5, "Model cache size (requires restart)", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}, section=section))
shared.opts.add_option("control_net_no_detectmap", shared.OptionInfo(
@ -525,7 +533,8 @@ def on_ui_settings():
shared.opts.add_option("controlnet_show_batch_images_in_ui", shared.OptionInfo(
False, "Show batch images in gradio gallery output", gr.Checkbox, {"interactive": True}, section=section))
shared.opts.add_option("controlnet_increment_seed_during_batch", shared.OptionInfo(
False, "Increment seed after each controlnet batch iteration", gr.Checkbox, {"interactive": True}, section=section))
False, "Increment seed after each controlnet batch iteration", gr.Checkbox, {"interactive": True},
section=section))
shared.opts.add_option("controlnet_disable_openpose_edit", shared.OptionInfo(
False, "Disable openpose edit", gr.Checkbox, {"interactive": True}, section=section))
shared.opts.add_option("controlnet_disable_photopea_edit", shared.OptionInfo(
@ -543,4 +552,4 @@ def on_ui_settings():
script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_infotext_pasted(Infotext.on_infotext_pasted)
script_callbacks.on_after_component(ControlNetUiGroup.on_after_component)
script_callbacks.on_before_reload(ControlNetUiGroup.reset)
script_callbacks.on_before_reload(ControlNetUiGroup.reset)