Add built-in extension "NeverOOM"

See also discussions.
lllyasviel 2024-02-24 19:09:06 -08:00
parent 50229a05c1
commit 437c348926
5 changed files with 68 additions and 6 deletions

View File

@@ -0,0 +1,47 @@
+import gradio as gr
+
+from modules import scripts
+from ldm_patched.modules import model_management
+
+
+class NeverOOMForForge(scripts.Script):
+    sorting_priority = 18
+
+    def __init__(self):
+        self.previous_unet_enabled = False
+        self.original_vram_state = model_management.vram_state
+
+    def title(self):
+        return "Never OOM Integrated"
+
+    def show(self, is_img2img):
+        return scripts.AlwaysVisible
+
+    def ui(self, *args, **kwargs):
+        with gr.Accordion(open=False, label=self.title()):
+            unet_enabled = gr.Checkbox(label='Enabled for UNet (always maximize offload)', value=False)
+            vae_enabled = gr.Checkbox(label='Enabled for VAE (always tiled)', value=False)
+        return unet_enabled, vae_enabled
+
+    def process(self, p, *script_args, **kwargs):
+        unet_enabled, vae_enabled = script_args
+
+        if unet_enabled:
+            print('NeverOOM Enabled for UNet (always maximize offload)')
+
+        if vae_enabled:
+            print('NeverOOM Enabled for VAE (always tiled)')
+
+        model_management.VAE_ALWAYS_TILED = vae_enabled
+
+        if self.previous_unet_enabled != unet_enabled:
+            model_management.unload_all_models()
+            if unet_enabled:
+                self.original_vram_state = model_management.vram_state
+                model_management.vram_state = model_management.VRAMState.NO_VRAM
+            else:
+                model_management.vram_state = self.original_vram_state
+            print(f'VRAM State Changed To {model_management.vram_state.name}')
+            self.previous_unet_enabled = unet_enabled
+
+        return
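
For illustration, a minimal standalone sketch of the save/restore pattern that `process` implements above: remember whatever VRAM mode the user had before forcing NO_VRAM, and put it back on disable. The `VRAMState` enum and module globals here are hypothetical stand-ins for `model_management`, not Forge code.

from enum import Enum

class VRAMState(Enum):  # hypothetical stand-in for model_management.VRAMState
    NORMAL_VRAM = 0
    NO_VRAM = 1

vram_state = VRAMState.NORMAL_VRAM
_original_state = vram_state
_previous_enabled = False

def toggle_never_oom(enabled: bool):
    global vram_state, _original_state, _previous_enabled
    if _previous_enabled == enabled:
        return  # no transition, nothing to do (mirrors the extension's state check)
    if enabled:
        _original_state = vram_state    # save the mode the user had
        vram_state = VRAMState.NO_VRAM  # force maximum offload
    else:
        vram_state = _original_state    # restore the saved mode
    print(f'VRAM State Changed To {vram_state.name}')
    _previous_enabled = enabled

toggle_never_oom(True)   # -> VRAM State Changed To NO_VRAM
toggle_never_oom(False)  # -> VRAM State Changed To NORMAL_VRAM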

View File

@@ -204,6 +204,9 @@ elif args.vae_in_fp32:
     VAE_DTYPE = torch.float32
 
+
+VAE_ALWAYS_TILED = False
+
 
 if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
     torch.backends.cuda.enable_flash_sdp(True)
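
As an aside, the context lines above use PyTorch's standard scaled-dot-product-attention backend switches in `torch.backends.cuda`. A quick sketch for flipping and inspecting those flags (guarded, since it needs a CUDA build):

import torch

# Enable the math and flash SDP kernels, then read the flags back.
# These are the same torch.backends.cuda toggles the hunk above sets.
if torch.cuda.is_available():
    torch.backends.cuda.enable_math_sdp(True)
    torch.backends.cuda.enable_flash_sdp(True)
    print('math sdp:', torch.backends.cuda.math_sdp_enabled())
    print('flash sdp:', torch.backends.cuda.flash_sdp_enabled())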

View File

@@ -208,7 +208,7 @@ class VAE:
         steps = samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
         steps += samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
         steps += samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
-        pbar = ldm_patched.modules.utils.ProgressBar(steps)
+        pbar = ldm_patched.modules.utils.ProgressBar(steps, title='VAE tiled decode')
 
         decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
         output = torch.clamp((
@@ -222,7 +222,7 @@ class VAE:
         steps = pixel_samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
         steps += pixel_samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
         steps += pixel_samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
-        pbar = ldm_patched.modules.utils.ProgressBar(steps)
+        pbar = ldm_patched.modules.utils.ProgressBar(steps, title='VAE tiled encode')
 
         encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
         samples = ldm_patched.modules.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
@@ -232,6 +232,9 @@ class VAE:
         return samples
 
     def decode(self, samples_in):
+        if model_management.VAE_ALWAYS_TILED:
+            return self.decode_tiled(samples_in).to(self.output_device)
+
         try:
             memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
             model_management.load_models_gpu([self.patcher], memory_required=memory_used)
@@ -256,6 +259,9 @@ class VAE:
         return output.movedim(1,-1)
 
     def encode(self, pixel_samples):
+        if model_management.VAE_ALWAYS_TILED:
+            return self.encode_tiled(pixel_samples)
+
         pixel_samples = pixel_samples.movedim(-1,1)
         try:
             memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
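
The step arithmetic above sizes the progress bar for three tiled passes (square, tall, and wide tiles) whose results are blended. A small worked example; the step-count helper below is an assumption modeled on the usual tiles-per-axis-with-overlap formula, not a copy of `get_tiled_scale_steps`:

import math

# Hypothetical stand-in for ldm_patched.modules.utils.get_tiled_scale_steps:
# overlapping tile count per axis, multiplied together.
def tiled_scale_steps(width, height, tile_x, tile_y, overlap):
    return math.ceil(height / (tile_y - overlap)) * math.ceil(width / (tile_x - overlap))

batch, width, height = 1, 64, 64       # 64x64 latent, e.g. a 512x512 image
tile_x, tile_y, overlap = 64, 64, 16   # typical latent-space tiling defaults (assumed)

steps = batch * tiled_scale_steps(width, height, tile_x, tile_y, overlap)             # square tiles: 4
steps += batch * tiled_scale_steps(width, height, tile_x // 2, tile_y * 2, overlap)   # tall tiles: 4
steps += batch * tiled_scale_steps(width, height, tile_x * 2, tile_y // 2, overlap)   # wide tiles: 4
print(steps)  # 12 -> the total the ProgressBar is constructed with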

View File

@@ -1,5 +1,5 @@
-# Taken from https://github.com/comfyanonymous/ComfyUI
-# This file is only for reference, and not used in the backend or runtime.
+# 1st edit https://github.com/comfyanonymous/ComfyUI
+# 2nd edit by Forge
 
 import torch
@@ -9,6 +9,7 @@ import ldm_patched.modules.checkpoint_pickle
 import safetensors.torch
 import numpy as np
 from PIL import Image
+from tqdm import tqdm
 
 def load_torch_file(ckpt, safe_load=False, device=None):
     if device is None:
@@ -448,20 +449,25 @@ def set_progress_bar_global_hook(function):
     PROGRESS_BAR_HOOK = function
 
 class ProgressBar:
-    def __init__(self, total):
+    def __init__(self, total, title=None):
         global PROGRESS_BAR_HOOK
         self.total = total
         self.current = 0
         self.hook = PROGRESS_BAR_HOOK
+        self.tqdm = tqdm(total=total, desc=title)
 
     def update_absolute(self, value, total=None, preview=None):
         if total is not None:
             self.total = total
         if value > self.total:
             value = self.total
+        inc = value - self.current
+        self.tqdm.update(inc)
         self.current = value
         if self.hook is not None:
             self.hook(self.current, self.total, preview)
+        if self.current >= self.total:
+            self.tqdm.close()
 
     def update(self, value):
         self.update_absolute(self.current + value)
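
A hedged usage sketch of how the patched ProgressBar is expected to behave once a title is passed through, as the VAE calls above now do. This is a standalone tqdm-only approximation, not an import of the backend module:

from tqdm import tqdm

class MiniProgressBar:
    def __init__(self, total, title=None):
        self.total = total
        self.current = 0
        self.tqdm = tqdm(total=total, desc=title)

    def update_absolute(self, value):
        value = min(value, self.total)
        self.tqdm.update(value - self.current)  # tqdm tracks increments, not absolutes
        self.current = value
        if self.current >= self.total:
            self.tqdm.close()  # close on completion so the bar does not linger

    def update(self, value):
        self.update_absolute(self.current + value)

pbar = MiniProgressBar(12, title='VAE tiled decode')
for _ in range(12):
    pbar.update(1)  # one tick per processed tile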

View File

@@ -1 +1 @@
-version = '0.0.15v1.8.0rc'
+version = '0.0.16v1.8.0rc'