f06ba8e60b
This moves all major Gradio calls into the main thread rather than arbitrary Gradio worker threads. This ensures that all `torch.nn.Module.to()` calls are performed in the main thread, avoiding GPU memory fragmentation as much as possible. In my tests, model moving is now 0.7–1.2 seconds faster, which means all 6GB/8GB VRAM users will get each SDXL image 0.7–1.2 seconds faster.
178 lines
5.6 KiB
Python
178 lines
5.6 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
import time
|
|
|
|
from modules import timer
|
|
from modules import initialize_util
|
|
from modules import initialize
|
|
from threading import Thread
|
|
from modules_forge.initialization import initialize_forge
|
|
from modules_forge import main_thread
|
|
|
|
|
|
# Module-level startup sequence: these statements run as soon as this module
# is executed/imported. They are kept in their original order; the later
# initialize.* steps presumably depend on the earlier ones (TODO confirm).
startup_timer = timer.startup_timer
# Record a timing checkpoint labelled "launcher" for the startup summary.
startup_timer.record("launcher")

# Forge-specific initialization (modules_forge.initialization), run before
# the webui initialize.* steps below.
initialize_forge()

initialize.imports()

initialize.check_versions()

initialize.initialize()
|
|
|
|
|
|
def create_api(app):
    """Construct the web UI REST ``Api`` around *app* and return it.

    The ``Api`` object is given ``queue_lock`` from ``modules.call_queue``,
    the same lock object other parts of the application import from there.
    """
    from modules.api.api import Api
    from modules.call_queue import queue_lock

    return Api(app, queue_lock)
|
|
|
|
|
|
def api_only_worker():
    """Serve only the REST API (no gradio UI) on a bare FastAPI app.

    Installs the shared middleware and the REST API. Then it runs the
    before-UI and app-started script callbacks, passing ``None`` as the
    demo and the FastAPI app. Finally it hands control to ``Api.launch``.
    The port is ``cmd_opts.port``, falling back to 7861 when that is unset
    or 0.
    """
    from fastapi import FastAPI
    from modules.shared_cmd_options import cmd_opts

    fastapi_app = FastAPI()
    initialize_util.setup_middleware(fastapi_app)
    rest_api = create_api(fastapi_app)

    from modules import script_callbacks
    script_callbacks.before_ui_callback()
    script_callbacks.app_started_callback(None, fastapi_app)

    print(f"Startup time: {startup_timer.summary()}.")

    # Resolve launch settings in the same order the original keyword
    # arguments were evaluated.
    server_name = initialize_util.gradio_server_name()
    listen_port = cmd_opts.port or 7861
    root_path = f"/{cmd_opts.subpath}" if cmd_opts.subpath else ""

    rest_api.launch(
        server_name=server_name,
        port=listen_port,
        root_path=root_path,
    )
|
|
|
|
|
|
def _should_auto_launch_browser(opts, cmd_opts):
    """Return True if gradio should open the UI in a browser on this launch.

    The browser is never opened when the SD_WEBUI_RESTARTING environment
    variable is '1'; webui_worker sets it before a UI reload. Otherwise the
    browser is opened in two cases:

    - ``opts.auto_launch_browser`` is "Remote", or ``cmd_opts.autolaunch``
      is set;
    - the option is "Local" and ``cmd_opts.webui_is_non_local`` is false.
    """
    if os.getenv('SD_WEBUI_RESTARTING') == '1':
        return False
    if opts.auto_launch_browser == "Remote" or cmd_opts.autolaunch:
        return True
    if opts.auto_launch_browser == "Local":
        return not cmd_opts.webui_is_non_local
    return False


def _wait_for_server_command(state):
    """Block until a "stop" or "restart" server command is received.

    Repeatedly calls ``state.wait_for_server_command(timeout=5)``, skipping
    empty results. Any other command is reported and ignored.

    Returns:
        "stop" or "restart". A KeyboardInterrupt while waiting is reported
        and treated as "stop".
    """
    try:
        while True:
            server_command = state.wait_for_server_command(timeout=5)
            if not server_command:
                continue
            if server_command in ("stop", "restart"):
                return server_command
            print(f"Unknown server command: {server_command}")
    except KeyboardInterrupt:
        # NOTE(review): webui() runs webui_worker on a daemon thread, and
        # CPython raises KeyboardInterrupt only in the main thread, so this
        # handler is probably unreachable there. Ctrl+C most likely surfaces
        # in main_thread.loop() instead; confirm intended shutdown path.
        print('Caught KeyboardInterrupt, stopping...')
        return "stop"


