split shared.py into multiple files; should resolve all circular reference import errors related to shared.py

This commit is contained in:
AUTOMATIC1111 2023-08-09 10:25:35 +03:00
parent 7d81ecbea6
commit 386245a264
24 changed files with 764 additions and 1667 deletions

View File

@ -3,7 +3,7 @@ import contextlib
from functools import lru_cache from functools import lru_cache
import torch import torch
from modules import errors from modules import errors, shared
if sys.platform == "darwin": if sys.platform == "darwin":
from modules import mac_specific from modules import mac_specific
@ -17,8 +17,6 @@ def has_mps() -> bool:
def get_cuda_device_string(): def get_cuda_device_string():
from modules import shared
if shared.cmd_opts.device_id is not None: if shared.cmd_opts.device_id is not None:
return f"cuda:{shared.cmd_opts.device_id}" return f"cuda:{shared.cmd_opts.device_id}"
@ -40,8 +38,6 @@ def get_optimal_device():
def get_device_for(task): def get_device_for(task):
from modules import shared
if task in shared.cmd_opts.use_cpu: if task in shared.cmd_opts.use_cpu:
return cpu return cpu
@ -97,8 +93,6 @@ nv_rng = None
def autocast(disable=False): def autocast(disable=False):
from modules import shared
if disable: if disable:
return contextlib.nullcontext() return contextlib.nullcontext()
@ -117,8 +111,6 @@ class NansException(Exception):
def test_for_nans(x, where): def test_for_nans(x, where):
from modules import shared
if shared.cmd_opts.disable_nan_check: if shared.cmd_opts.disable_nan_check:
return return

View File

@ -1,7 +1,7 @@
import os import os
import threading import threading
from modules import shared, errors, cache from modules import shared, errors, cache, scripts
from modules.gitpython_hack import Repo from modules.gitpython_hack import Repo
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401 from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
@ -90,8 +90,6 @@ class Extension:
self.have_info_from_repo = True self.have_info_from_repo = True
def list_files(self, subdir, extension): def list_files(self, subdir, extension):
from modules import scripts
dirpath = os.path.join(self.path, subdir) dirpath = os.path.join(self.path, subdir)
if not os.path.isdir(dirpath): if not os.path.isdir(dirpath):
return [] return []

View File

@ -6,7 +6,7 @@ import re
import gradio as gr import gradio as gr
from modules.paths import data_path from modules.paths import data_path
from modules import shared, ui_tempdir, script_callbacks from modules import shared, ui_tempdir, script_callbacks, processing
from PIL import Image from PIL import Image
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)' re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
@ -198,7 +198,6 @@ def restore_old_hires_fix_params(res):
height = int(res.get("Size-2", 512)) height = int(res.get("Size-2", 512))
if firstpass_width == 0 or firstpass_height == 0: if firstpass_width == 0 or firstpass_height == 0:
from modules import processing
firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height) firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
res['Size-1'] = firstpass_width res['Size-1'] = firstpass_width

View File

