
Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

520 lines
24 KiB
Raw Normal View History

import os
from typing import Dict, Optional, Tuple, List, Union

import torch

import modules.scripts as scripts
from modules import shared, script_callbacks, processing, masking, images
from modules.api.api import decode_base64_to_image
import gradio as gr

from lib_controlnet import global_state, external_code
from lib_controlnet.utils import align_dim_latent, image_dict_from_any, set_numpy_seed, crop_and_resize_image, \
    prepare_mask, judge_image_type
from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit
from lib_controlnet.controlnet_ui.photopea import Photopea
from lib_controlnet.logging import logger
from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img, \
    StableDiffusionProcessing
from lib_controlnet.infotext import Infotext
from modules_forge.forge_util import HWC3, numpy_to_pytorch

import numpy as np
import functools

from PIL import Image
from modules_forge.shared import try_load_supported_control_model
from modules_forge.supported_controlnet import ControlModelPatcher
# Gradio 3.32 bug fix.
# Make sure "<system temp dir>/gradio" exists as soon as this module is imported
# (presumably gradio 3.32 misbehaves when that directory is missing).
import tempfile

gradio_tempfile_path = os.path.join(
    tempfile.gettempdir(),
    'gradio',
)
os.makedirs(gradio_tempfile_path, mode=0o777, exist_ok=True)
@functools.lru_cache(maxsize=int(shared.opts.data.get("control_net_model_cache_size", 5)))
def cached_controlnet_loader(filename):
    """Load a supported control model from ``filename``, memoizing the result.

    The cache size comes from the "control_net_model_cache_size" setting and is
    read once, when this module is imported (hence "requires restart" on that
    setting). ``functools.lru_cache`` only accepts an int (or None) as
    ``maxsize``, so the stored setting is coerced with ``int()`` in case it was
    saved as a float.

    Args:
        filename: Path of the control model file to load.

    Returns:
        Whatever ``try_load_supported_control_model`` returns for ``filename``.
        Note that a ``None`` result is cached as well.
    """
    return try_load_supported_control_model(filename)
class ControlNetCachedParameters:
    """Per-unit state filled in when generation starts and reused by the
    per-sampling hooks of the ControlNet script."""

    def __init__(self):
        # Preprocessor object and control model chosen for the unit.
        self.preprocessor = self.model = None
        # Control conditioning for the base pass and for the high-res-fix pass.
        self.control_cond = self.control_cond_for_hr_fix = None
class ControlNetForForgeOfficial(scripts.Script):
    """Always-visible "ControlNet" script.

    Builds the per-unit UI, preprocesses every enabled unit once in
    ``process`` (storing the result in ``self.current_params``), and re-applies
    the cached preprocessor/model before every sampling pass.

    NOTE(review): this class was restored from a copy that had lost its
    indentation, its ``else:`` lines and several other lines. Lines marked
    "reconstructed" could not be recovered from that copy alone and should be
    confirmed against upstream.
    """

    def title(self):
        return "ControlNet"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def uigroup(self, tabname: str, is_img2img: bool, elem_id_tabname: str,
                photopea: Optional[Photopea]) -> Tuple[ControlNetUiGroup, gr.State]:
        """Create and render the UI of a single ControlNet unit.

        Returns the ``ControlNetUiGroup`` together with the value returned by its
        ``render`` call (annotated as the unit's ``gr.State``).
        """
        default_unit = UiControlNetUnit(enabled=False, module="None", model="None")
        group = ControlNetUiGroup(is_img2img, default_unit, photopea)
        return group, group.render(tabname, elem_id_tabname)

    def ui(self, is_img2img):
        """Build the ControlNet accordion and return one control per unit.

        With more than one configured unit ("control_net_unit_count") each unit
        gets its own tab; otherwise a single unit is rendered in a column. Every
        unit is registered with ``Infotext``; its field mappings are exposed on
        ``self`` when "control_net_sync_field_args" is enabled.
        """
        infotext = Infotext()
        ui_groups = []
        controls = []
        # range() needs an int; the setting may have been stored from a slider as a float.
        max_models = int(shared.opts.data.get("control_net_unit_count", 3))
        elem_id_tabname = ("img2img" if is_img2img else "txt2img") + "_controlnet"
        with gr.Group(elem_id=elem_id_tabname):
            with gr.Accordion("ControlNet Integrated", open=False, elem_id="controlnet"):
                photopea = Photopea() if not shared.opts.data.get("controlnet_disable_photopea_edit", False) else None
                if max_models > 1:
                    with gr.Tabs(elem_id=f"{elem_id_tabname}_tabs"):
                        for i in range(max_models):
                            # NOTE(review): reconstructed — the continuation of this call was lost.
                            with gr.Tab(f"ControlNet Unit {i}", elem_classes=['cnet-unit-tab']):
                                group, state = self.uigroup(f"ControlNet-{i}", is_img2img, elem_id_tabname, photopea)
                                ui_groups.append(group)
                                controls.append(state)
                else:
                    with gr.Column():
                        group, state = self.uigroup("ControlNet", is_img2img, elem_id_tabname, photopea)
                        ui_groups.append(group)
                        controls.append(state)

        for i, ui_group in enumerate(ui_groups):
            infotext.register_unit(i, ui_group)
        if shared.opts.data.get("control_net_sync_field_args", True):
            self.infotext_fields = infotext.infotext_fields
            self.paste_field_names = infotext.paste_field_names
        return tuple(controls)

    def get_enabled_units(self, p):
        """Return the ControlNet units attached to ``p`` that are enabled, in order."""
        return [unit for unit in external_code.get_all_units_in_processing(p) if unit.enabled]

    @staticmethod
    def choose_input_image(
            p: processing.StableDiffusionProcessing,
            unit: external_code.ControlNetUnit,
    ) -> Tuple[np.ndarray, external_code.ResizeMode]:
        """Choose the ControlNet input image.

        Sources, in descending priority:
            - unit.image: the image(s) given to this ControlNet unit (UI or API).
            - p.init_images[0]: the A1111 img2img input image.

        Returns:
            A tuple ``(input_image, resize_mode)``:
            - input_image: an ndarray normalized with HWC3 (a list of them when
              the unit carries several images). For inpaint modules a mask is
              appended as a 4th channel.
            - resize_mode: the unit's resize mode, or the img2img resize mode
              when the A1111 image is used.

        Raises:
            ValueError: if neither source provides an image.
        """

        def parse_unit_image(unit: external_code.ControlNetUnit) -> Union[
                List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]:
            # A non-empty list whose first entry has an "image" key means multiple inputs.
            unit_has_multiple_images = (
                    isinstance(unit.image, list) and
                    len(unit.image) > 0 and
                    "image" in unit.image[0]
            )
            if unit_has_multiple_images:
                return [
                    d
                    for img in unit.image
                    for d in (image_dict_from_any(img),)
                    if d is not None
                ]
            return image_dict_from_any(unit.image)

        def decode_image(img) -> np.ndarray:
            """Decode a base64 string (API input) to an ndarray; pass ndarrays through."""
            if isinstance(img, str):
                # Decode the argument itself. Decoding the enclosing `image['image']`
                # (as before) fails whenever `image` is a list of inputs.
                return np.asarray(decode_base64_to_image(img))
            assert isinstance(img, np.ndarray)
            return img

        # 2 input image sources.
        image = parse_unit_image(unit)
        # `init_images` is absent on txt2img; also tolerate it being None or empty.
        a1111_image = (getattr(p, "init_images", None) or [None])[0]

        resize_mode = external_code.resize_mode_from_value(unit.resize_mode)

        if image is not None:
            if isinstance(image, list):
                # Add mask logic if later there is a processor that accepts mask
                # on multiple inputs.
                input_image = [HWC3(decode_image(img['image'])) for img in image]
            else:
                input_image = HWC3(decode_image(image['image']))
                if 'mask' in image and image['mask'] is not None:
                    # Give the mask an explicit channel axis so [:, :, 0] indexing works.
                    while len(image['mask'].shape) < 3:
                        image['mask'] = image['mask'][..., np.newaxis]
                    if 'inpaint' in unit.module:
                        logger.info("using inpaint as input")
                        # Reuse the decoded image; `image['image']` may be a base64 string.
                        color = input_image
                        alpha = image['mask'][:, :, 0:1]
                        input_image = np.concatenate([color, alpha], axis=2)
                    elif (
                            not shared.opts.data.get("controlnet_ignore_noninpaint_mask", False) and
                            # There is a weird gradio issue that produces a mask that is
                            # not pure color when no scribble is made on canvas, so a
                            # (near-)uniform mask is treated as "no mask".
                            not (
                                    (image['mask'][:, :, 0] <= 5).all() or
                                    (image['mask'][:, :, 0] >= 250).all()
                            )
                    ):
                        logger.info("using mask as input")
                        input_image = HWC3(image['mask'][:, :, 0])
                        # NOTE(review): process_unit_after_click_generate resolves the
                        # preprocessor before calling this method, so this reassignment
                        # does not change the preprocessor used in the current run.
                        unit.module = 'none'  # Always use black bg and white line
        elif a1111_image is not None:
            input_image = HWC3(np.asarray(a1111_image))
            a1111_i2i_resize_mode = getattr(p, "resize_mode", None)
            assert a1111_i2i_resize_mode is not None
            resize_mode = external_code.resize_mode_from_value(a1111_i2i_resize_mode)

            a1111_mask_image: Optional[Image.Image] = getattr(p, "image_mask", None)
            if 'inpaint' in unit.module:
                if a1111_mask_image is not None:
                    a1111_mask = np.array(prepare_mask(a1111_mask_image, p))
                    assert a1111_mask.ndim == 2
                    assert a1111_mask.shape[0] == input_image.shape[0]
                    assert a1111_mask.shape[1] == input_image.shape[1]
                    input_image = np.concatenate([input_image[:, :, 0:3], a1111_mask[:, :, None]], axis=2)
                else:
                    # No A1111 mask: append an all-zero mask channel.
                    input_image = np.concatenate([
                        input_image[:, :, 0:3],
                        np.zeros_like(input_image, dtype=np.uint8)[:, :, 0:1],
                    ], axis=2)
        else:
            raise ValueError("controlnet is enabled but no input image is given")

        assert isinstance(input_image, (np.ndarray, list))
        return input_image, resize_mode

    @staticmethod
    def try_crop_image_with_a1111_mask(
            p: StableDiffusionProcessing,
            unit: external_code.ControlNetUnit,
            input_image: np.ndarray,
            resize_mode: external_code.ResizeMode,
            preprocessor,
    ) -> np.ndarray:
        """Crop the ControlNet input image based on the A1111 inpaint mask given.

        This logic is crucial in upscale scripts, as they use the A1111 mask +
        inpaint_full_res to crop tiles.

        The crop only happens for img2img with "Only masked" inpaint area and an
        A1111 mask, when an upscale script is detected or the unit requests it
        (``unit.inpaint_crop_input_image``), and the preprocessor allows it.
        Otherwise ``input_image`` is returned unchanged.
        """
        # Note: The method determining whether the active script is an upscale script is purely
        # based on `extra_generation_params` these scripts attach on `p`, and subject to change
        # in the future.
        # TODO: Change this to a more robust condition once A1111 offers a way to verify script name.
        is_upscale_script = any("upscale" in k.lower() for k in getattr(p, "extra_generation_params", {}).keys())
        # Note: `inpaint_full_res` is "inpaint area" on UI. The flag is `True` when "Only masked"
        # option is selected.
        a1111_mask_image: Optional[Image.Image] = getattr(p, "image_mask", None)
        is_only_masked_inpaint = (
                issubclass(type(p), StableDiffusionProcessingImg2Img) and
                p.inpaint_full_res and
                a1111_mask_image is not None
        )
        # NOTE(review): reconstructed — the first operand of the condition below was lost
        # in the restored copy; `preprocessor` is otherwise unused here. getattr keeps this
        # safe if the attribute name differs upstream.
        preprocessor_allows_crop = getattr(
            preprocessor, 'corp_image_with_a1111_mask_when_in_img2img_inpaint_tab', True)
        if (
                preprocessor_allows_crop
                and is_only_masked_inpaint
                and (is_upscale_script or unit.inpaint_crop_input_image)
        ):
            logger.debug("Crop input image based on A1111 mask.")
            # Process each channel as a separate single-channel PIL image.
            input_image = [input_image[:, :, i] for i in range(input_image.shape[2])]
            input_image = [Image.fromarray(x) for x in input_image]

            mask = prepare_mask(a1111_mask_image, p)

            crop_region = masking.get_crop_region(np.array(mask), p.inpaint_full_res_padding)
            crop_region = masking.expand_crop_region(crop_region, p.width, p.height, mask.width, mask.height)

            input_image = [
                images.resize_image(resize_mode.int_value(), i, mask.width, mask.height)
                for i in input_image
            ]
            input_image = [x.crop(crop_region) for x in input_image]
            input_image = [
                images.resize_image(external_code.ResizeMode.OUTER_FIT.int_value(), x, p.width, p.height)
                for x in input_image
            ]
            input_image = [np.asarray(x)[:, :, 0] for x in input_image]
            input_image = np.stack(input_image, axis=2)
        return input_image

    @staticmethod
    def bound_check_params(unit: external_code.ControlNetUnit) -> None:
        """Reset negative 'processor_res', 'threshold_a' and 'threshold_b' on ``unit``.

        Each negative value is replaced by the default ``value`` of the matching
        preprocessor slider (slider_resolution, slider_1, slider_2).

        Args:
            unit (external_code.ControlNetUnit): The ControlNetUnit instance to check.
        """
        preprocessor = global_state.get_preprocessor(unit.module)
        if unit.processor_res < 0:
            unit.processor_res = int(preprocessor.slider_resolution.gradio_update_kwargs.get('value', 512))
        # Thresholds keep the slider default as-is: int() would truncate fractional defaults.
        if unit.threshold_a < 0:
            unit.threshold_a = preprocessor.slider_1.gradio_update_kwargs.get('value', 1.0)
        if unit.threshold_b < 0:
            unit.threshold_b = preprocessor.slider_2.gradio_update_kwargs.get('value', 1.0)

    @staticmethod
    def _is_high_res_fix_enabled(p) -> bool:
        """True when ``p`` is a txt2img run with hires fix (``enable_hr``) turned on."""
        return isinstance(p, StableDiffusionProcessingTxt2Img) and bool(getattr(p, 'enable_hr', False))

    @staticmethod
    def get_target_dimensions(p: StableDiffusionProcessing) -> Tuple[int, int, int, int]:
        """Returns (h, w, hr_h, hr_w), each aligned with ``align_dim_latent``.

        Without hires fix, (hr_h, hr_w) equal (h, w). With hires fix the target
        is ``hr_scale`` times the base size when neither ``hr_resize_x`` nor
        ``hr_resize_y`` is set; when only one of them is set the other is derived
        from the base aspect ratio (instead of becoming 0); otherwise both are
        used as given.
        """
        h = align_dim_latent(p.height)
        w = align_dim_latent(p.width)

        if not ControlNetForForgeOfficial._is_high_res_fix_enabled(p):
            return h, w, h, w

        if p.hr_resize_x == 0 and p.hr_resize_y == 0:
            hr_y = int(p.height * p.hr_scale)
            hr_x = int(p.width * p.hr_scale)
        elif p.hr_resize_x == 0:
            # Only the height was given: keep the base aspect ratio.
            hr_y = p.hr_resize_y
            hr_x = p.hr_resize_y * p.width // p.height
        elif p.hr_resize_y == 0:
            # Only the width was given: keep the base aspect ratio.
            hr_x = p.hr_resize_x
            hr_y = p.hr_resize_x * p.height // p.width
        else:
            hr_y, hr_x = p.hr_resize_y, p.hr_resize_x

        return h, w, align_dim_latent(hr_y), align_dim_latent(hr_x)

    def get_input_data(self, p, unit, preprocessor):
        """Return ``(input_image, mask, resize_mode)`` for ``unit``.

        The image comes from ``choose_input_image``, is optionally cropped to the
        A1111 inpaint mask, and is returned as a fresh C-contiguous array.
        ``mask`` is always ``None`` here.
        """
        mask = None
        input_image, resize_mode = self.choose_input_image(p, unit)
        assert isinstance(input_image, np.ndarray), 'Invalid input image!'
        input_image = self.try_crop_image_with_a1111_mask(p, unit, input_image, resize_mode, preprocessor)
        input_image = np.ascontiguousarray(input_image).copy()  # safe numpy: own, C-contiguous buffer
        return input_image, mask, resize_mode

    def process_unit_after_click_generate(self,
                                          p: StableDiffusionProcessing,
                                          unit: external_code.ControlNetUnit,
                                          params: ControlNetCachedParameters,
                                          *args, **kwargs):
        """Run the preprocessor for ``unit`` and fill ``params``.

        Stores the resized control conditioning (for the base pass and the
        hires-fix pass), the control model (or a bare ``ControlModelPatcher``
        when the preprocessor needs no model) and the preprocessor itself, then
        lets both react via ``process_after_running_preprocessors``.

        Raises:
            AssertionError: if the selected control model cannot be loaded.
        """
        h, w, hr_y, hr_x = self.get_target_dimensions(p)
        has_high_res_fix = self._is_high_res_fix_enabled(p)

        preprocessor = global_state.get_preprocessor(unit.module)

        input_image, input_mask, resize_mode = self.get_input_data(p, unit, preprocessor)

        if unit.pixel_perfect:
            # NOTE(review): reconstructed — the arguments of this call were lost.
            unit.processor_res = external_code.pixel_perfect_resolution(
                input_image,
                target_H=h,
                target_W=w,
                resize_mode=resize_mode,
            )

        seed = set_numpy_seed(p)
        logger.debug(f"Use numpy seed {seed}.")
        logger.info(f"Using preprocessor: {unit.module}")
        logger.info(f'preprocessor resolution = {unit.processor_res}')

        # NOTE(review): reconstructed — the keyword arguments of this call were lost;
        # confirm against the preprocessor's __call__ signature.
        preprocessor_output = preprocessor(
            input_image=input_image,
            input_mask=input_mask,
            resolution=unit.processor_res,
            slider_1=unit.threshold_a,
            slider_2=unit.threshold_b,
        )

        preprocessor_output_is_image = judge_image_type(preprocessor_output)

        if preprocessor_output_is_image:
            params.control_cond = crop_and_resize_image(preprocessor_output, resize_mode, h, w)
            # NOTE(review): a line may have been lost here in the restored copy.
            params.control_cond = numpy_to_pytorch(params.control_cond).movedim(-1, 1)

            if has_high_res_fix:
                params.control_cond_for_hr_fix = crop_and_resize_image(preprocessor_output, resize_mode, hr_y, hr_x)
                # NOTE(review): a line may have been lost here in the restored copy.
                params.control_cond_for_hr_fix = numpy_to_pytorch(params.control_cond_for_hr_fix).movedim(-1, 1)
            else:
                params.control_cond_for_hr_fix = params.control_cond
        else:
            # Non-image outputs are passed through unchanged to both passes.
            params.control_cond = preprocessor_output
            params.control_cond_for_hr_fix = preprocessor_output

        if preprocessor.do_not_need_model:
            model_filename = 'Not Needed'
            params.model = ControlModelPatcher()
        else:
            model_filename = global_state.get_controlnet_filename(unit.model)
            params.model = cached_controlnet_loader(model_filename)
            if params.model is None:
                message = f"Recognizing Control Model failed: {model_filename}"
                logger.error(message)
                # Explicit raise instead of `assert` so the check survives `python -O`;
                # AssertionError is kept for compatibility with the previous assert.
                raise AssertionError(message)

        params.preprocessor = preprocessor

        params.preprocessor.process_after_running_preprocessors(process=p, params=params, **kwargs)
        params.model.process_after_running_preprocessors(process=p, params=params, **kwargs)

        logger.info(f"Current ControlNet {type(params.model).__name__}: {model_filename}")

    def process_unit_before_every_sampling(self,
                                           p: StableDiffusionProcessing,
                                           unit: external_code.ControlNetUnit,
                                           params: ControlNetCachedParameters,
                                           *args, **kwargs):
        """Configure the cached model for this sampling pass and hand it the conditioning.

        Uses the hires-fix conditioning when ``p.is_hr_pass`` is set. For
        4-channel (RGB + mask) tensors it applies the inpaint fix. It sets
        strength and start/end from the unit and picks per-block weighting by
        control mode. It then calls ``process_before_every_sampling`` on the
        preprocessor and the model.
        """
        is_hr_pass = getattr(p, 'is_hr_pass', False)

        if is_hr_pass:
            cond = params.control_cond_for_hr_fix
        else:
            cond = params.control_cond

        kwargs.update(dict(unit=unit, params=params))

        # CN inpaint fix: keep the original, then set masked pixels (mask == 1) to -1.
        if isinstance(cond, torch.Tensor) and cond.ndim == 4 and cond.shape[1] == 4:
            kwargs['cond_before_inpaint_fix'] = cond.clone()
            cond = cond[:, :3] * (1.0 - cond[:, 3:]) - cond[:, 3:]

        params.model.strength = float(unit.weight)
        params.model.start_percent = float(unit.guidance_start)
        params.model.end_percent = float(unit.guidance_end)
        params.model.positive_advanced_weighting = None
        params.model.negative_advanced_weighting = None
        params.model.advanced_frame_weighting = None
        params.model.advanced_sigma_weighting = None

        # Geometric soft-injection ramp 0.825 ** k for k = 12 .. 1 (12 levels, matching
        # zero_weighting). The k = 8 term (0.2146...) was previously missing.
        soft_levels = [0.09941396206337118, 0.12050177219802567, 0.14606275417942507, 0.17704576264172736,
                       0.214600924414215, 0.26012233262329093, 0.3152997971191405, 0.3821815722656249,
                       0.4632503906249999, 0.561515625, 0.6806249999999999, 0.825]
        soft_weighting = {
            'input': list(soft_levels),
            'middle': [1.0],
            'output': list(soft_levels),
        }
        zero_weighting = {
            'input': [0.0] * 12,
            'middle': [0.0],
            'output': [0.0] * 12,
        }

        if unit.control_mode == external_code.ControlMode.CONTROL.value:
            params.model.positive_advanced_weighting = soft_weighting.copy()
            params.model.negative_advanced_weighting = zero_weighting.copy()

        # The hires-fix pass always uses softer injections.
        if is_hr_pass or unit.control_mode == external_code.ControlMode.PROMPT.value:
            params.model.positive_advanced_weighting = soft_weighting.copy()
            params.model.negative_advanced_weighting = soft_weighting.copy()

        params.preprocessor.process_before_every_sampling(p, cond, *args, **kwargs)
        params.model.process_before_every_sampling(p, cond, *args, **kwargs)

        # NOTE(review): the placeholder in this message was lost; unit.module names the preprocessor.
        logger.info(f"ControlNet Method {unit.module} patched.")

    def process_unit_after_every_sampling(self,
                                          p: StableDiffusionProcessing,
                                          unit: external_code.ControlNetUnit,
                                          params: ControlNetCachedParameters,
                                          *args, **kwargs):
        """Forward the after-sampling hook to the unit's preprocessor and model."""
        params.preprocessor.process_after_every_sampling(p, params, *args, **kwargs)
        params.model.process_after_every_sampling(p, params, *args, **kwargs)

    def process(self, p, *args, **kwargs):
        """Prepare a ControlNetCachedParameters for every enabled unit.

        Results are stored in ``self.current_params`` keyed by the unit's index
        in ``get_enabled_units(p)``.
        """
        self.current_params = {}
        for i, unit in enumerate(self.get_enabled_units(p)):
            # NOTE(review): bound_check_params() is not called anywhere in this file as
            # shown; confirm whether it should run here for each unit.
            params = ControlNetCachedParameters()
            self.process_unit_after_click_generate(p, unit, params, *args, **kwargs)
            self.current_params[i] = params

    def process_before_every_sampling(self, p, *args, **kwargs):
        """Apply each enabled unit's cached parameters before a sampling pass."""
        for i, unit in enumerate(self.get_enabled_units(p)):
            self.process_unit_before_every_sampling(p, unit, self.current_params[i], *args, **kwargs)

    def postprocess_batch_list(self, p, *args, **kwargs):
        """Run the after-sampling hooks for each enabled unit, then drop the cached parameters."""
        for i, unit in enumerate(self.get_enabled_units(p)):
            self.process_unit_after_every_sampling(p, unit, self.current_params[i], *args, **kwargs)
        self.current_params = {}
