From 6287c73d988fd48973cef03119ceef08bd4739a3 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Sun, 25 Feb 2024 20:37:14 -0800 Subject: [PATCH] rework memory management for extras now face post-processing uses gpu close #312 --- extensions-builtin/LDSR/scripts/ldsr_model.py | 3 ++- extensions-builtin/SwinIR/scripts/swinir_model.py | 3 ++- modules/dat_model.py | 3 ++- modules/devices.py | 8 ++++---- modules/esrgan_model.py | 3 ++- modules/face_restoration_utils.py | 2 ++ modules/hat_model.py | 3 ++- modules/realesrgan_model.py | 3 ++- modules/upscaler.py | 11 ----------- modules_forge/forge_util.py | 12 ++++++++++++ 10 files changed, 30 insertions(+), 21 deletions(-) diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py index a8b11d3a..c7ae4d4b 100644 --- a/extensions-builtin/LDSR/scripts/ldsr_model.py +++ b/extensions-builtin/LDSR/scripts/ldsr_model.py @@ -1,7 +1,8 @@ import os from modules.modelloader import load_file_from_url -from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory +from modules.upscaler import Upscaler, UpscalerData +from modules_forge.forge_util import prepare_free_memory from ldsr_model_arch import LDSR from modules import shared, script_callbacks, errors import sd_hijack_autoencoder # noqa: F401 diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index d96e27c8..3b51ee86 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -5,7 +5,8 @@ import torch from PIL import Image from modules import devices, modelloader, script_callbacks, shared, upscaler_utils -from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory +from modules.upscaler import Upscaler, UpscalerData +from modules_forge.forge_util import prepare_free_memory SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" diff --git a/modules/dat_model.py b/modules/dat_model.py index 24c4c968..c2caaafd 100644 --- a/modules/dat_model.py +++ b/modules/dat_model.py @@ -2,8 +2,9 @@ import os from modules import modelloader, errors from modules.shared import cmd_opts, opts -from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory +from modules.upscaler import Upscaler, UpscalerData from modules.upscaler_utils import upscale_with_model +from modules_forge.forge_util import prepare_free_memory class UpscalerDAT(Upscaler): diff --git a/modules/devices.py b/modules/devices.py index 7c09d1f4..08d0d706 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -50,10 +50,10 @@ def enable_tf32(): cpu: torch.device = torch.device("cpu") fp8: bool = False device: torch.device = model_management.get_torch_device() -device_interrogate: torch.device = cpu # not used -device_gfpgan: torch.device = cpu -device_esrgan: torch.device = model_management.get_torch_device() # will be managed in special way -device_codeformer: torch.device = cpu +device_interrogate: torch.device = model_management.text_encoder_device() # for backward compatibility, not used now +device_gfpgan: torch.device = model_management.get_torch_device() # will be managed by memory management system +device_esrgan: torch.device = model_management.get_torch_device() # will be managed by memory management system +device_codeformer: torch.device = model_management.get_torch_device() # will be managed by memory management system dtype: torch.dtype = model_management.unet_dtype() dtype_vae: torch.dtype = model_management.vae_dtype() dtype_unet: torch.dtype = model_management.unet_dtype() diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 4ade7fce..2c4505a7 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -1,7 +1,8 @@ from modules import modelloader, devices, errors from modules.shared import opts -from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory +from modules.upscaler import Upscaler, UpscalerData from modules.upscaler_utils import upscale_with_model +from modules_forge.forge_util import prepare_free_memory class UpscalerESRGAN(Upscaler): diff --git a/modules/face_restoration_utils.py b/modules/face_restoration_utils.py index 1cbac236..1ba01734 100644 --- a/modules/face_restoration_utils.py +++ b/modules/face_restoration_utils.py @@ -10,6 +10,7 @@ import numpy as np import torch from modules import devices, errors, face_restoration, shared +from modules_forge.forge_util import prepare_free_memory if TYPE_CHECKING: from facexlib.utils.face_restoration_helper import FaceRestoreHelper @@ -153,6 +154,7 @@ class CommonFaceRestoration(face_restoration.FaceRestoration): return np_image try: + prepare_free_memory() self.send_model_to(self.get_device()) return restore_with_face_helper(np_image, self.face_helper, restore_face) finally: diff --git a/modules/hat_model.py b/modules/hat_model.py index e31dee21..fe6008b1 100644 --- a/modules/hat_model.py +++ b/modules/hat_model.py @@ -3,8 +3,9 @@ import sys from modules import modelloader, devices from modules.shared import opts -from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory +from modules.upscaler import Upscaler, UpscalerData from modules.upscaler_utils import upscale_with_model +from modules_forge.forge_util import prepare_free_memory class UpscalerHAT(Upscaler): diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index e9314510..27425bd9 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -2,8 +2,9 @@ import os from modules import modelloader, errors from modules.shared import cmd_opts, opts -from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory +from modules.upscaler import Upscaler, UpscalerData from modules.upscaler_utils import upscale_with_model +from modules_forge.forge_util import prepare_free_memory class UpscalerRealESRGAN(Upscaler): diff --git a/modules/upscaler.py b/modules/upscaler.py index b89b86ef..0e38d52f 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -6,17 +6,6 @@ from PIL import Image import modules.shared from modules import modelloader, shared -from ldm_patched.modules import model_management - - -def prepare_free_memory(aggressive=False): - if aggressive: - model_management.unload_all_models() - print('Upscale script freed all memory.') - return - - model_management.free_memory(memory_required=1024*1024*3, device=model_management.get_torch_device()) - print('Upscale script freed memory successfully.') LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) diff --git a/modules_forge/forge_util.py b/modules_forge/forge_util.py index 6ec2134c..4f24d85d 100644 --- a/modules_forge/forge_util.py +++ b/modules_forge/forge_util.py @@ -9,6 +9,18 @@ import cv2 from ldm_patched.modules import model_management +def prepare_free_memory(aggressive=False): + if aggressive: + model_management.unload_all_models() + print('Cleanup all memory.') + return + + model_management.free_memory(memory_required=model_management.minimum_inference_memory(), + device=model_management.get_torch_device()) + print('Cleanup minimal inference memory.') + return + + def HWC3(x): assert x.dtype == np.uint8 if x.ndim == 2: