rework memory computation for async loader (#377)
This commit is contained in:
parent
eacb14e115
commit
26c325296e
@ -291,9 +291,11 @@ class LoadedModel:
|
|||||||
else:
|
else:
|
||||||
return self.model_memory()
|
return self.model_memory()
|
||||||
|
|
||||||
def model_load(self, lowvram_model_memory=0):
|
def model_load(self, async_kept_memory=-1):
|
||||||
patch_model_to = None
|
patch_model_to = None
|
||||||
if lowvram_model_memory == 0:
|
disable_async_load = async_kept_memory < 0
|
||||||
|
|
||||||
|
if disable_async_load:
|
||||||
patch_model_to = self.device
|
patch_model_to = self.device
|
||||||
|
|
||||||
self.model.model_patches_to(self.device)
|
self.model.model_patches_to(self.device)
|
||||||
@ -306,23 +308,29 @@ class LoadedModel:
|
|||||||
self.model_unload()
|
self.model_unload()
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
if lowvram_model_memory > 0:
|
if not disable_async_load:
|
||||||
print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
|
print("[Memory Management] Requested Async Preserved Memory (MB) = ", async_kept_memory / (1024 * 1024))
|
||||||
|
real_async_memory = 0
|
||||||
|
real_kept_memory = 0
|
||||||
mem_counter = 0
|
mem_counter = 0
|
||||||
for m in self.real_model.modules():
|
for m in self.real_model.modules():
|
||||||
if hasattr(m, "ldm_patched_cast_weights"):
|
if hasattr(m, "ldm_patched_cast_weights"):
|
||||||
m.prev_ldm_patched_cast_weights = m.ldm_patched_cast_weights
|
m.prev_ldm_patched_cast_weights = m.ldm_patched_cast_weights
|
||||||
m.ldm_patched_cast_weights = True
|
m.ldm_patched_cast_weights = True
|
||||||
module_mem = module_size(m)
|
module_mem = module_size(m)
|
||||||
if mem_counter + module_mem < lowvram_model_memory:
|
if mem_counter + module_mem < async_kept_memory:
|
||||||
m.to(self.device)
|
m.to(self.device)
|
||||||
mem_counter += module_mem
|
mem_counter += module_mem
|
||||||
|
real_kept_memory += module_mem
|
||||||
else:
|
else:
|
||||||
|
real_async_memory += module_mem
|
||||||
m._apply(lambda x: x.pin_memory())
|
m._apply(lambda x: x.pin_memory())
|
||||||
elif hasattr(m, "weight"): #only modules with ldm_patched_cast_weights can be set to lowvram mode
|
elif hasattr(m, "weight"):
|
||||||
m.to(self.device)
|
m.to(self.device)
|
||||||
mem_counter += module_size(m)
|
mem_counter += module_size(m)
|
||||||
print("lowvram: loaded module regularly", m)
|
print("[Memory Management] Async Loader Disabled for ", m)
|
||||||
|
print("[Async Memory Management] Parameters Loaded to Async Stream (MB) = ", real_async_memory / (1024 * 1024))
|
||||||
|
print("[Async Memory Management] Parameters Loaded to GPU (MB) = ", real_kept_memory / (1024 * 1024))
|
||||||
|
|
||||||
self.model_accelerated = True
|
self.model_accelerated = True
|
||||||
|
|
||||||
@ -433,20 +441,31 @@ def load_models_gpu(models, memory_required=0):
|
|||||||
vram_set_state = VRAMState.DISABLED
|
vram_set_state = VRAMState.DISABLED
|
||||||
else:
|
else:
|
||||||
vram_set_state = vram_state
|
vram_set_state = vram_state
|
||||||
lowvram_model_memory = 0
|
|
||||||
|
async_kept_memory = 0
|
||||||
|
|
||||||
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
|
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
|
||||||
model_size = loaded_model.model_memory_required(torch_dev)
|
model_size = loaded_model.model_memory_required(torch_dev)
|
||||||
current_free_mem = get_free_memory(torch_dev)
|
current_free_mem = get_free_memory(torch_dev)
|
||||||
lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 2 * 1024 * (1024 * 1024)) / 1.3))
|
estimated_memory_remaining = current_free_mem - model_size - extra_mem
|
||||||
if model_size > (current_free_mem - inference_memory):
|
|
||||||
|
print("[Memory Management] Current Free Memory (MB) = ", current_free_mem / (1024 * 1024))
|
||||||
|
print("[Memory Management] Model Memory (MB) = ", model_size / (1024 * 1024))
|
||||||
|
print("[Memory Management] Estimated Inference Memory (MB) = ", extra_mem / (1024 * 1024))
|
||||||
|
print("[Memory Management] Estimated Remaining Memory (MB) = ", estimated_memory_remaining / (1024 * 1024))
|
||||||
|
|
||||||
|
if estimated_memory_remaining < 0:
|
||||||
vram_set_state = VRAMState.LOW_VRAM
|
vram_set_state = VRAMState.LOW_VRAM
|
||||||
|
async_overhead_memory = 1024 * 1024 * 1024
|
||||||
|
async_kept_memory = current_free_mem - extra_mem - async_overhead_memory
|
||||||
|
async_kept_memory = int(max(0, async_kept_memory))
|
||||||
else:
|
else:
|
||||||
lowvram_model_memory = 0
|
async_kept_memory = -1
|
||||||
|
|
||||||
if vram_set_state == VRAMState.NO_VRAM:
|
if vram_set_state == VRAMState.NO_VRAM:
|
||||||
lowvram_model_memory = 64 * 1024 * 1024
|
async_kept_memory = 0
|
||||||
|
|
||||||
cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
|
cur_loaded_model = loaded_model.model_load(async_kept_memory)
|
||||||
current_loaded_models.insert(0, loaded_model)
|
current_loaded_models.insert(0, loaded_model)
|
||||||
|
|
||||||
moving_time = time.perf_counter() - execution_start_time
|
moving_time = time.perf_counter() - execution_start_time
|
||||||
|
Loading…
Reference in New Issue
Block a user