2024-01-29 17:28:38 +00:00
|
|
|
import time
|
2024-01-16 10:33:39 +00:00
|
|
|
import torch
|
|
|
|
import contextlib
|
2024-01-29 15:47:56 +00:00
|
|
|
from ldm_patched.modules import model_management
|
2024-02-09 02:24:04 +00:00
|
|
|
from ldm_patched.modules.ops import use_patched_ops
|
2024-01-29 04:39:00 +00:00
|
|
|
|
|
|
|
|
|
|
|
@contextlib.contextmanager
def automatic_memory_management():
    """Track every torch.nn.Module created or moved inside the block, then
    offload all of them to CPU when the block exits.

    Before entering, asks model_management to free ~3 GiB on the active torch
    device. While the block runs, ``torch.nn.Module.__init__`` and
    ``torch.nn.Module.to`` are monkey-patched so that any module constructed
    or moved is recorded. On exit the original methods are always restored and
    the recorded modules are moved to CPU, followed by a cache flush — this
    now also happens on the exception path, so a failure inside the block can
    no longer strand freshly-created modules on the GPU (previously the
    offload ran only after a clean exit).

    Yields:
        None.

    NOTE(review): the monkey-patching is process-global and not thread-safe;
    modules constructed concurrently in other threads are also tracked.
    """
    model_management.free_memory(
        memory_required=3 * 1024 * 1024 * 1024,  # request ~3 GiB of headroom
        device=model_management.get_torch_device()
    )

    module_list = []

    original_init = torch.nn.Module.__init__
    original_to = torch.nn.Module.to

    def patched_init(self, *args, **kwargs):
        # Record every module as it is constructed.
        module_list.append(self)
        return original_init(self, *args, **kwargs)

    def patched_to(self, *args, **kwargs):
        # Record every module that is moved/cast via .to().
        module_list.append(self)
        return original_to(self, *args, **kwargs)

    try:
        torch.nn.Module.__init__ = patched_init
        torch.nn.Module.to = patched_to
        yield
    finally:
        # Restore the un-patched methods first so the offload below does not
        # re-record the modules it touches.
        torch.nn.Module.__init__ = original_init
        torch.nn.Module.to = original_to

        start = time.perf_counter()
        # Deduplicate: a module may have been recorded by both hooks
        # (constructed AND moved) or recorded multiple times.
        module_set = set(module_list)
        for module in module_set:
            module.cpu()
        model_management.soft_empty_cache()
        end = time.perf_counter()
        print(f'Automatic Memory Management: {len(module_set)} Modules in {(end - start):.2f} seconds.')
    return
|