Add optimization --cuda-stream

See also the readme for more details
lllyasviel 2024-02-24 14:00:48 -08:00 committed by GitHub
parent 0f09d98814
commit 434ca2169f
9 changed files with 63 additions and 73 deletions

View File

@ -88,6 +88,8 @@ Without any cmd flag, Forge can run SDXL with 4GB vram and SD1.5 with 2GB vram.
3. `--cuda-malloc` (This flag will make things **faster** but more risky). This will ask PyTorch to use *cudaMallocAsync* for tensor malloc. On some profilers I can observe a performance gain at the millisecond level, but the real speed-up on most of my devices is often unnoticeable (about 0.1 second per image or less). This cannot be set as default because many users have reported that the async malloc crashes the program. Users need to enable this cmd flag at their own risk.
4. `--cuda-stream` (This flag will make things **faster** but more risky). This will use PyTorch CUDA streams (a special type of thread on the GPU) to move models and compute tensors simultaneously. This can almost eliminate all model-moving time and speed up SDXL on 30XX/40XX devices with small VRAM (e.g., RTX 4050 6GB, RTX 3060 Laptop 6GB, etc.) by about 15% to 25%. However, this unfortunately cannot be set as default because I observe a higher chance of pure black images (NaN outputs) on the 2060, and a higher chance of OOM on the 1080 and 2060. When the resolution is large, there is a chance that the computation time of one single attention layer is longer than the time needed to move the entire model to the GPU. When that happens, the next attention layer will OOM, since the GPU is already filled with the entire model and no space remains for computing another attention layer. Most overhead-detection methods are not robust enough to be reliable on old devices (in my tests). Users need to enable this cmd flag at their own risk. (A rough sketch of this overlap pattern is shown below.)
If you really want to play with cmd flags, you can additionally control the GPU with:
(extreme VRAM cases)
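
To make the stream idea above more concrete, here is a minimal, self-contained PyTorch sketch of the copy/compute overlap that `--cuda-stream` relies on. It is an illustration under assumptions, not Forge's actual code: the toy two-layer model, tensor sizes, and variable names are invented; only the mover-stream/compute-stream split and the event handoff mirror the patch.

```python
import torch

assert torch.cuda.is_available()
device = torch.device('cuda')

compute_stream = torch.cuda.current_stream(device)  # default stream runs the math
mover_stream = torch.cuda.Stream(device)            # dedicated stream moves weights

# Two toy "layers" kept in pinned CPU memory so host-to-device copies can be async.
w1 = torch.randn(4096, 4096).pin_memory()
w2 = torch.randn(4096, 4096).pin_memory()
x = torch.randn(64, 4096, device=device)

# Copy w1 on the mover stream and record an event when the copy is queued.
with torch.cuda.stream(mover_stream):
    w1_gpu = w1.to(device, non_blocking=True)
    w1_ready = mover_stream.record_event()

# The compute stream waits only for the copy it needs; meanwhile the mover
# stream is free to start prefetching the next layer's weights in parallel.
compute_stream.wait_event(w1_ready)
with torch.cuda.stream(mover_stream):
    w2_gpu = w2.to(device, non_blocking=True)
    w2_ready = mover_stream.record_event()

y1 = x @ w1_gpu.t()          # overlaps with the w2 copy above

compute_stream.wait_event(w2_ready)
y2 = y1 @ w2_gpu.t()
torch.cuda.synchronize()      # real code must also keep copied tensors alive
                              # until the compute stream is done with them
```

In the actual patch, that last concern is handled by recording a "finished" event and stashing the copied `weight`/`bias` in `ldm_patched/modules/ops.py` (the `stash` dict in the diff below) so they are not freed while the compute stream may still be reading them.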

View File

@ -116,6 +116,7 @@ parser.add_argument("--disable-server-info", action="store_true")
parser.add_argument("--multi-user", action="store_true")
parser.add_argument("--cuda-malloc", action="store_true")
parser.add_argument("--cuda-stream", action="store_true")
parser.add_argument("--pin-shared-memory", action="store_true")
if ldm_patched.modules.options.args_parsing:

View File

@ -14,7 +14,7 @@ import ldm_patched.modules.ops
import ldm_patched.controlnet.cldm
import ldm_patched.t2ia.adapter
from ldm_patched.modules.ops import main_thread_worker
from ldm_patched.modules.ops import main_stream_worker
def broadcast_image_to(tensor, target_batch_size, batched_number):
@ -306,7 +306,7 @@ class ControlLoraOps:
def forward(self, input):
weight, bias, signal = ldm_patched.modules.ops.cast_bias_weight(self, input)
with main_thread_worker(weight, bias, signal):
with main_stream_worker(weight, bias, signal):
if self.up is not None:
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
else:
@ -347,7 +347,7 @@ class ControlLoraOps:
def forward(self, input):
weight, bias, signal = ldm_patched.modules.ops.cast_bias_weight(self, input)
with main_thread_worker(weight, bias, signal):
with main_stream_worker(weight, bias, signal):
if self.up is not None:
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
else:

View File

@ -215,9 +215,9 @@ class BaseModel(torch.nn.Module):
dtype_size = ldm_patched.modules.model_management.dtype_size(dtype)
if ldm_patched.modules.model_management.xformers_enabled() or ldm_patched.modules.model_management.pytorch_attention_flash_attention():
scaler = 1.25
scaler = 1.28
else:
scaler = 1.75
scaler = 1.65
if ldm_patched.ldm.modules.attention._ATTN_PRECISION == "fp32":
dtype_size = 4

View File

@ -6,6 +6,7 @@ import time
import psutil
from enum import Enum
from ldm_patched.modules.args_parser import args
from modules_forge import stream
import ldm_patched.modules.utils
import torch
import sys
@ -277,6 +278,8 @@ if 'rtx' in torch_device_name.lower():
print('Hint: your device supports --pin-shared-memory for potential speed improvements.')
if not args.cuda_malloc:
print('Hint: your device supports --cuda-malloc for potential speed improvements.')
if not args.cuda_stream:
print('Hint: your device supports --cuda-stream for potential speed improvements.')
print("VAE dtype:", VAE_DTYPE)
@ -326,7 +329,8 @@ class LoadedModel:
raise e
if not disable_async_load:
print("[Memory Management] Requested Async Preserved Memory (MB) = ", async_kept_memory / (1024 * 1024))
flag = 'ASYNC' if stream.using_stream else 'SYNC'
print(f"[Memory Management] Requested {flag} Preserved Memory (MB) = ", async_kept_memory / (1024 * 1024))
real_async_memory = 0
mem_counter = 0
for m in self.real_model.modules():
@ -345,9 +349,9 @@ class LoadedModel:
elif hasattr(m, "weight"):
m.to(self.device)
mem_counter += module_size(m)
print("[Memory Management] Async Loader Disabled for ", m)
print("[Async Memory Management] Parameters Loaded to Async Stream (MB) = ", real_async_memory / (1024 * 1024))
print("[Async Memory Management] Parameters Loaded to GPU (MB) = ", mem_counter / (1024 * 1024))
print(f"[Memory Management] {flag} Loader Disabled for ", m)
print(f"[Memory Management] Parameters Loaded to {flag} Stream (MB) = ", real_async_memory / (1024 * 1024))
print(f"[Memory Management] Parameters Loaded to GPU (MB) = ", mem_counter / (1024 * 1024))
self.model_accelerated = True
@ -372,7 +376,7 @@ class LoadedModel:
self.model.model_patches_to(self.model.offload_device)
def __eq__(self, other):
return self.model is other.model and self.memory_required == other.memory_required
return self.model is other.model # and self.memory_required == other.memory_required
def minimum_inference_memory():
return (1024 * 1024 * 1024)
@ -383,7 +387,8 @@ def unload_model_clones(model):
if model.is_clone(current_loaded_models[i].model):
to_unload = [i] + to_unload
print(f"Reuse {len(to_unload)} loaded models")
if len(to_unload) > 0:
print(f"Reuse {len(to_unload)} loaded models")
for i in to_unload:
current_loaded_models.pop(i).model_unload(avoid_model_moving=True)
@ -414,9 +419,7 @@ def load_models_gpu(models, memory_required=0):
global vram_state
execution_start_time = time.perf_counter()
inference_memory = minimum_inference_memory()
extra_mem = max(inference_memory, memory_required)
extra_mem = max(minimum_inference_memory(), memory_required)
models_to_load = []
models_already_loaded = []
@ -439,7 +442,7 @@ def load_models_gpu(models, memory_required=0):
free_memory(extra_mem, d, models_already_loaded)
moving_time = time.perf_counter() - execution_start_time
if moving_time > 0.01:
if moving_time > 0.1:
print(f'Memory cleanup has taken {moving_time:.2f} seconds')
return
@ -466,25 +469,25 @@ def load_models_gpu(models, memory_required=0):
async_kept_memory = -1
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = loaded_model.model_memory_required(torch_dev)
model_memory = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev)
estimated_memory_remaining = current_free_mem - model_size - extra_mem
minimal_inference_memory = minimum_inference_memory()
estimated_remaining_memory = current_free_mem - model_memory - minimal_inference_memory
print("[Memory Management] Current Free Memory (MB) = ", current_free_mem / (1024 * 1024))
print("[Memory Management] Model Memory (MB) = ", model_size / (1024 * 1024))
print("[Memory Management] Estimated Inference Memory (MB) = ", extra_mem / (1024 * 1024))
print("[Memory Management] Estimated Remaining Memory (MB) = ", estimated_memory_remaining / (1024 * 1024))
print("[Memory Management] Current Free GPU Memory (MB) = ", current_free_mem / (1024 * 1024))
print("[Memory Management] Model Memory (MB) = ", model_memory / (1024 * 1024))
print("[Memory Management] Minimal Inference Memory (MB) = ", minimal_inference_memory / (1024 * 1024))
print("[Memory Management] Estimated Remaining GPU Memory (MB) = ", estimated_remaining_memory / (1024 * 1024))
if estimated_memory_remaining < 0:
if estimated_remaining_memory < 0:
vram_set_state = VRAMState.LOW_VRAM
async_overhead_memory = 1024 * 1024 * 1024
async_kept_memory = current_free_mem - extra_mem - async_overhead_memory
async_kept_memory = (current_free_mem - minimal_inference_memory) / 1.3
async_kept_memory = int(max(0, async_kept_memory))
if vram_set_state == VRAMState.NO_VRAM:
async_kept_memory = 0
cur_loaded_model = loaded_model.model_load(async_kept_memory)
loaded_model.model_load(async_kept_memory)
current_loaded_models.insert(0, loaded_model)
moving_time = time.perf_counter() - execution_start_time
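
As a quick sanity check of the new preserved-memory budget, here is a worked example with made-up numbers (the 6 GB free figure is an assumption for illustration; the 1 GB value and the formula itself come from the patch, where `minimum_inference_memory()` returns `1024 * 1024 * 1024`):

```python
current_free_mem = 6 * 1024 ** 3          # assume ~6 GB reported free on the GPU
minimal_inference_memory = 1024 ** 3      # minimum_inference_memory() returns 1 GB

# New budget: keep a slice of what remains after reserving inference memory.
async_kept_memory = (current_free_mem - minimal_inference_memory) / 1.3
async_kept_memory = int(max(0, async_kept_memory))

print(async_kept_memory / (1024 * 1024))  # ≈ 3938 MB preserved for the stream loader
```

Compared with the old `current_free_mem - extra_mem - 1 GB` rule, the budget now scales with how much memory is actually free instead of subtracting a fixed overhead.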

View File

@ -10,7 +10,7 @@ from modules_forge import stream
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14855/files
gc = {}
stash = {}
@contextlib.contextmanager
@ -31,26 +31,25 @@ def use_patched_ops(operations):
def cast_bias_weight(s, input):
context = contextlib.nullcontext
signal = None
weight, bias, signal = None, None, None
non_blocking = ldm_patched.modules.model_management.device_supports_non_blocking(input.device)
if stream.using_stream:
context = stream.stream_context()
with context(stream.mover_stream):
bias = None
non_blocking = ldm_patched.modules.model_management.device_supports_non_blocking(input.device)
with stream.stream_context()(stream.mover_stream):
if s.bias is not None:
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
signal = stream.mover_stream.record_event()
else:
if s.bias is not None:
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
if stream.using_stream:
signal = stream.mover_stream.record_event()
return weight, bias, signal
@contextlib.contextmanager
def main_thread_worker(weight, bias, signal):
def main_stream_worker(weight, bias, signal):
if not stream.using_stream or signal is None:
yield
return
@ -59,40 +58,25 @@ def main_thread_worker(weight, bias, signal):
stream.current_stream.wait_event(signal)
yield
finished_signal = stream.current_stream.record_event()
size = weight.element_size() * weight.nelement()
if bias is not None:
size += bias.element_size() * bias.nelement()
gc[id(finished_signal)] = (weight, bias, finished_signal, size)
overhead = sum([l for k, (w, b, s, l) in gc.items()])
if overhead > 512 * 1024 * 1024:
stream.mover_stream.synchronize()
stream.current_stream.synchronize()
stash[id(finished_signal)] = (weight, bias, finished_signal)
garbage = []
for k, (w, b, s, l) in gc.items():
for k, (w, b, s) in stash.items():
if s.query():
garbage.append(k)
for k in garbage:
del gc[k]
del stash[k]
return
def cleanup_cache():
global gc
if not stream.using_stream:
return
if stream.current_stream is not None:
with stream.stream_context()(stream.current_stream):
for k, (w, b, s, l) in gc.items():
stream.current_stream.wait_event(s)
stream.current_stream.synchronize()
gc.clear()
if stream.mover_stream is not None:
stream.mover_stream.synchronize()
stream.current_stream.synchronize()
stream.mover_stream.synchronize()
stash.clear()
return
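
The `stash` dict above is a keep-alive cache: the freshly cast `weight`/`bias` must not be released while the compute stream may still be reading them, so they are held together with a recorded event and dropped only once `query()` reports that the GPU has passed that event. A small, self-contained sketch of the same pattern (names such as `retain_until_done` and `_pending` are invented for illustration, not the module's API):

```python
import torch

_pending = {}   # event id -> (tensors, event) kept referenced until the GPU passes the event

def retain_until_done(tensors, stream):
    # Record a marker on `stream`; the tensors stay referenced until it fires.
    done = stream.record_event()
    _pending[id(done)] = (tensors, done)
    # Opportunistically drop entries whose events have already completed.
    for key in [k for k, (_, ev) in _pending.items() if ev.query()]:
        del _pending[key]

# Usage sketch: keep a temporary weight alive past the Python scope that made it.
device = torch.device('cuda')
compute_stream = torch.cuda.current_stream(device)
weight = torch.randn(2048, 2048, device=device)
x = torch.randn(64, 2048, device=device)
y = x @ weight.t()                    # enqueued on the compute stream
retain_until_done((weight,), compute_stream)
del weight                            # safe: _pending still references the tensor
```

This is also why `cleanup_cache()` can be reduced to two `synchronize()` calls plus `stash.clear()` in the new version.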
@ -104,7 +88,7 @@ class disable_weight_init:
def forward_ldm_patched_cast_weights(self, input):
weight, bias, signal = cast_bias_weight(self, input)
with main_thread_worker(weight, bias, signal):
with main_stream_worker(weight, bias, signal):
return torch.nn.functional.linear(input, weight, bias)
def forward(self, *args, **kwargs):
@ -120,7 +104,7 @@ class disable_weight_init:
def forward_ldm_patched_cast_weights(self, input):
weight, bias, signal = cast_bias_weight(self, input)
with main_thread_worker(weight, bias, signal):
with main_stream_worker(weight, bias, signal):
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
@ -136,7 +120,7 @@ class disable_weight_init:
def forward_ldm_patched_cast_weights(self, input):
weight, bias, signal = cast_bias_weight(self, input)
with main_thread_worker(weight, bias, signal):
with main_stream_worker(weight, bias, signal):
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
@ -152,7 +136,7 @@ class disable_weight_init:
def forward_ldm_patched_cast_weights(self, input):
weight, bias, signal = cast_bias_weight(self, input)
with main_thread_worker(weight, bias, signal):
with main_stream_worker(weight, bias, signal):
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
def forward(self, *args, **kwargs):
@ -169,7 +153,7 @@ class disable_weight_init:
def forward_ldm_patched_cast_weights(self, input):
weight, bias, signal = cast_bias_weight(self, input)
with main_thread_worker(weight, bias, signal):
with main_stream_worker(weight, bias, signal):
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
def forward(self, *args, **kwargs):

View File

@ -216,7 +216,6 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd"
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
"fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
"cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."),
"use_non_streamlined_lowvram": OptionInfo(False, "Use non-streamlined low VRAM mode").info("(Requires restart in Forge.) Do not use the streamlined mode for low VRAM cards. For devices that do not support concurrently copy memory between host and device while executing a kernel. Significantly decreases performance."),
}))
options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {

View File

@ -59,6 +59,9 @@ def initialize_forge():
import modules_forge.patch_basic
modules_forge.patch_basic.patch_all_basics()
from modules_forge import stream
print('CUDA Stream Activated: ', stream.using_stream)
from modules_forge.shared import diffusers_dir
if 'TRANSFORMERS_CACHE' not in os.environ:

View File

@ -2,7 +2,7 @@
import torch
from modules import shared
from ldm_patched.modules import args_parser
from ldm_patched.modules import model_management
@ -56,14 +56,12 @@ def get_new_stream():
return None
if shared.opts.use_non_streamlined_lowvram:
current_stream = None
mover_stream = None
using_stream = False
else:
current_stream = None
mover_stream = None
using_stream = False
if args_parser.args.cuda_stream:
current_stream = get_current_stream()
mover_stream = get_new_stream()
using_stream = current_stream is not None and mover_stream is not None
if not using_stream:
print('Stream is not used.')
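
The bodies of `get_current_stream()` and `get_new_stream()` are not part of this hunk. Purely as an assumption, a minimal, self-contained version of the new gating could look like the sketch below: stream mode only activates when the flag is set and both streams can actually be created, otherwise Forge falls back to synchronous model moving.

```python
import torch

# Stand-in for args_parser.args.cuda_stream (assumption: set by the --cuda-stream flag).
cuda_stream_enabled = True

def get_current_stream():
    # Assumed behaviour: the default compute stream, or None when CUDA is unavailable.
    return torch.cuda.current_stream() if torch.cuda.is_available() else None

def get_new_stream():
    # Assumed behaviour: a dedicated "mover" stream, or None when CUDA is unavailable.
    return torch.cuda.Stream() if torch.cuda.is_available() else None

current_stream = None
mover_stream = None
using_stream = False

if cuda_stream_enabled:
    current_stream = get_current_stream()
    mover_stream = get_new_stream()
    using_stream = current_stream is not None and mover_stream is not None

if not using_stream:
    print('Stream is not used.')
```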