f06ba8e60b
This moves all major Gradio calls into the main thread rather than arbitrary Gradio worker threads. This ensures that all torch.module.to() calls are performed in the main thread, avoiding GPU memory fragmentation as completely as possible. In my tests, model moving is now 0.7 ~ 1.2 seconds faster, which means all 6GB/8GB VRAM users will get 0.7 ~ 1.2 seconds faster per image on SDXL.
210 lines
7.9 KiB
Python
210 lines
7.9 KiB
Python
import json
|
|
import os
|
|
import signal
|
|
import sys
|
|
import re
|
|
|
|
from modules.timer import startup_timer
|
|
|
|
|
|
def gradio_server_name():
    """Pick the server name for Gradio from the command-line options.

    Returns ``cmd_opts.server_name`` when it is set. Otherwise returns
    ``"0.0.0.0"`` if ``cmd_opts.listen`` is truthy, and ``None`` if not.
    """
    from modules.shared_cmd_options import cmd_opts

    # An explicitly configured name always wins over --listen.
    if cmd_opts.server_name:
        return cmd_opts.server_name

    if cmd_opts.listen:
        return "0.0.0.0"

    return None
|
|
|
|
|
|
def fix_torch_version():
    """Shorten the version string of nightly/local PyTorch builds.

    When ``torch.__version__`` contains ``".dev"`` or ``"+git"``, the full
    string is preserved in ``torch.__long_version__`` and ``torch.__version__``
    is replaced by its leading dotted-number part. This avoids exceptions in
    CodeFormer or Safetensors. Other version strings are left untouched.
    """
    import torch

    reported = torch.__version__
    if ".dev" not in reported and "+git" not in reported:
        return

    # Keep the untruncated value reachable before overwriting it.
    torch.__long_version__ = reported
    numeric_part = re.search(r'[\d.]+[\d]', reported)
    torch.__version__ = numeric_part.group(0)
|
|
|
|
|
|
def fix_asyncio_event_loop_policy():
    """
    Install an `asyncio` event loop policy that creates loops on any thread.

    The default policy only creates an event loop automatically in the main
    thread; in other threads `asyncio.get_event_loop` fails unless a loop was
    set explicitly. The policy installed here reacts to that failure by
    creating a new loop, registering it for the calling thread and
    returning it.
    """

    import asyncio

    # "Any thread" and "selector" should be orthogonal, but there is no clean
    # interface for composing policies, so choose the matching base class.
    wants_windows_selector = sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy")
    base_policy_cls = asyncio.WindowsSelectorEventLoopPolicy if wants_windows_selector else asyncio.DefaultEventLoopPolicy  # type: ignore

    class AnyThreadEventLoopPolicy(base_policy_cls):  # type: ignore
        """Event loop policy that allows loop creation on any thread.

        Usage::

            asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
        """

        def get_event_loop(self) -> asyncio.AbstractEventLoop:
            try:
                loop = super().get_event_loop()
            except (RuntimeError, AssertionError):
                # "There is no current event loop in thread %r" was an
                # AssertionError in Python 3.4.2 and became a RuntimeError
                # in 3.4.3, so both are handled.
                loop = self.new_event_loop()
                self.set_event_loop(loop)
            return loop

    asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
|
|
|
|
|
|
def restore_config_state_file():
    """Restore extension state from the file named in the options, if any.

    Reads ``shared.opts.restore_config_state_file``. An empty string means
    there is nothing to do. Otherwise the option is reset to ``""`` and the
    options are saved before the file is looked at. If the file exists, its
    JSON content is passed to ``config_states.restore_extension_config``;
    if not, a warning is printed.
    """
    from modules import shared, config_states

    backup_path = shared.opts.restore_config_state_file
    if backup_path == "":
        return

    # Clear the stored request and save it before acting on the file.
    shared.opts.restore_config_state_file = ""
    shared.opts.save(shared.config_filename)

    if not os.path.isfile(backup_path):
        if backup_path:
            print(f"!!! Config state backup not found: {backup_path}")
        return

    print(f"*** About to restore extension state from file: {backup_path}")
    with open(backup_path, "r", encoding="utf-8") as f:
        saved_state = json.load(f)
        config_states.restore_extension_config(saved_state)
    startup_timer.record("restore extension config")
