From d5da7ec3270d1d869b240fbc0de50bc89f863b0d Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Mon, 29 Jan 2024 16:28:48 -0800 Subject: [PATCH] i --- .../sd_forge_controlnet/scripts/controlnet.py | 227 +++--------------- modules_forge/initialization.py | 3 + 2 files changed, 31 insertions(+), 199 deletions(-) diff --git a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py index d77d0c82..7708b3c9 100644 --- a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py +++ b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py @@ -14,7 +14,7 @@ from lib_controlnet.controlnet_ui.photopea import Photopea from lib_controlnet.logging import logger from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img, StableDiffusionProcessing from lib_controlnet.infotext import Infotext -from modules_forge.forge_util import HWC3 +from modules_forge.forge_util import HWC3, numpy_to_pytorch import cv2 import numpy as np @@ -41,8 +41,10 @@ def cached_controlnet_loader(filename): class ControlNetCachedParameters: def __init__(self): - self.control_image = None - self.control_image_for_hr_fix = None + self.preprocessor = None + self.model = None + self.control_cond = None + self.control_cond_for_hr_fix = None class ControlNetForForgeOfficial(scripts.Script): @@ -389,197 +391,6 @@ class ControlNetForForgeOfficial(scripts.Script): return h, w, hr_y, hr_x - def controlnet_main_entry(self, p): - for idx, unit in enumerate(self.enabled_units): - - def preprocess_input_image(input_image: np.ndarray): - """ Preprocess single input image. """ - detected_map, is_image = self.preprocessor[unit.module]( - input_image, - res=unit.processor_res, - thr_a=unit.threshold_a, - thr_b=unit.threshold_b, - low_vram=( - ("clip" in unit.module or unit.module == "ip-adapter_face_id_plus") and - shared.opts.data.get("controlnet_clip_detector_on_cpu", False) - ), - ) - if high_res_fix: - if is_image: - hr_control, hr_detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, hr_y, hr_x) - store_detected_map(hr_detected_map, unit.module) - else: - hr_control = detected_map - else: - hr_control = None - - if is_image: - control, detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, h, w) - store_detected_map(detected_map, unit.module) - else: - control = detected_map - store_detected_map(input_image, unit.module) - - if control_model_type == ControlModelType.T2I_StyleAdapter: - control = control['last_hidden_state'] - - if control_model_type == ControlModelType.ReVision: - control = control['image_embeds'] - return control, hr_control - - controls, hr_controls = list(zip(*[preprocess_input_image(img) for img in input_images])) - if len(controls) == len(hr_controls) == 1: - control = controls[0] - hr_control = hr_controls[0] - else: - control = controls - hr_control = hr_controls - - preprocessor_dict = dict( - name=unit.module, - preprocessor_resolution=unit.processor_res, - threshold_a=unit.threshold_a, - threshold_b=unit.threshold_b - ) - - global_average_pooling = ( - control_model_type.is_controlnet() and - model_net.control_model.global_average_pooling - ) - control_mode = external_code.control_mode_from_value(unit.control_mode) - forward_param = ControlParams( - control_model=model_net, - preprocessor=preprocessor_dict, - hint_cond=control, - weight=unit.weight, - guidance_stopped=False, - start_guidance_percent=unit.guidance_start, - stop_guidance_percent=unit.guidance_end, - advanced_weighting=unit.advanced_weighting, - control_model_type=control_model_type, - global_average_pooling=global_average_pooling, - hr_hint_cond=hr_control, - hr_option=HiResFixOption.from_value(unit.hr_option) if high_res_fix else HiResFixOption.BOTH, - soft_injection=control_mode != external_code.ControlMode.BALANCED, - cfg_injection=control_mode == external_code.ControlMode.CONTROL, - ) - forward_params.append(forward_param) - - if 'inpaint_only' in unit.module: - final_inpaint_feed = hr_control if hr_control is not None else control - final_inpaint_feed = final_inpaint_feed.detach().cpu().numpy() - final_inpaint_feed = np.ascontiguousarray(final_inpaint_feed).copy() - final_inpaint_mask = final_inpaint_feed[0, 3, :, :].astype(np.float32) - final_inpaint_raw = final_inpaint_feed[0, :3].astype(np.float32) - sigma = shared.opts.data.get("control_net_inpaint_blur_sigma", 7) - final_inpaint_mask = cv2.dilate(final_inpaint_mask, np.ones((sigma, sigma), dtype=np.uint8)) - final_inpaint_mask = cv2.blur(final_inpaint_mask, (sigma, sigma))[None] - _, Hmask, Wmask = final_inpaint_mask.shape - final_inpaint_raw = torch.from_numpy(np.ascontiguousarray(final_inpaint_raw).copy()) - final_inpaint_mask = torch.from_numpy(np.ascontiguousarray(final_inpaint_mask).copy()) - - def inpaint_only_post_processing(x): - _, H, W = x.shape - if Hmask != H or Wmask != W: - logger.error('Error: ControlNet find post-processing resolution mismatch. This could be related to other extensions hacked processing.') - return x - r = final_inpaint_raw.to(x.dtype).to(x.device) - m = final_inpaint_mask.to(x.dtype).to(x.device) - y = m * x.clip(0, 1) + (1 - m) * r - y = y.clip(0, 1) - return y - - post_processors.append(inpaint_only_post_processing) - - if 'recolor' in unit.module: - final_feed = hr_control if hr_control is not None else control - final_feed = final_feed.detach().cpu().numpy() - final_feed = np.ascontiguousarray(final_feed).copy() - final_feed = final_feed[0, 0, :, :].astype(np.float32) - final_feed = (final_feed * 255).clip(0, 255).astype(np.uint8) - Hfeed, Wfeed = final_feed.shape - - if 'luminance' in unit.module: - - def recolor_luminance_post_processing(x): - C, H, W = x.shape - if Hfeed != H or Wfeed != W or C != 3: - logger.error('Error: ControlNet find post-processing resolution mismatch. This could be related to other extensions hacked processing.') - return x - h = x.detach().cpu().numpy().transpose((1, 2, 0)) - h = (h * 255).clip(0, 255).astype(np.uint8) - h = cv2.cvtColor(h, cv2.COLOR_RGB2LAB) - h[:, :, 0] = final_feed - h = cv2.cvtColor(h, cv2.COLOR_LAB2RGB) - h = (h.astype(np.float32) / 255.0).transpose((2, 0, 1)) - y = torch.from_numpy(h).clip(0, 1).to(x) - return y - - post_processors.append(recolor_luminance_post_processing) - - if 'intensity' in unit.module: - - def recolor_intensity_post_processing(x): - C, H, W = x.shape - if Hfeed != H or Wfeed != W or C != 3: - logger.error('Error: ControlNet find post-processing resolution mismatch. This could be related to other extensions hacked processing.') - return x - h = x.detach().cpu().numpy().transpose((1, 2, 0)) - h = (h * 255).clip(0, 255).astype(np.uint8) - h = cv2.cvtColor(h, cv2.COLOR_RGB2HSV) - h[:, :, 2] = final_feed - h = cv2.cvtColor(h, cv2.COLOR_HSV2RGB) - h = (h.astype(np.float32) / 255.0).transpose((2, 0, 1)) - y = torch.from_numpy(h).clip(0, 1).to(x) - return y - - post_processors.append(recolor_intensity_post_processing) - - if '+lama' in unit.module: - forward_param.used_hint_cond_latent = hook.UnetHook.call_vae_using_process(p, control) - self.noise_modifier = forward_param.used_hint_cond_latent - - del model_net - - is_low_vram = any(unit.low_vram for unit in self.enabled_units) - - for i, param in enumerate(forward_params): - if param.control_model_type == ControlModelType.IPAdapter: - param.control_model.hook( - model=unet, - preprocessor_outputs=param.hint_cond, - weight=param.weight, - dtype=torch.float32, - start=param.start_guidance_percent, - end=param.stop_guidance_percent - ) - if param.control_model_type == ControlModelType.Controlllite: - param.control_model.hook( - model=unet, - cond=param.hint_cond, - weight=param.weight, - start=param.start_guidance_percent, - end=param.stop_guidance_percent - ) - if param.control_model_type == ControlModelType.InstantID: - # For instant_id we always expect ip-adapter model followed - # by ControlNet model. - assert i > 0, "InstantID control model should follow ipadapter model." - ip_adapter_param = forward_params[i - 1] - assert ip_adapter_param.control_model_type == ControlModelType.IPAdapter, \ - "InstantID control model should follow ipadapter model." - control_model = ip_adapter_param.control_model - assert hasattr(control_model, "image_emb") - param.control_context_override = control_model.image_emb - - self.latest_network = UnetHook(lowvram=is_low_vram) - self.latest_network.hook(model=unet, sd_ldm=sd_ldm, control_params=forward_params, process=p, - batch_option_uint_separate=batch_option_uint_separate, - batch_option_style_align=batch_option_style_align) - - self.detected_map = detected_maps - self.post_processors = post_processors - def process_unit_after_click_generate(self, p, unit, params, *args, **kwargs): h, w, hr_y, hr_x = self.get_target_dimensions(p) @@ -607,7 +418,9 @@ class ControlNetForForgeOfficial(scripts.Script): logger.info(f"Using preprocessor: {unit.module}") logger.info(f'preprocessor resolution = {unit.processor_res}') - detected_map = global_state.get_preprocessor(unit.module)( + preprocessor = global_state.get_preprocessor(unit.module) + + detected_map = preprocessor( input_image=input_image, resolution=unit.processor_res, slider_1=unit.threshold_a, @@ -617,11 +430,27 @@ class ControlNetForForgeOfficial(scripts.Script): detected_map_is_image = detected_map.ndim == 3 and detected_map.shape[2] < 5 if detected_map_is_image: - control, detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, h, w) - store_detected_map(detected_map, unit.module) + params.control_cond = crop_and_resize_image(detected_map, resize_mode, h, w) + p.extra_result_images.append(params.control_cond) + params.control_cond = numpy_to_pytorch(params.control_cond).movedim(-1, 1) + + if has_high_res_fix: + params.control_cond_for_hr_fix = crop_and_resize_image(detected_map, resize_mode, hr_y, hr_x) + p.extra_result_images.append(params.control_cond_for_hr_fix) + params.control_cond_for_hr_fix = numpy_to_pytorch(params.control_cond_for_hr_fix).movedim(-1, 1) + else: + params.control_cond_for_hr_fix = params.control_cond else: - control = detected_map - store_detected_map(input_image, unit.module) + params.control_cond = detected_map + params.control_cond_for_hr_fix = detected_map + p.extra_result_images.append(input_image) + + params.preprocessor = preprocessor + + model_filename = global_state.get_controlnet_filename(unit.model) + params.model = cached_controlnet_loader(model_filename) + + logger.info(f"Current ControlNet: {model_filename}") return diff --git a/modules_forge/initialization.py b/modules_forge/initialization.py index 2c9a4210..32c67fe7 100644 --- a/modules_forge/initialization.py +++ b/modules_forge/initialization.py @@ -24,4 +24,7 @@ def initialize_forge(): model_management.lowvram_available = True model_management.OOM_EXCEPTION = Exception + from modules_forge import supported_preprocessor + from modules_forge import supported_controlnet + return