Add optimization --cuda-stream

See also the readme for more details
lllyasviel 2024-02-24 14:00:48 -08:00 committed by GitHub
parent 0f09d98814
commit 434ca2169f
9 changed files with 63 additions and 73 deletions

View File

@@ -88,6 +88,8 @@ Without any cmd flag, Forge can run SDXL with 4GB vram and SD1.5 with 2GB vram.
 3. `--cuda-malloc` (This flag will make things **faster** but more risky). This will ask pytorch to use *cudaMallocAsync* for tensor malloc. On some profilers I can observe performance gains at the millisecond level, but the real speedup on most of my devices is often unnoticeable (about or less than 0.1 second per image). This cannot be set as default because many users reported that the async malloc will crash the program. Users need to enable this cmd flag at their own risk.
+4. `--cuda-stream` (This flag will make things **faster** but more risky). This will use pytorch CUDA streams (a special type of thread on GPU) to move models and compute tensors simultaneously. This can almost eliminate all model moving time, and speed up SDXL on 30XX/40XX devices with small VRAM (e.g., RTX 4050 6GB, RTX 3060 Laptop 6GB, etc.) by about 15% to 25%. However, this unfortunately cannot be set as default because I observe a higher possibility of pure black images (NaN outputs) on the 2060, and a higher chance of OOM on the 1080 and 2060. When the resolution is large, there is a chance that the computation time of a single attention layer is longer than the time needed to move the entire model to the GPU. When that happens, the next attention layer will OOM, since the GPU is filled with the entire model and no space remains to compute another attention layer. Most overhead-detection methods are not robust enough to be reliable on old devices (in my tests). Users need to enable this cmd flag at their own risk.
 If you really want to play with cmd flags, you can additionally control the GPU with:
 (extreme VRAM cases)
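To make the `--cuda-stream` item above concrete: the flag's whole idea is to use a second CUDA stream so that the next layer's weights are copied to the GPU while the current layer is still computing, with events keeping the two streams ordered. Below is a minimal, self-contained PyTorch sketch of that general pattern (names are ours, and it assumes a CUDA device is available); Forge's actual implementation lives in `ldm_patched/modules/ops.py` and appears later in this diff.

```python
import torch

# Illustrative only: overlap a weight upload with compute using two CUDA streams.
device = torch.device('cuda')
compute_stream = torch.cuda.current_stream(device)
mover_stream = torch.cuda.Stream(device)

# Pinned CPU memory can be copied asynchronously (non_blocking=True).
cpu_weight = torch.randn(4096, 4096).pin_memory()
x = torch.randn(64, 4096, device=device)

with torch.cuda.stream(mover_stream):
    gpu_weight = cpu_weight.to(device, non_blocking=True)  # copy runs on the mover stream
    copy_done = mover_stream.record_event()                 # signal: weights are ready

compute_stream.wait_event(copy_done)   # compute stream waits only for the copy it needs
y = x @ gpu_weight.t()                 # matmul runs on the compute stream
torch.cuda.synchronize()
```

The events are what make this safe: the compute stream only starts a kernel once the copy that feeds it has finished, while unrelated work on both streams can still overlap.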

View File

@@ -116,6 +116,7 @@ parser.add_argument("--disable-server-info", action="store_true")
 parser.add_argument("--multi-user", action="store_true")
 parser.add_argument("--cuda-malloc", action="store_true")
+parser.add_argument("--cuda-stream", action="store_true")
 parser.add_argument("--pin-shared-memory", action="store_true")
 if ldm_patched.modules.options.args_parsing:

View File

@@ -14,7 +14,7 @@ import ldm_patched.modules.ops
 import ldm_patched.controlnet.cldm
 import ldm_patched.t2ia.adapter
-from ldm_patched.modules.ops import main_thread_worker
+from ldm_patched.modules.ops import main_stream_worker
 def broadcast_image_to(tensor, target_batch_size, batched_number):
@@ -306,7 +306,7 @@ class ControlLoraOps:
         def forward(self, input):
             weight, bias, signal = ldm_patched.modules.ops.cast_bias_weight(self, input)
-            with main_thread_worker(weight, bias, signal):
+            with main_stream_worker(weight, bias, signal):
                 if self.up is not None:
                     return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
                 else:
@@ -347,7 +347,7 @@ class ControlLoraOps:
         def forward(self, input):
             weight, bias, signal = ldm_patched.modules.ops.cast_bias_weight(self, input)
-            with main_thread_worker(weight, bias, signal):
+            with main_stream_worker(weight, bias, signal):
                 if self.up is not None:
                     return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
                 else:

View File

@@ -215,9 +215,9 @@ class BaseModel(torch.nn.Module):
         dtype_size = ldm_patched.modules.model_management.dtype_size(dtype)
         if ldm_patched.modules.model_management.xformers_enabled() or ldm_patched.modules.model_management.pytorch_attention_flash_attention():
-            scaler = 1.25
+            scaler = 1.28
         else:
-            scaler = 1.75
+            scaler = 1.65
         if ldm_patched.ldm.modules.attention._ATTN_PRECISION == "fp32":
             dtype_size = 4
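The two numbers above are empirical scalers in Forge's estimate of how much working memory attention needs: memory-efficient attention (xformers or PyTorch flash attention) gets the smaller factor. The surrounding formula is not part of this hunk, so the sketch below only shows how such a scaler enters an estimate; the helper name and the element-count shape of the formula are assumptions, not Forge's code.

```python
# Hypothetical illustration only; the real formula lives in BaseModel.memory_required()
# and is not shown in this diff.
def rough_attention_memory_bytes(latent_elements: int, dtype_size: int,
                                 memory_efficient_attention: bool) -> float:
    # Memory-efficient attention needs less working memory per element,
    # hence the smaller empirical scaler (1.28 vs 1.65 after this commit).
    scaler = 1.28 if memory_efficient_attention else 1.65
    return scaler * latent_elements * dtype_size
```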

View File

@@ -6,6 +6,7 @@ import time
 import psutil
 from enum import Enum
 from ldm_patched.modules.args_parser import args
+from modules_forge import stream
 import ldm_patched.modules.utils
 import torch
 import sys
@@ -277,6 +278,8 @@ if 'rtx' in torch_device_name.lower():
         print('Hint: your device supports --pin-shared-memory for potential speed improvements.')
     if not args.cuda_malloc:
         print('Hint: your device supports --cuda-malloc for potential speed improvements.')
+    if not args.cuda_stream:
+        print('Hint: your device supports --cuda-stream for potential speed improvements.')
 print("VAE dtype:", VAE_DTYPE)
@@ -326,7 +329,8 @@ class LoadedModel:
             raise e
         if not disable_async_load:
-            print("[Memory Management] Requested Async Preserved Memory (MB) = ", async_kept_memory / (1024 * 1024))
+            flag = 'ASYNC' if stream.using_stream else 'SYNC'
+            print(f"[Memory Management] Requested {flag} Preserved Memory (MB) = ", async_kept_memory / (1024 * 1024))
             real_async_memory = 0
             mem_counter = 0
             for m in self.real_model.modules():
@@ -345,9 +349,9 @@ class LoadedModel:
                 elif hasattr(m, "weight"):
                     m.to(self.device)
                     mem_counter += module_size(m)
-                    print("[Memory Management] Async Loader Disabled for ", m)
+                    print(f"[Memory Management] {flag} Loader Disabled for ", m)
-            print("[Async Memory Management] Parameters Loaded to Async Stream (MB) = ", real_async_memory / (1024 * 1024))
-            print("[Async Memory Management] Parameters Loaded to GPU (MB) = ", mem_counter / (1024 * 1024))
+            print(f"[Memory Management] Parameters Loaded to {flag} Stream (MB) = ", real_async_memory / (1024 * 1024))
+            print(f"[Memory Management] Parameters Loaded to GPU (MB) = ", mem_counter / (1024 * 1024))
             self.model_accelerated = True
@@ -372,7 +376,7 @@ class LoadedModel:
             self.model.model_patches_to(self.model.offload_device)
     def __eq__(self, other):
-        return self.model is other.model and self.memory_required == other.memory_required
+        return self.model is other.model  # and self.memory_required == other.memory_required
 def minimum_inference_memory():
     return (1024 * 1024 * 1024)
@@ -383,7 +387,8 @@ def unload_model_clones(model):
         if model.is_clone(current_loaded_models[i].model):
             to_unload = [i] + to_unload
-    print(f"Reuse {len(to_unload)} loaded models")
+    if len(to_unload) > 0:
+        print(f"Reuse {len(to_unload)} loaded models")
     for i in to_unload:
         current_loaded_models.pop(i).model_unload(avoid_model_moving=True)
@@ -414,9 +419,7 @@ def load_models_gpu(models, memory_required=0):
     global vram_state
     execution_start_time = time.perf_counter()
-    inference_memory = minimum_inference_memory()
-    extra_mem = max(inference_memory, memory_required)
+    extra_mem = max(minimum_inference_memory(), memory_required)
     models_to_load = []
     models_already_loaded = []
@@ -439,7 +442,7 @@ def load_models_gpu(models, memory_required=0):
         free_memory(extra_mem, d, models_already_loaded)
         moving_time = time.perf_counter() - execution_start_time
-        if moving_time > 0.01:
+        if moving_time > 0.1:
             print(f'Memory cleanup has taken {moving_time:.2f} seconds')
         return
@@ -466,25 +469,25 @@ def load_models_gpu(models, memory_required=0):
         async_kept_memory = -1
         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
-            model_size = loaded_model.model_memory_required(torch_dev)
+            model_memory = loaded_model.model_memory_required(torch_dev)
             current_free_mem = get_free_memory(torch_dev)
-            estimated_memory_remaining = current_free_mem - model_size - extra_mem
+            minimal_inference_memory = minimum_inference_memory()
+            estimated_remaining_memory = current_free_mem - model_memory - minimal_inference_memory
-            print("[Memory Management] Current Free Memory (MB) = ", current_free_mem / (1024 * 1024))
-            print("[Memory Management] Model Memory (MB) = ", model_size / (1024 * 1024))
-            print("[Memory Management] Estimated Inference Memory (MB) = ", extra_mem / (1024 * 1024))
-            print("[Memory Management] Estimated Remaining Memory (MB) = ", estimated_memory_remaining / (1024 * 1024))
+            print("[Memory Management] Current Free GPU Memory (MB) = ", current_free_mem / (1024 * 1024))
+            print("[Memory Management] Model Memory (MB) = ", model_memory / (1024 * 1024))
+            print("[Memory Management] Minimal Inference Memory (MB) = ", minimal_inference_memory / (1024 * 1024))
+            print("[Memory Management] Estimated Remaining GPU Memory (MB) = ", estimated_remaining_memory / (1024 * 1024))
-            if estimated_memory_remaining < 0:
+            if estimated_remaining_memory < 0:
                 vram_set_state = VRAMState.LOW_VRAM
-                async_overhead_memory = 1024 * 1024 * 1024
-                async_kept_memory = current_free_mem - extra_mem - async_overhead_memory
+                async_kept_memory = (current_free_mem - minimal_inference_memory) / 1.3
                 async_kept_memory = int(max(0, async_kept_memory))
         if vram_set_state == VRAMState.NO_VRAM:
             async_kept_memory = 0
-        cur_loaded_model = loaded_model.model_load(async_kept_memory)
+        loaded_model.model_load(async_kept_memory)
         current_loaded_models.insert(0, loaded_model)
         moving_time = time.perf_counter() - execution_start_time
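To see what the new branch above does in practice, here is the arithmetic with made-up numbers; only the values are invented, the formula is the one in the diff (the real code works in bytes, megabytes are used here for readability).

```python
# Worked example of the LOW_VRAM decision above; all numbers are made up (MB).
current_free_mem = 5500            # get_free_memory(torch_dev)
model_memory = 6200                # loaded_model.model_memory_required(torch_dev)
minimal_inference_memory = 1024    # minimum_inference_memory() == 1 GB

estimated_remaining_memory = current_free_mem - model_memory - minimal_inference_memory
# 5500 - 6200 - 1024 = -1724  ->  negative, the model cannot fit entirely: switch to LOW_VRAM

if estimated_remaining_memory < 0:
    # Keep only part of the model resident on the GPU; the /1.3 factor presumably
    # leaves headroom for fragmentation and the mover stream's in-flight copies.
    async_kept_memory = int(max(0, (current_free_mem - minimal_inference_memory) / 1.3))
    # (5500 - 1024) / 1.3 ≈ 3443 MB of weights kept on the GPU
```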

View File

@@ -10,7 +10,7 @@ from modules_forge import stream
 # https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14855/files
-gc = {}
+stash = {}
 @contextlib.contextmanager
@@ -31,26 +31,25 @@ def use_patched_ops(operations):
 def cast_bias_weight(s, input):
-    context = contextlib.nullcontext
-    signal = None
+    weight, bias, signal = None, None, None
+    non_blocking = ldm_patched.modules.model_management.device_supports_non_blocking(input.device)
     if stream.using_stream:
-        context = stream.stream_context()
-    with context(stream.mover_stream):
-        bias = None
-        non_blocking = ldm_patched.modules.model_management.device_supports_non_blocking(input.device)
+        with stream.stream_context()(stream.mover_stream):
+            if s.bias is not None:
+                bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+            weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+            signal = stream.mover_stream.record_event()
+    else:
         if s.bias is not None:
             bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
         weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
-        if stream.using_stream:
-            signal = stream.mover_stream.record_event()
     return weight, bias, signal
 @contextlib.contextmanager
-def main_thread_worker(weight, bias, signal):
+def main_stream_worker(weight, bias, signal):
     if not stream.using_stream or signal is None:
         yield
         return
@@ -59,40 +58,25 @@ def main_stream_worker(weight, bias, signal):
     stream.current_stream.wait_event(signal)
     yield
     finished_signal = stream.current_stream.record_event()
-    size = weight.element_size() * weight.nelement()
-    if bias is not None:
-        size += bias.element_size() * bias.nelement()
-    gc[id(finished_signal)] = (weight, bias, finished_signal, size)
-    overhead = sum([l for k, (w, b, s, l) in gc.items()])
-    if overhead > 512 * 1024 * 1024:
-        stream.mover_stream.synchronize()
-        stream.current_stream.synchronize()
+    stash[id(finished_signal)] = (weight, bias, finished_signal)
     garbage = []
-    for k, (w, b, s, l) in gc.items():
+    for k, (w, b, s) in stash.items():
         if s.query():
             garbage.append(k)
     for k in garbage:
-        del gc[k]
+        del stash[k]
     return
 def cleanup_cache():
-    global gc
-    if stream.current_stream is not None:
-        with stream.stream_context()(stream.current_stream):
-            for k, (w, b, s, l) in gc.items():
-                stream.current_stream.wait_event(s)
-            stream.current_stream.synchronize()
-        gc.clear()
-    if stream.mover_stream is not None:
-        stream.mover_stream.synchronize()
+    if not stream.using_stream:
+        return
+    stream.current_stream.synchronize()
+    stream.mover_stream.synchronize()
+    stash.clear()
     return
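The two functions above are the heart of the optimization: `cast_bias_weight` uploads a layer's weight and bias on the mover stream and returns a ready-event, `main_stream_worker` makes the compute stream wait on that event before running the layer, and the `stash` keeps the uploaded tensors alive until the compute stream has actually consumed them (otherwise Python could free the memory while a kernel is still reading it). The standalone sketch below reproduces that data flow with generic names so it is easier to follow; it is a simplified illustration (and assumes a CUDA device), not the Forge code itself.

```python
import contextlib
import torch

mover_stream = torch.cuda.Stream()
compute_stream = torch.cuda.current_stream()
stash = {}  # keeps (weight, bias, event) alive until the compute stream is done with them

def upload_weights(module, device, dtype):
    """Copy a module's weight/bias on the mover stream and return them with a ready-event."""
    with torch.cuda.stream(mover_stream):
        weight = module.weight.to(device=device, dtype=dtype, non_blocking=True)
        bias = module.bias.to(device=device, dtype=dtype, non_blocking=True) if module.bias is not None else None
        ready = mover_stream.record_event()
    return weight, bias, ready

@contextlib.contextmanager
def on_compute_stream(weight, bias, ready):
    """Make the compute stream wait for the upload, run the layer, then stash the tensors."""
    compute_stream.wait_event(ready)        # do not start the kernel before the copy lands
    yield
    done = compute_stream.record_event()    # marks when the kernel has consumed the weights
    stash[id(done)] = (weight, bias, done)  # keep references alive until `done` has fired
    for key in [k for k, (_, _, e) in stash.items() if e.query()]:
        del stash[key]                      # drop entries whose compute has finished

# Usage: a linear layer kept on the CPU, executed on the GPU with an overlapped upload.
layer = torch.nn.Linear(1024, 1024)
x = torch.randn(8, 1024, device='cuda', dtype=torch.float16)
w, b, ready = upload_weights(layer, x.device, x.dtype)
with on_compute_stream(w, b, ready):
    y = torch.nn.functional.linear(x, w, b)
```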
@@ -104,7 +88,7 @@ class disable_weight_init:
         def forward_ldm_patched_cast_weights(self, input):
             weight, bias, signal = cast_bias_weight(self, input)
-            with main_thread_worker(weight, bias, signal):
+            with main_stream_worker(weight, bias, signal):
                 return torch.nn.functional.linear(input, weight, bias)
         def forward(self, *args, **kwargs):
@@ -120,7 +104,7 @@ class disable_weight_init:
         def forward_ldm_patched_cast_weights(self, input):
             weight, bias, signal = cast_bias_weight(self, input)
-            with main_thread_worker(weight, bias, signal):
+            with main_stream_worker(weight, bias, signal):
                 return self._conv_forward(input, weight, bias)
         def forward(self, *args, **kwargs):
@@ -136,7 +120,7 @@ class disable_weight_init:
         def forward_ldm_patched_cast_weights(self, input):
             weight, bias, signal = cast_bias_weight(self, input)
-            with main_thread_worker(weight, bias, signal):
+            with main_stream_worker(weight, bias, signal):
                 return self._conv_forward(input, weight, bias)
         def forward(self, *args, **kwargs):
@@ -152,7 +136,7 @@ class disable_weight_init:
         def forward_ldm_patched_cast_weights(self, input):
             weight, bias, signal = cast_bias_weight(self, input)
-            with main_thread_worker(weight, bias, signal):
+            with main_stream_worker(weight, bias, signal):
                 return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
         def forward(self, *args, **kwargs):
@@ -169,7 +153,7 @@ class disable_weight_init:
         def forward_ldm_patched_cast_weights(self, input):
             weight, bias, signal = cast_bias_weight(self, input)
-            with main_thread_worker(weight, bias, signal):
+            with main_stream_worker(weight, bias, signal):
                 return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
         def forward(self, *args, **kwargs):

View File

@@ -216,7 +216,6 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd"
     "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
     "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
     "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."),
-    "use_non_streamlined_lowvram": OptionInfo(False, "Use non-streamlined low VRAM mode").info("(Requires restart in Forge.) Do not use the streamlined mode for low VRAM cards. For devices that do not support concurrently copy memory between host and device while executing a kernel. Significantly decreases performance."),
 }))
 options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {

View File

@@ -59,6 +59,9 @@ def initialize_forge():
     import modules_forge.patch_basic
     modules_forge.patch_basic.patch_all_basics()
+    from modules_forge import stream
+    print('CUDA Stream Activated: ', stream.using_stream)
     from modules_forge.shared import diffusers_dir
     if 'TRANSFORMERS_CACHE' not in os.environ:

View File

@@ -2,7 +2,7 @@
 import torch
-from modules import shared
+from ldm_patched.modules import args_parser
 from ldm_patched.modules import model_management
@@ -56,14 +56,12 @@ def get_new_stream():
         return None
-if shared.opts.use_non_streamlined_lowvram:
-    current_stream = None
-    mover_stream = None
-    using_stream = False
-else:
+current_stream = None
+mover_stream = None
+using_stream = False
+if args_parser.args.cuda_stream:
     current_stream = get_current_stream()
     mover_stream = get_new_stream()
     using_stream = current_stream is not None and mover_stream is not None
+if not using_stream:
+    print('Stream is not used.')
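After this change, the streams are created purely based on the `--cuda-stream` command-line flag instead of a UI option. The helpers `get_current_stream()` and `get_new_stream()` are defined earlier in `modules_forge/stream.py` and are not part of this diff; a plausible sketch of what such helpers might look like on CUDA (returning `None` when streams are unavailable, which is what leaves `using_stream` False on unsupported devices) is shown below. Treat it as an assumption, not the actual Forge implementation.

```python
import torch

# Assumed shape of the helpers referenced above; the real ones are not shown in this diff.
def get_current_stream():
    try:
        if torch.cuda.is_available():
            return torch.cuda.current_stream(torch.device('cuda'))
    except Exception as e:
        print('Could not get current CUDA stream:', e)
    return None

def get_new_stream():
    try:
        if torch.cuda.is_available():
            return torch.cuda.Stream(torch.device('cuda'))
    except Exception as e:
        print('Could not create CUDA stream:', e)
    return None
```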