|
|
|
|
|
|
def validate_tls_options():
    """Check the TLS key/certificate options and report the outcome.

    Does nothing unless both ``cmd_opts.tls_keyfile`` and
    ``cmd_opts.tls_certfile`` are set. A path that does not exist only
    produces a printed message; "Running with TLS" is still reported
    afterwards. If the existence checks raise ``TypeError``, both options
    are reset to ``None`` and the fallback to non-TLS is reported. The
    "TLS" startup timer entry is recorded in either case.
    """
    from modules.shared_cmd_options import cmd_opts

    if not cmd_opts.tls_keyfile or not cmd_opts.tls_certfile:
        return

    try:
        keyfile_found = os.path.exists(cmd_opts.tls_keyfile)
        if not keyfile_found:
            print("Invalid path to TLS keyfile given")
        certfile_found = os.path.exists(cmd_opts.tls_certfile)
        if not certfile_found:
            print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
    except TypeError:
        cmd_opts.tls_keyfile = None
        cmd_opts.tls_certfile = None
        outcome = "TLS setup invalid, running webui without TLS"
    else:
        outcome = "Running with TLS"

    print(outcome)
    startup_timer.record("TLS")
|
|
|
|
|
|
def get_gradio_auth_creds():
    """
    Yield credential tuples taken from the gradio_auth and gradio_auth_path
    commandline arguments.

    Entries from ``cmd_opts.gradio_auth`` come first, followed by entries
    from each line of the file at ``cmd_opts.gradio_auth_path``. Entries are
    comma-separated; each one is stripped and split on its first ``':'``.
    Blank entries are skipped.
    """
    from modules.shared_cmd_options import cmd_opts

    def process_credential_line(s):
        s = s.strip()
        return tuple(s.split(':', 1)) if s else None

    def credentials_in(text):
        # One comma-separated chunk of text -> the credentials it contains.
        for entry in text.split(','):
            parsed = process_credential_line(entry)
            if parsed:
                yield parsed

    if cmd_opts.gradio_auth:
        yield from credentials_in(cmd_opts.gradio_auth)

    if cmd_opts.gradio_auth_path:
        with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
            for line in file:
                yield from credentials_in(line.strip())
|
|
|
|
|
|
def dumpstacks():
    """Print a stack listing for every thread known to the interpreter.

    For each entry of ``sys._current_frames()`` a header with the thread
    name (empty if the id is not among ``threading.enumerate()``) and id is
    emitted, followed by one "File ..." line per stack frame and, when
    available, the stripped source line. Everything is printed in one call.
    """
    import threading
    import traceback

    names_by_ident = {}
    for thread in threading.enumerate():
        names_by_ident[thread.ident] = thread.name

    report = []
    for thread_id, frame in sys._current_frames().items():
        thread_name = names_by_ident.get(thread_id, '')
        report.append(f"\n# Thread: {thread_name}({thread_id})")
        for entry in traceback.extract_stack(frame):
            report.append(f"""File: "{entry.filename}", line {entry.lineno}, in {entry.name}""")
            source_text = entry.line
            if source_text:
                report.append(" " + source_text.strip())

    print("\n".join(report))
|
|
|
|
|
|
def configure_sigint_handler():
    """Make the program exit at Ctrl+C without waiting for anything.

    The installed SIGINT handler prints the signal and frame, dumps all
    thread stacks when ``shared.opts.dump_stacks_on_signal`` is truthy, and
    then terminates the process via ``os._exit(0)``. No handler is installed
    when the ``COVERAGE_RUN`` environment variable is set.
    """

    from modules import shared

    def sigint_handler(signum, stack_frame):
        print(f'Interrupted with signal {signum} in {stack_frame}')

        if shared.opts.dump_stacks_on_signal:
            dumpstacks()

        os._exit(0)

    if os.environ.get("COVERAGE_RUN"):
        # Under coverage the immediate-quit handler is skipped, as otherwise
        # the coverage report won't be generated.
        return

    signal.signal(signal.SIGINT, sigint_handler)
