my-sd/modules_forge/ops.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

47 lines
1.2 KiB
Python
Raw Normal View History

2024-01-29 17:28:38 +00:00
import time
import torch
import contextlib
2024-01-29 15:47:56 +00:00
from ldm_patched.modules import model_management
from ldm_patched.modules.ops import use_patched_ops
@contextlib.contextmanager
def automatic_memory_management(*, memory_required=3 * 1024 * 1024 * 1024):
    """Record torch modules touched inside the block, then move them to CPU.

    On entry this calls ``model_management.free_memory`` with
    ``memory_required`` and the device from
    ``model_management.get_torch_device()``.

    While the ``with`` body runs, ``torch.nn.Module.__init__`` and
    ``torch.nn.Module.to`` are monkey-patched process-wide. Every module that
    is constructed, or has ``.to()`` called on it, is recorded.

    On exit, the following happens even if the body raised:
      1. The original ``__init__`` and ``to`` are restored.
      2. Each recorded module (deduplicated by identity) is moved to CPU with
         ``.cpu()``.
      3. ``model_management.soft_empty_cache()`` is called.
      4. A one-line timing summary is printed.

    Any exception raised by the body propagates after this cleanup.

    Args:
        memory_required: Value forwarded as ``memory_required`` to
            ``model_management.free_memory`` (unit defined by that helper,
            presumably bytes). Defaults to ``3 * 1024 * 1024 * 1024``, the
            value that was previously hard-coded.

    Yields:
        None.
    """
    model_management.free_memory(
        memory_required=memory_required,
        device=model_management.get_torch_device(),
    )

    module_list = []

    original_init = torch.nn.Module.__init__
    original_to = torch.nn.Module.to

    def patched_init(self, *args, **kwargs):
        result = original_init(self, *args, **kwargs)
        # Record only after base initialisation succeeded. This keeps the
        # offload below from calling .cpu() on a half-constructed module.
        module_list.append(self)
        return result

    def patched_to(self, *args, **kwargs):
        module_list.append(self)
        return original_to(self, *args, **kwargs)

    try:
        torch.nn.Module.__init__ = patched_init
        torch.nn.Module.to = patched_to
        yield
    finally:
        # Undo the global patches first so nothing else gets tracked.
        torch.nn.Module.__init__ = original_init
        torch.nn.Module.to = original_to

        # Offload inside `finally` so GPU memory is released even when the
        # `with` body raised. Previously this step was skipped on error.
        start = time.perf_counter()
        # Deduplicate by identity, keeping first-seen order. This does not
        # depend on modules being hashable or on any custom __eq__.
        unique_modules = list({id(m): m for m in module_list}.values())
        for module in unique_modules:
            module.cpu()
        model_management.soft_empty_cache()
        end = time.perf_counter()

        print(f'Automatic Memory Management: {len(unique_modules)} Modules in {(end - start):.2f} seconds.')