safe cleanup to avoid potential problems

lllyasviel 2024-02-22 01:28:38 -08:00
parent 4080e25805
commit 539bc5035d
3 changed files with 18 additions and 8 deletions

ldm_patched/modules/ops.py

@@ -71,6 +71,22 @@ def main_thread_worker(weight, bias, signal):
     return
 
 
+def cleanup_cache():
+    global gc
+
+    if stream.current_stream is not None:
+        with stream.stream_context()(stream.current_stream):
+            for k, (w, b, s) in gc.items():
+                stream.current_stream.wait_event(s)
+            stream.current_stream.synchronize()
+
+    gc.clear()
+
+    if stream.mover_stream is not None:
+        stream.mover_stream.synchronize()
+    return
+
+
 class disable_weight_init:
     class Linear(torch.nn.Linear):
         ldm_patched_cast_weights = False
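For readers unfamiliar with the pattern: each entry in the module-level gc dict carries the CUDA event recorded when that weight/bias pair finished moving, and cleanup_cache() makes the compute stream wait on every event before the Python references are dropped, so nothing is freed while a copy or kernel that uses it is still in flight. Below is a minimal, self-contained sketch of that pattern, not Forge's actual code; compute_stream, mover_stream, cache, put() and cleanup() are hypothetical stand-ins for stream.current_stream, stream.mover_stream and gc.

import torch

compute_stream = torch.cuda.Stream()  # stands in for stream.current_stream
mover_stream = torch.cuda.Stream()    # stands in for stream.mover_stream
cache = {}                            # stands in for gc: key -> (weight, bias, event)

def put(key, weight, bias):
    # Issue async host-to-device copies on the mover stream, then record
    # an event that fires once both copies are complete.
    with torch.cuda.stream(mover_stream):
        w = weight.cuda(non_blocking=True)
        b = bias.cuda(non_blocking=True)
        event = torch.cuda.Event()
        event.record(mover_stream)
    cache[key] = (w, b, event)

def cleanup():
    # Same ordering as cleanup_cache() above: make the compute stream wait
    # on every recorded event, synchronize it, and only then drop the
    # references, so no tensor is freed while still in use on the GPU.
    with torch.cuda.stream(compute_stream):
        for key, (w, b, event) in cache.items():
            compute_stream.wait_event(event)
        compute_stream.synchronize()
    cache.clear()
    mover_stream.synchronize()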

modules_forge/forge_sampler.py

@@ -2,7 +2,7 @@ import torch
 from ldm_patched.modules.conds import CONDRegular, CONDCrossAttn
 from ldm_patched.modules.samplers import sampling_function
 from ldm_patched.modules import model_management
-from modules_forge.stream import synchronize_current_stream
+from ldm_patched.modules.ops import cleanup_cache
 
 
 def cond_from_a1111_to_patched_ldm(cond):
@@ -114,5 +114,5 @@ def sampling_prepare(unet, x):
 def sampling_cleanup(unet):
     for cnet in unet.list_controlnets():
         cnet.cleanup()
-    synchronize_current_stream()
+    cleanup_cache()
     return
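Note: with this call-site change, sampling_cleanup() no longer just drains the compute stream at the end of sampling; cleanup_cache() also waits on each cached weight's event, empties the cache, and synchronizes the mover stream, which appears to be the "safe cleanup" named in the commit title.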

modules_forge/stream.py

@@ -40,12 +40,6 @@ def get_new_stream():
     return None
 
 
-def synchronize_current_stream():
-    global current_stream
-    if current_stream is not None:
-        current_stream.synchronize()
-
-
 if shared.opts.use_non_streamlined_lowvram:
     current_stream = None
     mover_stream = None
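For comparison with the new cleanup_cache(): the helper removed here synchronized only current_stream. It never waited on mover_stream and never cleared the cached weight references, so copies issued on the mover stream could, in principle, still be in flight when a caller proceeded. Routing all callers through cleanup_cache() closes that gap.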