|
|
|
|
|
|
def configure_opts_onchange():
    """Register the callbacks that react to changes of individual options.

    Callbacks that touch the model are wrapped with ``wrap_queued_call``.
    Checkpoint and VAE reloads are additionally routed through
    ``main_thread.run_and_wait_result``. Registration order is kept as is,
    and the "opts onchange" startup timer entry is recorded at the end.
    """
    from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack
    from modules.call_queue import wrap_queued_call
    from modules_forge import main_thread

    # Targets are looked up inside the lambdas (at call time), not captured
    # here, so later replacement of e.g. sd_models.reload_model_weights is
    # still honoured.
    reload_checkpoint = wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_models.reload_model_weights))
    shared.opts.onchange("sd_model_checkpoint", reload_checkpoint, call=False)

    reload_vae = wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights))
    shared.opts.onchange("sd_vae", reload_vae, call=False)

    reload_vae_for_overrides = wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights))
    shared.opts.onchange("sd_vae_overrides_per_model_preferences", reload_vae_for_overrides, call=False)

    # These two are registered without call=False.
    shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
    shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)

    redo_hijack = wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model))
    shared.opts.onchange("cross_attention_optimization", redo_hijack, call=False)

    # NOTE(review): unlike "sd_model_checkpoint", the next two reload the
    # weights without main_thread.run_and_wait_result - confirm intentional.
    reload_for_fp8 = wrap_queued_call(lambda: sd_models.reload_model_weights())
    shared.opts.onchange("fp8_storage", reload_for_fp8, call=False)

    forced_reload = wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True))
    shared.opts.onchange("cache_fp16_weight", forced_reload, call=False)

    startup_timer.record("opts onchange")
|
|
|
|
|
|
def setup_middleware(app):
    """Add gzip and CORS middleware to *app*.

    ``app.middleware_stack`` is reset to ``None`` first so the middleware
    list can be modified. ``GZipMiddleware`` and then the CORS middleware
    (via ``configure_cors_middleware``) are added, and finally
    ``app.build_middleware_stack()`` is called. Its return value is not
    stored here.
    """
    from starlette.middleware.gzip import GZipMiddleware

    # Threshold handed to GZipMiddleware as ``minimum_size``.
    gzip_minimum_size = 1000

    app.middleware_stack = None  # reset current middleware to allow modifying user provided list
    app.add_middleware(GZipMiddleware, minimum_size=gzip_minimum_size)
    configure_cors_middleware(app)
    app.build_middleware_stack()  # rebuild middleware stack on-the-fly
|
|
|
|
|
|
def configure_cors_middleware(app):
    """Add ``CORSMiddleware`` to *app* based on the command-line options.

    All methods and headers are allowed and credentials are enabled.
    ``allow_origins`` is set from the comma-separated
    ``cmd_opts.cors_allow_origins`` and ``allow_origin_regex`` from
    ``cmd_opts.cors_allow_origins_regex``, each only when the option is set.
    """
    from starlette.middleware.cors import CORSMiddleware
    from modules.shared_cmd_options import cmd_opts

    cors_kwargs = dict(
        allow_methods=["*"],
        allow_headers=["*"],
        allow_credentials=True,
    )

    origins_csv = cmd_opts.cors_allow_origins
    if origins_csv:
        cors_kwargs["allow_origins"] = origins_csv.split(',')

    origin_pattern = cmd_opts.cors_allow_origins_regex
    if origin_pattern:
        cors_kwargs["allow_origin_regex"] = origin_pattern

    app.add_middleware(CORSMiddleware, **cors_kwargs)
|
|
|