From 26c325296e3e627732169b8086e4bfadd6ccf83a Mon Sep 17 00:00:00 2001
From: lllyasviel <19834515+lllyasviel@users.noreply.github.com>
Date: Fri, 23 Feb 2024 09:24:39 -0800
Subject: [PATCH] rework memory computation for async loader (#377)

---
 ldm_patched/modules/model_management.py | 45 ++++++++++++++++++-------
 1 file changed, 32 insertions(+), 13 deletions(-)

diff --git a/ldm_patched/modules/model_management.py b/ldm_patched/modules/model_management.py
index 711f34fe..12bad2e9 100644
--- a/ldm_patched/modules/model_management.py
+++ b/ldm_patched/modules/model_management.py
@@ -291,9 +291,11 @@ class LoadedModel:
         else:
             return self.model_memory()
 
-    def model_load(self, lowvram_model_memory=0):
+    def model_load(self, async_kept_memory=-1):
         patch_model_to = None
-        if lowvram_model_memory == 0:
+        disable_async_load = async_kept_memory < 0
+
+        if disable_async_load:
             patch_model_to = self.device
 
         self.model.model_patches_to(self.device)
@@ -306,23 +308,29 @@ class LoadedModel:
             self.model_unload()
             raise e
 
-        if lowvram_model_memory > 0:
-            print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
+        if not disable_async_load:
+            print("[Memory Management] Requested Async Preserved Memory (MB) = ", async_kept_memory / (1024 * 1024))
+            real_async_memory = 0
+            real_kept_memory = 0
             mem_counter = 0
             for m in self.real_model.modules():
                 if hasattr(m, "ldm_patched_cast_weights"):
                     m.prev_ldm_patched_cast_weights = m.ldm_patched_cast_weights
                     m.ldm_patched_cast_weights = True
                     module_mem = module_size(m)
-                    if mem_counter + module_mem < lowvram_model_memory:
+                    if mem_counter + module_mem < async_kept_memory:
                         m.to(self.device)
                         mem_counter += module_mem
+                        real_kept_memory += module_mem
                     else:
+                        real_async_memory += module_mem
                         m._apply(lambda x: x.pin_memory())
-                elif hasattr(m, "weight"): #only modules with ldm_patched_cast_weights can be set to lowvram mode
+                elif hasattr(m, "weight"):
                     m.to(self.device)
                     mem_counter += module_size(m)
-                    print("lowvram: loaded module regularly", m)
+                    print("[Memory Management] Async Loader Disabled for ", m)
+            print("[Async Memory Management] Parameters Loaded to Async Stream (MB) = ", real_async_memory / (1024 * 1024))
+            print("[Async Memory Management] Parameters Loaded to GPU (MB) = ", real_kept_memory / (1024 * 1024))
 
         self.model_accelerated = True
 
@@ -433,20 +441,31 @@ def load_models_gpu(models, memory_required=0):
             vram_set_state = VRAMState.DISABLED
         else:
             vram_set_state = vram_state
-        lowvram_model_memory = 0
+
+        async_kept_memory = 0
+
         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
             model_size = loaded_model.model_memory_required(torch_dev)
             current_free_mem = get_free_memory(torch_dev)
-            lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 2 * 1024 * (1024 * 1024)) / 1.3))
-            if model_size > (current_free_mem - inference_memory):
+            estimated_memory_remaining = current_free_mem - model_size - extra_mem
+
+            print("[Memory Management] Current Free Memory (MB) = ", current_free_mem / (1024 * 1024))
+            print("[Memory Management] Model Memory (MB) = ", model_size / (1024 * 1024))
+            print("[Memory Management] Estimated Inference Memory (MB) = ", extra_mem / (1024 * 1024))
+            print("[Memory Management] Estimated Remaining Memory (MB) = ", estimated_memory_remaining / (1024 * 1024))
+
+            if estimated_memory_remaining < 0:
                 vram_set_state = VRAMState.LOW_VRAM
+                async_overhead_memory = 1024 * 1024 * 1024
+                async_kept_memory = current_free_mem - extra_mem - async_overhead_memory
+                async_kept_memory = int(max(0, async_kept_memory))
         else:
-            lowvram_model_memory = 0
+            async_kept_memory = -1
 
         if vram_set_state == VRAMState.NO_VRAM:
-            lowvram_model_memory = 64 * 1024 * 1024
+            async_kept_memory = 0
 
-        cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
+        cur_loaded_model = loaded_model.model_load(async_kept_memory)
         current_loaded_models.insert(0, loaded_model)
 
         moving_time = time.perf_counter() - execution_start_time
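-- 
Note (not part of the patch; placed after the signature delimiter so git am
still applies cleanly): the async path above keeps the first async_kept_memory
bytes of modules resident on the GPU and pins the remainder in page-locked
host RAM via m._apply(lambda x: x.pin_memory()). Worked example of the new
computation: with 8 GB free, a 7 GB model, and 2 GB of estimated inference
memory, estimated_memory_remaining = 8 - 7 - 2 = -1 GB < 0, so the loader
switches to LOW_VRAM and keeps 8 - 2 - 1 (overhead) = 5 GB of parameters on
the GPU, streaming the rest.

A minimal sketch of the mechanism that pinning enables, assuming a CUDA
device is available; the helper name to_gpu_async is illustrative, not from
the codebase:

    import torch

    def to_gpu_async(t: torch.Tensor, device: str = "cuda") -> torch.Tensor:
        # pin_memory() returns a copy of the tensor in page-locked host RAM;
        # only page-locked memory supports truly asynchronous H2D copies.
        pinned = t.pin_memory()
        # With a pinned source, non_blocking=True queues the copy on the
        # current CUDA stream and returns immediately, letting the transfer
        # overlap with ongoing compute.
        return pinned.to(device, non_blocking=True)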