rework memory management for extras
now face post-processing uses gpu close #312
This commit is contained in:
parent
a252bbcf16
commit
6287c73d98
@ -1,7 +1,8 @@
|
||||
import os
|
||||
|
||||
from modules.modelloader import load_file_from_url
|
||||
from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules_forge.forge_util import prepare_free_memory
|
||||
from ldsr_model_arch import LDSR
|
||||
from modules import shared, script_callbacks, errors
|
||||
import sd_hijack_autoencoder # noqa: F401
|
||||
|
@ -5,7 +5,8 @@ import torch
|
||||
from PIL import Image
|
||||
|
||||
from modules import devices, modelloader, script_callbacks, shared, upscaler_utils
|
||||
from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules_forge.forge_util import prepare_free_memory
|
||||
|
||||
SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
|
||||
|
||||
|
@ -2,8 +2,9 @@ import os
|
||||
|
||||
from modules import modelloader, errors
|
||||
from modules.shared import cmd_opts, opts
|
||||
from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.upscaler_utils import upscale_with_model
|
||||
from modules_forge.forge_util import prepare_free_memory
|
||||
|
||||
|
||||
class UpscalerDAT(Upscaler):
|
||||
|
@ -50,10 +50,10 @@ def enable_tf32():
|
||||
cpu: torch.device = torch.device("cpu")
|
||||
fp8: bool = False
|
||||
device: torch.device = model_management.get_torch_device()
|
||||
device_interrogate: torch.device = cpu # not used
|
||||
device_gfpgan: torch.device = cpu
|
||||
device_esrgan: torch.device = model_management.get_torch_device() # will be managed in special way
|
||||
device_codeformer: torch.device = cpu
|
||||
device_interrogate: torch.device = model_management.text_encoder_device() # for backward compatibility, not used now
|
||||
device_gfpgan: torch.device = model_management.get_torch_device() # will be managed by memory management system
|
||||
device_esrgan: torch.device = model_management.get_torch_device() # will be managed by memory management system
|
||||
device_codeformer: torch.device = model_management.get_torch_device() # will be managed by memory management system
|
||||
dtype: torch.dtype = model_management.unet_dtype()
|
||||
dtype_vae: torch.dtype = model_management.vae_dtype()
|
||||
dtype_unet: torch.dtype = model_management.unet_dtype()
|
||||
|
@ -1,7 +1,8 @@
|
||||
from modules import modelloader, devices, errors
|
||||
from modules.shared import opts
|
||||
from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.upscaler_utils import upscale_with_model
|
||||
from modules_forge.forge_util import prepare_free_memory
|
||||
|
||||
|
||||
class UpscalerESRGAN(Upscaler):
|
||||
|
@ -10,6 +10,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from modules import devices, errors, face_restoration, shared
|
||||
from modules_forge.forge_util import prepare_free_memory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
@ -153,6 +154,7 @@ class CommonFaceRestoration(face_restoration.FaceRestoration):
|
||||
return np_image
|
||||
|
||||
try:
|
||||
prepare_free_memory()
|
||||
self.send_model_to(self.get_device())
|
||||
return restore_with_face_helper(np_image, self.face_helper, restore_face)
|
||||
finally:
|
||||
|
@ -3,8 +3,9 @@ import sys
|
||||
|
||||
from modules import modelloader, devices
|
||||
from modules.shared import opts
|
||||
from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.upscaler_utils import upscale_with_model
|
||||
from modules_forge.forge_util import prepare_free_memory
|
||||
|
||||
|
||||
class UpscalerHAT(Upscaler):
|
||||
|
@ -2,8 +2,9 @@ import os
|
||||
|
||||
from modules import modelloader, errors
|
||||
from modules.shared import cmd_opts, opts
|
||||
from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.upscaler_utils import upscale_with_model
|
||||
from modules_forge.forge_util import prepare_free_memory
|
||||
|
||||
|
||||
class UpscalerRealESRGAN(Upscaler):
|
||||
|
@ -6,17 +6,6 @@ from PIL import Image
|
||||
|
||||
import modules.shared
|
||||
from modules import modelloader, shared
|
||||
from ldm_patched.modules import model_management
|
||||
|
||||
|
||||
def prepare_free_memory(aggressive=False):
|
||||
if aggressive:
|
||||
model_management.unload_all_models()
|
||||
print('Upscale script freed all memory.')
|
||||
return
|
||||
|
||||
model_management.free_memory(memory_required=1024*1024*3, device=model_management.get_torch_device())
|
||||
print('Upscale script freed memory successfully.')
|
||||
|
||||
|
||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||
|
@ -9,6 +9,18 @@ import cv2
|
||||
from ldm_patched.modules import model_management
|
||||
|
||||
|
||||
def prepare_free_memory(aggressive=False):
|
||||
if aggressive:
|
||||
model_management.unload_all_models()
|
||||
print('Cleanup all memory.')
|
||||
return
|
||||
|
||||
model_management.free_memory(memory_required=model_management.minimum_inference_memory(),
|
||||
device=model_management.get_torch_device())
|
||||
print('Cleanup minimal inference memory.')
|
||||
return
|
||||
|
||||
|
||||
def HWC3(x):
|
||||
assert x.dtype == np.uint8
|
||||
if x.ndim == 2:
|
||||
|
Loading…
Reference in New Issue
Block a user