def on_ui_settings():
    """Register the ControlNet options under the "ControlNet" settings section.

    Every option is added through ``shared.opts.add_option`` with a
    ``shared.OptionInfo(default, label, [component, component_args], section=...)``.
    """
    section = ('control_net', "ControlNet")
    shared.opts.add_option("control_net_detectedmap_dir", shared.OptionInfo(
        "detected_maps", "Directory for detected maps auto saving", section=section))
    shared.opts.add_option("control_net_models_path", shared.OptionInfo(
        "", "Extra path to scan for ControlNet models (e.g. training output directory)", section=section))
    # NOTE(review): reconstructed — the default ("") and the closing `section=section))`
    # of this call were missing in the restored copy.
    shared.opts.add_option("control_net_modules_path", shared.OptionInfo(
        "",
        "Path to directory containing annotator model directories (requires restart, overrides corresponding command line flag)",
        section=section))
    shared.opts.add_option("control_net_unit_count", shared.OptionInfo(
        3, "Multi-ControlNet: ControlNet unit number (requires restart)", gr.Slider,
        {"minimum": 1, "maximum": 10, "step": 1}, section=section))
    shared.opts.add_option("control_net_model_cache_size", shared.OptionInfo(
        5, "Model cache size (requires restart)", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}, section=section))
    shared.opts.add_option("control_net_no_detectmap", shared.OptionInfo(
        False, "Do not append detectmap to output", gr.Checkbox, {"interactive": True}, section=section))
    shared.opts.add_option("control_net_detectmap_autosaving", shared.OptionInfo(
        False, "Allow detectmap auto saving", gr.Checkbox, {"interactive": True}, section=section))
    shared.opts.add_option("control_net_allow_script_control", shared.OptionInfo(
        False, "Allow other script to control this extension", gr.Checkbox, {"interactive": True}, section=section))
    shared.opts.add_option("control_net_sync_field_args", shared.OptionInfo(
        True, "Paste ControlNet parameters in infotext", gr.Checkbox, {"interactive": True}, section=section))
    shared.opts.add_option("controlnet_show_batch_images_in_ui", shared.OptionInfo(
        False, "Show batch images in gradio gallery output", gr.Checkbox, {"interactive": True}, section=section))
    # The closing `section=section))` of this call was missing in the restored copy.
    shared.opts.add_option("controlnet_increment_seed_during_batch", shared.OptionInfo(
        False, "Increment seed after each controlnet batch iteration", gr.Checkbox, {"interactive": True},
        section=section))
    shared.opts.add_option("controlnet_disable_openpose_edit", shared.OptionInfo(
        False, "Disable openpose edit", gr.Checkbox, {"interactive": True}, section=section))
    shared.opts.add_option("controlnet_disable_photopea_edit", shared.OptionInfo(
        False, "Disable photopea edit", gr.Checkbox, {"interactive": True}, section=section))
    shared.opts.add_option("controlnet_photopea_warning", shared.OptionInfo(
        True, "Photopea popup warning", gr.Checkbox, {"interactive": True}, section=section))
    shared.opts.add_option("controlnet_ignore_noninpaint_mask", shared.OptionInfo(
        False, "Ignore mask on ControlNet input image if control type is not inpaint",
        gr.Checkbox, {"interactive": True}, section=section))
    shared.opts.add_option("controlnet_clip_detector_on_cpu", shared.OptionInfo(
        False, "Load CLIP preprocessor model on CPU",
        gr.Checkbox, {"interactive": True}, section=section))