i
This commit is contained in:
parent
b0885d21b7
commit
d5da7ec327
@ -14,7 +14,7 @@ from lib_controlnet.controlnet_ui.photopea import Photopea
|
||||
from lib_controlnet.logging import logger
|
||||
from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img, StableDiffusionProcessing
|
||||
from lib_controlnet.infotext import Infotext
|
||||
from modules_forge.forge_util import HWC3
|
||||
from modules_forge.forge_util import HWC3, numpy_to_pytorch
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@ -41,8 +41,10 @@ def cached_controlnet_loader(filename):
|
||||
|
||||
class ControlNetCachedParameters:
|
||||
def __init__(self):
|
||||
self.control_image = None
|
||||
self.control_image_for_hr_fix = None
|
||||
self.preprocessor = None
|
||||
self.model = None
|
||||
self.control_cond = None
|
||||
self.control_cond_for_hr_fix = None
|
||||
|
||||
|
||||
class ControlNetForForgeOfficial(scripts.Script):
|
||||
@ -389,197 +391,6 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
|
||||
return h, w, hr_y, hr_x
|
||||
|
||||
def controlnet_main_entry(self, p):
|
||||
for idx, unit in enumerate(self.enabled_units):
|
||||
|
||||
def preprocess_input_image(input_image: np.ndarray):
|
||||
""" Preprocess single input image. """
|
||||
detected_map, is_image = self.preprocessor[unit.module](
|
||||
input_image,
|
||||
res=unit.processor_res,
|
||||
thr_a=unit.threshold_a,
|
||||
thr_b=unit.threshold_b,
|
||||
low_vram=(
|
||||
("clip" in unit.module or unit.module == "ip-adapter_face_id_plus") and
|
||||
shared.opts.data.get("controlnet_clip_detector_on_cpu", False)
|
||||
),
|
||||
)
|
||||
if high_res_fix:
|
||||
if is_image:
|
||||
hr_control, hr_detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, hr_y, hr_x)
|
||||
store_detected_map(hr_detected_map, unit.module)
|
||||
else:
|
||||
hr_control = detected_map
|
||||
else:
|
||||
hr_control = None
|
||||
|
||||
if is_image:
|
||||
control, detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, h, w)
|
||||
store_detected_map(detected_map, unit.module)
|
||||
else:
|
||||
control = detected_map
|
||||
store_detected_map(input_image, unit.module)
|
||||
|
||||
if control_model_type == ControlModelType.T2I_StyleAdapter:
|
||||
control = control['last_hidden_state']
|
||||
|
||||
if control_model_type == ControlModelType.ReVision:
|
||||
control = control['image_embeds']
|
||||
return control, hr_control
|
||||
|
||||
controls, hr_controls = list(zip(*[preprocess_input_image(img) for img in input_images]))
|
||||
if len(controls) == len(hr_controls) == 1:
|
||||
control = controls[0]
|
||||
hr_control = hr_controls[0]
|
||||
else:
|
||||
control = controls
|
||||
hr_control = hr_controls
|
||||
|
||||
preprocessor_dict = dict(
|
||||
name=unit.module,
|
||||
preprocessor_resolution=unit.processor_res,
|
||||
threshold_a=unit.threshold_a,
|
||||
threshold_b=unit.threshold_b
|
||||
)
|
||||
|
||||
global_average_pooling = (
|
||||
control_model_type.is_controlnet() and
|
||||
model_net.control_model.global_average_pooling
|
||||
)
|
||||
control_mode = external_code.control_mode_from_value(unit.control_mode)
|
||||
forward_param = ControlParams(
|
||||
control_model=model_net,
|
||||
preprocessor=preprocessor_dict,
|
||||
hint_cond=control,
|
||||
weight=unit.weight,
|
||||
guidance_stopped=False,
|
||||
start_guidance_percent=unit.guidance_start,
|
||||
stop_guidance_percent=unit.guidance_end,
|
||||
advanced_weighting=unit.advanced_weighting,
|
||||
control_model_type=control_model_type,
|
||||
global_average_pooling=global_average_pooling,
|
||||
hr_hint_cond=hr_control,
|
||||
hr_option=HiResFixOption.from_value(unit.hr_option) if high_res_fix else HiResFixOption.BOTH,
|
||||
soft_injection=control_mode != external_code.ControlMode.BALANCED,
|
||||
cfg_injection=control_mode == external_code.ControlMode.CONTROL,
|
||||
)
|
||||
forward_params.append(forward_param)
|
||||
|
||||
if 'inpaint_only' in unit.module:
|
||||
final_inpaint_feed = hr_control if hr_control is not None else control
|
||||
final_inpaint_feed = final_inpaint_feed.detach().cpu().numpy()
|
||||
final_inpaint_feed = np.ascontiguousarray(final_inpaint_feed).copy()
|
||||
final_inpaint_mask = final_inpaint_feed[0, 3, :, :].astype(np.float32)
|
||||
final_inpaint_raw = final_inpaint_feed[0, :3].astype(np.float32)
|
||||
sigma = shared.opts.data.get("control_net_inpaint_blur_sigma", 7)
|
||||
final_inpaint_mask = cv2.dilate(final_inpaint_mask, np.ones((sigma, sigma), dtype=np.uint8))
|
||||
final_inpaint_mask = cv2.blur(final_inpaint_mask, (sigma, sigma))[None]
|
||||
_, Hmask, Wmask = final_inpaint_mask.shape
|
||||
final_inpaint_raw = torch.from_numpy(np.ascontiguousarray(final_inpaint_raw).copy())
|
||||
final_inpaint_mask = torch.from_numpy(np.ascontiguousarray(final_inpaint_mask).copy())
|
||||
|
||||
def inpaint_only_post_processing(x):
|
||||
_, H, W = x.shape
|
||||
if Hmask != H or Wmask != W:
|
||||
logger.error('Error: ControlNet find post-processing resolution mismatch. This could be related to other extensions hacked processing.')
|
||||
return x
|
||||
r = final_inpaint_raw.to(x.dtype).to(x.device)
|
||||
m = final_inpaint_mask.to(x.dtype).to(x.device)
|
||||
y = m * x.clip(0, 1) + (1 - m) * r
|
||||
y = y.clip(0, 1)
|
||||
return y
|
||||
|
||||
post_processors.append(inpaint_only_post_processing)
|
||||
|
||||
if 'recolor' in unit.module:
|
||||
final_feed = hr_control if hr_control is not None else control
|
||||
final_feed = final_feed.detach().cpu().numpy()
|
||||
final_feed = np.ascontiguousarray(final_feed).copy()
|
||||
final_feed = final_feed[0, 0, :, :].astype(np.float32)
|
||||
final_feed = (final_feed * 255).clip(0, 255).astype(np.uint8)
|
||||
Hfeed, Wfeed = final_feed.shape
|
||||
|
||||
if 'luminance' in unit.module:
|
||||
|
||||
def recolor_luminance_post_processing(x):
|
||||
C, H, W = x.shape
|
||||
if Hfeed != H or Wfeed != W or C != 3:
|
||||
logger.error('Error: ControlNet find post-processing resolution mismatch. This could be related to other extensions hacked processing.')
|
||||
return x
|
||||
h = x.detach().cpu().numpy().transpose((1, 2, 0))
|
||||
h = (h * 255).clip(0, 255).astype(np.uint8)
|
||||
h = cv2.cvtColor(h, cv2.COLOR_RGB2LAB)
|
||||
h[:, :, 0] = final_feed
|
||||
h = cv2.cvtColor(h, cv2.COLOR_LAB2RGB)
|
||||
h = (h.astype(np.float32) / 255.0).transpose((2, 0, 1))
|
||||
y = torch.from_numpy(h).clip(0, 1).to(x)
|
||||
return y
|
||||
|
||||
post_processors.append(recolor_luminance_post_processing)
|
||||
|
||||
if 'intensity' in unit.module:
|
||||
|
||||
def recolor_intensity_post_processing(x):
|
||||
C, H, W = x.shape
|
||||
if Hfeed != H or Wfeed != W or C != 3:
|
||||
logger.error('Error: ControlNet find post-processing resolution mismatch. This could be related to other extensions hacked processing.')
|
||||
return x
|
||||
h = x.detach().cpu().numpy().transpose((1, 2, 0))
|
||||
h = (h * 255).clip(0, 255).astype(np.uint8)
|
||||
h = cv2.cvtColor(h, cv2.COLOR_RGB2HSV)
|
||||
h[:, :, 2] = final_feed
|
||||
h = cv2.cvtColor(h, cv2.COLOR_HSV2RGB)
|
||||
h = (h.astype(np.float32) / 255.0).transpose((2, 0, 1))
|
||||
y = torch.from_numpy(h).clip(0, 1).to(x)
|
||||
return y
|
||||
|
||||
post_processors.append(recolor_intensity_post_processing)
|
||||
|
||||
if '+lama' in unit.module:
|
||||
forward_param.used_hint_cond_latent = hook.UnetHook.call_vae_using_process(p, control)
|
||||
self.noise_modifier = forward_param.used_hint_cond_latent
|
||||
|
||||
del model_net
|
||||
|
||||
is_low_vram = any(unit.low_vram for unit in self.enabled_units)
|
||||
|
||||
for i, param in enumerate(forward_params):
|
||||
if param.control_model_type == ControlModelType.IPAdapter:
|
||||
param.control_model.hook(
|
||||
model=unet,
|
||||
preprocessor_outputs=param.hint_cond,
|
||||
weight=param.weight,
|
||||
dtype=torch.float32,
|
||||
start=param.start_guidance_percent,
|
||||
end=param.stop_guidance_percent
|
||||
)
|
||||
if param.control_model_type == ControlModelType.Controlllite:
|
||||
param.control_model.hook(
|
||||
model=unet,
|
||||
cond=param.hint_cond,
|
||||
weight=param.weight,
|
||||
start=param.start_guidance_percent,
|
||||
end=param.stop_guidance_percent
|
||||
)
|
||||
if param.control_model_type == ControlModelType.InstantID:
|
||||
# For instant_id we always expect ip-adapter model followed
|
||||
# by ControlNet model.
|
||||
assert i > 0, "InstantID control model should follow ipadapter model."
|
||||
ip_adapter_param = forward_params[i - 1]
|
||||
assert ip_adapter_param.control_model_type == ControlModelType.IPAdapter, \
|
||||
"InstantID control model should follow ipadapter model."
|
||||
control_model = ip_adapter_param.control_model
|
||||
assert hasattr(control_model, "image_emb")
|
||||
param.control_context_override = control_model.image_emb
|
||||
|
||||
self.latest_network = UnetHook(lowvram=is_low_vram)
|
||||
self.latest_network.hook(model=unet, sd_ldm=sd_ldm, control_params=forward_params, process=p,
|
||||
batch_option_uint_separate=batch_option_uint_separate,
|
||||
batch_option_style_align=batch_option_style_align)
|
||||
|
||||
self.detected_map = detected_maps
|
||||
self.post_processors = post_processors
|
||||
|
||||
def process_unit_after_click_generate(self, p, unit, params, *args, **kwargs):
|
||||
h, w, hr_y, hr_x = self.get_target_dimensions(p)
|
||||
|
||||
@ -607,7 +418,9 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
logger.info(f"Using preprocessor: {unit.module}")
|
||||
logger.info(f'preprocessor resolution = {unit.processor_res}')
|
||||
|
||||
detected_map = global_state.get_preprocessor(unit.module)(
|
||||
preprocessor = global_state.get_preprocessor(unit.module)
|
||||
|
||||
detected_map = preprocessor(
|
||||
input_image=input_image,
|
||||
resolution=unit.processor_res,
|
||||
slider_1=unit.threshold_a,
|
||||
@ -617,11 +430,27 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
detected_map_is_image = detected_map.ndim == 3 and detected_map.shape[2] < 5
|
||||
|
||||
if detected_map_is_image:
|
||||
control, detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, h, w)
|
||||
store_detected_map(detected_map, unit.module)
|
||||
params.control_cond = crop_and_resize_image(detected_map, resize_mode, h, w)
|
||||
p.extra_result_images.append(params.control_cond)
|
||||
params.control_cond = numpy_to_pytorch(params.control_cond).movedim(-1, 1)
|
||||
|
||||
if has_high_res_fix:
|
||||
params.control_cond_for_hr_fix = crop_and_resize_image(detected_map, resize_mode, hr_y, hr_x)
|
||||
p.extra_result_images.append(params.control_cond_for_hr_fix)
|
||||
params.control_cond_for_hr_fix = numpy_to_pytorch(params.control_cond_for_hr_fix).movedim(-1, 1)
|
||||
else:
|
||||
params.control_cond_for_hr_fix = params.control_cond
|
||||
else:
|
||||
control = detected_map
|
||||
store_detected_map(input_image, unit.module)
|
||||
params.control_cond = detected_map
|
||||
params.control_cond_for_hr_fix = detected_map
|
||||
p.extra_result_images.append(input_image)
|
||||
|
||||
params.preprocessor = preprocessor
|
||||
|
||||
model_filename = global_state.get_controlnet_filename(unit.model)
|
||||
params.model = cached_controlnet_loader(model_filename)
|
||||
|
||||
logger.info(f"Current ControlNet: {model_filename}")
|
||||
|
||||
return
|
||||
|
||||
|
@ -24,4 +24,7 @@ def initialize_forge():
|
||||
model_management.lowvram_available = True
|
||||
model_management.OOM_EXCEPTION = Exception
|
||||
|
||||
from modules_forge import supported_preprocessor
|
||||
from modules_forge import supported_controlnet
|
||||
|
||||
return
|
||||
|
Loading…
Reference in New Issue
Block a user