diff --git a/modules/img2img.py b/modules/img2img.py
index 0d4dca3f..6e9729a4 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -15,6 +15,7 @@ import modules.shared as shared
 import modules.processing as processing
 from modules.ui import plaintext_to_html
 import modules.scripts
+from modules_forge import main_thread


 def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
@@ -146,7 +147,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
     return batch_results


-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
+def img2img_function(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
     override_settings = create_override_settings_dict(override_settings_texts)

     is_batch = mode == 5
@@ -244,3 +245,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
         processed.images = []

     return processed.images + processed.extra_images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
+
+
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
+    return main_thread.run_and_wait_result(img2img_function, id_task, mode, prompt, negative_prompt, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps, sampler_name, mask_blur, mask_alpha, inpainting_fill, n_iter, batch_size, cfg_scale, image_cfg_scale, denoising_strength, selected_scale_tab, height, width, scale_by, resize_mode, inpaint_full_res, inpaint_full_res_padding, inpainting_mask_invert, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, override_settings_texts, img2img_batch_use_png_info, img2img_batch_png_info_props, img2img_batch_png_info_dir, request, *args)
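The img2img change above is the rename-and-delegate pattern used throughout this diff: the original body keeps its (very long) signature but is renamed `img2img_function`, and the public `img2img` becomes a thin shim that marshals the call onto the single worker loop in `modules_forge/main_thread.py`. A minimal sketch of the pattern with hypothetical names (`expensive_function` / `expensive` are illustrative, not from the diff):

```python
from modules_forge import main_thread


def expensive_function(x, y):
    # the original body, renamed; it now always executes on the main thread
    return x + y


def expensive(x, y):
    # the name Gradio binds to; called from a Gradio worker thread, it
    # blocks until the main thread has run expensive_function, then
    # returns its result
    return main_thread.run_and_wait_result(expensive_function, x, y)
```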
diff --git a/modules/initialize.py b/modules/initialize.py
index 6d4d4149..180e1f8e 100644
--- a/modules/initialize.py
+++ b/modules/initialize.py
@@ -149,24 +149,9 @@ def initialize_rest(*, reload_script_modules=False):
     sd_unet.list_unets()
     startup_timer.record("scripts list_unets")

-    def load_model():
-        """
-        Accesses shared.sd_model property to load model.
-        After it's available, if it has been loaded before this access by some extension,
-        its optimization may be None because the list of optimizaers has neet been filled
-        by that time, so we apply optimization again.
-        """
-        from modules import devices
-        devices.torch_npu_set_device()
-
-        shared.sd_model  # noqa: B018
-
-        if sd_hijack.current_optimizer is None:
-            sd_hijack.apply_optimizations()
-
-        devices.first_time_calculation()
-    if not shared.cmd_opts.skip_load_model_at_start:
-        Thread(target=load_model).start()
+    from modules_forge import main_thread
+    import modules.sd_models
+    main_thread.async_run(modules.sd_models.model_data.get_sd_model)

     from modules import shared_items
     shared_items.reload_hypernetworks()
diff --git a/modules/initialize_util.py b/modules/initialize_util.py
index b6767138..7801d932 100644
--- a/modules/initialize_util.py
+++ b/modules/initialize_util.py
@@ -170,10 +170,11 @@ def configure_sigint_handler():
 def configure_opts_onchange():
     from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack
     from modules.call_queue import wrap_queued_call
+    from modules_forge import main_thread

-    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
-    shared.opts.onchange("sd_vae", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
-    shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
+    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_models.reload_model_weights)), call=False)
+    shared.opts.onchange("sd_vae", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights)), call=False)
+    shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights)), call=False)
     shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
     shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
     shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 439bd5ff..409385a8 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -511,6 +511,11 @@ def start():
     else:
         webui.webui()

+    from modules_forge import main_thread
+
+    main_thread.loop()
+    return
+

 def dump_sysinfo():
     from modules import sysinfo
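Two details worth noting in the hunks above. In `launch_utils.start()`, `main_thread.loop()` is an infinite loop that never returns, so the `return` after it is unreachable and the process now spends its life executing queued tasks. And in `configure_opts_onchange()`, the model/VAE reload callbacks are double-wrapped: `wrap_queued_call` still serializes them against the generation queue, while `run_and_wait_result` hops the actual work onto the main thread. Any other model-touching settings callback could be marshalled the same way; a hedged sketch with a hypothetical setting name and callback:

```python
from modules import shared
from modules.call_queue import wrap_queued_call
from modules_forge import main_thread


def reload_foo():
    # hypothetical callback that moves or reloads model weights
    pass


# "foo_option" is a made-up setting name, for illustration only
shared.opts.onchange(
    "foo_option",
    wrap_queued_call(lambda: main_thread.run_and_wait_result(reload_foo)),
    call=False,
)
```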
diff --git a/modules/shared_state.py b/modules/shared_state.py
index 33996691..5da5c7a0 100644
--- a/modules/shared_state.py
+++ b/modules/shared_state.py
@@ -2,6 +2,8 @@ import datetime
 import logging
 import threading
 import time
+import traceback
+import torch

 from modules import errors, shared, devices
 from typing import Optional
@@ -134,6 +136,7 @@ class State:

         devices.torch_gc()

+    @torch.inference_mode()
     def set_current_image(self):
         """if enough sampling steps have been made after the last call to this, sets self.current_image from self.current_latent, and modifies self.id_live_preview accordingly"""
         if not shared.parallel_processing_allowed:
@@ -142,6 +145,7 @@ class State:
         if self.sampling_step - self.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps != -1:
             self.do_set_current_image()

+    @torch.inference_mode()
     def do_set_current_image(self):
         if self.current_latent is None:
             return
@@ -156,11 +160,14 @@ class State:

             self.current_image_sampling_step = self.sampling_step

-        except Exception:
+        except Exception as e:
+            # traceback.print_exc()
+            # print(e)
             # when switching models during genration, VAE would be on CPU, so creating an image will fail.
             # we silently ignore this error
             errors.record_exception()

+    @torch.inference_mode()
     def assign_current_image(self, image):
         self.current_image = image
         self.id_live_preview += 1
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 8582eddb..04d62a0a 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -9,6 +9,7 @@ import modules.shared as shared
 from modules.ui import plaintext_to_html
 from PIL import Image
 import gradio as gr
+from modules_forge import main_thread


 def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):
@@ -56,7 +57,7 @@ def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, ne
     return p


-def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args):
+def txt2img_upscale_function(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args):
     assert len(gallery) > 0, 'No image to upscale'
     assert 0 <= gallery_index < len(gallery), f'Bad image index: {gallery_index}'

@@ -100,7 +101,7 @@ def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, g
     return new_gallery, json.dumps(geninfo), plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")


-def txt2img(id_task: str, request: gr.Request, *args):
+def txt2img_function(id_task: str, request: gr.Request, *args):
     p = txt2img_create_processing(id_task, request, *args)

     with closing(p):
@@ -119,3 +120,11 @@ def txt2img(id_task: str, request: gr.Request, *args):
         processed.images = []

     return processed.images + processed.extra_images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
+
+
+def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args):
+    return main_thread.run_and_wait_result(txt2img_upscale_function, id_task, request, gallery, gallery_index, generation_info, *args)
+
+
+def txt2img(id_task: str, request: gr.Request, *args):
+    return main_thread.run_and_wait_result(txt2img_function, id_task, request, *args)
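The `shared_state.py` hunks decorate the live-preview methods with `@torch.inference_mode()`, so the VAE decode that produces progress previews runs without autograd tracking or its bookkeeping overhead. A standalone illustration of what the decorator does (the function below is a stand-in, not the real preview path):

```python
import torch


@torch.inference_mode()
def make_preview(latent: torch.Tensor) -> torch.Tensor:
    # tensors created here are inference tensors: no gradient bookkeeping
    return (latent.clamp(-1, 1) + 1) / 2


x = torch.randn(1, 4, 64, 64, requires_grad=True)
y = make_preview(x)
print(y.requires_grad)  # False
```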
diff --git a/modules_forge/main_thread.py b/modules_forge/main_thread.py
new file mode 100644
index 00000000..ff30345a
--- /dev/null
+++ b/modules_forge/main_thread.py
@@ -0,0 +1,68 @@
+# This file is the main thread that handles all gradio calls for major t2i or i2i processing.
+# Other gradio calls (like those from extensions) are not influenced.
+# By using one single thread to process all major calls, model moving is significantly faster.
+
+
+import time
+import traceback
+import threading
+
+
+lock = threading.Lock()
+last_id = 0
+waiting_list = []
+finished_list = []
+
+
+class Task:
+    def __init__(self, task_id, func, args, kwargs):
+        self.task_id = task_id
+        self.func = func
+        self.args = args
+        self.kwargs = kwargs
+        self.result = None
+
+    def work(self):
+        self.result = self.func(*self.args, **self.kwargs)
+
+
+def loop():
+    global lock, last_id, waiting_list, finished_list
+    while True:
+        time.sleep(0.01)
+        if len(waiting_list) > 0:
+            with lock:
+                task = waiting_list.pop(0)
+            try:
+                task.work()
+            except Exception as e:
+                traceback.print_exc()
+                print(e)
+            with lock:
+                finished_list.append(task)
+
+
+def async_run(func, *args, **kwargs):
+    global lock, last_id, waiting_list, finished_list
+    with lock:
+        last_id += 1
+        new_task = Task(task_id=last_id, func=func, args=args, kwargs=kwargs)
+        waiting_list.append(new_task)
+    return new_task.task_id
+
+
+def run_and_wait_result(func, *args, **kwargs):
+    global lock, last_id, waiting_list, finished_list
+    current_id = async_run(func, *args, **kwargs)
+    while True:
+        time.sleep(0.01)
+        finished_task = None
+        for t in finished_list.copy():  # thread safe shallow copy without needing a lock
+            if t.task_id == current_id:
+                finished_task = t
+                break
+        if finished_task is not None:
+            with lock:
+                finished_list.remove(finished_task)
+            return finished_task.result
+
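`main_thread.py` is a minimal single-consumer task queue: `loop()` runs forever on the process's real main thread, popping one `Task` at a time, while `async_run` and `run_and_wait_result` are the producer side called from Gradio worker threads. A hedged standalone demo of the handshake (the explicit `threading.Thread` here stands in for a Gradio worker):

```python
import threading
from modules_forge import main_thread


def caller():
    # runs on a worker thread; blocks until loop() has executed the task
    result = main_thread.run_and_wait_result(pow, 2, 10)
    print(result)  # 1024


threading.Thread(target=caller, daemon=True).start()
main_thread.loop()  # never returns; executes queued tasks one by one
```

One consequence of the `try`/`except` in `loop()`: a task that raises is logged and still appended to `finished_list`, so the waiting caller receives `None` as the result rather than seeing the exception.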
"__main__": api_only() else: webui() + + main_thread.loop()