rework memory computation for async loader (#377)

lllyasviel 2024-02-23 09:24:39 -08:00 committed by GitHub
parent eacb14e115
commit 26c325296e

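In brief: model_load's fixed lowvram_model_memory cap is replaced by an async_kept_memory budget computed from the device's actual free memory. Modules stay resident on the GPU until the budget is exhausted; everything past it is pinned in host memory so it can be streamed to the device asynchronously. A negative budget (the new default, -1) disables the async loader entirely, and 0 keeps nothing resident.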

@@ -291,9 +291,11 @@ class LoadedModel:
         else:
             return self.model_memory()

-    def model_load(self, lowvram_model_memory=0):
+    def model_load(self, async_kept_memory=-1):
         patch_model_to = None
-        if lowvram_model_memory == 0:
+
+        disable_async_load = async_kept_memory < 0
+        if disable_async_load:
             patch_model_to = self.device

         self.model.model_patches_to(self.device)
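The sign convention introduced above carries the whole mode switch. A toy illustration of the three regimes (the helper is hypothetical, not part of this commit):

# Hypothetical helper, for illustration only; not in the commit.
def describe_budget(async_kept_memory: int) -> str:
    if async_kept_memory < 0:
        return "async loader disabled: move the whole model to the device"
    if async_kept_memory == 0:
        return "keep nothing resident: pin and stream every module"
    return "keep up to %.0f MB resident, stream the rest" % (async_kept_memory / (1024 * 1024))

print(describe_budget(-1))  # the new default behaves like the old non-lowvram path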
@@ -306,23 +308,29 @@ class LoadedModel:
             self.model_unload()
             raise e

-        if lowvram_model_memory > 0:
-            print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
+        if not disable_async_load:
+            print("[Memory Management] Requested Async Preserved Memory (MB) = ", async_kept_memory / (1024 * 1024))
+            real_async_memory = 0
+            real_kept_memory = 0
             mem_counter = 0
             for m in self.real_model.modules():
                 if hasattr(m, "ldm_patched_cast_weights"):
                     m.prev_ldm_patched_cast_weights = m.ldm_patched_cast_weights
                     m.ldm_patched_cast_weights = True
                     module_mem = module_size(m)
-                    if mem_counter + module_mem < lowvram_model_memory:
+                    if mem_counter + module_mem < async_kept_memory:
                         m.to(self.device)
                         mem_counter += module_mem
+                        real_kept_memory += module_mem
+                    else:
+                        real_async_memory += module_mem
+                        m._apply(lambda x: x.pin_memory())
-                elif hasattr(m, "weight"): #only modules with ldm_patched_cast_weights can be set to lowvram mode
+                elif hasattr(m, "weight"):
                     m.to(self.device)
                     mem_counter += module_size(m)
-                    print("lowvram: loaded module regularly", m)
+                    print("[Memory Management] Async Loader Disabled for ", m)
+            print("[Async Memory Management] Parameters Loaded to Async Stream (MB) = ", real_async_memory / (1024 * 1024))
+            print("[Async Memory Management] Parameters Loaded to GPU (MB) = ", real_kept_memory / (1024 * 1024))

             self.model_accelerated = True
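The m._apply(lambda x: x.pin_memory()) call above is what makes the deferred transfers overlappable: PyTorch can only perform an asynchronous host-to-device copy from page-locked (pinned) memory. A minimal, self-contained sketch of that mechanism, assuming standalone names (an illustration, not the commit's code):

import torch

def stream_module_to_gpu(module: torch.nn.Module, device: torch.device) -> torch.nn.Module:
    # Re-allocate every parameter and buffer in pinned (page-locked) host
    # memory, mirroring the diff's m._apply(lambda x: x.pin_memory()).
    module._apply(lambda t: t.pin_memory())
    # With pinned sources, non_blocking=True lets the copy overlap with work
    # already queued on the current CUDA stream instead of blocking the host.
    return module.to(device, non_blocking=True)

if torch.cuda.is_available():
    layer = torch.nn.Linear(4096, 4096)  # stand-in for one model module
    layer = stream_module_to_gpu(layer, torch.device("cuda"))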
@@ -433,20 +441,31 @@ def load_models_gpu(models, memory_required=0):
             vram_set_state = VRAMState.DISABLED
         else:
             vram_set_state = vram_state

-        lowvram_model_memory = 0
+        async_kept_memory = 0
+
         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
             model_size = loaded_model.model_memory_required(torch_dev)
             current_free_mem = get_free_memory(torch_dev)
-            lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 2 * 1024 * (1024 * 1024)) / 1.3))
-            if model_size > (current_free_mem - inference_memory):
+            estimated_memory_remaining = current_free_mem - model_size - extra_mem
+
+            print("[Memory Management] Current Free Memory (MB) = ", current_free_mem / (1024 * 1024))
+            print("[Memory Management] Model Memory (MB) = ", model_size / (1024 * 1024))
+            print("[Memory Management] Estimated Inference Memory (MB) = ", extra_mem / (1024 * 1024))
+            print("[Memory Management] Estimated Remaining Memory (MB) = ", estimated_memory_remaining / (1024 * 1024))
+
+            if estimated_memory_remaining < 0:
                 vram_set_state = VRAMState.LOW_VRAM
+                async_overhead_memory = 1024 * 1024 * 1024
+                async_kept_memory = current_free_mem - extra_mem - async_overhead_memory
+                async_kept_memory = int(max(0, async_kept_memory))
             else:
-                lowvram_model_memory = 0
+                async_kept_memory = -1

         if vram_set_state == VRAMState.NO_VRAM:
-            lowvram_model_memory = 64 * 1024 * 1024
+            async_kept_memory = 0

-        cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
+        cur_loaded_model = loaded_model.model_load(async_kept_memory)
         current_loaded_models.insert(0, loaded_model)

         moving_time = time.perf_counter() - execution_start_time
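The budget arithmetic in the low-VRAM branch is easy to follow with concrete numbers. A worked example (the sizes are illustrative; the formula is the one in the hunk above):

MB = 1024 * 1024

current_free_mem = 8192 * MB  # example: 8 GB free on the device
model_size = 9216 * MB        # example: 9 GB of weights
extra_mem = 2048 * MB         # example: estimated inference memory

estimated_memory_remaining = current_free_mem - model_size - extra_mem  # -3 GB
if estimated_memory_remaining < 0:      # model plus inference does not fit
    async_overhead_memory = 1024 * MB   # fixed 1 GB margin for the async stream
    async_kept_memory = int(max(0, current_free_mem - extra_mem - async_overhead_memory))
    # 8 GB - 2 GB - 1 GB = 5 GB of modules stay resident on the GPU;
    # the remaining ~4 GB of weights are pinned and streamed asynchronously.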