Initial mask support
commit 9f6ee2a688 (parent 07977a1fec)
@@ -29,10 +29,11 @@ class PreprocessorInpaintOnly(PreprocessorInpaint):
         self.image = None
         self.mask = None
         self.latent = None
+        self.fill_mask_with_one_when_resize_and_fill = True
 
-    def process_before_every_sampling(self, process, cond, *args, **kwargs):
-        self.image = kwargs['cond_before_inpaint_fix'][:, 0:3]
-        self.mask = kwargs['cond_before_inpaint_fix'][:, 3:]
+    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
+        self.image = cond
+        self.mask = mask
 
         vae = process.sd_model.forge_objects.vae
         # This is a powerful VAE with integrated memory management, bf16, and tiled fallback.
@@ -58,7 +59,10 @@ class PreprocessorInpaintOnly(PreprocessorInpaint):
         process.sd_model.forge_objects.unet = unet
 
         self.latent = latent_image
-        return
+
+        mixed_cond = cond * (1.0 - mask) - mask
+
+        return mixed_cond, None
 
     def process_after_every_sampling(self, process, params, *args, **kwargs):
         a1111_batch_result = args[0]
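Before this change, the inpaint preprocessor reconstructed image and mask from a packed 4-channel cond (kwargs['cond_before_inpaint_fix']); the mask now arrives as its own argument and the preprocessor itself builds the inpaint conditioning via mixed_cond = cond * (1.0 - mask) - mask. A minimal sketch of that mixing rule, assuming cond values in [0, 1] and a binary mask (all names here are illustrative):

import torch

def mix_inpaint_cond(cond: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Keep original pixels where mask == 0; write -1 where mask == 1.
    # Because image channels live in [0, 1], the -1 marker unambiguously
    # tells the inpaint ControlNet which region to regenerate.
    return cond * (1.0 - mask) - mask

cond = torch.rand(1, 3, 64, 64)       # image-like conditioning in [0, 1]
mask = torch.zeros(1, 1, 64, 64)
mask[..., 16:48, 16:48] = 1.0         # square region to inpaint
mixed = mix_inpaint_cond(cond, mask)  # masked area becomes exactly -1.0
assert torch.all(mixed[:, :, 20, 20] == -1.0)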
@@ -95,19 +99,21 @@ class PreprocessorInpaintLama(PreprocessorInpaintOnly):
         self.setup_model_patcher(model)
         return
 
-    def __call__(self, input_image, resolution, slider_1=None, slider_2=None, slider_3=None, **kwargs):
+    def __call__(self, input_image, resolution, slider_1=None, slider_2=None, slider_3=None, input_mask=None, **kwargs):
         H, W, C = input_image.shape
-        raw_color = input_image[:, :, 0:3].copy()
-        raw_mask = input_image[:, :, 3:4].copy()
+        raw_color = input_image.copy()
+        raw_mask = input_mask.copy()
 
         input_image, remove_pad = resize_image_with_pad(input_image, 256)
+        input_mask, remove_pad = resize_image_with_pad(input_mask, 256)
+        input_mask = input_mask[..., :1]
 
         self.load_model()
 
         self.move_all_model_patchers_to_gpu()
 
-        color = np.ascontiguousarray(input_image[:, :, 0:3]).astype(np.float32) / 255.0
-        mask = np.ascontiguousarray(input_image[:, :, 3:4]).astype(np.float32) / 255.0
+        color = np.ascontiguousarray(input_image).astype(np.float32) / 255.0
+        mask = np.ascontiguousarray(input_mask).astype(np.float32) / 255.0
         with torch.no_grad():
             color = self.send_tensor_to_model_device(torch.from_numpy(color))
             mask = self.send_tensor_to_model_device(torch.from_numpy(mask))
@@ -128,15 +134,14 @@ class PreprocessorInpaintLama(PreprocessorInpaintOnly):
         fin_color = prd_color.astype(np.float32) * alpha + raw_color.astype(np.float32) * (1 - alpha)
         fin_color = fin_color.clip(0, 255).astype(np.uint8)
 
-        result = np.concatenate([fin_color, raw_mask], axis=2)
-        return result
+        return fin_color
 
-    def process_before_every_sampling(self, process, cond, *args, **kwargs):
-        super().process_before_every_sampling(process, cond, *args, **kwargs)
+    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
+        cond, mask = super().process_before_every_sampling(process, cond, mask, *args, **kwargs)
         sigma_max = process.sd_model.forge_objects.unet.model.model_sampling.sigma_max
         original_noise = kwargs['noise']
         process.modified_noise = original_noise + self.latent.to(original_noise) / sigma_max.to(original_noise)
-        return
+        return cond, mask
 
 
 add_supported_preprocessor(PreprocessorInpaint())
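With the RGBA packing gone, LaMa callers pass the mask through the new input_mask keyword instead of an alpha channel, and the result is plain H×W×3 color rather than color with the mask re-attached. A hypothetical invocation (array contents invented for illustration; a real call also needs the LaMa weights available):

import numpy as np

preprocessor = PreprocessorInpaintLama()

image = np.zeros((512, 512, 3), dtype=np.uint8)  # H×W×3 color, no alpha
mask = np.zeros((512, 512, 3), dtype=np.uint8)
mask[128:384, 128:384] = 255                     # white = region to inpaint

filled = preprocessor(image, resolution=512, input_mask=mask)
# `filled` is H×W×3; the mask now travels separately through the pipeline.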
@@ -38,9 +38,9 @@ class PreprocessorRecolor(Preprocessor):
         result = cv2.cvtColor(result, cv2.COLOR_GRAY2RGB)
         return result
 
-    def process_before_every_sampling(self, process, cond, *args, **kwargs):
+    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
         self.current_cond = cond
-        return
+        return cond, mask
 
     def process_after_every_sampling(self, process, params, *args, **kwargs):
         a1111_batch_result = args[0]
@@ -47,7 +47,7 @@ class PreprocessorReference(Preprocessor):
         self.recorded_attn1 = {}
         self.recorded_h = {}
 
-    def process_before_every_sampling(self, process, cond, *args, **kwargs):
+    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
         unit = kwargs['unit']
         weight = float(unit.weight)
         style_fidelity = float(unit.threshold_a)
@@ -190,7 +190,7 @@ class PreprocessorReference(Preprocessor):
 
         process.sd_model.forge_objects.unet = unet
 
-        return
+        return cond, mask
 
 
 add_supported_preprocessor(PreprocessorReference(
@@ -63,7 +63,7 @@ class PreprocessorClipVisionForRevision(PreprocessorClipVision):
         self.slider_1 = PreprocessorParameter(
             label="Noise Augmentation", minimum=0.0, maximum=1.0, value=0.0, visible=True)
 
-    def process_before_every_sampling(self, process, cond, *args, **kwargs):
+    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
         unit = kwargs['unit']
 
         weight = float(unit.weight)
@@ -84,7 +84,8 @@ class PreprocessorClipVisionForRevision(PreprocessorClipVision):
         unet.add_conditioning_modifier(revision_conditioning_modifier, ensure_uniqueness=True)
 
         process.sd_model.forge_objects.unet = unet
-        return
+
+        return cond, mask
 
 
 add_supported_preprocessor(PreprocessorClipVisionForRevision(
@@ -37,7 +37,7 @@ class PreprocessorTileColorFix(PreprocessorTile):
         self.variation = 8
         self.sharpness = None
 
-    def process_before_every_sampling(self, process, cond, *args, **kwargs):
+    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
         self.variation = int(kwargs['unit'].threshold_a)
 
         latent = self.register_latent(process, cond)
@@ -77,7 +77,7 @@ class PreprocessorTileColorFix(PreprocessorTile):
 
         process.sd_model.forge_objects.unet = unet
 
-        return
+        return cond, mask
 
 
 class PreprocessorTileColorFixSharp(PreprocessorTileColorFix):
@@ -86,10 +86,9 @@ class PreprocessorTileColorFixSharp(PreprocessorTileColorFix):
         self.name = 'tile_colorfix+sharp'
         self.slider_2 = PreprocessorParameter(label='Sharpness', value=1.0, minimum=0.0, maximum=2.0, step=0.01, visible=True)
 
-    def process_before_every_sampling(self, process, cond, *args, **kwargs):
+    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
         self.sharpness = float(kwargs['unit'].threshold_b)
-        super().process_before_every_sampling(process, cond, *args, **kwargs)
-        return
+        return super().process_before_every_sampling(process, cond, mask, *args, **kwargs)
 
 
 add_supported_preprocessor(PreprocessorTile())
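The same mechanical change repeats across recolor, reference, revision, and tile: process_before_every_sampling now receives and returns (cond, mask), so any override must forward the pair through super() rather than returning None. A schematic of the pattern, with hypothetical class names:

class Base:
    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
        # The base (or a parent like PreprocessorTileColorFix) may rewrite
        # either value before returning it.
        return cond, mask

class Derived(Base):
    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
        self.sharpness = 1.0  # read unit settings first, as tile_colorfix+sharp does
        # Return super()'s tuple so the caller sees the possibly-rewritten
        # cond and mask instead of stale local copies.
        return super().process_before_every_sampling(process, cond, mask, *args, **kwargs)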
@@ -215,7 +215,6 @@ class ControlNetUiGroup(object):
         self.image_upload_panel = None
         self.save_detected_map = None
         self.input_mode = gr.State(InputMode.SIMPLE)
-        self.inpaint_crop_input_image = None
         self.hr_option = None
         self.batch_image_dir_state = None
         self.output_dir_state = None
@@ -443,15 +442,6 @@ class ControlNetUiGroup(object):
         else:
             self.upload_independent_img_in_img2img = None
 
-        # Note: The checkbox needs to exist for both img2img and txt2img as infotext
-        # needs the checkbox value.
-        self.inpaint_crop_input_image = gr.Checkbox(
-            label="Crop input image based on A1111 mask",
-            value=False,
-            elem_classes=["cnet-crop-input-image"],
-            visible=False,
-        )
-
         with gr.Row(elem_classes=["controlnet_control_type", "controlnet_row"]):
             self.type_filter = gr.Radio(
                 global_state.get_all_preprocessor_tags(),
@@ -805,18 +795,10 @@ class ControlNetUiGroup(object):
             )
 
             img = HWC3(image["image"])
-            has_mask = not (
-                (image["mask"][:, :, 0] <= 5).all()
-                or (image["mask"][:, :, 0] >= 250).all()
-            )
-            if "inpaint" in module:
-                color = HWC3(image["image"])
-                alpha = image["mask"][:, :, 0:1]
-                img = np.concatenate([color, alpha], axis=2)
-            elif has_mask and not shared.opts.data.get(
-                "controlnet_ignore_noninpaint_mask", False
-            ):
-                img = HWC3(image["mask"][:, :, 0])
+            mask = HWC3(image["mask"])
+            if not (mask > 5).any():
+                mask = None
 
             preprocessor = global_state.get_preprocessor(module)
 
@@ -854,6 +836,7 @@ class ControlNetUiGroup(object):
                 resolution=pres,
                 slider_1=pthr_a,
                 slider_2=pthr_b,
+                input_mask=mask,
                 low_vram=shared.opts.data.get("controlnet_clip_detector_on_cpu", False),
                 json_pose_callback=json_acceptor.accept
                 if is_openpose(module)
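The single (mask > 5).any() test replaces the old pair of all() checks: gradio can return a canvas that is not perfectly black even when nothing was drawn (see the issue cited in the removed choose_input_image below), so a mask counts as present only when at least one pixel clears a small threshold. A quick numpy illustration:

import numpy as np

noise_only = np.full((8, 8, 3), 3, dtype=np.uint8)  # near-black canvas residue
drawn = noise_only.copy()
drawn[2:5, 2:5] = 255                               # an actual scribble

print((noise_only > 5).any())  # False -> mask is set to None
print((drawn > 5).any())       # True  -> forwarded as input_mask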
@@ -976,52 +959,9 @@ class ControlNetUiGroup(object):
         )
 
     def register_shift_crop_input_image(self):
-        # A1111 < 1.7.0 compatibility.
-        if any(c is None for c in ControlNetUiGroup.a1111_context.img2img_inpaint_tabs):
-            self.inpaint_crop_input_image.visible = True
-            self.inpaint_crop_input_image.value = True
-            return
-
-        is_inpaint_tab = gr.State(False)
-
-        def shift_crop_input_image(is_inpaint: bool, inpaint_area: int):
-            # Note: inpaint_area (0: Whole picture, 1: Only masked)
-            # By default set value to True, as most preprocessors need cropped result.
-            return gr.update(value=True, visible=is_inpaint and inpaint_area == 1)
-
-        gradio_kwargs = dict(
-            fn=shift_crop_input_image,
-            inputs=[
-                is_inpaint_tab,
-                ControlNetUiGroup.a1111_context.img2img_inpaint_area,
-            ],
-            outputs=[self.inpaint_crop_input_image],
-            show_progress=False,
-        )
-
-        for elem in ControlNetUiGroup.a1111_context.img2img_inpaint_tabs:
-            elem.select(fn=lambda: True, inputs=[], outputs=[is_inpaint_tab]).then(
-                **gradio_kwargs
-            )
-
-        for elem in ControlNetUiGroup.a1111_context.img2img_non_inpaint_tabs:
-            elem.select(fn=lambda: False, inputs=[], outputs=[is_inpaint_tab]).then(
-                **gradio_kwargs
-            )
-
-        ControlNetUiGroup.a1111_context.img2img_inpaint_area.change(**gradio_kwargs)
+        return
 
     def register_shift_hr_options(self):
-        # # A1111 version < 1.6.0.
-        # if not ControlNetUiGroup.a1111_context.txt2img_enable_hr:
-        #     return
-        #
-        # ControlNetUiGroup.a1111_context.txt2img_enable_hr.change(
-        #     fn=lambda checked: gr.update(visible=checked),
-        #     inputs=[ControlNetUiGroup.a1111_context.txt2img_enable_hr],
-        #     outputs=[self.hr_option],
-        #     show_progress=False,
-        # )
         return
 
     def register_shift_upload_mask(self):
@@ -317,11 +317,6 @@ def high_quality_resize(x, size):
     # Written by lvmin
     # Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges
 
-    inpaint_mask = None
-    if x.ndim == 3 and x.shape[2] == 4:
-        inpaint_mask = x[:, :, 3]
-        x = x[:, :, 0:3]
-
     if x.shape[0] != size[1] or x.shape[1] != size[0]:
         new_size_is_smaller = (size[0] * size[1]) < (x.shape[0] * x.shape[1])
         new_size_is_bigger = (size[0] * size[1]) > (x.shape[0] * x.shape[1])
@@ -346,8 +341,6 @@ def high_quality_resize(x, size):
             interpolation = cv2.INTER_CUBIC  # Must be CUBIC because we now use nms. NEVER CHANGE THIS
 
         y = cv2.resize(x, size, interpolation=interpolation)
-        if inpaint_mask is not None:
-            inpaint_mask = cv2.resize(inpaint_mask, size, interpolation=interpolation)
 
         if is_binary:
             y = np.mean(y.astype(np.float32), axis=2).clip(0, 255).astype(np.uint8)
@@ -361,15 +354,10 @@ def high_quality_resize(x, size):
     else:
         y = x
 
-    if inpaint_mask is not None:
-        inpaint_mask = (inpaint_mask > 127).astype(np.float32) * 255.0
-        inpaint_mask = inpaint_mask[:, :, None].clip(0, 255).astype(np.uint8)
-        y = np.concatenate([y, inpaint_mask], axis=2)
-
     return y
 
 
-def crop_and_resize_image(detected_map, resize_mode, h, w):
+def crop_and_resize_image(detected_map, resize_mode, h, w, fill_border_with_255=False):
     if resize_mode == external_code.ResizeMode.RESIZE:
         detected_map = high_quality_resize(detected_map, (w, h))
         detected_map = safe_numpy(detected_map)
@@ -387,9 +375,8 @@ def crop_and_resize_image(detected_map, resize_mode, h, w):
     k = min(k0, k1)
     borders = np.concatenate([detected_map[0, :, :], detected_map[-1, :, :], detected_map[:, 0, :], detected_map[:, -1, :]], axis=0)
     high_quality_border_color = np.median(borders, axis=0).astype(detected_map.dtype)
-    if len(high_quality_border_color) == 4:
-        # Inpaint hijack
-        high_quality_border_color[3] = 255
+    if fill_border_with_255:
+        high_quality_border_color = np.zeros_like(high_quality_border_color) + 255
     high_quality_background = np.tile(high_quality_border_color[None, None], [h, w, 1])
     detected_map = high_quality_resize(detected_map, (safeint(old_w * k), safeint(old_h * k)))
     new_h, new_w, _ = detected_map.shape
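crop_and_resize_image used to special-case 4-channel maps (the "inpaint hijack") by forcing the alpha of the padded border to 255; callers now opt in explicitly with fill_border_with_255, which turns the whole border color into 255. For a mask resized under RESIZE_AND_FILL this makes the padded area count as fully masked, which is what fill_mask_with_one_when_resize_and_fill = True requests on the inpaint preprocessor. A condensed sketch of just the border-fill choice (helper name hypothetical, assuming the map is smaller than the canvas):

import numpy as np

def pad_to_canvas(detected_map, h, w, fill_border_with_255=False):
    # Center detected_map on an h-by-w canvas, filling the border.
    if fill_border_with_255:
        border = np.full((detected_map.shape[2],), 255, detected_map.dtype)
    else:
        # Median of the outermost pixels, as in the real function.
        edges = np.concatenate([detected_map[0], detected_map[-1],
                                detected_map[:, 0], detected_map[:, -1]], axis=0)
        border = np.median(edges, axis=0).astype(detected_map.dtype)
    canvas = np.tile(border[None, None], [h, w, 1])
    dh, dw = detected_map.shape[:2]
    top, left = (h - dh) // 2, (w - dw) // 2
    canvas[top:top + dh, left:left + dw] = detected_map
    return canvas

mask = np.zeros((32, 64, 1), dtype=np.uint8)
padded = pad_to_canvas(mask, 64, 64, fill_border_with_255=True)
print(padded[0, 0])  # [255] -> padding counts as masked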
@@ -1,6 +1,7 @@
 import os
 from typing import Dict, Optional, Tuple, List, Union
 
+import cv2
 import torch
 
 import modules.scripts as scripts
@@ -46,6 +47,8 @@ class ControlNetCachedParameters:
         self.model = None
         self.control_cond = None
         self.control_cond_for_hr_fix = None
+        self.control_mask = None
+        self.control_mask_for_hr_fix = None
 
 
 class ControlNetForForgeOfficial(scripts.Script):
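Caching the mask tensors next to the conds lets the hires-fix pass reuse tensors prepared on the first pass instead of recomputing them. A dataclass restatement of the cached record (a hypothetical mirror of ControlNetCachedParameters, for illustration only):

from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class CachedControlInputs:
    # Conds and masks are computed once in process(), then consumed by
    # every sampling pass (base pass vs. hires-fix pass).
    control_cond: Optional[torch.Tensor] = None
    control_cond_for_hr_fix: Optional[torch.Tensor] = None
    control_mask: Optional[torch.Tensor] = None
    control_mask_for_hr_fix: Optional[torch.Tensor] = None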
@@ -95,105 +98,6 @@ class ControlNetForForgeOfficial(scripts.Script):
         enabled_units = [x for x in units if x.enabled]
         return enabled_units
 
-    def choose_input_image(
-            self,
-            p: processing.StableDiffusionProcessing,
-            unit: external_code.ControlNetUnit,
-    ) -> Tuple[np.ndarray, external_code.ResizeMode]:
-        """ Choose input image from following sources with descending priority:
-         - p.image_control: [Deprecated] Lagacy way to pass image to controlnet.
-         - p.control_net_input_image: [Deprecated] Lagacy way to pass image to controlnet.
-         - unit.image: ControlNet tab input image.
-         - p.init_images: A1111 img2img tab input image.
-
-        Returns:
-            - The input image in ndarray form.
-            - The resize mode.
-        """
-
-        def parse_unit_image(unit: external_code.ControlNetUnit) -> Union[
-                List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]:
-            unit_has_multiple_images = (
-                isinstance(unit.image, list) and
-                len(unit.image) > 0 and
-                "image" in unit.image[0]
-            )
-            if unit_has_multiple_images:
-                return [
-                    d
-                    for img in unit.image
-                    for d in (image_dict_from_any(img),)
-                    if d is not None
-                ]
-            return image_dict_from_any(unit.image)
-
-        def decode_image(img) -> np.ndarray:
-            """Need to check the image for API compatibility."""
-            if isinstance(img, str):
-                return np.asarray(decode_base64_to_image(image['image']))
-            else:
-                assert isinstance(img, np.ndarray)
-                return img
-
-        # 4 input image sources.
-        image = parse_unit_image(unit)
-        a1111_image = getattr(p, "init_images", [None])[0]
-
-        resize_mode = external_code.resize_mode_from_value(unit.resize_mode)
-
-        if image is not None:
-            if isinstance(image, list):
-                # Add mask logic if later there is a processor that accepts mask
-                # on multiple inputs.
-                input_image = [HWC3(decode_image(img['image'])) for img in image]
-            else:
-                input_image = HWC3(decode_image(image['image']))
-                if 'mask' in image and image['mask'] is not None:
-                    while len(image['mask'].shape) < 3:
-                        image['mask'] = image['mask'][..., np.newaxis]
-                    if 'inpaint' in unit.module:
-                        logger.info("using inpaint as input")
-                        color = HWC3(image['image'])
-                        alpha = image['mask'][:, :, 0:1]
-                        input_image = np.concatenate([color, alpha], axis=2)
-                    elif (
-                        not shared.opts.data.get("controlnet_ignore_noninpaint_mask", False) and
-                        # There is wield gradio issue that would produce mask that is
-                        # not pure color when no scribble is made on canvas.
-                        # See https://github.com/Mikubill/sd-webui-controlnet/issues/1638.
-                        not (
-                            (image['mask'][:, :, 0] <= 5).all() or
-                            (image['mask'][:, :, 0] >= 250).all()
-                        )
-                    ):
-                        logger.info("using mask as input")
-                        input_image = HWC3(image['mask'][:, :, 0])
-                        unit.module = 'none'  # Always use black bg and white line
-        elif a1111_image is not None:
-            input_image = HWC3(np.asarray(a1111_image))
-            a1111_i2i_resize_mode = getattr(p, "resize_mode", None)
-            assert a1111_i2i_resize_mode is not None
-            resize_mode = external_code.resize_mode_from_value(a1111_i2i_resize_mode)
-
-            a1111_mask_image: Optional[Image.Image] = getattr(p, "image_mask", None)
-            if 'inpaint' in unit.module:
-                if a1111_mask_image is not None:
-                    a1111_mask = np.array(prepare_mask(a1111_mask_image, p))
-                    assert a1111_mask.ndim == 2
-                    assert a1111_mask.shape[0] == input_image.shape[0]
-                    assert a1111_mask.shape[1] == input_image.shape[1]
-                    input_image = np.concatenate([input_image[:, :, 0:3], a1111_mask[:, :, None]], axis=2)
-                else:
-                    input_image = np.concatenate([
-                        input_image[:, :, 0:3],
-                        np.zeros_like(input_image, dtype=np.uint8)[:, :, 0:1],
-                    ], axis=2)
-        else:
-            raise ValueError("controlnet is enabled but no input image is given")
-
-        assert isinstance(input_image, (np.ndarray, list))
-        return input_image, resize_mode
-
     @staticmethod
     def try_crop_image_with_a1111_mask(
         p: StableDiffusionProcessing,
@@ -202,19 +106,6 @@ class ControlNetForForgeOfficial(scripts.Script):
         resize_mode: external_code.ResizeMode,
         preprocessor
     ) -> np.ndarray:
-        """
-        Crop ControlNet input image based on A1111 inpaint mask given.
-        This logic is crutial in upscale scripts, as they use A1111 mask + inpaint_full_res
-        to crop tiles.
-        """
-        # Note: The method determining whether the active script is an upscale script is purely
-        # based on `extra_generation_params` these scripts attach on `p`, and subject to change
-        # in the future.
-        # TODO: Change this to a more robust condition once A1111 offers a way to verify script name.
-        is_upscale_script = any("upscale" in k.lower() for k in getattr(p, "extra_generation_params", {}).keys())
-        logger.debug(f"is_upscale_script={is_upscale_script}")
-        # Note: `inpaint_full_res` is "inpaint area" on UI. The flag is `True` when "Only masked"
-        # option is selected.
         a1111_mask_image: Optional[Image.Image] = getattr(p, "image_mask", None)
         is_only_masked_inpaint = (
             issubclass(type(p), StableDiffusionProcessingImg2Img) and
@@ -224,9 +115,8 @@ class ControlNetForForgeOfficial(scripts.Script):
         if (
             preprocessor.corp_image_with_a1111_mask_when_in_img2img_inpaint_tab
             and is_only_masked_inpaint
-            and (is_upscale_script or unit.inpaint_crop_input_image)
         ):
-            logger.debug("Crop input image based on A1111 mask.")
+            logger.info("Crop input image based on A1111 mask.")
             input_image = [input_image[:, :, i] for i in range(input_image.shape[2])]
             input_image = [Image.fromarray(x) for x in input_image]
 
@@ -251,12 +141,47 @@ class ControlNetForForgeOfficial(scripts.Script):
         return input_image
 
     def get_input_data(self, p, unit, preprocessor):
-        mask = None
-        input_image, resize_mode = self.choose_input_image(p, unit)
-        assert isinstance(input_image, np.ndarray), 'Invalid input image!'
-        input_image = self.try_crop_image_with_a1111_mask(p, unit, input_image, resize_mode, preprocessor)
-        input_image = np.ascontiguousarray(input_image.copy()).copy()  # safe numpy
-        return input_image, mask, resize_mode
+        a1111_i2i_image = getattr(p, "init_images", [None])[0]
+        a1111_i2i_mask = getattr(p, "image_mask", None)
+
+        using_a1111_data = False
+
+        resize_mode = external_code.resize_mode_from_value(unit.resize_mode)
+
+        if unit.use_preview_as_input and unit.generated_image is not None:
+            image = unit.generated_image
+        elif unit.image is None:
+            resize_mode = external_code.resize_mode_from_value(p.resize_mode)
+            image = HWC3(np.asarray(a1111_i2i_image))
+            using_a1111_data = True
+        elif (unit.image['image'] < 5).all() and (unit.image['mask'] > 5).any():
+            image = unit.image['mask']
+        else:
+            image = unit.image['image']
+
+        if not isinstance(image, np.ndarray):
+            raise ValueError("controlnet is enabled but no input image is given")
+
+        image = HWC3(image)
+
+        if using_a1111_data:
+            mask = HWC3(np.asarray(a1111_i2i_mask))
+        elif unit.mask_image is not None and (unit.mask_image['image'] > 5).any():
+            mask = unit.mask_image['image']
+        elif unit.mask_image is not None and (unit.mask_image['mask'] > 5).any():
+            mask = unit.mask_image['mask']
+        elif unit.image is not None and (unit.image['mask'] > 5).any():
+            mask = unit.image['mask']
+        else:
+            mask = None
+
+        image = self.try_crop_image_with_a1111_mask(p, unit, image, resize_mode, preprocessor)
+
+        if mask is not None:
+            mask = cv2.resize(HWC3(mask), (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
+            mask = self.try_crop_image_with_a1111_mask(p, unit, mask, resize_mode, preprocessor)
+
+        return image, mask, resize_mode
 
     @staticmethod
     def get_target_dimensions(p: StableDiffusionProcessing) -> Tuple[int, int, int, int]:
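The rewritten get_input_data establishes a clear priority order for the mask source: the A1111 img2img mask when the unit has no image of its own, then the dedicated unit.mask_image upload (its image layer first, then its drawn-mask layer), and finally a mask drawn directly on unit.image. A standalone restatement of that cascade (dict layout assumed to match the {'image': ..., 'mask': ...} shape used above):

import numpy as np

def pick_mask(unit_image, unit_mask_image, a1111_mask, using_a1111_data):
    # Return the first usable mask source, mirroring get_input_data.
    present = lambda m: m is not None and (m > 5).any()
    if using_a1111_data:
        return a1111_mask
    if unit_mask_image is not None and present(unit_mask_image['image']):
        return unit_mask_image['image']  # uploaded mask image
    if unit_mask_image is not None and present(unit_mask_image['mask']):
        return unit_mask_image['mask']   # mask drawn on the mask canvas
    if unit_image is not None and present(unit_image['mask']):
        return unit_image['mask']        # mask drawn on the input image
    return None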
@@ -296,9 +221,13 @@ class ControlNetForForgeOfficial(scripts.Script):
             and getattr(p, 'enable_hr', False)
         )
 
+        if unit.use_preview_as_input:
+            unit.module = 'None'
+
         preprocessor = global_state.get_preprocessor(unit.module)
 
         input_image, input_mask, resize_mode = self.get_input_data(p, unit, preprocessor)
+        # p.extra_result_images.append(input_image)
 
         if unit.pixel_perfect:
             unit.processor_res = external_code.pixel_perfect_resolution(
@@ -339,10 +268,24 @@ class ControlNetForForgeOfficial(scripts.Script):
             params.control_cond_for_hr_fix = preprocessor_output
             p.extra_result_images.append(input_image)
 
+        if input_mask is not None:
+            fill_border = preprocessor.fill_mask_with_one_when_resize_and_fill
+            params.control_mask = crop_and_resize_image(input_mask, resize_mode, h, w, fill_border)
+            p.extra_result_images.append(params.control_mask)
+            params.control_mask = numpy_to_pytorch(params.control_mask).movedim(-1, 1)[:, :1]
+
+            if has_high_res_fix:
+                params.control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border)
+                p.extra_result_images.append(params.control_mask_for_hr_fix)
+                params.control_mask_for_hr_fix = numpy_to_pytorch(params.control_mask_for_hr_fix).movedim(-1, 1)[:, :1]
+            else:
+                params.control_mask_for_hr_fix = params.control_mask
+
         if preprocessor.do_not_need_model:
             model_filename = 'Not Needed'
             params.model = ControlModelPatcher()
         else:
+            assert unit.model != 'None', 'You have not selected any control model!'
             model_filename = global_state.get_controlnet_filename(unit.model)
             params.model = cached_controlnet_loader(model_filename)
             assert params.model is not None, logger.error(f"Recognizing Control Model failed: {model_filename}")
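numpy_to_pytorch (Forge's HWC-to-tensor helper) yields a 1×H×W×C float tensor, so movedim(-1, 1) makes it channel-first and [:, :1] keeps a single channel: params.control_mask ends up 1×1×H×W, ready to broadcast against a B×C×H×W cond. A shape-only sketch in plain torch, standing in for the helper:

import torch

h, w = 64, 64
mask = torch.rand(1, h, w, 3)      # as if numpy_to_pytorch added a batch dim
mask = mask.movedim(-1, 1)[:, :1]  # channel-first, single channel
print(mask.shape)                  # torch.Size([1, 1, 64, 64])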
@@ -366,16 +309,13 @@ class ControlNetForForgeOfficial(scripts.Script):
 
         if is_hr_pass:
             cond = params.control_cond_for_hr_fix
+            mask = params.control_mask_for_hr_fix
         else:
             cond = params.control_cond
+            mask = params.control_mask
 
         kwargs.update(dict(unit=unit, params=params))
 
-        # CN inpaint fix
-        if isinstance(cond, torch.Tensor) and cond.ndim == 4 and cond.shape[1] == 4:
-            kwargs['cond_before_inpaint_fix'] = cond.clone()
-            cond = cond[:, :3] * (1.0 - cond[:, 3:]) - cond[:, 3:]
-
         params.model.strength = float(unit.weight)
         params.model.start_percent = float(unit.guidance_start)
         params.model.end_percent = float(unit.guidance_end)
@@ -411,8 +351,8 @@ class ControlNetForForgeOfficial(scripts.Script):
         params.model.positive_advanced_weighting = soft_weighting.copy()
         params.model.negative_advanced_weighting = soft_weighting.copy()
 
-        params.preprocessor.process_before_every_sampling(p, cond, *args, **kwargs)
-        params.model.process_before_every_sampling(p, cond, *args, **kwargs)
+        cond, mask = params.preprocessor.process_before_every_sampling(p, cond, mask, *args, **kwargs)
+        params.model.process_before_every_sampling(p, cond, mask, *args, **kwargs)
 
         logger.info(f"ControlNet Method {params.preprocessor.name} patched.")
         return
@@ -31,6 +31,7 @@ class Preprocessor:
         self.do_not_need_model = False
         self.sorting_priority = 0  # higher goes to top in the list
         self.corp_image_with_a1111_mask_when_in_img2img_inpaint_tab = True
+        self.fill_mask_with_one_when_resize_and_fill = False
 
     def setup_model_patcher(self, model, load_device=None, offload_device=None, dtype=torch.float32, **kwargs):
         if load_device is None:
@@ -59,8 +60,8 @@ class Preprocessor:
     def process_after_running_preprocessors(self, process, params, *args, **kwargs):
         return
 
-    def process_before_every_sampling(self, process, cond, *args, **kwargs):
-        return
+    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
+        return cond, mask
 
     def process_after_every_sampling(self, process, params, *args, **kwargs):
         return
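With the base class updated, "(cond, mask) in, (cond, mask) out" is the contract: the script threads the pair through the preprocessor and then hands both to the model patcher. A minimal custom preprocessor against this contract (illustrative only):

class PreprocessorPassthrough(Preprocessor):
    def __init__(self):
        super().__init__()
        self.name = 'passthrough_example'  # hypothetical registration name

    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
        # Even a no-op override must return the pair; returning None would
        # break the `cond, mask = ...` unpacking at the call site.
        return cond, mask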