Merge upstream PR 14855
This commit is contained in:
parent
95ddac3117
commit
638ee43bf1
@ -6,6 +6,12 @@ import torch
|
||||
import ldm_patched.modules.model_management
|
||||
import contextlib
|
||||
|
||||
from modules_forge import stream
|
||||
|
||||
|
||||
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14855/files
|
||||
gc = {}
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def use_patched_ops(operations):
|
||||
@ -25,12 +31,44 @@ def use_patched_ops(operations):
|
||||
|
||||
|
||||
def cast_bias_weight(s, input):
|
||||
bias = None
|
||||
non_blocking = ldm_patched.modules.model_management.device_supports_non_blocking(input.device)
|
||||
if s.bias is not None:
|
||||
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
||||
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
||||
return weight, bias
|
||||
context = contextlib.nullcontext
|
||||
signal = None
|
||||
|
||||
if stream.using_stream:
|
||||
context = stream.stream_context()
|
||||
|
||||
with context(stream.mover_stream):
|
||||
bias = None
|
||||
non_blocking = ldm_patched.modules.model_management.device_supports_non_blocking(input.device)
|
||||
if s.bias is not None:
|
||||
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
||||
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
||||
|
||||
if stream.using_stream:
|
||||
signal = stream.mover_stream.record_event()
|
||||
return weight, bias, signal
|
||||
|
||||
|
||||
@contextlib.contextmanager
def main_thread_worker(weight, bias, signal):
    """Run the wrapped body on the current stream after ``signal`` completes.

    Args:
        weight: Tensor handed back by ``cast_bias_weight``; stored in ``gc``
            together with the completion event.
        bias: Bias tensor (may be None); stored alongside ``weight``.
        signal: Event the current stream waits on before the body runs.
            When it is None, or streams are disabled, the body runs with
            no stream handling at all.

    After the body finishes, an event is recorded on the current stream and
    ``(weight, bias, event)`` is stored in the module-level ``gc`` dict.
    Entries whose event already reports completion are then dropped.
    """
    streaming = stream.using_stream and signal is not None
    if not streaming:
        yield
        return

    compute_stream = stream.current_stream
    with stream.stream_context()(compute_stream):
        compute_stream.wait_event(signal)
        yield
        done_event = compute_stream.record_event()
        # presumably this keeps the tensors referenced until the queued work
        # that uses them has finished -- confirm with the PR discussion
        gc[id(done_event)] = (weight, bias, done_event)

    # Collect the keys first: the dict cannot be resized while iterating it.
    finished_keys = [key for key, (_, _, event) in gc.items() if event.query()]
    for key in finished_keys:
        del gc[key]