@ -21,8 +21,6 @@ from modules import sd_samplers, shared, script_callbacks, errors
from modules.paths_internal import roboto_ttf_file from modules.paths_internal import roboto_ttf_file
from modules.shared import opts from modules.shared import opts
import modules.sd_vae as sd_vae
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
@ -342,16 +340,6 @@ def sanitize_filename_part(text, replace_spaces=True):
class FilenameGenerator: class FilenameGenerator:
def get_vae_filename(self): #get the name of the VAE file.
if sd_vae.loaded_vae_file is None:
return "NoneType"
file_name = os.path.basename(sd_vae.loaded_vae_file)
split_file_name = file_name.split('.')
if len(split_file_name) > 1 and split_file_name[0] == '':
return split_file_name[1] # if the first character of the filename is "." then [1] is obtained.
else:
return split_file_name[0]
replacements = { replacements = {
'seed': lambda self: self.seed if self.seed is not None else '', 'seed': lambda self: self.seed if self.seed is not None else '',
'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0], 'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0],
@ -391,6 +379,22 @@ class FilenameGenerator:
self.image = image self.image = image
self.zip = zip self.zip = zip
def get_vae_filename(self):
"""Get the name of the VAE file."""
import modules.sd_vae as sd_vae
if sd_vae.loaded_vae_file is None:
return "NoneType"
file_name = os.path.basename(sd_vae.loaded_vae_file)
split_file_name = file_name.split('.')
if len(split_file_name) > 1 and split_file_name[0] == '':
return split_file_name[1] # if the first character of the filename is "." then [1] is obtained.
else:
return split_file_name[0]
def hasprompt(self, *args): def hasprompt(self, *args):
lower = self.prompt.lower() lower = self.prompt.lower()
if self.p is None or self.prompt is None: if self.p is None or self.prompt is None:

View File

@ -1,7 +1,7 @@
import json import json
import os import os
from modules import errors from modules import errors, scripts
localizations = {} localizations = {}
@ -16,7 +16,6 @@ def list_localizations(dirname):
localizations[fn] = os.path.join(dirname, file) localizations[fn] = os.path.join(dirname, file)
from modules import scripts
for file in scripts.list_scripts("localizations", ".json"): for file in scripts.list_scripts("localizations", ".json"):
fn, ext = os.path.splitext(file.filename) fn, ext = os.path.splitext(file.filename)
localizations[fn] = file.path localizations[fn] = file.path

View File

@ -4,6 +4,7 @@ import torch
import platform import platform
from modules.sd_hijack_utils import CondFunc from modules.sd_hijack_utils import CondFunc
from packaging import version from packaging import version
from modules import shared
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -30,8 +31,7 @@ has_mps = check_for_mps()
def torch_mps_gc() -> None: def torch_mps_gc() -> None:
try: try:
from modules.shared import state if shared.state.current_latent is not None:
if state.current_latent is not None:
log.debug("`current_latent` is set, skipping MPS garbage collection") log.debug("`current_latent` is set, skipping MPS garbage collection")
return return
from torch.mps import empty_cache from torch.mps import empty_cache

236
modules/options.py Normal file
View File

@ -0,0 +1,236 @@
import json
import sys
import gradio as gr
from modules import errors
from modules.shared_cmd_options import cmd_opts
class OptionInfo:
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after=''):
self.default = default
self.label = label
self.component = component
self.component_args = component_args
self.onchange = onchange
self.section = section
self.refresh = refresh
self.do_not_save = False
self.comment_before = comment_before
"""HTML text that will be added after label in UI"""
self.comment_after = comment_after
"""HTML text that will be added before label in UI"""
def link(self, label, url):
self.comment_before += f"[<a href='{url}' target='_blank'>{label}</a>]"
return self
def js(self, label, js_func):
self.comment_before += f"[<a onclick='{js_func}(); return false'>{label}</a>]"
return self
def info(self, info):
self.comment_after += f"<span class='info'>({info})</span>"
return self
def html(self, html):
self.comment_after += html
return self
def needs_restart(self):
self.comment_after += " <span class='info'>(requires restart)</span>"
return self
def needs_reload_ui(self):
self.comment_after += " <span class='info'>(requires Reload UI)</span>"
return self
class OptionHTML(OptionInfo):
def __init__(self, text):
super().__init__(str(text).strip(), label='', component=lambda **kwargs: gr.HTML(elem_classes="settings-info", **kwargs))
self.do_not_save = True
def options_section(section_identifier, options_dict):
for v in options_dict.values():
v.section = section_identifier
return options_dict
options_builtin_fields = {"data_labels", "data", "restricted_opts", "typemap"}
class Options:
typemap = {int: float}
def __init__(self, data_labels, restricted_opts):
self.data_labels = data_labels
self.data = {k: v.default for k, v in self.data_labels.items()}
self.restricted_opts = restricted_opts
def __setattr__(self, key, value):
if key in options_builtin_fields:
return super(Options, self).__setattr__(key, value)
if self.data is not None:
if key in self.data or key in self.data_labels:
assert not cmd_opts.freeze_settings, "changing settings is disabled"
info = self.data_labels.get(key, None)
if info.do_not_save:
return
comp_args = info.component_args if info else None
if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
raise RuntimeError(f"not possible to set {key} because it is restricted")
if cmd_opts.hide_ui_dir_config and key in self.restricted_opts:
raise RuntimeError(f"not possible to set {key} because it is restricted")
self.data[key] = value
return
return super(Options, self).__setattr__(key, value)
def __getattr__(self, item):
if item in options_builtin_fields:
return super(Options, self).__getattribute__(item)
if self.data is not None:
if item in self.data:
return self.data[item]
if item in self.data_labels:
return self.data_labels[item].default
return super(Options, self).__getattribute__(item)
def set(self, key, value):
"""sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
oldval = self.data.get(key, None)
if oldval == value:
return False
if self.data_labels[key].do_not_save:
return False
try:
setattr(self, key, value)
except RuntimeError:
return False
if self.data_labels[key].onchange is not None:
try:
self.data_labels[key].onchange()
except Exception as e:
errors.display(e, f"changing setting {key} to {value}")
setattr(self, key, oldval)
return False
return True
def get_default(self, key):
"""returns the default value for the key"""
data_label = self.data_labels.get(key)
if data_label is None:
return None
return data_label.default
def save(self, filename):
assert not cmd_opts.freeze_settings, "saving settings is disabled"
with open(filename, "w", encoding="utf8") as file:
json.dump(self.data, file, indent=4)
def same_type(self, x, y):
if x is None or y is None:
return True
type_x = self.typemap.get(type(x), type(x))
type_y = self.typemap.get(type(y), type(y))
return type_x == type_y
def load(self, filename):
with open(filename, "r", encoding="utf8") as file:
self.data = json.load(file)
# 1.6.0 VAE defaults
if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None:
self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default')
# 1.1.1 quicksettings list migration
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
# 1.4.0 ui_reorder
if isinstance(self.data.get('ui_reorder'), str) and self.data.get('ui_reorder') and "ui_reorder_list" not in self.data:
self.data['ui_reorder_list'] = [i.strip() for i in self.data.get('ui_reorder').split(',')]
bad_settings = 0
for k, v in self.data.items():
info = self.data_labels.get(k, None)
if info is not None and not self.same_type(info.default, v):
print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr)
bad_settings += 1
if bad_settings > 0:
print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
def onchange(self, key, func, call=True):
item = self.data_labels.get(key)
item.onchange = func
if call:
func()
def dumpjson(self):
d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
return json.dumps(d)
def add_option(self, key, info):
self.data_labels[key] = info
def reorder(self):
"""reorder settings so that all items related to section always go together"""
section_ids = {}
settings_items = self.data_labels.items()
for _, item in settings_items:
if item.section not in section_ids:
section_ids[item.section] = len(section_ids)
self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
def cast_value(self, key, value):
"""casts an arbitrary to the same type as this setting's value with key
Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str)
"""
if value is None:
return None
default_value = self.data_labels[key].default
if default_value is None:
default_value = getattr(self, key, None)
if default_value is None:
return None
expected_type = type(default_value)
if expected_type == bool and value == "False":
value = False
else:
value = expected_type(value)
return value

View File

@ -63,9 +63,8 @@ def randn_without_seed(shape, generator=None):
def manual_seed(seed): def manual_seed(seed):
"""Set up a global random number generator using the specified seed.""" """Set up a global random number generator using the specified seed."""
from modules.shared import opts
if opts.randn_source == "NV": if shared.opts.randn_source == "NV":
global nv_rng global nv_rng
nv_rng = rng_philox.Generator(seed) nv_rng = rng_philox.Generator(seed)
return return

View File

@ -14,7 +14,7 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
from modules.timer import Timer from modules.timer import Timer
import tomesd import tomesd
@ -473,7 +473,6 @@ model_data = SdModelData()
def get_empty_cond(sd_model): def get_empty_cond(sd_model):
from modules import extra_networks, processing
p = processing.StableDiffusionProcessingTxt2Img() p = processing.StableDiffusionProcessingTxt2Img()
extra_networks.activate(p, {}) extra_networks.activate(p, {})
@ -486,8 +485,6 @@ def get_empty_cond(sd_model):
def send_model_to_cpu(m): def send_model_to_cpu(m):
from modules import lowvram
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu() lowvram.send_everything_to_cpu()
else: else:
@ -497,8 +494,6 @@ def send_model_to_cpu(m):
def send_model_to_device(m): def send_model_to_device(m):
from modules import lowvram
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram) lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
else: else:
@ -642,7 +637,6 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
def reload_model_weights(sd_model=None, info=None): def reload_model_weights(sd_model=None, info=None):
from modules import devices, sd_hijack
checkpoint_info = info or select_checkpoint() checkpoint_info = info or select_checkpoint()
timer = Timer() timer = Timer()
@ -705,7 +699,6 @@ def reload_model_weights(sd_model=None, info=None):
def unload_model_weights(sd_model=None, info=None): def unload_model_weights(sd_model=None, info=None):
from modules import devices, sd_hijack
timer = Timer() timer = Timer()
if model_data.sd_model: if model_data.sd_model:

View File

@ -2,7 +2,7 @@ import os
import torch import torch
from modules import shared, paths, sd_disable_initialization from modules import shared, paths, sd_disable_initialization, devices
sd_configs_path = shared.sd_configs_path sd_configs_path = shared.sd_configs_path
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion") sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
@ -29,7 +29,6 @@ def is_using_v_parameterization_for_sd2(state_dict):
""" """
import ldm.modules.diffusionmodules.openaimodel import ldm.modules.diffusionmodules.openaimodel
from modules import devices
device = devices.cpu device = devices.cpu

View File

@ -2,7 +2,8 @@ import os
import collections import collections
from dataclasses import dataclass from dataclasses import dataclass
from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, lowvram, sd_hijack
import glob import glob
from copy import deepcopy from copy import deepcopy
@ -231,8 +232,6 @@ unspecified = object()
def reload_vae_weights(sd_model=None, vae_file=unspecified): def reload_vae_weights(sd_model=None, vae_file=unspecified):
from modules import lowvram, devices, sd_hijack
if not sd_model: if not sd_model:
sd_model = shared.sd_model sd_model = shared.sd_model

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,18 @@
import os
import launch
from modules import cmd_args, script_loading
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
parser = cmd_args.parser
script_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file))
script_loading.preload_extensions(extensions_builtin_dir, parser)
if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
cmd_opts = parser.parse_args()
else:
cmd_opts, _ = parser.parse_known_args()
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access

View File

@ -0,0 +1,66 @@
import os
import gradio as gr
from modules import errors, shared
from modules.paths_internal import script_path
# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
gradio_hf_hub_themes = [
"gradio/base",
"gradio/glass",
"gradio/monochrome",
"gradio/seafoam",
"gradio/soft",
"gradio/dracula_test",
"abidlabs/dracula_test",
"abidlabs/Lime",
"abidlabs/pakistan",
"Ama434/neutral-barlow",
"dawood/microsoft_windows",
"finlaymacklon/smooth_slate",
"Franklisi/darkmode",
"freddyaboulton/dracula_revamped",
"freddyaboulton/test-blue",
"gstaff/xkcd",
"Insuz/Mocha",
"Insuz/SimpleIndigo",
"JohnSmith9982/small_and_pretty",
"nota-ai/theme",
"nuttea/Softblue",
"ParityError/Anime",
"reilnuud/polite",
"remilia/Ghostly",
"rottenlittlecreature/Moon_Goblin",
"step-3-profit/Midnight-Deep",
"Taithrah/Minimal",
"ysharma/huggingface",
"ysharma/steampunk"
]
def reload_gradio_theme(theme_name=None):
if not theme_name:
theme_name = shared.opts.gradio_theme
default_theme_args = dict(
font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
)
if theme_name == "Default":
shared.gradio_theme = gr.themes.Default(**default_theme_args)
else:
try:
theme_cache_dir = os.path.join(script_path, 'tmp', 'gradio_themes')
theme_cache_path = os.path.join(theme_cache_dir, f'{theme_name.replace("/", "_")}.json')
if shared.opts.gradio_themes_cache and os.path.exists(theme_cache_path):
shared.gradio_theme = gr.themes.ThemeClass.load(theme_cache_path)
else:
os.makedirs(theme_cache_dir, exist_ok=True)
gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
gradio_theme.dump(theme_cache_path)
except Exception as e:
errors.display(e, "changing gradio theme")
shared.gradio_theme = gr.themes.Default(**default_theme_args)

51
modules/shared_init.py Normal file
View File

@ -0,0 +1,51 @@
import os
import torch
from modules import shared
from modules.shared import cmd_opts
import sys
sys.setrecursionlimit(1000)
def initialize():
"""Initializes fields inside the shared module in a controlled manner.
Should be called early because some other modules you can import mingt need these fields to be already set.
"""
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
from modules import options, shared_options
shared.options_templates = shared_options.options_templates
shared.opts = options.Options(shared_options.options_templates, shared_options.restricted_opts)
if os.path.exists(shared.config_filename):
shared.opts.load(shared.config_filename)
from modules import devices
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
shared.device = devices.device
shared.weight_load_location = None if cmd_opts.lowram else "cpu"
from modules import shared_state
shared.state = shared_state.State()
from modules import styles
shared.prompt_styles = styles.StyleDatabase(shared.styles_filename)
from modules import interrogate
shared.interrogator = interrogate.InterrogateModels("interrogate")
from modules import shared_total_tqdm
shared.total_tqdm = shared_total_tqdm.TotalTQDM()
from modules import memmon, devices
shared.mem_mon = memmon.MemUsageMonitor("MemMon", devices.device, shared.opts)
shared.mem_mon.start()

View File

@ -1,3 +1,6 @@
import sys
from modules.shared_cmd_options import cmd_opts
def realesrgan_models_names(): def realesrgan_models_names():
@ -41,6 +44,28 @@ def refresh_unet_list():
modules.sd_unet.list_unets() modules.sd_unet.list_unets()
def list_checkpoint_tiles():
import modules.sd_models
return modules.sd_models.checkpoint_tiles()
def refresh_checkpoints():
import modules.sd_models
return modules.sd_models.list_models()
def list_samplers():
import modules.sd_samplers
return modules.sd_samplers.all_samplers
def reload_hypernetworks():
from modules.hypernetworks import hypernetwork
from modules import shared
shared.hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
ui_reorder_categories_builtin_items = [ ui_reorder_categories_builtin_items = [
"inpaint", "inpaint",
"sampler", "sampler",
@ -67,3 +92,27 @@ def ui_reorder_categories():
yield from sections yield from sections
yield "scripts" yield "scripts"
class Shared(sys.modules[__name__].__class__):
"""
this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
at program startup.
"""
sd_model_val = None
@property
def sd_model(self):
import modules.sd_models
return modules.sd_models.model_data.get_sd_model()
@sd_model.setter
def sd_model(self, value):
import modules.sd_models
modules.sd_models.model_data.set_sd_model(value)
sys.modules['modules.shared'].__class__ = Shared

View File

@ -1,40 +1,12 @@
import datetime
import json
import os
import re
import sys
import threading
import time
import logging
import gradio as gr import gradio as gr
import torch
import tqdm
import launch from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes
import modules.interrogate
import modules.memmon
import modules.styles
import modules.devices as devices
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args, rng # noqa: F401
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401 from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
from ldm.models.diffusion.ddpm import LatentDiffusion from modules.shared_cmd_options import cmd_opts
from typing import Optional from modules.options import options_section, OptionInfo, OptionHTML
log = logging.getLogger(__name__)
demo = None
parser = cmd_args.parser
script_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file))
script_loading.preload_extensions(extensions_builtin_dir, parser)
if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
cmd_opts = parser.parse_args()
else:
cmd_opts, _ = parser.parse_known_args()
options_templates = {}
hide_dirs = shared.hide_dirs
restricted_opts = { restricted_opts = {
"samples_filename_pattern", "samples_filename_pattern",
@ -49,302 +21,6 @@ restricted_opts = {
"outdir_init_images" "outdir_init_images"
} }
# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
gradio_hf_hub_themes = [
"gradio/base",
"gradio/glass",
"gradio/monochrome",
"gradio/seafoam",
"gradio/soft",
"gradio/dracula_test",
"abidlabs/dracula_test",
"abidlabs/Lime",
"abidlabs/pakistan",
"Ama434/neutral-barlow",
"dawood/microsoft_windows",
"finlaymacklon/smooth_slate",
"Franklisi/darkmode",
"freddyaboulton/dracula_revamped",
"freddyaboulton/test-blue",
"gstaff/xkcd",
"Insuz/Mocha",
"Insuz/SimpleIndigo",
"JohnSmith9982/small_and_pretty",
"nota-ai/theme",
"nuttea/Softblue",
"ParityError/Anime",
"reilnuud/polite",
"remilia/Ghostly",
"rottenlittlecreature/Moon_Goblin",
"step-3-profit/Midnight-Deep",
"Taithrah/Minimal",
"ysharma/huggingface",
"ysharma/steampunk"
]
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
device = devices.device
weight_load_location = None if cmd_opts.lowram else "cpu"
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
xformers_available = False
config_filename = cmd_opts.ui_settings_file
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
hypernetworks = {}
loaded_hypernetworks = []
def reload_hypernetworks():
from modules.hypernetworks import hypernetwork
global hypernetworks
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
class State:
skipped = False
interrupted = False
job = ""
job_no = 0
job_count = 0
processing_has_refined_job_count = False
job_timestamp = '0'
sampling_step = 0
sampling_steps = 0
current_latent = None
current_image = None
current_image_sampling_step = 0
id_live_preview = 0
textinfo = None
time_start = None
server_start = None
_server_command_signal = threading.Event()
_server_command: Optional[str] = None
@property
def need_restart(self) -> bool:
# Compatibility getter for need_restart.
return self.server_command == "restart"
@need_restart.setter
def need_restart(self, value: bool) -> None:
# Compatibility setter for need_restart.
if value:
self.server_command = "restart"
@property
def server_command(self):
return self._server_command
@server_command.setter
def server_command(self, value: Optional[str]) -> None:
"""
Set the server command to `value` and signal that it's been set.
"""
self._server_command = value
self._server_command_signal.set()
def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
"""
Wait for server command to get set; return and clear the value and signal.
"""
if self._server_command_signal.wait(timeout):
self._server_command_signal.clear()
req = self._server_command
self._server_command = None
return req
return None
def request_restart(self) -> None:
self.interrupt()
self.server_command = "restart"
log.info("Received restart request")
def skip(self):
self.skipped = True
log.info("Received skip request")
def interrupt(self):
self.interrupted = True
log.info("Received interrupt request")
def nextjob(self):
if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
self.do_set_current_image()
self.job_no += 1
self.sampling_step = 0
self.current_image_sampling_step = 0
def dict(self):
obj = {
"skipped": self.skipped,
"interrupted": self.interrupted,
"job": self.job,
"job_count": self.job_count,
"job_timestamp": self.job_timestamp,
"job_no": self.job_no,
"sampling_step": self.sampling_step,
"sampling_steps": self.sampling_steps,
}
return obj
def begin(self, job: str = "(unknown)"):
self.sampling_step = 0
self.job_count = -1
self.processing_has_refined_job_count = False
self.job_no = 0
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
self.current_latent = None
self.current_image = None
self.current_image_sampling_step = 0
self.id_live_preview = 0
self.skipped = False
self.interrupted = False
self.textinfo = None
self.time_start = time.time()
self.job = job
devices.torch_gc()
log.info("Starting job %s", job)
def end(self):
duration = time.time() - self.time_start
log.info("Ending job %s (%.2f seconds)", self.job, duration)
self.job = ""
self.job_count = 0
devices.torch_gc()
def set_current_image(self):
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
if not parallel_processing_allowed:
return
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable and opts.show_progress_every_n_steps != -1:
self.do_set_current_image()
def do_set_current_image(self):
if self.current_latent is None:
return
import modules.sd_samplers
try:
if opts.show_progress_grid:
self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
else:
self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
self.current_image_sampling_step = self.sampling_step
except Exception:
# when switching models during genration, VAE would be on CPU, so creating an image will fail.
# we silently ignore this error
errors.record_exception()
def assign_current_image(self, image):
self.current_image = image
self.id_live_preview += 1
state = State()
state.server_start = time.time()
styles_filename = cmd_opts.styles_file
prompt_styles = modules.styles.StyleDatabase(styles_filename)
interrogator = modules.interrogate.InterrogateModels("interrogate")
face_restorers = []
class OptionInfo:
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after=''):
self.default = default
self.label = label
self.component = component
self.component_args = component_args
self.onchange = onchange
self.section = section
self.refresh = refresh
self.do_not_save = False
self.comment_before = comment_before
"""HTML text that will be added after label in UI"""
self.comment_after = comment_after
"""HTML text that will be added before label in UI"""
def link(self, label, url):
self.comment_before += f"[<a href='{url}' target='_blank'>{label}</a>]"
return self
def js(self, label, js_func):
self.comment_before += f"[<a onclick='{js_func}(); return false'>{label}</a>]"
return self
def info(self, info):
self.comment_after += f"<span class='info'>({info})</span>"
return self
def html(self, html):
self.comment_after += html
return self
def needs_restart(self):
self.comment_after += " <span class='info'>(requires restart)</span>"
return self
def needs_reload_ui(self):
self.comment_after += " <span class='info'>(requires Reload UI)</span>"
return self
class OptionHTML(OptionInfo):
def __init__(self, text):
super().__init__(str(text).strip(), label='', component=lambda **kwargs: gr.HTML(elem_classes="settings-info", **kwargs))
self.do_not_save = True
def options_section(section_identifier, options_dict):
for v in options_dict.values():
v.section = section_identifier
return options_dict
def list_checkpoint_tiles():
import modules.sd_models
return modules.sd_models.checkpoint_tiles()
def refresh_checkpoints():
import modules.sd_models
return modules.sd_models.list_models()
def list_samplers():
import modules.sd_samplers
return modules.sd_samplers.all_samplers
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
tab_names = []
options_templates = {}
options_templates.update(options_section(('saving-images', "Saving images/grids"), { options_templates.update(options_section(('saving-images', "Saving images/grids"), {
"samples_save": OptionInfo(True, "Always save all generated images"), "samples_save": OptionInfo(True, "Always save all generated images"),
"samples_format": OptionInfo('png', 'File format for images'), "samples_format": OptionInfo('png', 'File format for images'),
@ -412,11 +88,11 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"), "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}), "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}),
})) }))
options_templates.update(options_section(('face-restoration', "Face restoration"), { options_templates.update(options_section(('face-restoration', "Face restoration"), {
"face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}), "face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in shared.face_restorers]}),
"code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"), "code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"), "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
})) }))
@ -450,7 +126,7 @@ options_templates.update(options_section(('training', "Training"), {
})) }))
options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles()}, refresh=shared_items.refresh_checkpoints),
"sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}), "sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
"sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"), "sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
@ -526,7 +202,7 @@ options_templates.update(options_section(('interrogate', "Interrogate"), {
"interrogate_clip_min_length": OptionInfo(24, "BLIP: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_min_length": OptionInfo(24, "BLIP: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "BLIP: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "BLIP: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
"interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file").info("0 = No limit"), "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file").info("0 = No limit"),
"interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types), "interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": interrogate.category_types()}, refresh=interrogate.category_types),
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "deepbooru: score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "deepbooru: score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"deepbooru_sort_alpha": OptionInfo(True, "deepbooru: sort tags alphabetically").info("if not: sort by score"), "deepbooru_sort_alpha": OptionInfo(True, "deepbooru: sort tags alphabetically").info("if not: sort by score"),
"deepbooru_use_spaces": OptionInfo(True, "deepbooru: use spaces in tags").info("if not: use underscores"), "deepbooru_use_spaces": OptionInfo(True, "deepbooru: use spaces in tags").info("if not: use underscores"),
@ -546,12 +222,12 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(), "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
"textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"), "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
"textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"), "textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"),
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks), "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *shared.hypernetworks]}, refresh=shared_items.reload_hypernetworks),
})) }))
options_templates.update(options_section(('ui', "User interface"), { options_templates.update(options_section(('ui', "User interface"), {
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(), "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(),
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(), "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(),
"gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"), "gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"),
"return_grid": OptionInfo(True, "Show grid in results for web"), "return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
@ -568,9 +244,9 @@ options_templates.update(options_section(('ui', "User interface"), {
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"), "keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"), "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(), "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(),
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_reload_ui(), "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_reload_ui(), "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
"ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(), "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(),
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(), "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(),
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(), "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
@ -605,7 +281,7 @@ options_templates.update(options_section(('ui', "Live previews"), {
})) }))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), { options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}).needs_reload_ui(), "hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(),
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"), "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"),
"eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"), "eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}), "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
@ -638,339 +314,3 @@ options_templates.update(options_section((None, "Hidden options"), {
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"), "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
})) }))
options_templates.update()
class Options:
data = None
data_labels = options_templates
typemap = {int: float}
def __init__(self):
self.data = {k: v.default for k, v in self.data_labels.items()}
def __setattr__(self, key, value):
if self.data is not None:
if key in self.data or key in self.data_labels:
assert not cmd_opts.freeze_settings, "changing settings is disabled"
info = opts.data_labels.get(key, None)
if info.do_not_save:
return
comp_args = info.component_args if info else None
if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
raise RuntimeError(f"not possible to set {key} because it is restricted")
if cmd_opts.hide_ui_dir_config and key in restricted_opts:
raise RuntimeError(f"not possible to set {key} because it is restricted")
self.data[key] = value
return
return super(Options, self).__setattr__(key, value)
def __getattr__(self, item):
if self.data is not None:
if item in self.data:
return self.data[item]
if item in self.data_labels:
return self.data_labels[item].default
return super(Options, self).__getattribute__(item)
def set(self, key, value):
"""sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
oldval = self.data.get(key, None)
if oldval == value:
return False
if self.data_labels[key].do_not_save:
return False
try:
setattr(self, key, value)
except RuntimeError:
return False
if self.data_labels[key].onchange is not None:
try:
self.data_labels[key].onchange()
except Exception as e:
errors.display(e, f"changing setting {key} to {value}")
setattr(self, key, oldval)
return False
return True
def get_default(self, key):
"""returns the default value for the key"""
data_label = self.data_labels.get(key)
if data_label is None:
return None
return data_label.default
def save(self, filename):
assert not cmd_opts.freeze_settings, "saving settings is disabled"
with open(filename, "w", encoding="utf8") as file:
json.dump(self.data, file, indent=4)
def same_type(self, x, y):
if x is None or y is None:
return True
type_x = self.typemap.get(type(x), type(x))
type_y = self.typemap.get(type(y), type(y))
return type_x == type_y
def load(self, filename):
with open(filename, "r", encoding="utf8") as file:
self.data = json.load(file)
# 1.6.0 VAE defaults
if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None:
self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default')
# 1.1.1 quicksettings list migration
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
# 1.4.0 ui_reorder
if isinstance(self.data.get('ui_reorder'), str) and self.data.get('ui_reorder') and "ui_reorder_list" not in self.data:
self.data['ui_reorder_list'] = [i.strip() for i in self.data.get('ui_reorder').split(',')]
bad_settings = 0
for k, v in self.data.items():
info = self.data_labels.get(k, None)
if info is not None and not self.same_type(info.default, v):
print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr)
bad_settings += 1
if bad_settings > 0:
print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
def onchange(self, key, func, call=True):
item = self.data_labels.get(key)
item.onchange = func
if call:
func()
def dumpjson(self):
d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
return json.dumps(d)
def add_option(self, key, info):
self.data_labels[key] = info
def reorder(self):
"""reorder settings so that all items related to section always go together"""
section_ids = {}
settings_items = self.data_labels.items()
for _, item in settings_items:
if item.section not in section_ids:
section_ids[item.section] = len(section_ids)
self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
def cast_value(self, key, value):
"""casts an arbitrary to the same type as this setting's value with key
Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str)
"""
if value is None:
return None
default_value = self.data_labels[key].default
if default_value is None:
default_value = getattr(self, key, None)
if default_value is None:
return None
expected_type = type(default_value)
if expected_type == bool and value == "False":
value = False
else:
value = expected_type(value)
return value
opts = Options()
if os.path.exists(config_filename):
opts.load(config_filename)
class Shared(sys.modules[__name__].__class__):
"""
this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
at program startup.
"""
sd_model_val = None
@property
def sd_model(self):
import modules.sd_models
return modules.sd_models.model_data.get_sd_model()
@sd_model.setter
def sd_model(self, value):
import modules.sd_models
modules.sd_models.model_data.set_sd_model(value)
sd_model: LatentDiffusion = None # this var is here just for IDE's type checking; it cannot be accessed because the class field above will be accessed instead
sys.modules[__name__].__class__ = Shared
settings_components = None
"""assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
latent_upscale_default_mode = "Latent"
latent_upscale_modes = {
"Latent": {"mode": "bilinear", "antialias": False},
"Latent (antialiased)": {"mode": "bilinear", "antialias": True},
"Latent (bicubic)": {"mode": "bicubic", "antialias": False},
"Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True},
"Latent (nearest)": {"mode": "nearest", "antialias": False},
"Latent (nearest-exact)": {"mode": "nearest-exact", "antialias": False},
}
sd_upscalers = []
clip_model = None
progress_print_out = sys.stdout
gradio_theme = gr.themes.Base()
def reload_gradio_theme(theme_name=None):
global gradio_theme
if not theme_name:
theme_name = opts.gradio_theme
default_theme_args = dict(
font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
)
if theme_name == "Default":
gradio_theme = gr.themes.Default(**default_theme_args)
else:
try:
theme_cache_dir = os.path.join(script_path, 'tmp', 'gradio_themes')
theme_cache_path = os.path.join(theme_cache_dir, f'{theme_name.replace("/", "_")}.json')
if opts.gradio_themes_cache and os.path.exists(theme_cache_path):
gradio_theme = gr.themes.ThemeClass.load(theme_cache_path)
else:
os.makedirs(theme_cache_dir, exist_ok=True)
gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
gradio_theme.dump(theme_cache_path)
except Exception as e:
errors.display(e, "changing gradio theme")
gradio_theme = gr.themes.Default(**default_theme_args)
class TotalTQDM:
def __init__(self):
self._tqdm = None
def reset(self):
self._tqdm = tqdm.tqdm(
desc="Total progress",
total=state.job_count * state.sampling_steps,
position=1,
file=progress_print_out
)
def update(self):
if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
return
if self._tqdm is None:
self.reset()
self._tqdm.update()
def updateTotal(self, new_total):
if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
return
if self._tqdm is None:
self.reset()
self._tqdm.total = new_total
def clear(self):
if self._tqdm is not None:
self._tqdm.refresh()
self._tqdm.close()
self._tqdm = None
total_tqdm = TotalTQDM()
mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
mem_mon.start()
def natural_sort_key(s, regex=re.compile('([0-9]+)')):
return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]
def listfiles(dirname):
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
return [file for file in filenames if os.path.isfile(file)]
def html_path(filename):
return os.path.join(script_path, "html", filename)
def html(filename):
path = html_path(filename)
if os.path.exists(path):
with open(path, encoding="utf8") as file:
return file.read()
return ""
def walk_files(path, allowed_extensions=None):
if not os.path.exists(path):
return
if allowed_extensions is not None:
allowed_extensions = set(allowed_extensions)
items = list(os.walk(path, followlinks=True))
items = sorted(items, key=lambda x: natural_sort_key(x[0]))
for root, _, files in items:
for filename in sorted(files, key=natural_sort_key):
if allowed_extensions is not None:
_, ext = os.path.splitext(filename)
if ext not in allowed_extensions:
continue
if not opts.list_hidden_files and ("/." in root or "\\." in root):
continue
yield os.path.join(root, filename)
def ldm_print(*args, **kwargs):
if opts.hide_ldm_prints:
return
print(*args, **kwargs)

159
modules/shared_state.py Normal file
View File

@ -0,0 +1,159 @@
import datetime
import logging
import threading
import time
from modules import errors, shared, devices
from typing import Optional
log = logging.getLogger(__name__)
class State:
skipped = False
interrupted = False
job = ""
job_no = 0
job_count = 0
processing_has_refined_job_count = False
job_timestamp = '0'
sampling_step = 0
sampling_steps = 0
current_latent = None
current_image = None
current_image_sampling_step = 0
id_live_preview = 0
textinfo = None
time_start = None
server_start = None
_server_command_signal = threading.Event()
_server_command: Optional[str] = None
def __init__(self):
self.server_start = time.time()
@property
def need_restart(self) -> bool:
# Compatibility getter for need_restart.
return self.server_command == "restart"
@need_restart.setter
def need_restart(self, value: bool) -> None:
# Compatibility setter for need_restart.
if value:
self.server_command = "restart"
@property
def server_command(self):
return self._server_command
@server_command.setter
def server_command(self, value: Optional[str]) -> None:
"""
Set the server command to `value` and signal that it's been set.
"""
self._server_command = value
self._server_command_signal.set()
def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
"""
Wait for server command to get set; return and clear the value and signal.
"""
if self._server_command_signal.wait(timeout):
self._server_command_signal.clear()
req = self._server_command
self._server_command = None
return req
return None
def request_restart(self) -> None:
self.interrupt()
self.server_command = "restart"
log.info("Received restart request")
def skip(self):
self.skipped = True
log.info("Received skip request")
def interrupt(self):
self.interrupted = True
log.info("Received interrupt request")
def nextjob(self):
if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1:
self.do_set_current_image()
self.job_no += 1
self.sampling_step = 0
self.current_image_sampling_step = 0
def dict(self):
obj = {
"skipped": self.skipped,
"interrupted": self.interrupted,
"job": self.job,
"job_count": self.job_count,
"job_timestamp": self.job_timestamp,
"job_no": self.job_no,
"sampling_step": self.sampling_step,
"sampling_steps": self.sampling_steps,
}
return obj
def begin(self, job: str = "(unknown)"):
self.sampling_step = 0
self.job_count = -1
self.processing_has_refined_job_count = False
self.job_no = 0
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
self.current_latent = None
self.current_image = None
self.current_image_sampling_step = 0
self.id_live_preview = 0
self.skipped = False
self.interrupted = False
self.textinfo = None
self.time_start = time.time()
self.job = job
devices.torch_gc()
log.info("Starting job %s", job)
def end(self):
duration = time.time() - self.time_start
log.info("Ending job %s (%.2f seconds)", self.job, duration)
self.job = ""
self.job_count = 0
devices.torch_gc()
def set_current_image(self):
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
if not shared.parallel_processing_allowed:
return
if self.sampling_step - self.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps != -1:
self.do_set_current_image()
def do_set_current_image(self):
if self.current_latent is None:
return
import modules.sd_samplers
try:
if shared.opts.show_progress_grid:
self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
else:
self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
self.current_image_sampling_step = self.sampling_step
except Exception:
# when switching models during genration, VAE would be on CPU, so creating an image will fail.
# we silently ignore this error
errors.record_exception()
def assign_current_image(self, image):
self.current_image = image
self.id_live_preview += 1

View File

@ -0,0 +1,37 @@
import tqdm
from modules import shared
class TotalTQDM:
def __init__(self):
self._tqdm = None
def reset(self):
self._tqdm = tqdm.tqdm(
desc="Total progress",
total=shared.state.job_count * shared.state.sampling_steps,
position=1,
file=shared.progress_print_out
)
def update(self):
if not shared.opts.multiple_tqdm or shared.cmd_opts.disable_console_progressbars:
return
if self._tqdm is None:
self.reset()
self._tqdm.update()
def updateTotal(self, new_total):
if not shared.opts.multiple_tqdm or shared.cmd_opts.disable_console_progressbars:
return
if self._tqdm is None:
self.reset()
self._tqdm.total = new_total
def clear(self):
if self._tqdm is not None:
self._tqdm.refresh()
self._tqdm.close()
self._tqdm = None

View File

@ -10,7 +10,7 @@ import psutil
import re import re
import launch import launch
from modules import paths_internal, timer from modules import paths_internal, timer, shared, extensions, errors
checksum_token = "DontStealMyGamePlz__WINNERS_DONT_USE_DRUGS__DONT_COPY_THAT_FLOPPY" checksum_token = "DontStealMyGamePlz__WINNERS_DONT_USE_DRUGS__DONT_COPY_THAT_FLOPPY"
environment_whitelist = { environment_whitelist = {
@ -115,8 +115,6 @@ def format_exception(e, tb):
def get_exceptions(): def get_exceptions():
try: try:
from modules import errors
return list(reversed(errors.exception_records)) return list(reversed(errors.exception_records))
except Exception as e: except Exception as e:
return str(e) return str(e)
@ -142,8 +140,6 @@ def get_torch_sysinfo():
def get_extensions(*, enabled): def get_extensions(*, enabled):
try: try:
from modules import extensions
def to_json(x: extensions.Extension): def to_json(x: extensions.Extension):
return { return {
"name": x.name, "name": x.name,
@ -160,7 +156,6 @@ def get_extensions(*, enabled):
def get_config(): def get_config():
try: try:
from modules import shared
return shared.opts.data return shared.opts.data
except Exception as e: except Exception as e:
return str(e) return str(e)

View File

@ -13,7 +13,7 @@ from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import gradio_extensons # noqa: F401 from modules import gradio_extensons # noqa: F401
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers, processing, devices, ui_extra_networks
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path from modules.paths import script_path
from modules.ui_common import create_refresh_button from modules.ui_common import create_refresh_button
@ -91,8 +91,6 @@ def send_gradio_gallery_to_image(x):
def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
from modules import processing, devices
if not enable: if not enable:
return "" return ""
@ -630,7 +628,6 @@ def create_ui():
toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter]) toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter]) toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
from modules import ui_extra_networks
extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img') extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery) ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
@ -995,7 +992,6 @@ def create_ui():
paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None, paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None,
)) ))
from modules import ui_extra_networks
extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img') extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img')
ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery) ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)

View File

@ -11,7 +11,7 @@ from modules import call_queue, shared
from modules.generation_parameters_copypaste import image_from_url_text from modules.generation_parameters_copypaste import image_from_url_text
import modules.images import modules.images
from modules.ui_components import ToolButton from modules.ui_components import ToolButton
import modules.generation_parameters_copypaste as parameters_copypaste
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄 refresh_symbol = '\U0001f504' # 🔄
@ -105,8 +105,6 @@ def save_files(js_data, images, do_make_zip, index):
def create_output_panel(tabname, outdir): def create_output_panel(tabname, outdir):
from modules import shared
import modules.generation_parameters_copypaste as parameters_copypaste
def open_folder(f): def open_folder(f):
if not os.path.exists(f): if not os.path.exists(f):

58
modules/util.py Normal file
View File

@ -0,0 +1,58 @@
import os
import re
from modules import shared
from modules.paths_internal import script_path
def natural_sort_key(s, regex=re.compile('([0-9]+)')):
return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]
def listfiles(dirname):
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
return [file for file in filenames if os.path.isfile(file)]
def html_path(filename):
return os.path.join(script_path, "html", filename)
def html(filename):
path = html_path(filename)
if os.path.exists(path):
with open(path, encoding="utf8") as file:
return file.read()
return ""
def walk_files(path, allowed_extensions=None):
if not os.path.exists(path):
return
if allowed_extensions is not None:
allowed_extensions = set(allowed_extensions)
items = list(os.walk(path, followlinks=True))
items = sorted(items, key=lambda x: natural_sort_key(x[0]))
for root, _, files in items:
for filename in sorted(files, key=natural_sort_key):
if allowed_extensions is not None:
_, ext = os.path.splitext(filename)
if ext not in allowed_extensions:
continue
if not shared.opts.list_hidden_files and ("/." in root or "\\." in root):
continue
yield os.path.join(root, filename)
def ldm_print(*args, **kwargs):
if shared.opts.hide_ldm_prints:
return
print(*args, **kwargs)

View File

@ -43,12 +43,15 @@ startup_timer.record("import torch")
import gradio # noqa: F401 import gradio # noqa: F401
startup_timer.record("import gradio") startup_timer.record("import gradio")
from modules import paths, timer, import_hook, errors, devices # noqa: F401 from modules import paths, timer, import_hook, errors # noqa: F401
startup_timer.record("setup paths") startup_timer.record("setup paths")
import ldm.modules.encoders.modules # noqa: F401 import ldm.modules.encoders.modules # noqa: F401
startup_timer.record("import ldm") startup_timer.record("import ldm")
from modules import shared_init, shared, shared_items
shared_init.initialize()
startup_timer.record("initialize shared")
from modules import extra_networks from modules import extra_networks
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401
@ -58,8 +61,6 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__:
torch.__long_version__ = torch.__version__ torch.__long_version__ = torch.__version__
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
from modules import shared
if not shared.cmd_opts.skip_version_check: if not shared.cmd_opts.skip_version_check:
errors.check_versions() errors.check_versions()
@ -82,7 +83,7 @@ import modules.textual_inversion.textual_inversion
import modules.progress import modules.progress
import modules.ui import modules.ui
from modules import modelloader from modules import modelloader, devices
from modules.shared import cmd_opts from modules.shared import cmd_opts
import modules.hypernetworks.hypernetwork import modules.hypernetworks.hypernetwork
@ -297,7 +298,7 @@ def initialize_rest(*, reload_script_modules=False):
Thread(target=load_model).start() Thread(target=load_model).start()
shared.reload_hypernetworks() shared_items.reload_hypernetworks()
startup_timer.record("reload hypernetworks") startup_timer.record("reload hypernetworks")
ui_extra_networks.initialize() ui_extra_networks.initialize()