diff --git a/extensions-builtin/sd_forge_neveroom/scripts/forge_never_oom.py b/extensions-builtin/sd_forge_neveroom/scripts/forge_never_oom.py
new file mode 100644
index 00000000..374e8123
--- /dev/null
+++ b/extensions-builtin/sd_forge_neveroom/scripts/forge_never_oom.py
@@ -0,0 +1,47 @@
+import gradio as gr
+
+from modules import scripts
+from ldm_patched.modules import model_management
+
+
+class NeverOOMForForge(scripts.Script):
+    sorting_priority = 18
+
+    def __init__(self):
+        self.previous_unet_enabled = False
+        self.original_vram_state = model_management.vram_state
+
+    def title(self):
+        return "Never OOM Integrated"
+
+    def show(self, is_img2img):
+        return scripts.AlwaysVisible
+
+    def ui(self, *args, **kwargs):
+        with gr.Accordion(open=False, label=self.title()):
+            unet_enabled = gr.Checkbox(label='Enabled for UNet (always maximize offload)', value=False)
+            vae_enabled = gr.Checkbox(label='Enabled for VAE (always tiled)', value=False)
+        return unet_enabled, vae_enabled
+
+    def process(self, p, *script_args, **kwargs):
+        unet_enabled, vae_enabled = script_args
+
+        if unet_enabled:
+            print('NeverOOM Enabled for UNet (always maximize offload)')
+
+        if vae_enabled:
+            print('NeverOOM Enabled for VAE (always tiled)')
+
+        model_management.VAE_ALWAYS_TILED = vae_enabled
+
+        if self.previous_unet_enabled != unet_enabled:
+            model_management.unload_all_models()
+            if unet_enabled:
+                self.original_vram_state = model_management.vram_state
+                model_management.vram_state = model_management.VRAMState.NO_VRAM
+            else:
+                model_management.vram_state = self.original_vram_state
+            print(f'VRAM State Changed To {model_management.vram_state.name}')
+            self.previous_unet_enabled = unet_enabled
+
+        return
diff --git a/ldm_patched/modules/model_management.py b/ldm_patched/modules/model_management.py
index 21ba1652..b7febda7 100644
--- a/ldm_patched/modules/model_management.py
+++ b/ldm_patched/modules/model_management.py
@@ -204,6 +204,9 @@
 elif args.vae_in_fp32:
     VAE_DTYPE = torch.float32
 
+VAE_ALWAYS_TILED = False
+
+
 if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
     torch.backends.cuda.enable_flash_sdp(True)
diff --git a/ldm_patched/modules/sd.py b/ldm_patched/modules/sd.py
index 254fa1e3..ed232b5a 100644
--- a/ldm_patched/modules/sd.py
+++ b/ldm_patched/modules/sd.py
@@ -208,7 +208,7 @@ class VAE:
         steps = samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
         steps += samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
         steps += samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
-        pbar = ldm_patched.modules.utils.ProgressBar(steps)
+        pbar = ldm_patched.modules.utils.ProgressBar(steps, title='VAE tiled decode')
 
         decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
         output = torch.clamp((
@@ -222,7 +222,7 @@ class VAE:
         steps = pixel_samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
         steps += pixel_samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
         steps += pixel_samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
-        pbar = ldm_patched.modules.utils.ProgressBar(steps)
+        pbar = ldm_patched.modules.utils.ProgressBar(steps, title='VAE tiled encode')
 
         encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
         samples = ldm_patched.modules.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
@@ -232,6 +232,9 @@ class VAE:
         return samples
 
     def decode(self, samples_in):
+        if model_management.VAE_ALWAYS_TILED:
+            return self.decode_tiled(samples_in).to(self.output_device)
+
         try:
             memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
             model_management.load_models_gpu([self.patcher], memory_required=memory_used)
@@ -256,6 +259,9 @@ class VAE:
         return output.movedim(1,-1)
 
     def encode(self, pixel_samples):
+        if model_management.VAE_ALWAYS_TILED:
+            return self.encode_tiled(pixel_samples)
+
         pixel_samples = pixel_samples.movedim(-1,1)
         try:
             memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
diff --git a/ldm_patched/modules/utils.py b/ldm_patched/modules/utils.py
index 57bbb578..5d00dcc0 100644
--- a/ldm_patched/modules/utils.py
+++ b/ldm_patched/modules/utils.py
@@ -1,5 +1,5 @@
-# Taken from https://github.com/comfyanonymous/ComfyUI
-# This file is only for reference, and not used in the backend or runtime.
+# 1st edit https://github.com/comfyanonymous/ComfyUI
+# 2nd edit by Forge
 
 import torch
 
@@ -9,6 +9,7 @@ import ldm_patched.modules.checkpoint_pickle
 import safetensors.torch
 import numpy as np
 from PIL import Image
+from tqdm import tqdm
 
 def load_torch_file(ckpt, safe_load=False, device=None):
     if device is None:
@@ -448,20 +449,25 @@ def set_progress_bar_global_hook(function):
     PROGRESS_BAR_HOOK = function
 
 class ProgressBar:
-    def __init__(self, total):
+    def __init__(self, total, title=None):
         global PROGRESS_BAR_HOOK
         self.total = total
         self.current = 0
         self.hook = PROGRESS_BAR_HOOK
+        self.tqdm = tqdm(total=total, desc=title)
 
     def update_absolute(self, value, total=None, preview=None):
         if total is not None:
             self.total = total
         if value > self.total:
             value = self.total
+        inc = value - self.current
+        self.tqdm.update(inc)
         self.current = value
         if self.hook is not None:
             self.hook(self.current, self.total, preview)
+        if self.current >= self.total:
+            self.tqdm.close()
 
     def update(self, value):
         self.update_absolute(self.current + value)
diff --git a/modules_forge/forge_version.py b/modules_forge/forge_version.py
index 97f1cf1a..11772f7d 100644
--- a/modules_forge/forge_version.py
+++ b/modules_forge/forge_version.py
@@ -1 +1 @@
-version = '0.0.15v1.8.0rc'
+version = '0.0.16v1.8.0rc'
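
Two mechanisms in this patch are worth illustrating outside the diff. First, the module-level VAE_ALWAYS_TILED flag: the NeverOOM script sets it from the checkbox, and VAE.decode / VAE.encode in sd.py check it before anything else, short-circuiting into the tiled code paths without ever attempting the memory estimate for a single-pass pass. The sketch below models only that control flow; TinyVAE and decode_full are hypothetical stand-ins for illustration, not Forge API (only the decode / decode_tiled names mirror the real class).

# Minimal sketch of the VAE_ALWAYS_TILED gate. TinyVAE and decode_full
# are hypothetical stand-ins; they are not the Forge implementation.
VAE_ALWAYS_TILED = False  # toggled by the NeverOOM checkbox in the real code


class TinyVAE:
    def decode_tiled(self, samples):
        # Tiled path: bounded memory, many small decodes stitched together.
        return f'tiled decode of {samples!r}'

    def decode_full(self, samples):
        # Single-pass path: fast, but can OOM on large latents.
        return f'single-pass decode of {samples!r}'

    def decode(self, samples):
        # The gate runs before any memory estimation or model loading,
        # so the OOM-prone path is never even attempted when the flag is set.
        if VAE_ALWAYS_TILED:
            return self.decode_tiled(samples)
        return self.decode_full(samples)

In the real decode, the tiled result is additionally moved to self.output_device, so the flagged path returns tensors on the same device as the normal path.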
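Second, ProgressBar in utils.py now mirrors its progress to a console tqdm bar and accepts an optional title (surfaced as 'VAE tiled decode' / 'VAE tiled encode' above). Below is a simplified standalone sketch of that bookkeeping, with the PROGRESS_BAR_HOOK callback and the preview argument omitted.

from tqdm import tqdm


class ProgressBar:
    """Simplified sketch of the tqdm-backed ProgressBar from this patch.

    Hook/preview plumbing is omitted; only the tqdm bookkeeping is shown.
    """

    def __init__(self, total, title=None):
        self.total = total
        self.current = 0
        # tqdm renders the console bar; `desc` carries the title.
        self.tqdm = tqdm(total=total, desc=title)

    def update_absolute(self, value, total=None):
        if total is not None:
            self.total = total
        if value > self.total:
            value = self.total
        # tqdm.update() is incremental, while this method receives an
        # absolute position, so only the delta is forwarded.
        self.tqdm.update(value - self.current)
        self.current = value
        if self.current >= self.total:
            self.tqdm.close()  # release the bar once the work is complete

    def update(self, value):
        self.update_absolute(self.current + value)


if __name__ == '__main__':
    pbar = ProgressBar(100, title='VAE tiled decode')
    for _ in range(10):
        pbar.update(10)

Closing the bar as soon as current reaches total keeps a finished bar from lingering when the caller never tears it down explicitly, which matches how the tiled encode/decode loops simply stop calling update.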