ini mask support

lllyasviel 2024-02-01 21:21:41 -08:00
parent 07977a1fec
commit 9f6ee2a688
9 changed files with 109 additions and 236 deletions

View File

@@ -29,10 +29,11 @@ class PreprocessorInpaintOnly(PreprocessorInpaint):
         self.image = None
         self.mask = None
         self.latent = None
+        self.fill_mask_with_one_when_resize_and_fill = True

-    def process_before_every_sampling(self, process, cond, *args, **kwargs):
+    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
-        self.image = kwargs['cond_before_inpaint_fix'][:, 0:3]
-        self.mask = kwargs['cond_before_inpaint_fix'][:, 3:]
+        self.image = cond
+        self.mask = mask
         vae = process.sd_model.forge_objects.vae
         # This is a powerful VAE with integrated memory management, bf16, and tiled fallback.

@@ -58,7 +59,10 @@ class PreprocessorInpaintOnly(PreprocessorInpaint):
         process.sd_model.forge_objects.unet = unet
         self.latent = latent_image
-        return
+        mixed_cond = cond * (1.0 - mask) - mask
+        return mixed_cond, None

     def process_after_every_sampling(self, process, params, *args, **kwargs):
         a1111_batch_result = args[0]

@@ -95,19 +99,21 @@ class PreprocessorInpaintLama(PreprocessorInpaintOnly):
         self.setup_model_patcher(model)
         return

-    def __call__(self, input_image, resolution, slider_1=None, slider_2=None, slider_3=None, **kwargs):
+    def __call__(self, input_image, resolution, slider_1=None, slider_2=None, slider_3=None, input_mask=None, **kwargs):
         H, W, C = input_image.shape
-        raw_color = input_image[:, :, 0:3].copy()
-        raw_mask = input_image[:, :, 3:4].copy()
+        raw_color = input_image.copy()
+        raw_mask = input_mask.copy()
         input_image, remove_pad = resize_image_with_pad(input_image, 256)
+        input_mask, remove_pad = resize_image_with_pad(input_mask, 256)
+        input_mask = input_mask[..., :1]
         self.load_model()
         self.move_all_model_patchers_to_gpu()
-        color = np.ascontiguousarray(input_image[:, :, 0:3]).astype(np.float32) / 255.0
-        mask = np.ascontiguousarray(input_image[:, :, 3:4]).astype(np.float32) / 255.0
+        color = np.ascontiguousarray(input_image).astype(np.float32) / 255.0
+        mask = np.ascontiguousarray(input_mask).astype(np.float32) / 255.0
         with torch.no_grad():
             color = self.send_tensor_to_model_device(torch.from_numpy(color))
             mask = self.send_tensor_to_model_device(torch.from_numpy(mask))

@@ -128,15 +134,14 @@ class PreprocessorInpaintLama(PreprocessorInpaintOnly):
         fin_color = prd_color.astype(np.float32) * alpha + raw_color.astype(np.float32) * (1 - alpha)
         fin_color = fin_color.clip(0, 255).astype(np.uint8)
-        result = np.concatenate([fin_color, raw_mask], axis=2)
-        return result
+        return fin_color

-    def process_before_every_sampling(self, process, cond, *args, **kwargs):
-        super().process_before_every_sampling(process, cond, *args, **kwargs)
+    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
+        cond, mask = super().process_before_every_sampling(process, cond, mask, *args, **kwargs)
         sigma_max = process.sd_model.forge_objects.unet.model.model_sampling.sigma_max
         original_noise = kwargs['noise']
         process.modified_noise = original_noise + self.latent.to(original_noise) / sigma_max.to(original_noise)
-        return
+        return cond, mask

 add_supported_preprocessor(PreprocessorInpaint())
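
Stepping outside the diff for a moment: the change above settles the new preprocessor contract. process_before_every_sampling now receives the control mask as an explicit argument and returns a (cond, mask) pair instead of unpacking a 4-channel cond, and the inpaint preprocessor folds the mask into the conditioning before handing it back. A minimal sketch of that mixing convention, assuming cond is a [B, 3, H, W] tensor in 0..1 and mask is [B, 1, H, W] in 0..1:

import torch

def mix_inpaint_cond(cond: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Unmasked pixels (mask == 0) pass through unchanged; fully masked
    # pixels land at -1, the value the inpaint ControlNet reads as
    # "region to be filled".
    return cond * (1.0 - mask) - mask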

View File

@@ -38,9 +38,9 @@ class PreprocessorRecolor(Preprocessor):
         result = cv2.cvtColor(result, cv2.COLOR_GRAY2RGB)
         return result

-    def process_before_every_sampling(self, process, cond, *args, **kwargs):
+    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
         self.current_cond = cond
-        return
+        return cond, mask

     def process_after_every_sampling(self, process, params, *args, **kwargs):
         a1111_batch_result = args[0]

View File

@@ -47,7 +47,7 @@ class PreprocessorReference(Preprocessor):
         self.recorded_attn1 = {}
         self.recorded_h = {}

-    def process_before_every_sampling(self, process, cond, *args, **kwargs):
+    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
         unit = kwargs['unit']
         weight = float(unit.weight)
         style_fidelity = float(unit.threshold_a)

@@ -190,7 +190,7 @@ class PreprocessorReference(Preprocessor):
         process.sd_model.forge_objects.unet = unet
-        return
+        return cond, mask

 add_supported_preprocessor(PreprocessorReference(

View File

@@ -63,7 +63,7 @@ class PreprocessorClipVisionForRevision(PreprocessorClipVision):
         self.slider_1 = PreprocessorParameter(
             label="Noise Augmentation", minimum=0.0, maximum=1.0, value=0.0, visible=True)

-    def process_before_every_sampling(self, process, cond, *args, **kwargs):
+    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
         unit = kwargs['unit']
         weight = float(unit.weight)

@@ -84,7 +84,8 @@ class PreprocessorClipVisionForRevision(PreprocessorClipVision):
         unet.add_conditioning_modifier(revision_conditioning_modifier, ensure_uniqueness=True)
         process.sd_model.forge_objects.unet = unet
-        return
+
+        return cond, mask

 add_supported_preprocessor(PreprocessorClipVisionForRevision(

View File

@@ -37,7 +37,7 @@ class PreprocessorTileColorFix(PreprocessorTile):
         self.variation = 8
         self.sharpness = None

-    def process_before_every_sampling(self, process, cond, *args, **kwargs):
+    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
         self.variation = int(kwargs['unit'].threshold_a)
         latent = self.register_latent(process, cond)

@@ -77,7 +77,7 @@ class PreprocessorTileColorFix(PreprocessorTile):
         process.sd_model.forge_objects.unet = unet
-        return
+        return cond, mask

 class PreprocessorTileColorFixSharp(PreprocessorTileColorFix):

@@ -86,10 +86,9 @@ class PreprocessorTileColorFixSharp(PreprocessorTileColorFix):
         self.name = 'tile_colorfix+sharp'
         self.slider_2 = PreprocessorParameter(label='Sharpness', value=1.0, minimum=0.0, maximum=2.0, step=0.01, visible=True)

-    def process_before_every_sampling(self, process, cond, *args, **kwargs):
+    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
         self.sharpness = float(kwargs['unit'].threshold_b)
-        super().process_before_every_sampling(process, cond, *args, **kwargs)
-        return
+        return super().process_before_every_sampling(process, cond, mask, *args, **kwargs)

 add_supported_preprocessor(PreprocessorTile())

View File

@@ -215,7 +215,6 @@ class ControlNetUiGroup(object):
         self.image_upload_panel = None
         self.save_detected_map = None
         self.input_mode = gr.State(InputMode.SIMPLE)
-        self.inpaint_crop_input_image = None
         self.hr_option = None
         self.batch_image_dir_state = None
         self.output_dir_state = None

@@ -443,15 +442,6 @@ class ControlNetUiGroup(object):
         else:
             self.upload_independent_img_in_img2img = None

-        # Note: The checkbox needs to exist for both img2img and txt2img as infotext
-        # needs the checkbox value.
-        self.inpaint_crop_input_image = gr.Checkbox(
-            label="Crop input image based on A1111 mask",
-            value=False,
-            elem_classes=["cnet-crop-input-image"],
-            visible=False,
-        )
-
         with gr.Row(elem_classes=["controlnet_control_type", "controlnet_row"]):
             self.type_filter = gr.Radio(
                 global_state.get_all_preprocessor_tags(),

@@ -805,18 +795,10 @@ class ControlNetUiGroup(object):
             )
             img = HWC3(image["image"])
-            has_mask = not (
-                (image["mask"][:, :, 0] <= 5).all()
-                or (image["mask"][:, :, 0] >= 250).all()
-            )
-            if "inpaint" in module:
-                color = HWC3(image["image"])
-                alpha = image["mask"][:, :, 0:1]
-                img = np.concatenate([color, alpha], axis=2)
-            elif has_mask and not shared.opts.data.get(
-                "controlnet_ignore_noninpaint_mask", False
-            ):
-                img = HWC3(image["mask"][:, :, 0])
+            mask = HWC3(image["mask"])
+            if not (mask > 5).any():
+                mask = None

             preprocessor = global_state.get_preprocessor(module)

@@ -854,6 +836,7 @@ class ControlNetUiGroup(object):
                 resolution=pres,
                 slider_1=pthr_a,
                 slider_2=pthr_b,
+                input_mask=mask,
                 low_vram=shared.opts.data.get("controlnet_clip_detector_on_cpu", False),
                 json_pose_callback=json_acceptor.accept
                 if is_openpose(module)

@@ -976,52 +959,9 @@ class ControlNetUiGroup(object):
         )

     def register_shift_crop_input_image(self):
-        # A1111 < 1.7.0 compatibility.
-        if any(c is None for c in ControlNetUiGroup.a1111_context.img2img_inpaint_tabs):
-            self.inpaint_crop_input_image.visible = True
-            self.inpaint_crop_input_image.value = True
-            return
-
-        is_inpaint_tab = gr.State(False)
-
-        def shift_crop_input_image(is_inpaint: bool, inpaint_area: int):
-            # Note: inpaint_area (0: Whole picture, 1: Only masked)
-            # By default set value to True, as most preprocessors need cropped result.
-            return gr.update(value=True, visible=is_inpaint and inpaint_area == 1)
-
-        gradio_kwargs = dict(
-            fn=shift_crop_input_image,
-            inputs=[
-                is_inpaint_tab,
-                ControlNetUiGroup.a1111_context.img2img_inpaint_area,
-            ],
-            outputs=[self.inpaint_crop_input_image],
-            show_progress=False,
-        )
-
-        for elem in ControlNetUiGroup.a1111_context.img2img_inpaint_tabs:
-            elem.select(fn=lambda: True, inputs=[], outputs=[is_inpaint_tab]).then(
-                **gradio_kwargs
-            )
-        for elem in ControlNetUiGroup.a1111_context.img2img_non_inpaint_tabs:
-            elem.select(fn=lambda: False, inputs=[], outputs=[is_inpaint_tab]).then(
-                **gradio_kwargs
-            )
-        ControlNetUiGroup.a1111_context.img2img_inpaint_area.change(**gradio_kwargs)
+        return

     def register_shift_hr_options(self):
-        # # A1111 version < 1.6.0.
-        # if not ControlNetUiGroup.a1111_context.txt2img_enable_hr:
-        #     return
-        #
-        # ControlNetUiGroup.a1111_context.txt2img_enable_hr.change(
-        #     fn=lambda checked: gr.update(visible=checked),
-        #     inputs=[ControlNetUiGroup.a1111_context.txt2img_enable_hr],
-        #     outputs=[self.hr_option],
-        #     show_progress=False,
-        # )
         return

     def register_shift_upload_mask(self):
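
A note on the preview path above, outside the diff: the UI no longer packs the canvas mask into a fourth image channel; it normalizes the mask separately and passes it to the preprocessor as input_mask. A sketch of the emptiness convention, assuming mask is an HWC uint8 array from the gradio canvas:

import numpy as np

def normalize_preview_mask(mask: np.ndarray):
    # Gradio can emit an all-black canvas even when nothing was drawn,
    # so a mask with no pixel brighter than 5 is treated as absent.
    return mask if (mask > 5).any() else None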

View File

@@ -317,11 +317,6 @@ def high_quality_resize(x, size):
     # Written by lvmin
     # Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges

-    inpaint_mask = None
-    if x.ndim == 3 and x.shape[2] == 4:
-        inpaint_mask = x[:, :, 3]
-        x = x[:, :, 0:3]
-
     if x.shape[0] != size[1] or x.shape[1] != size[0]:
         new_size_is_smaller = (size[0] * size[1]) < (x.shape[0] * x.shape[1])
         new_size_is_bigger = (size[0] * size[1]) > (x.shape[0] * x.shape[1])

@@ -346,8 +341,6 @@ def high_quality_resize(x, size):
             interpolation = cv2.INTER_CUBIC  # Must be CUBIC because we now use nms. NEVER CHANGE THIS
         y = cv2.resize(x, size, interpolation=interpolation)
-        if inpaint_mask is not None:
-            inpaint_mask = cv2.resize(inpaint_mask, size, interpolation=interpolation)

         if is_binary:
             y = np.mean(y.astype(np.float32), axis=2).clip(0, 255).astype(np.uint8)

@@ -361,15 +354,10 @@ def high_quality_resize(x, size):
     else:
         y = x

-    if inpaint_mask is not None:
-        inpaint_mask = (inpaint_mask > 127).astype(np.float32) * 255.0
-        inpaint_mask = inpaint_mask[:, :, None].clip(0, 255).astype(np.uint8)
-        y = np.concatenate([y, inpaint_mask], axis=2)
-
     return y

-def crop_and_resize_image(detected_map, resize_mode, h, w):
+def crop_and_resize_image(detected_map, resize_mode, h, w, fill_border_with_255=False):
     if resize_mode == external_code.ResizeMode.RESIZE:
         detected_map = high_quality_resize(detected_map, (w, h))
         detected_map = safe_numpy(detected_map)

@@ -387,9 +375,8 @@ def crop_and_resize_image(detected_map, resize_mode, h, w):
         k = min(k0, k1)
         borders = np.concatenate([detected_map[0, :, :], detected_map[-1, :, :], detected_map[:, 0, :], detected_map[:, -1, :]], axis=0)
         high_quality_border_color = np.median(borders, axis=0).astype(detected_map.dtype)
-        if len(high_quality_border_color) == 4:
-            # Inpaint hijack
-            high_quality_border_color[3] = 255
+        if fill_border_with_255:
+            high_quality_border_color = np.zeros_like(high_quality_border_color) + 255
         high_quality_background = np.tile(high_quality_border_color[None, None], [h, w, 1])
         detected_map = high_quality_resize(detected_map, (safeint(old_w * k), safeint(old_h * k)))
         new_h, new_w, _ = detected_map.shape
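
With the 4-channel "inpaint hijack" removed from high_quality_resize, crop_and_resize_image must be told explicitly how to pad masks under Resize and Fill. The new fill_border_with_255 flag does that: preprocessors that opt in via fill_mask_with_one_when_resize_and_fill get padded borders treated as fully masked rather than filled with the median border color. A standalone restatement of that border logic, assuming detected_map is an HWC uint8 array:

import numpy as np

def border_fill_color(detected_map: np.ndarray, fill_border_with_255: bool) -> np.ndarray:
    # Median color of the outermost rows/columns, as in the diff.
    borders = np.concatenate([detected_map[0, :, :], detected_map[-1, :, :],
                              detected_map[:, 0, :], detected_map[:, -1, :]], axis=0)
    color = np.median(borders, axis=0).astype(detected_map.dtype)
    if fill_border_with_255:
        # For masks: padding counts as masked area (value 255 everywhere).
        color = np.zeros_like(color) + 255
    return color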

View File

@@ -1,6 +1,7 @@
 import os
 from typing import Dict, Optional, Tuple, List, Union

+import cv2
 import torch

 import modules.scripts as scripts

@@ -46,6 +47,8 @@ class ControlNetCachedParameters:
         self.model = None
         self.control_cond = None
         self.control_cond_for_hr_fix = None
+        self.control_mask = None
+        self.control_mask_for_hr_fix = None

 class ControlNetForForgeOfficial(scripts.Script):
@@ -95,105 +98,6 @@ class ControlNetForForgeOfficial(scripts.Script):
         enabled_units = [x for x in units if x.enabled]
         return enabled_units

-    def choose_input_image(
-            self,
-            p: processing.StableDiffusionProcessing,
-            unit: external_code.ControlNetUnit,
-    ) -> Tuple[np.ndarray, external_code.ResizeMode]:
-        """ Choose input image from following sources with descending priority:
-         - p.image_control: [Deprecated] Lagacy way to pass image to controlnet.
-         - p.control_net_input_image: [Deprecated] Lagacy way to pass image to controlnet.
-         - unit.image: ControlNet tab input image.
-         - p.init_images: A1111 img2img tab input image.
-
-        Returns:
-            - The input image in ndarray form.
-            - The resize mode.
-        """
-        def parse_unit_image(unit: external_code.ControlNetUnit) -> Union[
-                List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]:
-            unit_has_multiple_images = (
-                isinstance(unit.image, list) and
-                len(unit.image) > 0 and
-                "image" in unit.image[0]
-            )
-            if unit_has_multiple_images:
-                return [
-                    d
-                    for img in unit.image
-                    for d in (image_dict_from_any(img),)
-                    if d is not None
-                ]
-            return image_dict_from_any(unit.image)
-
-        def decode_image(img) -> np.ndarray:
-            """Need to check the image for API compatibility."""
-            if isinstance(img, str):
-                return np.asarray(decode_base64_to_image(image['image']))
-            else:
-                assert isinstance(img, np.ndarray)
-                return img
-
-        # 4 input image sources.
-        image = parse_unit_image(unit)
-        a1111_image = getattr(p, "init_images", [None])[0]
-
-        resize_mode = external_code.resize_mode_from_value(unit.resize_mode)
-
-        if image is not None:
-            if isinstance(image, list):
-                # Add mask logic if later there is a processor that accepts mask
-                # on multiple inputs.
-                input_image = [HWC3(decode_image(img['image'])) for img in image]
-            else:
-                input_image = HWC3(decode_image(image['image']))
-                if 'mask' in image and image['mask'] is not None:
-                    while len(image['mask'].shape) < 3:
-                        image['mask'] = image['mask'][..., np.newaxis]
-                    if 'inpaint' in unit.module:
-                        logger.info("using inpaint as input")
-                        color = HWC3(image['image'])
-                        alpha = image['mask'][:, :, 0:1]
-                        input_image = np.concatenate([color, alpha], axis=2)
-                    elif (
-                        not shared.opts.data.get("controlnet_ignore_noninpaint_mask", False) and
-                        # There is wield gradio issue that would produce mask that is
-                        # not pure color when no scribble is made on canvas.
-                        # See https://github.com/Mikubill/sd-webui-controlnet/issues/1638.
-                        not (
-                            (image['mask'][:, :, 0] <= 5).all() or
-                            (image['mask'][:, :, 0] >= 250).all()
-                        )
-                    ):
-                        logger.info("using mask as input")
-                        input_image = HWC3(image['mask'][:, :, 0])
-                        unit.module = 'none'  # Always use black bg and white line
-        elif a1111_image is not None:
-            input_image = HWC3(np.asarray(a1111_image))
-            a1111_i2i_resize_mode = getattr(p, "resize_mode", None)
-            assert a1111_i2i_resize_mode is not None
-            resize_mode = external_code.resize_mode_from_value(a1111_i2i_resize_mode)
-
-            a1111_mask_image: Optional[Image.Image] = getattr(p, "image_mask", None)
-            if 'inpaint' in unit.module:
-                if a1111_mask_image is not None:
-                    a1111_mask = np.array(prepare_mask(a1111_mask_image, p))
-                    assert a1111_mask.ndim == 2
-                    assert a1111_mask.shape[0] == input_image.shape[0]
-                    assert a1111_mask.shape[1] == input_image.shape[1]
-                    input_image = np.concatenate([input_image[:, :, 0:3], a1111_mask[:, :, None]], axis=2)
-                else:
-                    input_image = np.concatenate([
-                        input_image[:, :, 0:3],
-                        np.zeros_like(input_image, dtype=np.uint8)[:, :, 0:1],
-                    ], axis=2)
-        else:
-            raise ValueError("controlnet is enabled but no input image is given")
-
-        assert isinstance(input_image, (np.ndarray, list))
-        return input_image, resize_mode
-
     @staticmethod
     def try_crop_image_with_a1111_mask(

@@ -202,19 +106,6 @@ class ControlNetForForgeOfficial(scripts.Script):
             resize_mode: external_code.ResizeMode,
             preprocessor
     ) -> np.ndarray:
-        """
-        Crop ControlNet input image based on A1111 inpaint mask given.
-        This logic is crutial in upscale scripts, as they use A1111 mask + inpaint_full_res
-        to crop tiles.
-        """
-        # Note: The method determining whether the active script is an upscale script is purely
-        # based on `extra_generation_params` these scripts attach on `p`, and subject to change
-        # in the future.
-        # TODO: Change this to a more robust condition once A1111 offers a way to verify script name.
-        is_upscale_script = any("upscale" in k.lower() for k in getattr(p, "extra_generation_params", {}).keys())
-        logger.debug(f"is_upscale_script={is_upscale_script}")
-        # Note: `inpaint_full_res` is "inpaint area" on UI. The flag is `True` when "Only masked"
-        # option is selected.
         a1111_mask_image: Optional[Image.Image] = getattr(p, "image_mask", None)
         is_only_masked_inpaint = (
             issubclass(type(p), StableDiffusionProcessingImg2Img) and

@@ -224,9 +115,8 @@ class ControlNetForForgeOfficial(scripts.Script):
         if (
             preprocessor.corp_image_with_a1111_mask_when_in_img2img_inpaint_tab
             and is_only_masked_inpaint
-            and (is_upscale_script or unit.inpaint_crop_input_image)
         ):
-            logger.debug("Crop input image based on A1111 mask.")
+            logger.info("Crop input image based on A1111 mask.")
             input_image = [input_image[:, :, i] for i in range(input_image.shape[2])]
             input_image = [Image.fromarray(x) for x in input_image]
@@ -251,12 +141,47 @@ class ControlNetForForgeOfficial(scripts.Script):
             return input_image

     def get_input_data(self, p, unit, preprocessor):
-        mask = None
-        input_image, resize_mode = self.choose_input_image(p, unit)
-        assert isinstance(input_image, np.ndarray), 'Invalid input image!'
-        input_image = self.try_crop_image_with_a1111_mask(p, unit, input_image, resize_mode, preprocessor)
-        input_image = np.ascontiguousarray(input_image.copy()).copy()  # safe numpy
-        return input_image, mask, resize_mode
+        a1111_i2i_image = getattr(p, "init_images", [None])[0]
+        a1111_i2i_mask = getattr(p, "image_mask", None)
+
+        using_a1111_data = False
+
+        resize_mode = external_code.resize_mode_from_value(unit.resize_mode)
+
+        if unit.use_preview_as_input and unit.generated_image is not None:
+            image = unit.generated_image
+        elif unit.image is None:
+            resize_mode = external_code.resize_mode_from_value(p.resize_mode)
+            image = HWC3(np.asarray(a1111_i2i_image))
+            using_a1111_data = True
+        elif (unit.image['image'] < 5).all() and (unit.image['mask'] > 5).any():
+            image = unit.image['mask']
+        else:
+            image = unit.image['image']
+
+        if not isinstance(image, np.ndarray):
+            raise ValueError("controlnet is enabled but no input image is given")
+
+        image = HWC3(image)
+
+        if using_a1111_data:
+            mask = HWC3(np.asarray(a1111_i2i_mask))
+        elif unit.mask_image is not None and (unit.mask_image['image'] > 5).any():
+            mask = unit.mask_image['image']
+        elif unit.mask_image is not None and (unit.mask_image['mask'] > 5).any():
+            mask = unit.mask_image['mask']
+        elif unit.image is not None and (unit.image['mask'] > 5).any():
+            mask = unit.image['mask']
+        else:
+            mask = None
+
+        image = self.try_crop_image_with_a1111_mask(p, unit, image, resize_mode, preprocessor)
+
+        if mask is not None:
+            mask = cv2.resize(HWC3(mask), (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
+            mask = self.try_crop_image_with_a1111_mask(p, unit, mask, resize_mode, preprocessor)
+
+        return image, mask, resize_mode

     @staticmethod
     def get_target_dimensions(p: StableDiffusionProcessing) -> Tuple[int, int, int, int]:
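
The rewritten get_input_data above replaces choose_input_image entirely and makes the mask sources explicit. Condensed as a sketch (same >5 stroke threshold as the diff; this is a summary, not the shipped code):

# Mask priority, highest first:
#   1. the A1111 img2img mask, when the unit falls back to the i2i init image
#   2. an uploaded standalone mask image (unit.mask_image['image'])
#   3. strokes drawn on the mask canvas (unit.mask_image['mask'])
#   4. strokes drawn on the main input canvas (unit.image['mask'])
#   5. otherwise, no mask
def has_strokes(arr):
    # An all-black canvas (nothing brighter than 5) counts as "no mask".
    return arr is not None and (arr > 5).any()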
@@ -296,9 +221,13 @@ class ControlNetForForgeOfficial(scripts.Script):
             and getattr(p, 'enable_hr', False)
         )

+        if unit.use_preview_as_input:
+            unit.module = 'None'
+
         preprocessor = global_state.get_preprocessor(unit.module)
         input_image, input_mask, resize_mode = self.get_input_data(p, unit, preprocessor)
+        # p.extra_result_images.append(input_image)

         if unit.pixel_perfect:
             unit.processor_res = external_code.pixel_perfect_resolution(
@@ -339,10 +268,24 @@ class ControlNetForForgeOfficial(scripts.Script):
             params.control_cond_for_hr_fix = preprocessor_output
             p.extra_result_images.append(input_image)

+        if input_mask is not None:
+            fill_border = preprocessor.fill_mask_with_one_when_resize_and_fill
+            params.control_mask = crop_and_resize_image(input_mask, resize_mode, h, w, fill_border)
+            p.extra_result_images.append(params.control_mask)
+            params.control_mask = numpy_to_pytorch(params.control_mask).movedim(-1, 1)[:, :1]
+            if has_high_res_fix:
+                params.control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border)
+                p.extra_result_images.append(params.control_mask_for_hr_fix)
+                params.control_mask_for_hr_fix = numpy_to_pytorch(params.control_mask_for_hr_fix).movedim(-1, 1)[:, :1]
+            else:
+                params.control_mask_for_hr_fix = params.control_mask
+
         if preprocessor.do_not_need_model:
             model_filename = 'Not Needed'
             params.model = ControlModelPatcher()
         else:
+            assert unit.model != 'None', 'You have not selected any control model!'
             model_filename = global_state.get_controlnet_filename(unit.model)
             params.model = cached_controlnet_loader(model_filename)
             assert params.model is not None, logger.error(f"Recognizing Control Model failed: {model_filename}")
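
The movedim(-1, 1)[:, :1] step above converts the resized HWC mask into the single-channel BCHW layout the downstream patchers consume. A minimal equivalent, assuming numpy_to_pytorch yields a [1, H, W, C] float tensor in 0..1 (its exact behavior is Forge-internal):

import numpy as np
import torch

def mask_to_bchw(mask_hwc: np.ndarray) -> torch.Tensor:
    t = torch.from_numpy(mask_hwc).float() / 255.0  # [H, W, C] in 0..1
    t = t.unsqueeze(0).movedim(-1, 1)               # [1, C, H, W]
    return t[:, :1]                                 # keep a single mask channel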
@@ -366,16 +309,13 @@ class ControlNetForForgeOfficial(scripts.Script):
         if is_hr_pass:
             cond = params.control_cond_for_hr_fix
+            mask = params.control_mask_for_hr_fix
         else:
             cond = params.control_cond
+            mask = params.control_mask

         kwargs.update(dict(unit=unit, params=params))

-        # CN inpaint fix
-        if isinstance(cond, torch.Tensor) and cond.ndim == 4 and cond.shape[1] == 4:
-            kwargs['cond_before_inpaint_fix'] = cond.clone()
-            cond = cond[:, :3] * (1.0 - cond[:, 3:]) - cond[:, 3:]
-
         params.model.strength = float(unit.weight)
         params.model.start_percent = float(unit.guidance_start)
         params.model.end_percent = float(unit.guidance_end)

@@ -411,8 +351,8 @@ class ControlNetForForgeOfficial(scripts.Script):
         params.model.positive_advanced_weighting = soft_weighting.copy()
         params.model.negative_advanced_weighting = soft_weighting.copy()

-        params.preprocessor.process_before_every_sampling(p, cond, *args, **kwargs)
-        params.model.process_before_every_sampling(p, cond, *args, **kwargs)
+        cond, mask = params.preprocessor.process_before_every_sampling(p, cond, mask, *args, **kwargs)
+        params.model.process_before_every_sampling(p, cond, mask, *args, **kwargs)

         logger.info(f"ControlNet Method {params.preprocessor.name} patched.")
         return

View File

@@ -31,6 +31,7 @@ class Preprocessor:
         self.do_not_need_model = False
         self.sorting_priority = 0  # higher goes to top in the list
         self.corp_image_with_a1111_mask_when_in_img2img_inpaint_tab = True
+        self.fill_mask_with_one_when_resize_and_fill = False

     def setup_model_patcher(self, model, load_device=None, offload_device=None, dtype=torch.float32, **kwargs):
         if load_device is None:

@@ -59,8 +60,8 @@ class Preprocessor:
     def process_after_running_preprocessors(self, process, params, *args, **kwargs):
         return

-    def process_before_every_sampling(self, process, cond, *args, **kwargs):
-        return
+    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
+        return cond, mask

     def process_after_every_sampling(self, process, params, *args, **kwargs):
         return
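
Taken together, the base class now pins down the contract every preprocessor touched by this commit follows: accept the mask as an argument and return the (possibly modified) pair. A minimal conforming subclass, as a sketch (the class name is hypothetical):

class PassthroughPreprocessor(Preprocessor):
    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
        # Real implementations patch process.sd_model.forge_objects here
        # (unet, vae, ...) before handing the pair back.
        return cond, mask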