def webui_worker():
    """Build, launch and serve the gradio web UI until a "stop" command.

    Each pass of the outer loop does the following:

    1. Optionally cleans the temp dir, runs the before-UI callbacks and
       builds the UI.
    2. Launches gradio with ``prevent_thread_lock=True`` so the call returns.
    3. Attaches middleware and the HTTP APIs, then runs the app-started
       callbacks.
    4. Waits for a server command. "stop" closes the UI and exits the loop.
       "restart" closes the UI, runs the reload/unload callbacks,
       re-initializes with script modules reloaded, and loops again.
    """
    from modules.shared_cmd_options import cmd_opts

    launch_api = cmd_opts.api

    from modules import shared, ui_tempdir, script_callbacks, ui, progress, ui_extra_networks

    while True:
        if shared.opts.clean_temp_dir_at_start:
            ui_tempdir.cleanup_tmpdr()
            startup_timer.record("cleanup temp dir")

        script_callbacks.before_ui_callback()
        startup_timer.record("scripts before_ui_callback")

        shared.demo = ui.create_ui()
        startup_timer.record("create ui")

        if not cmd_opts.no_gradio_queue:
            shared.demo.queue(64)

        gradio_auth_creds = list(initialize_util.get_gradio_auth_creds()) or None

        auto_launch_browser = _should_auto_launch_browser(shared.opts, cmd_opts)

        # launch() returns (app, local_url, share_url); only the app is used.
        app, _, _ = shared.demo.launch(
            share=cmd_opts.share,
            server_name=initialize_util.gradio_server_name(),
            server_port=cmd_opts.port,
            ssl_keyfile=cmd_opts.tls_keyfile,
            ssl_certfile=cmd_opts.tls_certfile,
            ssl_verify=cmd_opts.disable_tls_verify,
            debug=cmd_opts.gradio_debug,
            auth=gradio_auth_creds,
            inbrowser=auto_launch_browser,
            # Return immediately instead of blocking this thread.
            prevent_thread_lock=True,
            allowed_paths=cmd_opts.gradio_allowed_path,
            app_kwargs={
                "docs_url": "/docs",
                "redoc_url": "/redoc",
            },
            root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "",
        )

        startup_timer.record("gradio launch")

        # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
        # an attacker to trick the user into opening a malicious HTML page, which makes a request to the
        # running web ui and do whatever the attacker wants, including installing an extension and
        # running its code. We disable this here. Suggested by RyotaK.
        app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']

        initialize_util.setup_middleware(app)

        progress.setup_progress_api(app)
        ui.setup_ui_api(app)

        if launch_api:
            create_api(app)

        ui_extra_networks.add_pages_to_demo(app)

        startup_timer.record("add APIs")

        with startup_timer.subcategory("app_started_callback"):
            script_callbacks.app_started_callback(shared.demo, app)

        timer.startup_record = startup_timer.dump()
        print(f"Startup time: {startup_timer.summary()}.")

        server_command = _wait_for_server_command(shared.state)

        if server_command == "stop":
            print("Stopping server...")
            # A stop was requested (or a KeyboardInterrupt was caught): close the UI and leave the loop.
            shared.demo.close()
            break

        # disable auto launch webui in browser for subsequent UI Reload
        os.environ.setdefault('SD_WEBUI_RESTARTING', '1')

        print('Restarting UI...')
        shared.demo.close()
        time.sleep(0.5)
        startup_timer.reset()
        script_callbacks.app_reload_callback()
        startup_timer.record("app reload callback")
        script_callbacks.script_unloaded_callback()
        startup_timer.record("scripts unloaded callback")
        initialize.initialize_rest(reload_script_modules=True)
|
|
|
|
|
|
def api_only():
    """Start the API-only server (api_only_worker) on a daemon thread.

    Returns immediately after starting the thread; in this file the main
    thread then enters ``main_thread.loop()``.
    """
    worker = Thread(target=api_only_worker, daemon=True)
    worker.start()
|
|
|
|
|
|
def webui():
    """Start the gradio web UI server (webui_worker) on a daemon thread.

    Returns immediately after starting the thread; in this file the main
    thread then enters ``main_thread.loop()``.
    """
    worker = Thread(target=webui_worker, daemon=True)
    worker.start()
|
|
|
|
|
|
if __name__ == "__main__":
    from modules.shared_cmd_options import cmd_opts

    # Start the chosen server on a background daemon thread...
    if not cmd_opts.nowebui:
        webui()
    else:
        api_only()

    # ...and give the main thread to modules_forge.main_thread.loop(), which
    # (per this file's commit note) is where major gradio calls are executed.
    main_thread.loop()
|