Significantly reduce thread abuse for faster model moving
This will move all major gradio calls into the main thread rather than random gradio threads. This ensures that all torch.module.to() are performed in main thread to completely possible avoid GPU fragments. In my test now model moving is 0.7 ~ 1.2 seconds faster, which means all 6GB/8GB VRAM users will get 0.7 ~ 1.2 seconds faster per image on SDXL.
This commit is contained in:
parent
291ec743b6
commit
f06ba8e60b
@ -15,6 +15,7 @@ import modules.shared as shared
|
||||
import modules.processing as processing
|
||||
from modules.ui import plaintext_to_html
|
||||
import modules.scripts
|
||||
from modules_forge import main_thread
|
||||
|
||||
|
||||
def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
|
||||
@ -146,7 +147,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
||||
return batch_results
|
||||
|
||||
|
||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
||||
def img2img_function(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
||||
override_settings = create_override_settings_dict(override_settings_texts)
|
||||
|
||||
is_batch = mode == 5
|
||||
@ -244,3 +245,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
||||
processed.images = []
|
||||
|
||||
return processed.images + processed.extra_images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
|
||||
|
||||
|
||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
||||
return main_thread.run_and_wait_result(img2img_function, id_task, mode, prompt, negative_prompt, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps, sampler_name, mask_blur, mask_alpha, inpainting_fill, n_iter, batch_size, cfg_scale, image_cfg_scale, denoising_strength, selected_scale_tab, height, width, scale_by, resize_mode, inpaint_full_res, inpaint_full_res_padding, inpainting_mask_invert, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, override_settings_texts, img2img_batch_use_png_info, img2img_batch_png_info_props, img2img_batch_png_info_dir, request, *args)
|
||||
|
@ -149,24 +149,9 @@ def initialize_rest(*, reload_script_modules=False):
|
||||
sd_unet.list_unets()
|
||||
startup_timer.record("scripts list_unets")
|
||||
|
||||
def load_model():
|
||||
"""
|
||||
Accesses shared.sd_model property to load model.
|
||||
After it's available, if it has been loaded before this access by some extension,
|
||||
its optimization may be None because the list of optimizaers has neet been filled
|
||||
by that time, so we apply optimization again.
|
||||
"""
|
||||
from modules import devices
|
||||
devices.torch_npu_set_device()
|
||||
|
||||
shared.sd_model # noqa: B018
|
||||
|
||||
if sd_hijack.current_optimizer is None:
|
||||
sd_hijack.apply_optimizations()
|
||||
|
||||
devices.first_time_calculation()
|
||||
if not shared.cmd_opts.skip_load_model_at_start:
|
||||
Thread(target=load_model).start()
|
||||
from modules_forge import main_thread
|
||||
import modules.sd_models
|
||||
main_thread.async_run(modules.sd_models.model_data.get_sd_model)
|
||||
|
||||
from modules import shared_items
|
||||
shared_items.reload_hypernetworks()
|
||||
|
@ -170,10 +170,11 @@ def configure_sigint_handler():
|
||||
def configure_opts_onchange():
|
||||
from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack
|
||||
from modules.call_queue import wrap_queued_call
|
||||
from modules_forge import main_thread
|
||||
|
||||
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
|
||||
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
|
||||
shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
|
||||
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_models.reload_model_weights)), call=False)
|
||||
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights)), call=False)
|
||||
shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights)), call=False)
|
||||
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
|
||||
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
|
||||
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
|
||||
|
@ -511,6 +511,11 @@ def start():
|
||||
else:
|
||||
webui.webui()
|
||||
|
||||
from modules_forge import main_thread
|
||||
|
||||
main_thread.loop()
|
||||
return
|
||||
|
||||
|
||||
def dump_sysinfo():
|
||||
from modules import sysinfo
|
||||
|
@ -2,6 +2,8 @@ import datetime
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import torch
|
||||
|
||||
from modules import errors, shared, devices
|
||||
from typing import Optional
|
||||
@ -134,6 +136,7 @@ class State:
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
@torch.inference_mode()
|
||||
def set_current_image(self):
|
||||
"""if enough sampling steps have been made after the last call to this, sets self.current_image from self.current_latent, and modifies self.id_live_preview accordingly"""
|
||||
if not shared.parallel_processing_allowed:
|
||||
@ -142,6 +145,7 @@ class State:
|
||||
if self.sampling_step - self.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps != -1:
|
||||
self.do_set_current_image()
|
||||
|
||||
@torch.inference_mode()
|
||||
def do_set_current_image(self):
|
||||
if self.current_latent is None:
|
||||
return
|
||||
@ -156,11 +160,14 @@ class State:
|
||||
|
||||
self.current_image_sampling_step = self.sampling_step
|
||||
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
# traceback.print_exc()
|
||||
# print(e)
|
||||
# when switching models during genration, VAE would be on CPU, so creating an image will fail.
|
||||
# we silently ignore this error
|
||||
errors.record_exception()
|
||||
|
||||
@torch.inference_mode()
|
||||
def assign_current_image(self, image):
|
||||
self.current_image = image
|
||||
self.id_live_preview += 1
|
||||
|
@ -9,6 +9,7 @@ import modules.shared as shared
|
||||
from modules.ui import plaintext_to_html
|
||||
from PIL import Image
|
||||
import gradio as gr
|
||||
from modules_forge import main_thread
|
||||
|
||||
|
||||
def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):
|
||||
@ -56,7 +57,7 @@ def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, ne
|
||||
return p
|
||||
|
||||
|
||||
def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args):
|
||||
def txt2img_upscale_function(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args):
|
||||
assert len(gallery) > 0, 'No image to upscale'
|
||||
assert 0 <= gallery_index < len(gallery), f'Bad image index: {gallery_index}'
|
||||
|
||||
@ -100,7 +101,7 @@ def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, g
|
||||
return new_gallery, json.dumps(geninfo), plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
|
||||
|
||||
|
||||
def txt2img(id_task: str, request: gr.Request, *args):
|
||||
def txt2img_function(id_task: str, request: gr.Request, *args):
|
||||
p = txt2img_create_processing(id_task, request, *args)
|
||||
|
||||
with closing(p):
|
||||
@ -119,3 +120,11 @@ def txt2img(id_task: str, request: gr.Request, *args):
|
||||
processed.images = []
|
||||
|
||||
return processed.images + processed.extra_images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
|
||||
|
||||
|
||||
def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args):
|
||||
return main_thread.run_and_wait_result(txt2img_upscale_function, id_task, request, gallery, gallery_index, generation_info, *args)
|
||||
|
||||
|
||||
def txt2img(id_task: str, request: gr.Request, *args):
|
||||
return main_thread.run_and_wait_result(txt2img_function, id_task, request, *args)
|
||||
|
68
modules_forge/main_thread.py
Normal file
68
modules_forge/main_thread.py
Normal file
@ -0,0 +1,68 @@
|
||||
# This file is the main thread that handles all gradio calls for major t2i or i2i processing.
|
||||
# Other gradio calls (like those from extensions) are not influenced.
|
||||
# By using one single thread to process all major calls, model moving is significantly faster.
|
||||
|
||||
|
||||
import time
|
||||
import traceback
|
||||
import threading
|
||||
|
||||
|
||||
lock = threading.Lock()
|
||||
last_id = 0
|
||||
waiting_list = []
|
||||
finished_list = []
|
||||
|
||||
|
||||
class Task:
|
||||
def __init__(self, task_id, func, args, kwargs):
|
||||
self.task_id = task_id
|
||||
self.func = func
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.result = None
|
||||
|
||||
def work(self):
|
||||
self.result = self.func(*self.args, **self.kwargs)
|
||||
|
||||
|
||||
def loop():
|
||||
global lock, last_id, waiting_list, finished_list
|
||||
while True:
|
||||
time.sleep(0.01)
|
||||
if len(waiting_list) > 0:
|
||||
with lock:
|
||||
task = waiting_list.pop(0)
|
||||
try:
|
||||
task.work()
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
print(e)
|
||||
with lock:
|
||||
finished_list.append(task)
|
||||
|
||||
|
||||
def async_run(func, *args, **kwargs):
|
||||
global lock, last_id, waiting_list, finished_list
|
||||
with lock:
|
||||
last_id += 1
|
||||
new_task = Task(task_id=last_id, func=func, args=args, kwargs=kwargs)
|
||||
waiting_list.append(new_task)
|
||||
return new_task.task_id
|
||||
|
||||
|
||||
def run_and_wait_result(func, *args, **kwargs):
|
||||
global lock, last_id, waiting_list, finished_list
|
||||
current_id = async_run(func, *args, **kwargs)
|
||||
while True:
|
||||
time.sleep(0.01)
|
||||
finished_task = None
|
||||
for t in finished_list.copy(): # thread safe shallow copy without needing a lock
|
||||
if t.task_id == current_id:
|
||||
finished_task = t
|
||||
break
|
||||
if finished_task is not None:
|
||||
with lock:
|
||||
finished_list.remove(finished_task)
|
||||
return finished_task.result
|
||||
|
23
webui.py
23
webui.py
@ -6,8 +6,10 @@ import time
|
||||
from modules import timer
|
||||
from modules import initialize_util
|
||||
from modules import initialize
|
||||
|
||||
from threading import Thread
|
||||
from modules_forge.initialization import initialize_forge
|
||||
from modules_forge import main_thread
|
||||
|
||||
|
||||
startup_timer = timer.startup_timer
|
||||
startup_timer.record("launcher")
|
||||
@ -18,6 +20,8 @@ initialize.imports()
|
||||
|
||||
initialize.check_versions()
|
||||
|
||||
initialize.initialize()
|
||||
|
||||
|
||||
def create_api(app):
|
||||
from modules.api.api import Api
|
||||
@ -27,12 +31,10 @@ def create_api(app):
|
||||
return api
|
||||
|
||||
|
||||
def api_only():
|
||||
def api_only_worker():
|
||||
from fastapi import FastAPI
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
|
||||
initialize.initialize()
|
||||
|
||||
app = FastAPI()
|
||||
initialize_util.setup_middleware(app)
|
||||
api = create_api(app)
|
||||
@ -49,11 +51,10 @@ def api_only():
|
||||
)
|
||||
|
||||
|
||||
def webui():
|
||||
def webui_worker():
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
|
||||
launch_api = cmd_opts.api
|
||||
initialize.initialize()
|
||||
|
||||
from modules import shared, ui_tempdir, script_callbacks, ui, progress, ui_extra_networks
|
||||
|
||||
@ -157,6 +158,14 @@ def webui():
|
||||
initialize.initialize_rest(reload_script_modules=True)
|
||||
|
||||
|
||||
def api_only():
|
||||
Thread(target=api_only_worker, daemon=True).start()
|
||||
|
||||
|
||||
def webui():
|
||||
Thread(target=webui_worker, daemon=True).start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
|
||||
@ -164,3 +173,5 @@ if __name__ == "__main__":
|
||||
api_only()
|
||||
else:
|
||||
webui()
|
||||
|
||||
main_thread.loop()
|
||||
|
Loading…
Reference in New Issue
Block a user