diff --git a/README.md b/README.md
index 77a04f11..6cb59742 100644
--- a/README.md
+++ b/README.md
@@ -88,6 +88,8 @@ Without any cmd flag, Forge can run SDXL with 4GB vram and SD1.5 with 2GB vram.
 
 3. `--cuda-malloc` (This flag will make things **faster** but more risky). This will ask pytorch to use *cudaMallocAsync* for tensor malloc. On some profilers I can observe performance gain at millisecond level, but the real speed up on most my devices are often unnoticed (about or less than 0.1 second per image). This cannot be set as default because many users reported issues that the async malloc will crash the program. Users need to enable this cmd flag at their own risk.
 
+4. `--cuda-stream` (This flag will make things **faster** but more risky). This will use pytorch CUDA streams (a special type of thread on GPU) to move models and compute tensors simultaneously. This can almost eliminate all model moving time, and speed up SDXL on 30XX/40XX devices with small VRAM (e.g., RTX 4050 6GB, RTX 3060 Laptop 6GB, etc.) by about 15% to 25%. However, this unfortunately cannot be set as default because I observe a higher chance of pure black images (NaN outputs) on the 2060, and a higher chance of OOM on the 1080 and 2060. When the resolution is large, there is a chance that the computation time of a single attention layer is longer than the time needed to move the entire model to the GPU. When that happens, the next attention layer will OOM, since the GPU is already filled with the entire model and no space remains to compute another attention layer. Most overhead-detection methods are not robust enough to be reliable on old devices (in my tests). Users need to enable this cmd flag at their own risk.
+
 If you really want to play with cmd flags, you can additionally control the GPU with:
 
 (extreme VRAM cases)
diff --git a/ldm_patched/modules/args_parser.py b/ldm_patched/modules/args_parser.py
index e4aac7bc..8d39f757 100644
--- a/ldm_patched/modules/args_parser.py
+++ b/ldm_patched/modules/args_parser.py
@@ -116,6 +116,7 @@
 parser.add_argument("--disable-server-info", action="store_true")
 parser.add_argument("--multi-user", action="store_true")
 parser.add_argument("--cuda-malloc", action="store_true")
+parser.add_argument("--cuda-stream", action="store_true")
 parser.add_argument("--pin-shared-memory", action="store_true")
 
 if ldm_patched.modules.options.args_parsing:
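The mechanism behind `--cuda-stream` (implemented in `ldm_patched/modules/ops.py` further down) is PyTorch's standard side-stream overlap: weights are copied to the GPU on a dedicated "mover" stream with `non_blocking=True`, the mover stream records an event, and the compute stream waits on that event before touching the weights, so uploads and math can run concurrently. A minimal, self-contained sketch of that pattern (illustrative only; the tensor names and sizes below are invented, and a CUDA device is assumed):

```python
import torch

# Sketch of the copy/compute overlap that --cuda-stream relies on.
# Assumes torch.cuda.is_available(); shapes are arbitrary examples.
device = torch.device("cuda")
compute_stream = torch.cuda.current_stream()  # stream that runs the math
mover_stream = torch.cuda.Stream()            # stream that uploads weights

weight_cpu = torch.randn(4096, 4096, pin_memory=True)  # pinned memory enables async copies
x = torch.randn(64, 4096, device=device)

with torch.cuda.stream(mover_stream):
    weight_gpu = weight_cpu.to(device, non_blocking=True)  # copy issued on mover_stream
    copy_done = mover_stream.record_event()                 # "weights are ready" signal

compute_stream.wait_event(copy_done)  # wait only for this copy, not a full device sync
y = x @ weight_gpu.t()                # matmul runs on the compute stream after the event
```

Because the compute stream waits on a single event rather than a full synchronize, the next layer's weights can be uploaded while the current layer is still computing; that overlap is where the speedup described above comes from, and a layer whose computation outlasts the whole upload is exactly the OOM scenario the README warns about.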
diff --git a/ldm_patched/modules/controlnet.py b/ldm_patched/modules/controlnet.py
index 296941b5..6192d7ae 100644
--- a/ldm_patched/modules/controlnet.py
+++ b/ldm_patched/modules/controlnet.py
@@ -14,7 +14,7 @@
 import ldm_patched.modules.ops
 import ldm_patched.controlnet.cldm
 import ldm_patched.t2ia.adapter
 
-from ldm_patched.modules.ops import main_thread_worker
+from ldm_patched.modules.ops import main_stream_worker
 
 def broadcast_image_to(tensor, target_batch_size, batched_number):
@@ -306,7 +306,7 @@ class ControlLoraOps:
 
         def forward(self, input):
             weight, bias, signal = ldm_patched.modules.ops.cast_bias_weight(self, input)
-            with main_thread_worker(weight, bias, signal):
+            with main_stream_worker(weight, bias, signal):
                 if self.up is not None:
                     return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
                 else:
@@ -347,7 +347,7 @@ class ControlLoraOps:
 
         def forward(self, input):
             weight, bias, signal = ldm_patched.modules.ops.cast_bias_weight(self, input)
-            with main_thread_worker(weight, bias, signal):
+            with main_stream_worker(weight, bias, signal):
                 if self.up is not None:
                     return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
                 else:
diff --git a/ldm_patched/modules/model_base.py b/ldm_patched/modules/model_base.py
index 614b0d47..e847648c 100644
--- a/ldm_patched/modules/model_base.py
+++ b/ldm_patched/modules/model_base.py
@@ -215,9 +215,9 @@ class BaseModel(torch.nn.Module):
         dtype_size = ldm_patched.modules.model_management.dtype_size(dtype)
 
         if ldm_patched.modules.model_management.xformers_enabled() or ldm_patched.modules.model_management.pytorch_attention_flash_attention():
-            scaler = 1.25
+            scaler = 1.28
         else:
-            scaler = 1.75
+            scaler = 1.65
 
         if ldm_patched.ldm.modules.attention._ATTN_PRECISION == "fp32":
             dtype_size = 4
diff --git a/ldm_patched/modules/model_management.py b/ldm_patched/modules/model_management.py
index ee22d406..21ba1652 100644
--- a/ldm_patched/modules/model_management.py
+++ b/ldm_patched/modules/model_management.py
@@ -6,6 +6,7 @@ import time
 import psutil
 from enum import Enum
 from ldm_patched.modules.args_parser import args
+from modules_forge import stream
 import ldm_patched.modules.utils
 import torch
 import sys
@@ -277,6 +278,8 @@
 if 'rtx' in torch_device_name.lower():
     print('Hint: your device supports --pin-shared-memory for potential speed improvements.')
     if not args.cuda_malloc:
        print('Hint: your device supports --cuda-malloc for potential speed improvements.')
+    if not args.cuda_stream:
+        print('Hint: your device supports --cuda-stream for potential speed improvements.')
 
 print("VAE dtype:", VAE_DTYPE)
@@ -326,7 +329,8 @@ class LoadedModel:
             raise e
 
         if not disable_async_load:
-            print("[Memory Management] Requested Async Preserved Memory (MB) = ", async_kept_memory / (1024 * 1024))
+            flag = 'ASYNC' if stream.using_stream else 'SYNC'
+            print(f"[Memory Management] Requested {flag} Preserved Memory (MB) = ", async_kept_memory / (1024 * 1024))
             real_async_memory = 0
             mem_counter = 0
             for m in self.real_model.modules():
@@ -345,9 +349,9 @@
                 elif hasattr(m, "weight"):
                     m.to(self.device)
                     mem_counter += module_size(m)
-                    print("[Memory Management] Async Loader Disabled for ", m)
-            print("[Async Memory Management] Parameters Loaded to Async Stream (MB) = ", real_async_memory / (1024 * 1024))
-            print("[Async Memory Management] Parameters Loaded to GPU (MB) = ", mem_counter / (1024 * 1024))
+                    print(f"[Memory Management] {flag} Loader Disabled for ", m)
+            print(f"[Memory Management] Parameters Loaded to {flag} Stream (MB) = ", real_async_memory / (1024 * 1024))
+            print(f"[Memory Management] Parameters Loaded to GPU (MB) = ", mem_counter / (1024 * 1024))
 
         self.model_accelerated = True
@@ -372,7 +376,7 @@
         self.model.model_patches_to(self.model.offload_device)
 
     def __eq__(self, other):
-        return self.model is other.model and self.memory_required == other.memory_required
+        return self.model is other.model # and self.memory_required == other.memory_required
 
 def minimum_inference_memory():
     return (1024 * 1024 * 1024)
@@ -383,7 +387,8 @@ def unload_model_clones(model):
         if model.is_clone(current_loaded_models[i].model):
             to_unload = [i] + to_unload
 
-    print(f"Reuse {len(to_unload)} loaded models")
+    if len(to_unload) > 0:
+        print(f"Reuse {len(to_unload)} loaded models")
 
     for i in to_unload:
         current_loaded_models.pop(i).model_unload(avoid_model_moving=True)
@@ -414,9 +419,7 @@
 def load_models_gpu(models, memory_required=0):
     global vram_state
     execution_start_time = time.perf_counter()
-
-    inference_memory = minimum_inference_memory()
-    extra_mem = max(inference_memory, memory_required)
+    extra_mem = max(minimum_inference_memory(), memory_required)
 
     models_to_load = []
     models_already_loaded = []
@@ -439,7 +442,7 @@ def load_models_gpu(models, memory_required=0):
         free_memory(extra_mem, d, models_already_loaded)
 
         moving_time = time.perf_counter() - execution_start_time
-        if moving_time > 0.01:
+        if moving_time > 0.1:
             print(f'Memory cleanup has taken {moving_time:.2f} seconds')
 
         return
@@ -466,25 +469,25 @@ def load_models_gpu(models, memory_required=0):
         async_kept_memory = -1
 
         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
-            model_size = loaded_model.model_memory_required(torch_dev)
+            model_memory = loaded_model.model_memory_required(torch_dev)
             current_free_mem = get_free_memory(torch_dev)
-            estimated_memory_remaining = current_free_mem - model_size - extra_mem
+            minimal_inference_memory = minimum_inference_memory()
+            estimated_remaining_memory = current_free_mem - model_memory - minimal_inference_memory
 
-            print("[Memory Management] Current Free Memory (MB) = ", current_free_mem / (1024 * 1024))
-            print("[Memory Management] Model Memory (MB) = ", model_size / (1024 * 1024))
-            print("[Memory Management] Estimated Inference Memory (MB) = ", extra_mem / (1024 * 1024))
-            print("[Memory Management] Estimated Remaining Memory (MB) = ", estimated_memory_remaining / (1024 * 1024))
+            print("[Memory Management] Current Free GPU Memory (MB) = ", current_free_mem / (1024 * 1024))
+            print("[Memory Management] Model Memory (MB) = ", model_memory / (1024 * 1024))
+            print("[Memory Management] Minimal Inference Memory (MB) = ", minimal_inference_memory / (1024 * 1024))
+            print("[Memory Management] Estimated Remaining GPU Memory (MB) = ", estimated_remaining_memory / (1024 * 1024))
 
-            if estimated_memory_remaining < 0:
+            if estimated_remaining_memory < 0:
                 vram_set_state = VRAMState.LOW_VRAM
-                async_overhead_memory = 1024 * 1024 * 1024
-                async_kept_memory = current_free_mem - extra_mem - async_overhead_memory
+                async_kept_memory = (current_free_mem - minimal_inference_memory) / 1.3
                 async_kept_memory = int(max(0, async_kept_memory))
 
         if vram_set_state == VRAMState.NO_VRAM:
             async_kept_memory = 0
 
-        cur_loaded_model = loaded_model.model_load(async_kept_memory)
+        loaded_model.model_load(async_kept_memory)
         current_loaded_models.insert(0, loaded_model)
 
     moving_time = time.perf_counter() - execution_start_time
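The low-VRAM branch above also changes how much of the model stays resident: instead of subtracting a fixed 1GB `async_overhead_memory`, it now keeps `(current_free_mem - minimum_inference_memory()) / 1.3` on the GPU. A rough worked example with assumed numbers (not measurements; it also assumes `memory_required` is at the 1GB floor, so the old `extra_mem` equals `minimum_inference_memory()`):

```python
# Worked example of the old vs. new async_kept_memory heuristic (assumed numbers).
GB = 1024 * 1024 * 1024

current_free_mem = 6 * GB          # hypothetical free VRAM on a 6GB card
minimal_inference_memory = 1 * GB  # minimum_inference_memory() returns 1GB

# Old rule: free memory minus the inference reserve minus a fixed 1GB overhead.
old_kept = current_free_mem - minimal_inference_memory - 1 * GB   # 4.00 GB

# New rule: keep ~77% of whatever is left after the inference reserve.
new_kept = (current_free_mem - minimal_inference_memory) / 1.3    # ~3.85 GB

print(f"old: {old_kept / GB:.2f} GB, new: {new_kept / GB:.2f} GB")
```

The new rule therefore scales its safety margin with how much memory is actually free, rather than reserving the same fixed 1GB on every card.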
diff --git a/ldm_patched/modules/ops.py b/ldm_patched/modules/ops.py
index cd83fc5e..beb8f265 100644
--- a/ldm_patched/modules/ops.py
+++ b/ldm_patched/modules/ops.py
@@ -10,7 +10,7 @@
 from modules_forge import stream
 
 # https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14855/files
-gc = {}
+stash = {}
 
 
 @contextlib.contextmanager
@@ -31,26 +31,25 @@ def use_patched_ops(operations):
 
 
 def cast_bias_weight(s, input):
-    context = contextlib.nullcontext
-    signal = None
+    weight, bias, signal = None, None, None
+    non_blocking = ldm_patched.modules.model_management.device_supports_non_blocking(input.device)
 
     if stream.using_stream:
-        context = stream.stream_context()
-
-    with context(stream.mover_stream):
-        bias = None
-        non_blocking = ldm_patched.modules.model_management.device_supports_non_blocking(input.device)
+        with stream.stream_context()(stream.mover_stream):
+            if s.bias is not None:
+                bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+            weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+            signal = stream.mover_stream.record_event()
+    else:
         if s.bias is not None:
             bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
         weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
 
-    if stream.using_stream:
-        signal = stream.mover_stream.record_event()
-
     return weight, bias, signal
 
 
 @contextlib.contextmanager
-def main_thread_worker(weight, bias, signal):
+def main_stream_worker(weight, bias, signal):
     if not stream.using_stream or signal is None:
         yield
         return
@@ -59,40 +58,25 @@ def main_thread_worker(weight, bias, signal):
         stream.current_stream.wait_event(signal)
         yield
         finished_signal = stream.current_stream.record_event()
-        size = weight.element_size() * weight.nelement()
-        if bias is not None:
-            size += bias.element_size() * bias.nelement()
-        gc[id(finished_signal)] = (weight, bias, finished_signal, size)
-
-        overhead = sum([l for k, (w, b, s, l) in gc.items()])
-
-        if overhead > 512 * 1024 * 1024:
-            stream.mover_stream.synchronize()
-            stream.current_stream.synchronize()
+        stash[id(finished_signal)] = (weight, bias, finished_signal)
 
     garbage = []
-    for k, (w, b, s, l) in gc.items():
+    for k, (w, b, s) in stash.items():
         if s.query():
             garbage.append(k)
 
     for k in garbage:
-        del gc[k]
+        del stash[k]
 
     return
 
 
 def cleanup_cache():
-    global gc
+    if not stream.using_stream:
+        return
 
-    if stream.current_stream is not None:
-        with stream.stream_context()(stream.current_stream):
-            for k, (w, b, s, l) in gc.items():
-                stream.current_stream.wait_event(s)
-            stream.current_stream.synchronize()
-
-    gc.clear()
-
-    if stream.mover_stream is not None:
-        stream.mover_stream.synchronize()
+    stream.current_stream.synchronize()
+    stream.mover_stream.synchronize()
+    stash.clear()
 
     return
@@ -104,7 +88,7 @@ class disable_weight_init:
 
         def forward_ldm_patched_cast_weights(self, input):
             weight, bias, signal = cast_bias_weight(self, input)
-            with main_thread_worker(weight, bias, signal):
+            with main_stream_worker(weight, bias, signal):
                 return torch.nn.functional.linear(input, weight, bias)
 
         def forward(self, *args, **kwargs):
@@ -120,7 +104,7 @@
 
         def forward_ldm_patched_cast_weights(self, input):
             weight, bias, signal = cast_bias_weight(self, input)
-            with main_thread_worker(weight, bias, signal):
+            with main_stream_worker(weight, bias, signal):
                 return self._conv_forward(input, weight, bias)
 
         def forward(self, *args, **kwargs):
@@ -136,7 +120,7 @@
 
         def forward_ldm_patched_cast_weights(self, input):
             weight, bias, signal = cast_bias_weight(self, input)
-            with main_thread_worker(weight, bias, signal):
+            with main_stream_worker(weight, bias, signal):
                 return self._conv_forward(input, weight, bias)
 
         def forward(self, *args, **kwargs):
@@ -152,7 +136,7 @@
 
         def forward_ldm_patched_cast_weights(self, input):
             weight, bias, signal = cast_bias_weight(self, input)
-            with main_thread_worker(weight, bias, signal):
+            with main_stream_worker(weight, bias, signal):
                 return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
 
         def forward(self, *args, **kwargs):
@@ -169,7 +153,7 @@
 
         def forward_ldm_patched_cast_weights(self, input):
             weight, bias, signal = cast_bias_weight(self, input)
-            with main_thread_worker(weight, bias, signal):
+            with main_stream_worker(weight, bias, signal):
                 return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
 
         def forward(self, *args, **kwargs):
diff --git a/modules/shared_options.py b/modules/shared_options.py
index 847d1359..72cb79a9 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -216,7 +216,6 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd"
     "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
     "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
     "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."),
-    "use_non_streamlined_lowvram": OptionInfo(False, "Use non-streamlined low VRAM mode").info("(Requires restart in Forge.) Do not use the streamlined mode for low VRAM cards. For devices that do not support concurrently copy memory between host and device while executing a kernel. Significantly decreases performance."),
 }))
 
 options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {
diff --git a/modules_forge/initialization.py b/modules_forge/initialization.py
index 605b9096..2a37bcb8 100644
--- a/modules_forge/initialization.py
+++ b/modules_forge/initialization.py
@@ -59,6 +59,9 @@ def initialize_forge():
     import modules_forge.patch_basic
     modules_forge.patch_basic.patch_all_basics()
 
+    from modules_forge import stream
+    print('CUDA Stream Activated: ', stream.using_stream)
+
     from modules_forge.shared import diffusers_dir
 
     if 'TRANSFORMERS_CACHE' not in os.environ:
diff --git a/modules_forge/stream.py b/modules_forge/stream.py
index d019047d..93ddadeb 100644
--- a/modules_forge/stream.py
+++ b/modules_forge/stream.py
@@ -2,7 +2,7 @@
 
 import torch
 
-from modules import shared
+from ldm_patched.modules import args_parser
 from ldm_patched.modules import model_management
 
 
@@ -56,14 +56,12 @@ def get_new_stream():
     return None
 
 
-if shared.opts.use_non_streamlined_lowvram:
-    current_stream = None
-    mover_stream = None
-    using_stream = False
-else:
+current_stream = None
+mover_stream = None
+using_stream = False
+
+if args_parser.args.cuda_stream:
     current_stream = get_current_stream()
     mover_stream = get_new_stream()
     using_stream = current_stream is not None and mover_stream is not None
 
-if not using_stream:
-    print('Stream is not used.')
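`modules_forge/stream.py` above only shows the module-level wiring that decides `using_stream`; the bodies of `get_current_stream()` and `get_new_stream()` sit outside these hunks (the surviving context line shows they may return `None`). As a rough idea of what such helpers look like on top of PyTorch's stream API, here is a sketch under that assumption, not the actual Forge implementation:

```python
import torch

def get_current_stream():
    # Sketch: the stream PyTorch is already issuing kernels on, or None when
    # CUDA streams are unusable (CPU-only machine, unsupported backend, etc.).
    try:
        return torch.cuda.current_stream() if torch.cuda.is_available() else None
    except Exception:
        return None

def get_new_stream():
    # Sketch: a second stream dedicated to host-to-GPU weight copies.
    try:
        return torch.cuda.Stream() if torch.cuda.is_available() else None
    except Exception:
        return None

def stream_context():
    # ops.py calls stream.stream_context()(stream.mover_stream), i.e. it expects
    # a context-manager factory; torch.cuda.stream has exactly that shape.
    return torch.cuda.stream
```

Returning `None` from either helper is what lets `using_stream = current_stream is not None and mover_stream is not None` fall back to the synchronous code path on machines where a second stream cannot be created.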