From a1d2a31233d3ee0db2d8ba73caec4378c2294e7c Mon Sep 17 00:00:00 2001
From: lllyasviel
Date: Mon, 29 Jan 2024 23:28:13 -0800
Subject: [PATCH] more api

---
 .../sd_forge_controlnet/scripts/controlnet.py | 3 +++
 modules_forge/forge_util.py                   | 8 ++++++++
 modules_forge/supported_controlnet.py         | 5 ++++-
 modules_forge/supported_preprocessor.py       | 8 +-------
 4 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py
index 714e88b1..b6fd1e8d 100644
--- a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py
+++ b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py
@@ -460,6 +460,9 @@ class ControlNetForForgeOfficial(scripts.Script):
         params.model = cached_controlnet_loader(model_filename)
         params.preprocessor = preprocessor
 
+        params.preprocessor.process_after_running_preprocessors(process=p, params=params, **kwargs)
+        params.model.process_after_running_preprocessors(process=p, params=params, **kwargs)
+
         logger.info(f"Current ControlNet {type(params.model).__name__}: {model_filename}")
         return
 
diff --git a/modules_forge/forge_util.py b/modules_forge/forge_util.py
index c9c6c579..ba1e8bb5 100644
--- a/modules_forge/forge_util.py
+++ b/modules_forge/forge_util.py
@@ -6,6 +6,8 @@
 import random
 import string
 import cv2
+from ldm_patched.modules import model_management
+
 
 def HWC3(x):
     assert x.dtype == np.uint8
@@ -124,3 +126,9 @@
         return safer_memory(x[:H_target, :W_target])
 
     return safer_memory(img_padded), remove_pad
+
+
+def lazy_memory_management(model):
+    required_memory = model_management.module_size(model) + model_management.minimum_inference_memory()
+    model_management.free_memory(required_memory, device=model_management.get_torch_device())
+    return
diff --git a/modules_forge/supported_controlnet.py b/modules_forge/supported_controlnet.py
index 67dc0488..90c3e59c 100644
--- a/modules_forge/supported_controlnet.py
+++ b/modules_forge/supported_controlnet.py
@@ -23,7 +23,10 @@ class ControlModelPatcher:
         self.advanced_frame_weighting = None
         self.advanced_sigma_weighting = None
 
-    def patch_to_process(self, p, control_image):
+    def process_after_running_preprocessors(self, process, params, *args, **kwargs):
+        return
+
+    def process_before_every_sampling(self, process, cond, *args, **kwargs):
         return
 
 
diff --git a/modules_forge/supported_preprocessor.py b/modules_forge/supported_preprocessor.py
index 09721f00..2a0e2fab 100644
--- a/modules_forge/supported_preprocessor.py
+++ b/modules_forge/supported_preprocessor.py
@@ -53,13 +53,7 @@ class Preprocessor:
     def send_tensor_to_model_device(self, x):
         return x.to(device=self.model_patcher.current_device, dtype=self.model_patcher.dtype)
 
-    def lazy_memory_management(self, model):
-        # This is a lazy method to just free some memory
-        # so that we can still use old codes to manage memory in a bad way
-        # Ideally this should all be removed and all memory should be managed by model patcher.
-        # But the workload is too big, so we just use a quick method to manage in dirty way.
-        required_memory = model_management.module_size(model) + model_management.minimum_inference_memory()
-        model_management.free_memory(required_memory, device=model_management.get_torch_device())
+    def process_after_running_preprocessors(self, process, params, *args, **kwargs):
         return
 
     def process_before_every_sampling(self, process, cond, *args, **kwargs):
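
Note (not part of the patch): a minimal sketch of how an extension-side patcher might implement the two hooks this patch standardizes. Only ControlModelPatcher, the two method signatures, and the lazy_memory_management helper come from the diff above; the subclass name and the method bodies are illustrative assumptions.

# Illustrative sketch only; assumes the Forge modules shown in the diff above.
from modules_forge.supported_controlnet import ControlModelPatcher


class ExampleControlNetPatcher(ControlModelPatcher):
    # Hypothetical subclass; the two overrides below match the signatures
    # added to ControlModelPatcher in this patch.

    def process_after_running_preprocessors(self, process, params, *args, **kwargs):
        # Invoked right after the unit's preprocessor has run and the model is
        # loaded (see the controlnet.py hunk above). A subclass could, for
        # example, free VRAM here via modules_forge.forge_util.lazy_memory_management,
        # the helper this patch moves out of Preprocessor.
        return

    def process_before_every_sampling(self, process, cond, *args, **kwargs):
        # Per its name, expected to run before each sampling pass with the
        # prepared conditioning; left as a no-op in this sketch.
        return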