Merge upstream PR 14855

lllyasviel 2024-02-21 23:59:40 -08:00
parent 95ddac3117
commit 638ee43bf1
4 changed files with 118 additions and 16 deletions


@@ -6,6 +6,12 @@ import torch
 import ldm_patched.modules.model_management
 import contextlib
+from modules_forge import stream
+
+# https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14855/files
+gc = {}
 
 @contextlib.contextmanager
 def use_patched_ops(operations):
@@ -25,12 +31,44 @@ def use_patched_ops(operations):
 def cast_bias_weight(s, input):
-    bias = None
-    non_blocking = ldm_patched.modules.model_management.device_supports_non_blocking(input.device)
-    if s.bias is not None:
-        bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
-    weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
-    return weight, bias
+    context = contextlib.nullcontext
+    signal = None
+
+    if stream.using_stream:
+        context = stream.stream_context()
+
+    with context(stream.mover_stream):
+        bias = None
+        non_blocking = ldm_patched.modules.model_management.device_supports_non_blocking(input.device)
+        if s.bias is not None:
+            bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+        weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+
+    if stream.using_stream:
+        signal = stream.mover_stream.record_event()
+
+    return weight, bias, signal
+
+
+@contextlib.contextmanager
+def main_thread_worker(weight, bias, signal):
+    if not stream.using_stream or signal is None:
+        yield
+        return
+
+    with stream.stream_context()(stream.current_stream):
+        stream.current_stream.wait_event(signal)
+        yield
+        finished_signal = stream.current_stream.record_event()
+        gc[id(finished_signal)] = (weight, bias, finished_signal)
+
+    garbage = []
+    for k, (w, b, s) in gc.items():
+        if s.query():
+            garbage.append(k)
+
+    for k in garbage:
+        del gc[k]
+
+    return
+
+
 class disable_weight_init:
@@ -40,8 +78,9 @@ class disable_weight_init:
             return None
 
         def forward_ldm_patched_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.linear(input, weight, bias)
+            weight, bias, signal = cast_bias_weight(self, input)
+            with main_thread_worker(weight, bias, signal):
+                return torch.nn.functional.linear(input, weight, bias)
 
         def forward(self, *args, **kwargs):
             if self.ldm_patched_cast_weights:
@@ -55,8 +94,9 @@ class disable_weight_init:
             return None
 
         def forward_ldm_patched_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, signal = cast_bias_weight(self, input)
+            with main_thread_worker(weight, bias, signal):
+                return self._conv_forward(input, weight, bias)
 
         def forward(self, *args, **kwargs):
             if self.ldm_patched_cast_weights:
@@ -70,8 +110,9 @@ class disable_weight_init:
             return None
 
         def forward_ldm_patched_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, signal = cast_bias_weight(self, input)
+            with main_thread_worker(weight, bias, signal):
+                return self._conv_forward(input, weight, bias)
 
         def forward(self, *args, **kwargs):
             if self.ldm_patched_cast_weights:
@@ -85,8 +126,9 @@ class disable_weight_init:
             return None
 
         def forward_ldm_patched_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+            weight, bias, signal = cast_bias_weight(self, input)
+            with main_thread_worker(weight, bias, signal):
+                return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
 
         def forward(self, *args, **kwargs):
             if self.ldm_patched_cast_weights:
@@ -101,8 +143,9 @@ class disable_weight_init:
             return None
 
         def forward_ldm_patched_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+            weight, bias, signal = cast_bias_weight(self, input)
+            with main_thread_worker(weight, bias, signal):
+                return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
 
         def forward(self, *args, **kwargs):
             if self.ldm_patched_cast_weights:
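Taken together, the hunks above implement the copy/compute overlap from upstream PR 14855: cast_bias_weight queues the weight and bias transfers on a dedicated mover stream and returns the recorded event as signal, while main_thread_worker makes the compute stream wait on that event, runs the wrapped forward, and parks the casted tensors in the module-level gc dict until a "finished" event reports via query() that the compute stream no longer needs them. A minimal standalone sketch of the same pattern in plain PyTorch, assuming a CUDA device (compute, mover, keep_alive and the two helpers below are illustrative names, not part of the commit):

import torch

compute = torch.cuda.current_stream()
mover = torch.cuda.Stream()
keep_alive = {}  # plays the role of the module-level `gc` dict above

def cast_async(param, device, dtype):
    # Queue the transfer on the mover stream and record an event that
    # marks the moment the copy is complete.
    with torch.cuda.stream(mover):
        casted = param.to(device=device, dtype=dtype, non_blocking=True)
    return casted, mover.record_event()

def run_when_ready(casted, signal, fn):
    # Make the compute stream wait for the copy, launch the kernel, and
    # record a "finished" event so the casted tensor can be released later.
    compute.wait_event(signal)
    out = fn(casted)
    finished = compute.record_event()
    keep_alive[id(finished)] = (casted, finished)
    # Prune entries whose work has completed; Event.query() never blocks.
    for key in [k for k, (_, ev) in keep_alive.items() if ev.query()]:
        del keep_alive[key]
    return out

Keeping a reference to the casted tensors until the finished event fires appears to be the point of the gc dict: it stops the caching allocator from recycling that memory for a new allocation on the mover stream while the compute stream may still be reading it.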


@@ -216,6 +216,7 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd"
     "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
     "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
     "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."),
+    "use_non_streamlined_lowvram": OptionInfo(False, "Use non-streamlined low VRAM mode").info("(Requires restart in Forge.) Do not use the streamlined mode for low VRAM cards. For devices that do not support concurrently copy memory between host and device while executing a kernel. Significantly decreases performance."),
 }))
 
 options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {


@@ -2,6 +2,7 @@ import torch
 from ldm_patched.modules.conds import CONDRegular, CONDCrossAttn
 from ldm_patched.modules.samplers import sampling_function
 from ldm_patched.modules import model_management
+from modules_forge.stream import synchronize_current_stream
 
 
 def cond_from_a1111_to_patched_ldm(cond):
@@ -113,4 +114,5 @@ def sampling_prepare(unet, x):
 def sampling_cleanup(unet):
     for cnet in unet.list_controlnets():
         cnet.cleanup()
+    synchronize_current_stream()
     return
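The synchronize_current_stream() call added to sampling_cleanup makes the host wait for everything the streamed forward passes queued on the compute stream before ControlNets are torn down and models are offloaded. A small sketch of the ordering it enforces, assuming a CUDA device (the surrounding lines are illustrative, not from the diff):

import torch

s = torch.cuda.current_stream()
# ... sampling queues kernels and asynchronous weight copies on `s` ...
s.synchronize()   # blocks the host until the queued work is drained
assert s.query()  # nothing left in flight; now safe to free or offload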

modules_forge/stream.py (new file, 56 lines)

@@ -0,0 +1,56 @@
+# https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14855
+
+import torch
+
+from modules import shared
+from ldm_patched.modules import model_management
+
+
+def stream_context():
+    if torch.cuda.is_available():
+        return torch.cuda.stream
+
+    if model_management.is_intel_xpu():
+        return torch.xpu.stream
+
+    return None
+
+
+def get_current_stream():
+    try:
+        if torch.cuda.is_available():
+            return torch.cuda.current_stream(torch.device(torch.cuda.current_device()))
+        if model_management.is_intel_xpu():
+            return torch.xpu.current_stream(torch.device("xpu"))
+    except:
+        pass
+    print('Stream is not used.')
+    return None
+
+
+def get_new_stream():
+    try:
+        if torch.cuda.is_available():
+            return torch.cuda.Stream(torch.device(torch.cuda.current_device()))
+        if model_management.is_intel_xpu():
+            return torch.xpu.Stream(torch.device("xpu"))
+    except:
+        pass
+    print('Stream is not used.')
+    return None
+
+
+def synchronize_current_stream():
+    global current_stream
+    if current_stream is not None:
+        current_stream.synchronize()
+
+
+if shared.opts.use_non_streamlined_lowvram:
+    current_stream = None
+    mover_stream = None
+    using_stream = False
+else:
+    current_stream = get_current_stream()
+    mover_stream = get_new_stream()
+    using_stream = current_stream is not None and mover_stream is not None
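modules_forge/stream.py resolves current_stream, mover_stream, and using_stream once at import time, honoring the new use_non_streamlined_lowvram option. A rough sketch of how these globals are meant to be consumed, mirroring the patched cast_bias_weight above (move_async is a hypothetical helper, not part of the diff):

from modules_forge import stream

def move_async(tensor, device, dtype):
    # Fall back to a plain synchronous copy when streams are unavailable.
    if not stream.using_stream:
        return tensor.to(device=device, dtype=dtype), None
    # Queue the copy on the mover stream and hand back the completion event.
    with stream.stream_context()(stream.mover_stream):
        moved = tensor.to(device=device, dtype=dtype, non_blocking=True)
    return moved, stream.mover_stream.record_event()

Before any kernel consumes the returned tensor, the compute stream has to wait on the returned event (stream.current_stream.wait_event(...)), and synchronize_current_stream() at the end of sampling drains whatever is still queued.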