Add build-in extension "NeverOOM"
see also discussions
This commit is contained in:
parent
50229a05c1
commit
437c348926
@ -0,0 +1,47 @@
|
||||
import gradio as gr
|
||||
|
||||
from modules import scripts
|
||||
from ldm_patched.modules import model_management
|
||||
|
||||
|
||||
class NeverOOMForForge(scripts.Script):
|
||||
sorting_priority = 18
|
||||
|
||||
def __init__(self):
|
||||
self.previous_unet_enabled = False
|
||||
self.original_vram_state = model_management.vram_state
|
||||
|
||||
def title(self):
|
||||
return "Never OOM Integrated"
|
||||
|
||||
def show(self, is_img2img):
|
||||
return scripts.AlwaysVisible
|
||||
|
||||
def ui(self, *args, **kwargs):
|
||||
with gr.Accordion(open=False, label=self.title()):
|
||||
unet_enabled = gr.Checkbox(label='Enabled for UNet (always maximize offload)', value=False)
|
||||
vae_enabled = gr.Checkbox(label='Enabled for VAE (always tiled)', value=False)
|
||||
return unet_enabled, vae_enabled
|
||||
|
||||
def process(self, p, *script_args, **kwargs):
|
||||
unet_enabled, vae_enabled = script_args
|
||||
|
||||
if unet_enabled:
|
||||
print('NeverOOM Enabled for UNet (always maximize offload)')
|
||||
|
||||
if vae_enabled:
|
||||
print('NeverOOM Enabled for VAE (always tiled)')
|
||||
|
||||
model_management.VAE_ALWAYS_TILED = vae_enabled
|
||||
|
||||
if self.previous_unet_enabled != unet_enabled:
|
||||
model_management.unload_all_models()
|
||||
if unet_enabled:
|
||||
self.original_vram_state = model_management.vram_state
|
||||
model_management.vram_state = model_management.VRAMState.NO_VRAM
|
||||
else:
|
||||
model_management.vram_state = self.original_vram_state
|
||||
print(f'VARM State Changed To {model_management.vram_state.name}')
|
||||
self.previous_unet_enabled = unet_enabled
|
||||
|
||||
return
|
@ -204,6 +204,9 @@ elif args.vae_in_fp32:
|
||||
VAE_DTYPE = torch.float32
|
||||
|
||||
|
||||
VAE_ALWAYS_TILED = False
|
||||
|
||||
|
||||
if ENABLE_PYTORCH_ATTENTION:
|
||||
torch.backends.cuda.enable_math_sdp(True)
|
||||
torch.backends.cuda.enable_flash_sdp(True)
|
||||
|
@ -208,7 +208,7 @@ class VAE:
|
||||
steps = samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
steps += samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
pbar = ldm_patched.modules.utils.ProgressBar(steps)
|
||||
pbar = ldm_patched.modules.utils.ProgressBar(steps, title='VAE tiled decode')
|
||||
|
||||
decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
|
||||
output = torch.clamp((
|
||||
@ -222,7 +222,7 @@ class VAE:
|
||||
steps = pixel_samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += pixel_samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
steps += pixel_samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
pbar = ldm_patched.modules.utils.ProgressBar(steps)
|
||||
pbar = ldm_patched.modules.utils.ProgressBar(steps, title='VAE tiled encode')
|
||||
|
||||
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
|
||||
samples = ldm_patched.modules.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
@ -232,6 +232,9 @@ class VAE:
|
||||
return samples
|
||||
|
||||
def decode(self, samples_in):
|
||||
if model_management.VAE_ALWAYS_TILED:
|
||||
return self.decode_tiled(samples_in).to(self.output_device)
|
||||
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
@ -256,6 +259,9 @@ class VAE:
|
||||
return output.movedim(1,-1)
|
||||
|
||||
def encode(self, pixel_samples):
|
||||
if model_management.VAE_ALWAYS_TILED:
|
||||
return self.encode_tiled(pixel_samples)
|
||||
|
||||
pixel_samples = pixel_samples.movedim(-1,1)
|
||||
try:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
|
@ -1,5 +1,5 @@
|
||||
# Taken from https://github.com/comfyanonymous/ComfyUI
|
||||
# This file is only for reference, and not used in the backend or runtime.
|
||||
# 1st edit https://github.com/comfyanonymous/ComfyUI
|
||||
# 2nd edit by Forge
|
||||
|
||||
|
||||
import torch
|
||||
@ -9,6 +9,7 @@ import ldm_patched.modules.checkpoint_pickle
|
||||
import safetensors.torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
def load_torch_file(ckpt, safe_load=False, device=None):
|
||||
if device is None:
|
||||
@ -448,20 +449,25 @@ def set_progress_bar_global_hook(function):
|
||||
PROGRESS_BAR_HOOK = function
|
||||
|
||||
class ProgressBar:
|
||||
def __init__(self, total):
|
||||
def __init__(self, total, title=None):
|
||||
global PROGRESS_BAR_HOOK
|
||||
self.total = total
|
||||
self.current = 0
|
||||
self.hook = PROGRESS_BAR_HOOK
|
||||
self.tqdm = tqdm(total=total, desc=title)
|
||||
|
||||
def update_absolute(self, value, total=None, preview=None):
|
||||
if total is not None:
|
||||
self.total = total
|
||||
if value > self.total:
|
||||
value = self.total
|
||||
inc = value - self.current
|
||||
self.tqdm.update(inc)
|
||||
self.current = value
|
||||
if self.hook is not None:
|
||||
self.hook(self.current, self.total, preview)
|
||||
if self.current >= self.total:
|
||||
self.tqdm.close()
|
||||
|
||||
def update(self, value):
|
||||
self.update_absolute(self.current + value)
|
||||
|
@ -1 +1 @@
|
||||
version = '0.0.15v1.8.0rc'
|
||||
version = '0.0.16v1.8.0rc'
|
||||
|
Loading…
Reference in New Issue
Block a user