ini mask support
This commit is contained in:
parent
07977a1fec
commit
9f6ee2a688
@ -29,10 +29,11 @@ class PreprocessorInpaintOnly(PreprocessorInpaint):
|
||||
self.image = None
|
||||
self.mask = None
|
||||
self.latent = None
|
||||
self.fill_mask_with_one_when_resize_and_fill = True
|
||||
|
||||
def process_before_every_sampling(self, process, cond, *args, **kwargs):
|
||||
self.image = kwargs['cond_before_inpaint_fix'][:, 0:3]
|
||||
self.mask = kwargs['cond_before_inpaint_fix'][:, 3:]
|
||||
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
|
||||
self.image = cond
|
||||
self.mask = mask
|
||||
|
||||
vae = process.sd_model.forge_objects.vae
|
||||
# This is a powerful VAE with integrated memory management, bf16, and tiled fallback.
|
||||
@ -58,7 +59,10 @@ class PreprocessorInpaintOnly(PreprocessorInpaint):
|
||||
process.sd_model.forge_objects.unet = unet
|
||||
|
||||
self.latent = latent_image
|
||||
return
|
||||
|
||||
mixed_cond = cond * (1.0 - mask) - mask
|
||||
|
||||
return mixed_cond, None
|
||||
|
||||
def process_after_every_sampling(self, process, params, *args, **kwargs):
|
||||
a1111_batch_result = args[0]
|
||||
@ -95,19 +99,21 @@ class PreprocessorInpaintLama(PreprocessorInpaintOnly):
|
||||
self.setup_model_patcher(model)
|
||||
return
|
||||
|
||||
def __call__(self, input_image, resolution, slider_1=None, slider_2=None, slider_3=None, **kwargs):
|
||||
def __call__(self, input_image, resolution, slider_1=None, slider_2=None, slider_3=None, input_mask=None, **kwargs):
|
||||
H, W, C = input_image.shape
|
||||
raw_color = input_image[:, :, 0:3].copy()
|
||||
raw_mask = input_image[:, :, 3:4].copy()
|
||||
raw_color = input_image.copy()
|
||||
raw_mask = input_mask.copy()
|
||||
|
||||
input_image, remove_pad = resize_image_with_pad(input_image, 256)
|
||||
input_mask, remove_pad = resize_image_with_pad(input_mask, 256)
|
||||
input_mask = input_mask[..., :1]
|
||||
|
||||
self.load_model()
|
||||
|
||||
self.move_all_model_patchers_to_gpu()
|
||||
|
||||
color = np.ascontiguousarray(input_image[:, :, 0:3]).astype(np.float32) / 255.0
|
||||
mask = np.ascontiguousarray(input_image[:, :, 3:4]).astype(np.float32) / 255.0
|
||||
color = np.ascontiguousarray(input_image).astype(np.float32) / 255.0
|
||||
mask = np.ascontiguousarray(input_mask).astype(np.float32) / 255.0
|
||||
with torch.no_grad():
|
||||
color = self.send_tensor_to_model_device(torch.from_numpy(color))
|
||||
mask = self.send_tensor_to_model_device(torch.from_numpy(mask))
|
||||
@ -128,15 +134,14 @@ class PreprocessorInpaintLama(PreprocessorInpaintOnly):
|
||||
fin_color = prd_color.astype(np.float32) * alpha + raw_color.astype(np.float32) * (1 - alpha)
|
||||
fin_color = fin_color.clip(0, 255).astype(np.uint8)
|
||||
|
||||
result = np.concatenate([fin_color, raw_mask], axis=2)
|
||||
return result
|
||||
return fin_color
|
||||
|
||||
def process_before_every_sampling(self, process, cond, *args, **kwargs):
|
||||
super().process_before_every_sampling(process, cond, *args, **kwargs)
|
||||
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
|
||||
cond, mask = super().process_before_every_sampling(process, cond, mask, *args, **kwargs)
|
||||
sigma_max = process.sd_model.forge_objects.unet.model.model_sampling.sigma_max
|
||||
original_noise = kwargs['noise']
|
||||
process.modified_noise = original_noise + self.latent.to(original_noise) / sigma_max.to(original_noise)
|
||||
return
|
||||
return cond, mask
|
||||
|
||||
|
||||
add_supported_preprocessor(PreprocessorInpaint())
|
||||
|
@ -38,9 +38,9 @@ class PreprocessorRecolor(Preprocessor):
|
||||
result = cv2.cvtColor(result, cv2.COLOR_GRAY2RGB)
|
||||
return result
|
||||
|
||||
def process_before_every_sampling(self, process, cond, *args, **kwargs):
|
||||
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
|
||||
self.current_cond = cond
|
||||
return
|
||||
return cond, mask
|
||||
|
||||
def process_after_every_sampling(self, process, params, *args, **kwargs):
|
||||
a1111_batch_result = args[0]
|
||||
|
@ -47,7 +47,7 @@ class PreprocessorReference(Preprocessor):
|
||||
self.recorded_attn1 = {}
|
||||
self.recorded_h = {}
|
||||
|
||||
def process_before_every_sampling(self, process, cond, *args, **kwargs):
|
||||
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
|
||||
unit = kwargs['unit']
|
||||
weight = float(unit.weight)
|
||||
style_fidelity = float(unit.threshold_a)
|
||||
@ -190,7 +190,7 @@ class PreprocessorReference(Preprocessor):
|
||||
|
||||
process.sd_model.forge_objects.unet = unet
|
||||
|
||||
return
|
||||
return cond, mask
|
||||
|
||||
|
||||
add_supported_preprocessor(PreprocessorReference(
|
||||
|
@ -63,7 +63,7 @@ class PreprocessorClipVisionForRevision(PreprocessorClipVision):
|
||||
self.slider_1 = PreprocessorParameter(
|
||||
label="Noise Augmentation", minimum=0.0, maximum=1.0, value=0.0, visible=True)
|
||||
|
||||
def process_before_every_sampling(self, process, cond, *args, **kwargs):
|
||||
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
|
||||
unit = kwargs['unit']
|
||||
|
||||
weight = float(unit.weight)
|
||||
@ -84,7 +84,8 @@ class PreprocessorClipVisionForRevision(PreprocessorClipVision):
|
||||
unet.add_conditioning_modifier(revision_conditioning_modifier, ensure_uniqueness=True)
|
||||
|
||||
process.sd_model.forge_objects.unet = unet
|
||||
return
|
||||
|
||||
return cond, mask
|
||||
|
||||
|
||||
add_supported_preprocessor(PreprocessorClipVisionForRevision(
|
||||
|
@ -37,7 +37,7 @@ class PreprocessorTileColorFix(PreprocessorTile):
|
||||
self.variation = 8
|
||||
self.sharpness = None
|
||||
|
||||
def process_before_every_sampling(self, process, cond, *args, **kwargs):
|
||||
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
|
||||
self.variation = int(kwargs['unit'].threshold_a)
|
||||
|
||||
latent = self.register_latent(process, cond)
|
||||
@ -77,7 +77,7 @@ class PreprocessorTileColorFix(PreprocessorTile):
|
||||
|
||||
process.sd_model.forge_objects.unet = unet
|
||||
|
||||
return
|
||||
return cond, mask
|
||||
|
||||
|
||||
class PreprocessorTileColorFixSharp(PreprocessorTileColorFix):
|
||||
@ -86,10 +86,9 @@ class PreprocessorTileColorFixSharp(PreprocessorTileColorFix):
|
||||
self.name = 'tile_colorfix+sharp'
|
||||
self.slider_2 = PreprocessorParameter(label='Sharpness', value=1.0, minimum=0.0, maximum=2.0, step=0.01, visible=True)
|
||||
|
||||
def process_before_every_sampling(self, process, cond, *args, **kwargs):
|
||||
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
|
||||
self.sharpness = float(kwargs['unit'].threshold_b)
|
||||
super().process_before_every_sampling(process, cond, *args, **kwargs)
|
||||
return
|
||||
return super().process_before_every_sampling(process, cond, mask, *args, **kwargs)
|
||||
|
||||
|
||||
add_supported_preprocessor(PreprocessorTile())
|
||||
|
@ -215,7 +215,6 @@ class ControlNetUiGroup(object):
|
||||
self.image_upload_panel = None
|
||||
self.save_detected_map = None
|
||||
self.input_mode = gr.State(InputMode.SIMPLE)
|
||||
self.inpaint_crop_input_image = None
|
||||
self.hr_option = None
|
||||
self.batch_image_dir_state = None
|
||||
self.output_dir_state = None
|
||||
@ -443,15 +442,6 @@ class ControlNetUiGroup(object):
|
||||
else:
|
||||
self.upload_independent_img_in_img2img = None
|
||||
|
||||
# Note: The checkbox needs to exist for both img2img and txt2img as infotext
|
||||
# needs the checkbox value.
|
||||
self.inpaint_crop_input_image = gr.Checkbox(
|
||||
label="Crop input image based on A1111 mask",
|
||||
value=False,
|
||||
elem_classes=["cnet-crop-input-image"],
|
||||
visible=False,
|
||||
)
|
||||
|
||||
with gr.Row(elem_classes=["controlnet_control_type", "controlnet_row"]):
|
||||
self.type_filter = gr.Radio(
|
||||
global_state.get_all_preprocessor_tags(),
|
||||
@ -805,18 +795,10 @@ class ControlNetUiGroup(object):
|
||||
)
|
||||
|
||||
img = HWC3(image["image"])
|
||||
has_mask = not (
|
||||
(image["mask"][:, :, 0] <= 5).all()
|
||||
or (image["mask"][:, :, 0] >= 250).all()
|
||||
)
|
||||
if "inpaint" in module:
|
||||
color = HWC3(image["image"])
|
||||
alpha = image["mask"][:, :, 0:1]
|
||||
img = np.concatenate([color, alpha], axis=2)
|
||||
elif has_mask and not shared.opts.data.get(
|
||||
"controlnet_ignore_noninpaint_mask", False
|
||||
):
|
||||
img = HWC3(image["mask"][:, :, 0])
|
||||
mask = HWC3(image["mask"])
|
||||
|
||||
if not (mask > 5).any():
|
||||
mask = None
|
||||
|
||||
preprocessor = global_state.get_preprocessor(module)
|
||||
|
||||
@ -854,6 +836,7 @@ class ControlNetUiGroup(object):
|
||||
resolution=pres,
|
||||
slider_1=pthr_a,
|
||||
slider_2=pthr_b,
|
||||
input_mask=mask,
|
||||
low_vram=shared.opts.data.get("controlnet_clip_detector_on_cpu", False),
|
||||
json_pose_callback=json_acceptor.accept
|
||||
if is_openpose(module)
|
||||
@ -976,52 +959,9 @@ class ControlNetUiGroup(object):
|
||||
)
|
||||
|
||||
def register_shift_crop_input_image(self):
|
||||
# A1111 < 1.7.0 compatibility.
|
||||
if any(c is None for c in ControlNetUiGroup.a1111_context.img2img_inpaint_tabs):
|
||||
self.inpaint_crop_input_image.visible = True
|
||||
self.inpaint_crop_input_image.value = True
|
||||
return
|
||||
|
||||
is_inpaint_tab = gr.State(False)
|
||||
|
||||
def shift_crop_input_image(is_inpaint: bool, inpaint_area: int):
|
||||
# Note: inpaint_area (0: Whole picture, 1: Only masked)
|
||||
# By default set value to True, as most preprocessors need cropped result.
|
||||
return gr.update(value=True, visible=is_inpaint and inpaint_area == 1)
|
||||
|
||||
gradio_kwargs = dict(
|
||||
fn=shift_crop_input_image,
|
||||
inputs=[
|
||||
is_inpaint_tab,
|
||||
ControlNetUiGroup.a1111_context.img2img_inpaint_area,
|
||||
],
|
||||
outputs=[self.inpaint_crop_input_image],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
for elem in ControlNetUiGroup.a1111_context.img2img_inpaint_tabs:
|
||||
elem.select(fn=lambda: True, inputs=[], outputs=[is_inpaint_tab]).then(
|
||||
**gradio_kwargs
|
||||
)
|
||||
|
||||
for elem in ControlNetUiGroup.a1111_context.img2img_non_inpaint_tabs:
|
||||
elem.select(fn=lambda: False, inputs=[], outputs=[is_inpaint_tab]).then(
|
||||
**gradio_kwargs
|
||||
)
|
||||
|
||||
ControlNetUiGroup.a1111_context.img2img_inpaint_area.change(**gradio_kwargs)
|
||||
return
|
||||
|
||||
def register_shift_hr_options(self):
|
||||
# # A1111 version < 1.6.0.
|
||||
# if not ControlNetUiGroup.a1111_context.txt2img_enable_hr:
|
||||
# return
|
||||
#
|
||||
# ControlNetUiGroup.a1111_context.txt2img_enable_hr.change(
|
||||
# fn=lambda checked: gr.update(visible=checked),
|
||||
# inputs=[ControlNetUiGroup.a1111_context.txt2img_enable_hr],
|
||||
# outputs=[self.hr_option],
|
||||
# show_progress=False,
|
||||
# )
|
||||
return
|
||||
|
||||
def register_shift_upload_mask(self):
|
||||
|
@ -317,11 +317,6 @@ def high_quality_resize(x, size):
|
||||
# Written by lvmin
|
||||
# Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges
|
||||
|
||||
inpaint_mask = None
|
||||
if x.ndim == 3 and x.shape[2] == 4:
|
||||
inpaint_mask = x[:, :, 3]
|
||||
x = x[:, :, 0:3]
|
||||
|
||||
if x.shape[0] != size[1] or x.shape[1] != size[0]:
|
||||
new_size_is_smaller = (size[0] * size[1]) < (x.shape[0] * x.shape[1])
|
||||
new_size_is_bigger = (size[0] * size[1]) > (x.shape[0] * x.shape[1])
|
||||
@ -346,8 +341,6 @@ def high_quality_resize(x, size):
|
||||
interpolation = cv2.INTER_CUBIC # Must be CUBIC because we now use nms. NEVER CHANGE THIS
|
||||
|
||||
y = cv2.resize(x, size, interpolation=interpolation)
|
||||
if inpaint_mask is not None:
|
||||
inpaint_mask = cv2.resize(inpaint_mask, size, interpolation=interpolation)
|
||||
|
||||
if is_binary:
|
||||
y = np.mean(y.astype(np.float32), axis=2).clip(0, 255).astype(np.uint8)
|
||||
@ -361,15 +354,10 @@ def high_quality_resize(x, size):
|
||||
else:
|
||||
y = x
|
||||
|
||||
if inpaint_mask is not None:
|
||||
inpaint_mask = (inpaint_mask > 127).astype(np.float32) * 255.0
|
||||
inpaint_mask = inpaint_mask[:, :, None].clip(0, 255).astype(np.uint8)
|
||||
y = np.concatenate([y, inpaint_mask], axis=2)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
def crop_and_resize_image(detected_map, resize_mode, h, w):
|
||||
def crop_and_resize_image(detected_map, resize_mode, h, w, fill_border_with_255=False):
|
||||
if resize_mode == external_code.ResizeMode.RESIZE:
|
||||
detected_map = high_quality_resize(detected_map, (w, h))
|
||||
detected_map = safe_numpy(detected_map)
|
||||
@ -387,9 +375,8 @@ def crop_and_resize_image(detected_map, resize_mode, h, w):
|
||||
k = min(k0, k1)
|
||||
borders = np.concatenate([detected_map[0, :, :], detected_map[-1, :, :], detected_map[:, 0, :], detected_map[:, -1, :]], axis=0)
|
||||
high_quality_border_color = np.median(borders, axis=0).astype(detected_map.dtype)
|
||||
if len(high_quality_border_color) == 4:
|
||||
# Inpaint hijack
|
||||
high_quality_border_color[3] = 255
|
||||
if fill_border_with_255:
|
||||
high_quality_border_color = np.zeros_like(high_quality_border_color) + 255
|
||||
high_quality_background = np.tile(high_quality_border_color[None, None], [h, w, 1])
|
||||
detected_map = high_quality_resize(detected_map, (safeint(old_w * k), safeint(old_h * k)))
|
||||
new_h, new_w, _ = detected_map.shape
|
||||
|
@ -1,6 +1,7 @@
|
||||
import os
|
||||
from typing import Dict, Optional, Tuple, List, Union
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
import modules.scripts as scripts
|
||||
@ -46,6 +47,8 @@ class ControlNetCachedParameters:
|
||||
self.model = None
|
||||
self.control_cond = None
|
||||
self.control_cond_for_hr_fix = None
|
||||
self.control_mask = None
|
||||
self.control_mask_for_hr_fix = None
|
||||
|
||||
|
||||
class ControlNetForForgeOfficial(scripts.Script):
|
||||
@ -95,105 +98,6 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
enabled_units = [x for x in units if x.enabled]
|
||||
return enabled_units
|
||||
|
||||
def choose_input_image(
|
||||
self,
|
||||
p: processing.StableDiffusionProcessing,
|
||||
unit: external_code.ControlNetUnit,
|
||||
) -> Tuple[np.ndarray, external_code.ResizeMode]:
|
||||
""" Choose input image from following sources with descending priority:
|
||||
- p.image_control: [Deprecated] Legacy way to pass image to controlnet.
|
||||
- p.control_net_input_image: [Deprecated] Legacy way to pass image to controlnet.
|
||||
- unit.image: ControlNet tab input image.
|
||||
- p.init_images: A1111 img2img tab input image.
|
||||
|
||||
Returns:
|
||||
- The input image in ndarray form.
|
||||
- The resize mode.
|
||||
"""
|
||||
|
||||
def parse_unit_image(unit: external_code.ControlNetUnit) -> Union[
|
||||
List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]:
|
||||
unit_has_multiple_images = (
|
||||
isinstance(unit.image, list) and
|
||||
len(unit.image) > 0 and
|
||||
"image" in unit.image[0]
|
||||
)
|
||||
if unit_has_multiple_images:
|
||||
return [
|
||||
d
|
||||
for img in unit.image
|
||||
for d in (image_dict_from_any(img),)
|
||||
if d is not None
|
||||
]
|
||||
return image_dict_from_any(unit.image)
|
||||
|
||||
def decode_image(img) -> np.ndarray:
|
||||
"""Need to check the image for API compatibility."""
|
||||
if isinstance(img, str):
|
||||
return np.asarray(decode_base64_to_image(image['image']))
|
||||
else:
|
||||
assert isinstance(img, np.ndarray)
|
||||
return img
|
||||
|
||||
# 4 input image sources.
|
||||
image = parse_unit_image(unit)
|
||||
a1111_image = getattr(p, "init_images", [None])[0]
|
||||
|
||||
resize_mode = external_code.resize_mode_from_value(unit.resize_mode)
|
||||
|
||||
if image is not None:
|
||||
if isinstance(image, list):
|
||||
# Add mask logic if later there is a processor that accepts mask
|
||||
# on multiple inputs.
|
||||
input_image = [HWC3(decode_image(img['image'])) for img in image]
|
||||
else:
|
||||
input_image = HWC3(decode_image(image['image']))
|
||||
if 'mask' in image and image['mask'] is not None:
|
||||
while len(image['mask'].shape) < 3:
|
||||
image['mask'] = image['mask'][..., np.newaxis]
|
||||
if 'inpaint' in unit.module:
|
||||
logger.info("using inpaint as input")
|
||||
color = HWC3(image['image'])
|
||||
alpha = image['mask'][:, :, 0:1]
|
||||
input_image = np.concatenate([color, alpha], axis=2)
|
||||
elif (
|
||||
not shared.opts.data.get("controlnet_ignore_noninpaint_mask", False) and
|
||||
# There is a weird gradio issue that would produce a mask that is
|
||||
# not pure color when no scribble is made on canvas.
|
||||
# See https://github.com/Mikubill/sd-webui-controlnet/issues/1638.
|
||||
not (
|
||||
(image['mask'][:, :, 0] <= 5).all() or
|
||||
(image['mask'][:, :, 0] >= 250).all()
|
||||
)
|
||||
):
|
||||
logger.info("using mask as input")
|
||||
input_image = HWC3(image['mask'][:, :, 0])
|
||||
unit.module = 'none' # Always use black bg and white line
|
||||
elif a1111_image is not None:
|
||||
input_image = HWC3(np.asarray(a1111_image))
|
||||
a1111_i2i_resize_mode = getattr(p, "resize_mode", None)
|
||||
assert a1111_i2i_resize_mode is not None
|
||||
resize_mode = external_code.resize_mode_from_value(a1111_i2i_resize_mode)
|
||||
|
||||
a1111_mask_image: Optional[Image.Image] = getattr(p, "image_mask", None)
|
||||
if 'inpaint' in unit.module:
|
||||
if a1111_mask_image is not None:
|
||||
a1111_mask = np.array(prepare_mask(a1111_mask_image, p))
|
||||
assert a1111_mask.ndim == 2
|
||||
assert a1111_mask.shape[0] == input_image.shape[0]
|
||||
assert a1111_mask.shape[1] == input_image.shape[1]
|
||||
input_image = np.concatenate([input_image[:, :, 0:3], a1111_mask[:, :, None]], axis=2)
|
||||
else:
|
||||
input_image = np.concatenate([
|
||||
input_image[:, :, 0:3],
|
||||
np.zeros_like(input_image, dtype=np.uint8)[:, :, 0:1],
|
||||
], axis=2)
|
||||
else:
|
||||
raise ValueError("controlnet is enabled but no input image is given")
|
||||
|
||||
assert isinstance(input_image, (np.ndarray, list))
|
||||
return input_image, resize_mode
|
||||
|
||||
@staticmethod
|
||||
def try_crop_image_with_a1111_mask(
|
||||
p: StableDiffusionProcessing,
|
||||
@ -202,19 +106,6 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
resize_mode: external_code.ResizeMode,
|
||||
preprocessor
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Crop ControlNet input image based on A1111 inpaint mask given.
|
||||
This logic is crucial in upscale scripts, as they use A1111 mask + inpaint_full_res
|
||||
to crop tiles.
|
||||
"""
|
||||
# Note: The method determining whether the active script is an upscale script is purely
|
||||
# based on `extra_generation_params` these scripts attach on `p`, and subject to change
|
||||
# in the future.
|
||||
# TODO: Change this to a more robust condition once A1111 offers a way to verify script name.
|
||||
is_upscale_script = any("upscale" in k.lower() for k in getattr(p, "extra_generation_params", {}).keys())
|
||||
logger.debug(f"is_upscale_script={is_upscale_script}")
|
||||
# Note: `inpaint_full_res` is "inpaint area" on UI. The flag is `True` when "Only masked"
|
||||
# option is selected.
|
||||
a1111_mask_image: Optional[Image.Image] = getattr(p, "image_mask", None)
|
||||
is_only_masked_inpaint = (
|
||||
issubclass(type(p), StableDiffusionProcessingImg2Img) and
|
||||
@ -224,9 +115,8 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
if (
|
||||
preprocessor.corp_image_with_a1111_mask_when_in_img2img_inpaint_tab
|
||||
and is_only_masked_inpaint
|
||||
and (is_upscale_script or unit.inpaint_crop_input_image)
|
||||
):
|
||||
logger.debug("Crop input image based on A1111 mask.")
|
||||
logger.info("Crop input image based on A1111 mask.")
|
||||
input_image = [input_image[:, :, i] for i in range(input_image.shape[2])]
|
||||
input_image = [Image.fromarray(x) for x in input_image]
|
||||
|
||||
@ -251,12 +141,47 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
return input_image
|
||||
|
||||
def get_input_data(self, p, unit, preprocessor):
|
||||
mask = None
|
||||
input_image, resize_mode = self.choose_input_image(p, unit)
|
||||
assert isinstance(input_image, np.ndarray), 'Invalid input image!'
|
||||
input_image = self.try_crop_image_with_a1111_mask(p, unit, input_image, resize_mode, preprocessor)
|
||||
input_image = np.ascontiguousarray(input_image.copy()).copy() # safe numpy
|
||||
return input_image, mask, resize_mode
|
||||
a1111_i2i_image = getattr(p, "init_images", [None])[0]
|
||||
a1111_i2i_mask = getattr(p, "image_mask", None)
|
||||
|
||||
using_a1111_data = False
|
||||
|
||||
resize_mode = external_code.resize_mode_from_value(unit.resize_mode)
|
||||
|
||||
if unit.use_preview_as_input and unit.generated_image is not None:
|
||||
image = unit.generated_image
|
||||
elif unit.image is None:
|
||||
resize_mode = external_code.resize_mode_from_value(p.resize_mode)
|
||||
image = HWC3(np.asarray(a1111_i2i_image))
|
||||
using_a1111_data = True
|
||||
elif (unit.image['image'] < 5).all() and (unit.image['mask'] > 5).any():
|
||||
image = unit.image['mask']
|
||||
else:
|
||||
image = unit.image['image']
|
||||
|
||||
if not isinstance(image, np.ndarray):
|
||||
raise ValueError("controlnet is enabled but no input image is given")
|
||||
|
||||
image = HWC3(image)
|
||||
|
||||
if using_a1111_data:
|
||||
mask = HWC3(np.asarray(a1111_i2i_mask))
|
||||
elif unit.mask_image is not None and (unit.mask_image['image'] > 5).any():
|
||||
mask = unit.mask_image['image']
|
||||
elif unit.mask_image is not None and (unit.mask_image['mask'] > 5).any():
|
||||
mask = unit.mask_image['mask']
|
||||
elif unit.image is not None and (unit.image['mask'] > 5).any():
|
||||
mask = unit.image['mask']
|
||||
else:
|
||||
mask = None
|
||||
|
||||
image = self.try_crop_image_with_a1111_mask(p, unit, image, resize_mode, preprocessor)
|
||||
|
||||
if mask is not None:
|
||||
mask = cv2.resize(HWC3(mask), (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
|
||||
mask = self.try_crop_image_with_a1111_mask(p, unit, mask, resize_mode, preprocessor)
|
||||
|
||||
return image, mask, resize_mode
|
||||
|
||||
@staticmethod
|
||||
def get_target_dimensions(p: StableDiffusionProcessing) -> Tuple[int, int, int, int]:
|
||||
@ -296,9 +221,13 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
and getattr(p, 'enable_hr', False)
|
||||
)
|
||||
|
||||
if unit.use_preview_as_input:
|
||||
unit.module = 'None'
|
||||
|
||||
preprocessor = global_state.get_preprocessor(unit.module)
|
||||
|
||||
input_image, input_mask, resize_mode = self.get_input_data(p, unit, preprocessor)
|
||||
# p.extra_result_images.append(input_image)
|
||||
|
||||
if unit.pixel_perfect:
|
||||
unit.processor_res = external_code.pixel_perfect_resolution(
|
||||
@ -339,10 +268,24 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
params.control_cond_for_hr_fix = preprocessor_output
|
||||
p.extra_result_images.append(input_image)
|
||||
|
||||
if input_mask is not None:
|
||||
fill_border = preprocessor.fill_mask_with_one_when_resize_and_fill
|
||||
params.control_mask = crop_and_resize_image(input_mask, resize_mode, h, w, fill_border)
|
||||
p.extra_result_images.append(params.control_mask)
|
||||
params.control_mask = numpy_to_pytorch(params.control_mask).movedim(-1, 1)[:, :1]
|
||||
|
||||
if has_high_res_fix:
|
||||
params.control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border)
|
||||
p.extra_result_images.append(params.control_mask_for_hr_fix)
|
||||
params.control_mask_for_hr_fix = numpy_to_pytorch(params.control_mask_for_hr_fix).movedim(-1, 1)[:, :1]
|
||||
else:
|
||||
params.control_mask_for_hr_fix = params.control_mask
|
||||
|
||||
if preprocessor.do_not_need_model:
|
||||
model_filename = 'Not Needed'
|
||||
params.model = ControlModelPatcher()
|
||||
else:
|
||||
assert unit.model != 'None', 'You have not selected any control model!'
|
||||
model_filename = global_state.get_controlnet_filename(unit.model)
|
||||
params.model = cached_controlnet_loader(model_filename)
|
||||
assert params.model is not None, logger.error(f"Recognizing Control Model failed: {model_filename}")
|
||||
@ -366,16 +309,13 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
|
||||
if is_hr_pass:
|
||||
cond = params.control_cond_for_hr_fix
|
||||
mask = params.control_mask_for_hr_fix
|
||||
else:
|
||||
cond = params.control_cond
|
||||
mask = params.control_mask
|
||||
|
||||
kwargs.update(dict(unit=unit, params=params))
|
||||
|
||||
# CN inpaint fix
|
||||
if isinstance(cond, torch.Tensor) and cond.ndim == 4 and cond.shape[1] == 4:
|
||||
kwargs['cond_before_inpaint_fix'] = cond.clone()
|
||||
cond = cond[:, :3] * (1.0 - cond[:, 3:]) - cond[:, 3:]
|
||||
|
||||
params.model.strength = float(unit.weight)
|
||||
params.model.start_percent = float(unit.guidance_start)
|
||||
params.model.end_percent = float(unit.guidance_end)
|
||||
@ -411,8 +351,8 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
params.model.positive_advanced_weighting = soft_weighting.copy()
|
||||
params.model.negative_advanced_weighting = soft_weighting.copy()
|
||||
|
||||
params.preprocessor.process_before_every_sampling(p, cond, *args, **kwargs)
|
||||
params.model.process_before_every_sampling(p, cond, *args, **kwargs)
|
||||
cond, mask = params.preprocessor.process_before_every_sampling(p, cond, mask, *args, **kwargs)
|
||||
params.model.process_before_every_sampling(p, cond, mask, *args, **kwargs)
|
||||
|
||||
logger.info(f"ControlNet Method {params.preprocessor.name} patched.")
|
||||
return
|
||||
|
@ -31,6 +31,7 @@ class Preprocessor:
|
||||
self.do_not_need_model = False
|
||||
self.sorting_priority = 0 # higher goes to top in the list
|
||||
self.corp_image_with_a1111_mask_when_in_img2img_inpaint_tab = True
|
||||
self.fill_mask_with_one_when_resize_and_fill = False
|
||||
|
||||
def setup_model_patcher(self, model, load_device=None, offload_device=None, dtype=torch.float32, **kwargs):
|
||||
if load_device is None:
|
||||
@ -59,8 +60,8 @@ class Preprocessor:
|
||||
def process_after_running_preprocessors(self, process, params, *args, **kwargs):
|
||||
return
|
||||
|
||||
def process_before_every_sampling(self, process, cond, *args, **kwargs):
|
||||
return
|
||||
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
|
||||
return cond, mask
|
||||
|
||||
def process_after_every_sampling(self, process, params, *args, **kwargs):
|
||||
return
|
||||
|
Loading…
Reference in New Issue
Block a user