rework memory management for extras

now face post-processing uses gpu
close #312
This commit is contained in:
lllyasviel 2024-02-25 20:37:14 -08:00
parent a252bbcf16
commit 6287c73d98
10 changed files with 30 additions and 21 deletions

View File

@ -1,7 +1,8 @@
import os
from modules.modelloader import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory
from modules.upscaler import Upscaler, UpscalerData
from modules_forge.forge_util import prepare_free_memory
from ldsr_model_arch import LDSR
from modules import shared, script_callbacks, errors
import sd_hijack_autoencoder # noqa: F401

View File

@ -5,7 +5,8 @@ import torch
from PIL import Image
from modules import devices, modelloader, script_callbacks, shared, upscaler_utils
from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory
from modules.upscaler import Upscaler, UpscalerData
from modules_forge.forge_util import prepare_free_memory
SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"

View File

@ -2,8 +2,9 @@ import os
from modules import modelloader, errors
from modules.shared import cmd_opts, opts
from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory
from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler_utils import upscale_with_model
from modules_forge.forge_util import prepare_free_memory
class UpscalerDAT(Upscaler):

View File

@ -50,10 +50,10 @@ def enable_tf32():
cpu: torch.device = torch.device("cpu")
fp8: bool = False
device: torch.device = model_management.get_torch_device()
device_interrogate: torch.device = cpu # not used
device_gfpgan: torch.device = cpu
device_esrgan: torch.device = model_management.get_torch_device() # will be managed in special way
device_codeformer: torch.device = cpu
device_interrogate: torch.device = model_management.text_encoder_device() # for backward compatibility, not used now
device_gfpgan: torch.device = model_management.get_torch_device() # will be managed by memory management system
device_esrgan: torch.device = model_management.get_torch_device() # will be managed by memory management system
device_codeformer: torch.device = model_management.get_torch_device() # will be managed by memory management system
dtype: torch.dtype = model_management.unet_dtype()
dtype_vae: torch.dtype = model_management.vae_dtype()
dtype_unet: torch.dtype = model_management.unet_dtype()

View File

@ -1,7 +1,8 @@
from modules import modelloader, devices, errors
from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory
from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler_utils import upscale_with_model
from modules_forge.forge_util import prepare_free_memory
class UpscalerESRGAN(Upscaler):

View File

@ -10,6 +10,7 @@ import numpy as np
import torch
from modules import devices, errors, face_restoration, shared
from modules_forge.forge_util import prepare_free_memory
if TYPE_CHECKING:
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
@ -153,6 +154,7 @@ class CommonFaceRestoration(face_restoration.FaceRestoration):
return np_image
try:
prepare_free_memory()
self.send_model_to(self.get_device())
return restore_with_face_helper(np_image, self.face_helper, restore_face)
finally:

View File

@ -3,8 +3,9 @@ import sys
from modules import modelloader, devices
from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory
from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler_utils import upscale_with_model
from modules_forge.forge_util import prepare_free_memory
class UpscalerHAT(Upscaler):

View File

@ -2,8 +2,9 @@ import os
from modules import modelloader, errors
from modules.shared import cmd_opts, opts
from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory
from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler_utils import upscale_with_model
from modules_forge.forge_util import prepare_free_memory
class UpscalerRealESRGAN(Upscaler):

View File

@ -6,17 +6,6 @@ from PIL import Image
import modules.shared
from modules import modelloader, shared
from ldm_patched.modules import model_management
def prepare_free_memory(aggressive=False):
if aggressive:
model_management.unload_all_models()
print('Upscale script freed all memory.')
return
model_management.free_memory(memory_required=1024*1024*3, device=model_management.get_torch_device())
print('Upscale script freed memory successfully.')
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)

View File

@ -9,6 +9,18 @@ import cv2
from ldm_patched.modules import model_management
def prepare_free_memory(aggressive=False):
if aggressive:
model_management.unload_all_models()
print('Cleanup all memory.')
return
model_management.free_memory(memory_required=model_management.minimum_inference_memory(),
device=model_management.get_torch_device())
print('Cleanup minimal inference memory.')
return
def HWC3(x):
assert x.dtype == np.uint8
if x.ndim == 2: