lllyasviel 2024-01-29 23:28:13 -08:00
parent 8ee804a159
commit a1d2a31233
4 changed files with 16 additions and 8 deletions

View File

@@ -460,6 +460,9 @@ class ControlNetForForgeOfficial(scripts.Script):
params.model = cached_controlnet_loader(model_filename)
params.preprocessor = preprocessor
params.preprocessor.process_after_running_preprocessors(process=p, params=params, **kwargs)
params.model.process_after_running_preprocessors(process=p, params=params, **kwargs)
logger.info(f"Current ControlNet {type(params.model).__name__}: {model_filename}")
return
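The added call gives the model patcher the same post-preprocessor hook the preprocessor already had, so both objects are notified once preprocessing is done. A minimal sketch of an override, assuming a hypothetical subclass of ControlModelPatcher (the base class and hook signature appear in a later hunk of this diff; the body here is illustrative only):

class LoggingControlModelPatcher(ControlModelPatcher):
    # hypothetical subclass, not part of this commit
    def process_after_running_preprocessors(self, process, params, *args, **kwargs):
        # react once all preprocessors have finished, e.g. just log the event
        print(f'ControlNet model patcher ready, params: {type(params).__name__}')
        return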

View File

@@ -6,6 +6,8 @@ import random
import string
import cv2
from ldm_patched.modules import model_management
def HWC3(x):
    assert x.dtype == np.uint8
@@ -124,3 +126,9 @@ def resize_image_with_pad(img, resolution):
return safer_memory(x[:H_target, :W_target])
return safer_memory(img_padded), remove_pad
def lazy_memory_management(model):
    required_memory = model_management.module_size(model) + model_management.minimum_inference_memory()
    model_management.free_memory(required_memory, device=model_management.get_torch_device())
    return
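The helper is now a module-level function: it computes the module's size plus the minimum inference headroom, then asks model_management to free that much memory on the torch device. A minimal usage sketch, assuming the helper lives in the ControlNet utils module patched above (the import path and the caller are assumptions, not part of this commit):

from lib_controlnet.utils import lazy_memory_management  # assumed import path
from ldm_patched.modules import model_management

def run_annotator(annotator_model, x):
    # hypothetical caller: make room first, then move the module to the torch device and run it
    lazy_memory_management(annotator_model)
    annotator_model.to(model_management.get_torch_device())
    return annotator_model(x)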

View File

@@ -23,7 +23,10 @@ class ControlModelPatcher:
self.advanced_frame_weighting = None
self.advanced_sigma_weighting = None
def patch_to_process(self, p, control_image):
def process_after_running_preprocessors(self, process, params, *args, **kwargs):
    return
def process_before_every_sampling(self, process, cond, *args, **kwargs):
    return
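Both hooks default to no-ops, so the ControlNet script can call them unconditionally; the first hunk of this commit shows the caller side for process_after_running_preprocessors. A sketch of how process_before_every_sampling would presumably be driven, mirroring that hunk (the call site below is an assumption, not part of this diff):

# assumed call site, mirroring the post-preprocessor calls in the first hunk
params.preprocessor.process_before_every_sampling(process=p, cond=cond, **kwargs)
params.model.process_before_every_sampling(process=p, cond=cond, **kwargs)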

View File

@@ -53,13 +53,7 @@ class Preprocessor:
def send_tensor_to_model_device(self, x):
    return x.to(device=self.model_patcher.current_device, dtype=self.model_patcher.dtype)
def lazy_memory_management(self, model):
    # This is a lazy method to just free some memory
    # so that we can still use old code to manage memory in a bad way.
    # Ideally this should all be removed and all memory should be managed by the model patcher.
    # But the workload is too big, so we just use a quick method to manage it in a dirty way.
    required_memory = model_management.module_size(model) + model_management.minimum_inference_memory()
    model_management.free_memory(required_memory, device=model_management.get_torch_device())
def process_after_running_preprocessors(self, process, params, *args, **kwargs):
    return
def process_before_every_sampling(self, process, cond, *args, **kwargs):
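With the method removed here, memory is presumably freed through the module-level lazy_memory_management added in the utils hunk above. A sketch of a preprocessor that wants the old behaviour, assuming a hypothetical subclass and the assumed import path (not part of this commit):

from lib_controlnet.utils import lazy_memory_management  # assumed import path

class ExamplePreprocessor(Preprocessor):  # hypothetical subclass
    def __call__(self, input_image, resolution, **kwargs):
        # free VRAM for the annotator before running it, as the removed method used to do
        lazy_memory_management(self.model_patcher.model)  # .model attribute assumed from ldm_patched's ModelPatcher
        x = self.send_tensor_to_model_device(input_image)
        return self.model_patcher.model(x)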