|
||||
|
||||
|
||||
class disable_weight_init:
|
||||
@ -40,8 +78,9 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_ldm_patched_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
weight, bias, signal = cast_bias_weight(self, input)
|
||||
with main_thread_worker(weight, bias, signal):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.ldm_patched_cast_weights:
|
||||
@ -55,8 +94,9 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_ldm_patched_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
weight, bias, signal = cast_bias_weight(self, input)
|
||||
with main_thread_worker(weight, bias, signal):
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.ldm_patched_cast_weights:
|
||||
@ -70,8 +110,9 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_ldm_patched_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
weight, bias, signal = cast_bias_weight(self, input)
|
||||
with main_thread_worker(weight, bias, signal):
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.ldm_patched_cast_weights:
|
||||
@ -85,8 +126,9 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_ldm_patched_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
||||
weight, bias, signal = cast_bias_weight(self, input)
|
||||
with main_thread_worker(weight, bias, signal):
|
||||
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.ldm_patched_cast_weights:
|
||||
@ -101,8 +143,9 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_ldm_patched_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
|
||||
weight, bias, signal = cast_bias_weight(self, input)
|
||||
with main_thread_worker(weight, bias, signal):
|
||||
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.ldm_patched_cast_weights:
|
||||
|
@ -216,6 +216,7 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd"
|
||||
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
|
||||
"fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
|
||||
"cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."),
|
||||
"use_non_streamlined_lowvram": OptionInfo(False, "Use non-streamlined low VRAM mode").info("(Requires restart in Forge.) Do not use the streamlined mode for low VRAM cards. For devices that do not support concurrently copy memory between host and device while executing a kernel. Significantly decreases performance."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {
|
||||
|
@ -2,6 +2,7 @@ import torch
|
||||
from ldm_patched.modules.conds import CONDRegular, CONDCrossAttn
|
||||
from ldm_patched.modules.samplers import sampling_function
|
||||
from ldm_patched.modules import model_management
|
||||
from modules_forge.stream import synchronize_current_stream
|
||||
|
||||
|
||||
def cond_from_a1111_to_patched_ldm(cond):
|
||||
@ -113,4 +114,5 @@ def sampling_prepare(unet, x):
|
||||
def sampling_cleanup(unet):
|
||||
for cnet in unet.list_controlnets():
|
||||
cnet.cleanup()
|
||||
synchronize_current_stream()
|
||||
return
|
||||
|
56
modules_forge/stream.py
Normal file
56
modules_forge/stream.py
Normal file
@ -0,0 +1,56 @@
|
||||
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14855
|
||||
|
||||
import torch
|
||||
|
||||
from modules import shared
|
||||
from ldm_patched.modules import model_management
|
||||
|
||||
|
||||
def stream_context():
    """Return the stream context-manager factory for the available backend.

    Returns ``torch.cuda.stream`` when CUDA is available, otherwise
    ``torch.xpu.stream`` when ``model_management.is_intel_xpu()`` is true,
    otherwise None.
    """
    factory = None
    if torch.cuda.is_available():
        factory = torch.cuda.stream
    elif model_management.is_intel_xpu():
        factory = torch.xpu.stream
    return factory
|
||||
|
||||
|
||||
def get_current_stream():
    """Return the current stream of the available accelerator, or None.

    CUDA is probed first (current stream of the current CUDA device), then
    Intel XPU. If neither backend is available, or the lookup raises,
    'Stream is not used.' is printed and None is returned.
    """
    try:
        if torch.cuda.is_available():
            return torch.cuda.current_stream(torch.device(torch.cuda.current_device()))
        if model_management.is_intel_xpu():
            return torch.xpu.current_stream(torch.device("xpu"))
    except Exception:
        # Best-effort probe: any backend failure means "no stream". Catching
        # Exception instead of using a bare except lets KeyboardInterrupt and
        # SystemExit propagate.
        pass
    print('Stream is not used.')
    return None
|
||||
|
||||
|
||||
def get_new_stream():
    """Create and return a new stream on the available accelerator, or None.

    CUDA is probed first (a new stream on the current CUDA device), then
    Intel XPU. If neither backend is available, or stream creation raises,
    'Stream is not used.' is printed and None is returned.
    """
    try:
        if torch.cuda.is_available():
            return torch.cuda.Stream(torch.device(torch.cuda.current_device()))
        if model_management.is_intel_xpu():
            return torch.xpu.Stream(torch.device("xpu"))
    except Exception:
        # Best-effort probe: any backend failure means "no stream". Catching
        # Exception instead of using a bare except lets KeyboardInterrupt and
        # SystemExit propagate.
        pass
    print('Stream is not used.')
    return None
|
||||
|
||||
|
||||
def synchronize_current_stream():
    """Call ``synchronize()`` on the module-level ``current_stream``.

    Does nothing when ``current_stream`` is None.
    """
    active = current_stream
    if active is None:
        return
    active.synchronize()
|
||||
|
||||
|
||||
# Module-level stream state, resolved once at import time.
# When the "use_non_streamlined_lowvram" option is set, no streams are looked
# up or created and streaming stays off.
_streamlined = not shared.opts.use_non_streamlined_lowvram

current_stream = get_current_stream() if _streamlined else None
mover_stream = get_new_stream() if _streamlined else None

# Streaming is on only when both streams were obtained.
using_stream = current_stream is not None and mover_stream is not None
|
Loading…
Reference in New Issue
Block a user