From 539bc5035d6913154d266f02965edada45acd0f4 Mon Sep 17 00:00:00 2001
From: lllyasviel
Date: Thu, 22 Feb 2024 01:28:38 -0800
Subject: [PATCH] safe cleanup to avoid potential problems

---
 ldm_patched/modules/ops.py     | 16 ++++++++++++++++
 modules_forge/forge_sampler.py |  4 ++--
 modules_forge/stream.py        |  6 ------
 3 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/ldm_patched/modules/ops.py b/ldm_patched/modules/ops.py
index fdbaa064..8cc41123 100644
--- a/ldm_patched/modules/ops.py
+++ b/ldm_patched/modules/ops.py
@@ -71,6 +71,22 @@ def main_thread_worker(weight, bias, signal):
     return
 
 
+def cleanup_cache():
+    global gc
+
+    if stream.current_stream is not None:
+        with stream.stream_context()(stream.current_stream):
+            for k, (w, b, s) in gc.items():
+                stream.current_stream.wait_event(s)
+            stream.current_stream.synchronize()
+
+    gc.clear()
+
+    if stream.mover_stream is not None:
+        stream.mover_stream.synchronize()
+    return
+
+
 class disable_weight_init:
     class Linear(torch.nn.Linear):
         ldm_patched_cast_weights = False
diff --git a/modules_forge/forge_sampler.py b/modules_forge/forge_sampler.py
index de5fb72c..79524dff 100644
--- a/modules_forge/forge_sampler.py
+++ b/modules_forge/forge_sampler.py
@@ -2,7 +2,7 @@ import torch
 from ldm_patched.modules.conds import CONDRegular, CONDCrossAttn
 from ldm_patched.modules.samplers import sampling_function
 from ldm_patched.modules import model_management
-from modules_forge.stream import synchronize_current_stream
+from ldm_patched.modules.ops import cleanup_cache
 
 
 def cond_from_a1111_to_patched_ldm(cond):
@@ -114,5 +114,5 @@ def sampling_prepare(unet, x):
 def sampling_cleanup(unet):
     for cnet in unet.list_controlnets():
         cnet.cleanup()
-    synchronize_current_stream()
+    cleanup_cache()
     return
diff --git a/modules_forge/stream.py b/modules_forge/stream.py
index ccb9784c..2ca067e2 100644
--- a/modules_forge/stream.py
+++ b/modules_forge/stream.py
@@ -40,12 +40,6 @@ def get_new_stream():
         return None
 
 
-def synchronize_current_stream():
-    global current_stream
-    if current_stream is not None:
-        current_stream.synchronize()
-
-
 if shared.opts.use_non_streamlined_lowvram:
     current_stream = None
     mover_stream = None
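
Note for reviewers: below is a minimal, standalone sketch of the stream/event pattern that the new cleanup_cache() relies on, assuming a CUDA device. The names compute_stream, mover_stream and weight_cache are illustrative stand-ins for Forge's stream.current_stream, stream.mover_stream and the module-level gc cache; they are not the project's actual objects.

import torch


def main():
    # Stand-ins for Forge's current/mover streams and the weight cache.
    compute_stream = torch.cuda.Stream()
    mover_stream = torch.cuda.Stream()
    weight_cache = {}

    # Upload weights asynchronously on the mover stream, recording one event
    # per copy so other streams can wait for that copy to complete.
    for key in ('w1', 'w2'):
        cpu_tensor = torch.randn(8, 8).pin_memory()
        with torch.cuda.stream(mover_stream):
            gpu_tensor = cpu_tensor.to('cuda', non_blocking=True)
            done = torch.cuda.Event()
            done.record(mover_stream)
        weight_cache[key] = (gpu_tensor, None, done)

    # Cleanup, mirroring the shape of the patched cleanup_cache(): the compute
    # stream waits on every pending copy event, synchronizes, and only then is
    # the cache cleared and the mover stream synchronized.
    with torch.cuda.stream(compute_stream):
        for key, (w, b, event) in weight_cache.items():
            compute_stream.wait_event(event)
        compute_stream.synchronize()
    weight_cache.clear()
    mover_stream.synchronize()


if __name__ == '__main__':
    if torch.cuda.is_available():
        main()

The point of waiting on each recorded event before clearing the cache is that no in-flight weight copy can still be touching a cached tensor when the entries are dropped, which appears to be the "safe cleanup" this patch aims for compared with the old synchronize_current_stream() helper.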