This commit is contained in:
lllyasviel 2024-01-29 16:28:48 -08:00
parent b0885d21b7
commit d5da7ec327
2 changed files with 31 additions and 199 deletions

View File

@ -14,7 +14,7 @@ from lib_controlnet.controlnet_ui.photopea import Photopea
from lib_controlnet.logging import logger
from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img, StableDiffusionProcessing
from lib_controlnet.infotext import Infotext
from modules_forge.forge_util import HWC3
from modules_forge.forge_util import HWC3, numpy_to_pytorch
import cv2
import numpy as np
@ -41,8 +41,10 @@ def cached_controlnet_loader(filename):
class ControlNetCachedParameters:
def __init__(self):
self.control_image = None
self.control_image_for_hr_fix = None
self.preprocessor = None
self.model = None
self.control_cond = None
self.control_cond_for_hr_fix = None
class ControlNetForForgeOfficial(scripts.Script):
@ -389,197 +391,6 @@ class ControlNetForForgeOfficial(scripts.Script):
return h, w, hr_y, hr_x
def controlnet_main_entry(self, p):
for idx, unit in enumerate(self.enabled_units):
def preprocess_input_image(input_image: np.ndarray):
""" Preprocess single input image. """
detected_map, is_image = self.preprocessor[unit.module](
input_image,
res=unit.processor_res,
thr_a=unit.threshold_a,
thr_b=unit.threshold_b,
low_vram=(
("clip" in unit.module or unit.module == "ip-adapter_face_id_plus") and
shared.opts.data.get("controlnet_clip_detector_on_cpu", False)
),
)
if high_res_fix:
if is_image:
hr_control, hr_detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, hr_y, hr_x)
store_detected_map(hr_detected_map, unit.module)
else:
hr_control = detected_map
else:
hr_control = None
if is_image:
control, detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, h, w)
store_detected_map(detected_map, unit.module)
else:
control = detected_map
store_detected_map(input_image, unit.module)
if control_model_type == ControlModelType.T2I_StyleAdapter:
control = control['last_hidden_state']
if control_model_type == ControlModelType.ReVision:
control = control['image_embeds']
return control, hr_control
controls, hr_controls = list(zip(*[preprocess_input_image(img) for img in input_images]))
if len(controls) == len(hr_controls) == 1:
control = controls[0]
hr_control = hr_controls[0]
else:
control = controls
hr_control = hr_controls
preprocessor_dict = dict(
name=unit.module,
preprocessor_resolution=unit.processor_res,
threshold_a=unit.threshold_a,
threshold_b=unit.threshold_b
)
global_average_pooling = (
control_model_type.is_controlnet() and
model_net.control_model.global_average_pooling
)
control_mode = external_code.control_mode_from_value(unit.control_mode)
forward_param = ControlParams(
control_model=model_net,
preprocessor=preprocessor_dict,
hint_cond=control,
weight=unit.weight,
guidance_stopped=False,
start_guidance_percent=unit.guidance_start,
stop_guidance_percent=unit.guidance_end,
advanced_weighting=unit.advanced_weighting,
control_model_type=control_model_type,
global_average_pooling=global_average_pooling,
hr_hint_cond=hr_control,
hr_option=HiResFixOption.from_value(unit.hr_option) if high_res_fix else HiResFixOption.BOTH,
soft_injection=control_mode != external_code.ControlMode.BALANCED,
cfg_injection=control_mode == external_code.ControlMode.CONTROL,
)
forward_params.append(forward_param)
if 'inpaint_only' in unit.module:
final_inpaint_feed = hr_control if hr_control is not None else control
final_inpaint_feed = final_inpaint_feed.detach().cpu().numpy()
final_inpaint_feed = np.ascontiguousarray(final_inpaint_feed).copy()
final_inpaint_mask = final_inpaint_feed[0, 3, :, :].astype(np.float32)
final_inpaint_raw = final_inpaint_feed[0, :3].astype(np.float32)
sigma = shared.opts.data.get("control_net_inpaint_blur_sigma", 7)
final_inpaint_mask = cv2.dilate(final_inpaint_mask, np.ones((sigma, sigma), dtype=np.uint8))
final_inpaint_mask = cv2.blur(final_inpaint_mask, (sigma, sigma))[None]
_, Hmask, Wmask = final_inpaint_mask.shape
final_inpaint_raw = torch.from_numpy(np.ascontiguousarray(final_inpaint_raw).copy())
final_inpaint_mask = torch.from_numpy(np.ascontiguousarray(final_inpaint_mask).copy())
def inpaint_only_post_processing(x):
_, H, W = x.shape
if Hmask != H or Wmask != W:
logger.error('Error: ControlNet find post-processing resolution mismatch. This could be related to other extensions hacked processing.')
return x
r = final_inpaint_raw.to(x.dtype).to(x.device)
m = final_inpaint_mask.to(x.dtype).to(x.device)
y = m * x.clip(0, 1) + (1 - m) * r
y = y.clip(0, 1)
return y
post_processors.append(inpaint_only_post_processing)
if 'recolor' in unit.module:
final_feed = hr_control if hr_control is not None else control
final_feed = final_feed.detach().cpu().numpy()
final_feed = np.ascontiguousarray(final_feed).copy()
final_feed = final_feed[0, 0, :, :].astype(np.float32)
final_feed = (final_feed * 255).clip(0, 255).astype(np.uint8)
Hfeed, Wfeed = final_feed.shape
if 'luminance' in unit.module:
def recolor_luminance_post_processing(x):
C, H, W = x.shape
if Hfeed != H or Wfeed != W or C != 3:
logger.error('Error: ControlNet find post-processing resolution mismatch. This could be related to other extensions hacked processing.')
return x
h = x.detach().cpu().numpy().transpose((1, 2, 0))
h = (h * 255).clip(0, 255).astype(np.uint8)
h = cv2.cvtColor(h, cv2.COLOR_RGB2LAB)
h[:, :, 0] = final_feed
h = cv2.cvtColor(h, cv2.COLOR_LAB2RGB)
h = (h.astype(np.float32) / 255.0).transpose((2, 0, 1))
y = torch.from_numpy(h).clip(0, 1).to(x)
return y
post_processors.append(recolor_luminance_post_processing)
if 'intensity' in unit.module:
def recolor_intensity_post_processing(x):
C, H, W = x.shape
if Hfeed != H or Wfeed != W or C != 3:
logger.error('Error: ControlNet find post-processing resolution mismatch. This could be related to other extensions hacked processing.')
return x
h = x.detach().cpu().numpy().transpose((1, 2, 0))
h = (h * 255).clip(0, 255).astype(np.uint8)
h = cv2.cvtColor(h, cv2.COLOR_RGB2HSV)
h[:, :, 2] = final_feed
h = cv2.cvtColor(h, cv2.COLOR_HSV2RGB)
h = (h.astype(np.float32) / 255.0).transpose((2, 0, 1))
y = torch.from_numpy(h).clip(0, 1).to(x)
return y
post_processors.append(recolor_intensity_post_processing)
if '+lama' in unit.module:
forward_param.used_hint_cond_latent = hook.UnetHook.call_vae_using_process(p, control)
self.noise_modifier = forward_param.used_hint_cond_latent
del model_net
is_low_vram = any(unit.low_vram for unit in self.enabled_units)
for i, param in enumerate(forward_params):
if param.control_model_type == ControlModelType.IPAdapter:
param.control_model.hook(
model=unet,
preprocessor_outputs=param.hint_cond,
weight=param.weight,
dtype=torch.float32,
start=param.start_guidance_percent,
end=param.stop_guidance_percent
)
if param.control_model_type == ControlModelType.Controlllite:
param.control_model.hook(
model=unet,
cond=param.hint_cond,
weight=param.weight,
start=param.start_guidance_percent,
end=param.stop_guidance_percent
)
if param.control_model_type == ControlModelType.InstantID:
# For instant_id we always expect ip-adapter model followed
# by ControlNet model.
assert i > 0, "InstantID control model should follow ipadapter model."
ip_adapter_param = forward_params[i - 1]
assert ip_adapter_param.control_model_type == ControlModelType.IPAdapter, \
"InstantID control model should follow ipadapter model."
control_model = ip_adapter_param.control_model
assert hasattr(control_model, "image_emb")
param.control_context_override = control_model.image_emb
self.latest_network = UnetHook(lowvram=is_low_vram)
self.latest_network.hook(model=unet, sd_ldm=sd_ldm, control_params=forward_params, process=p,
batch_option_uint_separate=batch_option_uint_separate,
batch_option_style_align=batch_option_style_align)
self.detected_map = detected_maps
self.post_processors = post_processors
def process_unit_after_click_generate(self, p, unit, params, *args, **kwargs):
h, w, hr_y, hr_x = self.get_target_dimensions(p)
@ -607,7 +418,9 @@ class ControlNetForForgeOfficial(scripts.Script):
logger.info(f"Using preprocessor: {unit.module}")
logger.info(f'preprocessor resolution = {unit.processor_res}')
detected_map = global_state.get_preprocessor(unit.module)(
preprocessor = global_state.get_preprocessor(unit.module)
detected_map = preprocessor(
input_image=input_image,
resolution=unit.processor_res,
slider_1=unit.threshold_a,
@ -617,11 +430,27 @@ class ControlNetForForgeOfficial(scripts.Script):
detected_map_is_image = detected_map.ndim == 3 and detected_map.shape[2] < 5
if detected_map_is_image:
control, detected_map = Script.detectmap_proc(detected_map, unit.module, resize_mode, h, w)
store_detected_map(detected_map, unit.module)
params.control_cond = crop_and_resize_image(detected_map, resize_mode, h, w)
p.extra_result_images.append(params.control_cond)
params.control_cond = numpy_to_pytorch(params.control_cond).movedim(-1, 1)
if has_high_res_fix:
params.control_cond_for_hr_fix = crop_and_resize_image(detected_map, resize_mode, hr_y, hr_x)
p.extra_result_images.append(params.control_cond_for_hr_fix)
params.control_cond_for_hr_fix = numpy_to_pytorch(params.control_cond_for_hr_fix).movedim(-1, 1)
else:
params.control_cond_for_hr_fix = params.control_cond
else:
control = detected_map
store_detected_map(input_image, unit.module)
params.control_cond = detected_map
params.control_cond_for_hr_fix = detected_map
p.extra_result_images.append(input_image)
params.preprocessor = preprocessor
model_filename = global_state.get_controlnet_filename(unit.model)
params.model = cached_controlnet_loader(model_filename)
logger.info(f"Current ControlNet: {model_filename}")
return

View File

@ -24,4 +24,7 @@ def initialize_forge():
model_management.lowvram_available = True
model_management.OOM_EXCEPTION = Exception
from modules_forge import supported_preprocessor
from modules_forge import supported_controlnet
return