diff --git a/javascript/extensions.js b/javascript/extensions.js
index 312131b7..cc8ee220 100644
--- a/javascript/extensions.js
+++ b/javascript/extensions.js
@@ -2,8 +2,11 @@ function extensions_apply(_disabled_list, _update_list, disable_all) {
     var disable = [];
     var update = [];
-
-    gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x) {
+    const extensions_input = gradioApp().querySelectorAll('#extensions input[type="checkbox"]');
+    if (extensions_input.length == 0) {
+        throw Error("Extensions page not yet loaded.");
+    }
+    extensions_input.forEach(function(x) {
         if (x.name.startsWith("enable_") && !x.checked) {
             disable.push(x.name.substring(7));
         }
diff --git a/modules/infotext_utils.py b/modules/infotext_utils.py
index 7e9cceb0..d21d3333 100644
--- a/modules/infotext_utils.py
+++ b/modules/infotext_utils.py
@@ -358,6 +358,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
     if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
         res["Cache FP16 weight for LoRA"] = False
 
+    if "Emphasis" not in res:
+        res["Emphasis"] = "Original"
+
     infotext_versions.backcompat(res)
 
     for key in skip_fields:
diff --git a/modules/processing.py b/modules/processing.py
index f477c9b1..10d0ea5e 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -458,6 +458,7 @@ class StableDiffusionProcessing:
             self.height,
             opts.fp8_storage,
             opts.cache_fp16_weight,
+            opts.emphasis,
         )
 
     def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
diff --git a/modules/processing_scripts/comments.py b/modules/processing_scripts/comments.py
new file mode 100644
index 00000000..638e39f2
--- /dev/null
+++ b/modules/processing_scripts/comments.py
@@ -0,0 +1,42 @@
+from modules import scripts, shared, script_callbacks
+import re
+
+
+def strip_comments(text):
+    text = re.sub('(^|\n)#[^\n]*(\n|$)', '\n', text)  # whole-line comment
+    text = re.sub('#[^\n]*(\n|$)', '\n', text)  # comment in the middle of the line
+
+    return text
+
+
+class ScriptStripComments(scripts.Script):
+    def title(self):
+        return "Comments"
+
+    def show(self, is_img2img):
+        return scripts.AlwaysVisible
+
+    def process(self, p, *args):
+        if not shared.opts.enable_prompt_comments:
+            return
+
+        p.all_prompts = [strip_comments(x) for x in p.all_prompts]
+        p.all_negative_prompts = [strip_comments(x) for x in p.all_negative_prompts]
+
+        p.main_prompt = strip_comments(p.main_prompt)
+        p.main_negative_prompt = strip_comments(p.main_negative_prompt)
+
+
+def before_token_counter(params: script_callbacks.BeforeTokenCounterParams):
+    if not shared.opts.enable_prompt_comments:
+        return
+
+    params.prompt = strip_comments(params.prompt)
+
+
+script_callbacks.on_before_token_counter(before_token_counter)
+
+
+shared.options_templates.update(shared.options_section(('sd', "Stable Diffusion", "sd"), {
+    "enable_prompt_comments": shared.OptionInfo(True, "Enable comments").info("Use # anywhere in the prompt to hide the text between # and the end of the line from the generation."),
+}))
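The two regexes in `strip_comments` are order-sensitive: the first removes comments that occupy a whole line (together with the line break), the second truncates a line at an inline `#`. A minimal sketch of the resulting behavior (the sample prompt is illustrative, not from any test suite):

```python
# Illustrative only: mirrors the strip_comments logic from comments.py above.
import re


def strip_comments(text):
    text = re.sub('(^|\n)#[^\n]*(\n|$)', '\n', text)  # whole-line comment
    text = re.sub('#[^\n]*(\n|$)', '\n', text)  # comment in the middle of the line
    return text


prompt = "masterpiece # quality tags\na cat, # subject\n# this whole line goes away\nsunset"
print(strip_comments(prompt))
# -> "masterpiece \na cat, \nsunset" (everything from # to end of line is removed)
```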
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 48ab289c..2c50f43c 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -1,3 +1,4 @@
+import dataclasses
 import inspect
 import os
 from collections import namedtuple
@@ -106,6 +107,15 @@ class ImageGridLoopParams:
         self.rows = rows
 
 
+@dataclasses.dataclass
+class BeforeTokenCounterParams:
+    prompt: str
+    steps: int
+    styles: list
+
+    is_positive: bool = True
+
+
 ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
 callback_map = dict(
     callbacks_app_started=[],
@@ -128,6 +138,7 @@ callback_map = dict(
     callbacks_on_reload=[],
     callbacks_list_optimizers=[],
     callbacks_list_unets=[],
+    callbacks_before_token_counter=[],
 )
 event_subscriber_map = dict(
     callbacks_setting_updated=[],
@@ -312,6 +323,14 @@ def list_unets_callback():
     return res
 
 
+def before_token_counter_callback(params: BeforeTokenCounterParams):
+    for c in callback_map['callbacks_before_token_counter']:
+        try:
+            c.callback(params)
+        except Exception:
+            report_exception(c, 'before_token_counter')
+
+
 def setting_updated_event_subscriber_chain(handler, component, setting_name: str):
     """
     Arguments:
@@ -505,6 +524,13 @@ def on_list_unets(callback):
     add_callback(callback_map['callbacks_list_unets'], callback)
 
 
+def on_before_token_counter(callback):
+    """register a function to be called when the UI is counting tokens for a prompt.
+    The function will be called with one argument of type BeforeTokenCounterParams, and should modify its fields if necessary."""
+
+    add_callback(callback_map['callbacks_before_token_counter'], callback)
+
+
 def on_setting_updated_subscriber(subscriber_params):
     """register a function to be called after settings update. `subscriber_params` should contain necessary
     fields to register an gradio event handler. Necessary
@@ -513,3 +539,4 @@
     sure to handle these extra params when defining the callback function.
     """
     event_subscriber_map['callbacks_setting_updated'].append(subscriber_params)
+
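With the hook wired up, an extension can rewrite what the token counter sees before counting happens. A minimal sketch of a consumer, assuming it ships as an extension script (the `!!raw` marker it strips is hypothetical):

```python
# Hypothetical extension script using the new hook added above.
from modules import script_callbacks


def hide_marker_from_counter(params: script_callbacks.BeforeTokenCounterParams):
    # Strip a hypothetical "!!raw" marker so it is not counted as prompt tokens.
    params.prompt = params.prompt.replace("!!raw", "")


script_callbacks.on_before_token_counter(hide_marker_from_counter)
```

This is the same registration pattern `modules/processing_scripts/comments.py` uses above to hide `#` comments from the counter.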
diff --git a/modules/sd_emphasis.py b/modules/sd_emphasis.py
new file mode 100644
index 00000000..654817b6
--- /dev/null
+++ b/modules/sd_emphasis.py
@@ -0,0 +1,70 @@
+from __future__ import annotations
+import torch
+
+
+class Emphasis:
+    """Emphasis class decides how to deal with (emphasized:1.1) text in prompts"""
+
+    name: str = "Base"
+    description: str = ""
+
+    tokens: list[list[int]]
+    """tokens from the chunk of the prompt"""
+
+    multipliers: torch.Tensor
+    """tensor with multipliers, one for each token"""
+
+    z: torch.Tensor
+    """output of cond transformers network (CLIP)"""
+
+    def after_transformers(self):
+        """Called after cond transformers network has processed the chunk of the prompt; this function should modify self.z to apply the emphasis"""
+
+        pass
+
+
+class EmphasisNone(Emphasis):
+    name = "None"
+    description = "disable the mechanism entirely and treat (:1.1) as literal characters"
+
+
+class EmphasisIgnore(Emphasis):
+    name = "Ignore"
+    description = "treat all emphasized words as if they have no emphasis"
+
+
+class EmphasisOriginal(Emphasis):
+    name = "Original"
+    description = "the original emphasis implementation"
+
+    def after_transformers(self):
+        original_mean = self.z.mean()
+        self.z = self.z * self.multipliers.reshape(self.multipliers.shape + (1,)).expand(self.z.shape)
+
+        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
+        new_mean = self.z.mean()
+        self.z = self.z * (original_mean / new_mean)
+
+
+class EmphasisOriginalNoNorm(EmphasisOriginal):
+    name = "No norm"
+    description = "same as original, but without normalization (seems to work better for SDXL)"
+
+    def after_transformers(self):
+        self.z = self.z * self.multipliers.reshape(self.multipliers.shape + (1,)).expand(self.z.shape)
+
+
+def get_current_option(emphasis_option_name):
+    return next(iter([x for x in options if x.name == emphasis_option_name]), EmphasisOriginal)
+
+
+def get_options_descriptions():
+    return ", ".join(f"{x.name}: {x.description}" for x in options)
+
+
+options = [
+    EmphasisNone,
+    EmphasisIgnore,
+    EmphasisOriginal,
+    EmphasisOriginalNoNorm,
+]
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 8f29057a..98350ac4 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -3,7 +3,7 @@ from collections import namedtuple
 
 import torch
 
-from modules import prompt_parser, devices, sd_hijack
+from modules import prompt_parser, devices, sd_hijack, sd_emphasis
 from modules.shared import opts
 
@@ -88,7 +88,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         Returns the list and the total number of tokens in the prompt.
         """
 
-        if opts.enable_emphasis:
+        if opts.emphasis != "None":
             parsed = prompt_parser.parse_prompt_attention(line)
         else:
             parsed = [[line, 1.0]]
@@ -249,6 +249,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
                 hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
             self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
 
+        if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original":
+            self.hijack.extra_generation_params["Emphasis"] = opts.emphasis
+
         if getattr(self.wrapped, 'return_pooled', False):
             return torch.hstack(zs), zs[0].pooled
         else:
@@ -274,12 +277,14 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
 
         pooled = getattr(z, 'pooled', None)
 
-        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
-        batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
-        original_mean = z.mean()
-        z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
-        new_mean = z.mean()
-        z = z * (original_mean / new_mean)
+        emphasis = sd_emphasis.get_current_option(opts.emphasis)()
+        emphasis.tokens = remade_batch_tokens
+        emphasis.multipliers = torch.asarray(batch_multipliers).to(devices.device)
+        emphasis.z = z
+
+        emphasis.after_transformers()
+
+        z = emphasis.z
 
         if pooled is not None:
             z.pooled = pooled
diff --git a/modules/sd_hijack_clip_old.py b/modules/sd_hijack_clip_old.py
index c5c6270b..43e9b952 100644
--- a/modules/sd_hijack_clip_old.py
+++ b/modules/sd_hijack_clip_old.py
@@ -32,7 +32,7 @@ def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase,
 
             embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
 
-            mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
+            mult_change = self.token_mults.get(token) if shared.opts.emphasis != "None" else None
             if mult_change is not None:
                 mult *= mult_change
             i += 1
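The `Emphasis` base class makes multiplier application pluggable: `sd_hijack_clip` (above) fills in `tokens`, `multipliers` and `z`, then calls `after_transformers()`. A sketch of what a custom mode could look like, assuming an extension appends it to `sd_emphasis.options` (the "Squared" mode itself is hypothetical, not part of this PR):

```python
# Hypothetical custom emphasis mode built on the new Emphasis base class.
from modules import sd_emphasis


class EmphasisSquared(sd_emphasis.Emphasis):
    name = "Squared"
    description = "illustrative: square the multipliers to exaggerate emphasis"

    def after_transformers(self):
        # Apply squared per-token multipliers along the embedding dimension of z.
        multipliers = self.multipliers ** 2
        self.z = self.z * multipliers.reshape(multipliers.shape + (1,)).expand(self.z.shape)


# Registering it makes the "Emphasis mode" radio pick it up, since the UI
# choices are built from sd_emphasis.options (see shared_options.py below).
sd_emphasis.options.append(EmphasisSquared)
```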
diff --git a/modules/shared_options.py b/modules/shared_options.py
index 0d3f518e..de5deada 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -1,7 +1,7 @@
 import os
 import gradio as gr
 
-from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes, util
+from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes, util, sd_emphasis
 from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir, default_output_dir  # noqa: F401
 from modules.shared_cmd_options import cmd_opts
 from modules.options import options_section, OptionInfo, OptionHTML, categories
@@ -154,7 +154,7 @@
 options_templates.update(options_section(('sd', "Stable Diffusion", "sd"), {
     "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
     "sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
     "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds").needs_reload_ui(),
-    "enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
+    "emphasis": OptionInfo("Original", "Emphasis mode", gr.Radio, lambda: {"choices": [x.name for x in sd_emphasis.options]}, infotext="Emphasis").info("makes it possible to make the model pay (more:1.1) or (less:0.9) attention to text when you use the syntax in the prompt; " + sd_emphasis.get_options_descriptions()),
     "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
     "comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
     "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}, infotext="Clip skip").link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
@@ -271,6 +271,7 @@ options_templates.update(options_section(('ui_prompt_editing', "Prompt editing",
     "keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}),
     "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
     "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
+    "include_styles_into_token_counters": OptionInfo(True, "Count tokens of enabled styles").info("When calculating how many tokens the prompt has, also consider tokens added by enabled styles."),
 }))
 
 options_templates.update(options_section(('ui_gallery', "Gallery", "ui"), {
diff --git a/modules/ui.py b/modules/ui.py
index 5744f192..81655525 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -151,7 +151,18 @@ def connect_clear_prompt(button):
     )
 
 
-def update_token_counter(text, steps, *, is_positive=True):
+def update_token_counter(text, steps, styles, *, is_positive=True):
+    params = script_callbacks.BeforeTokenCounterParams(text, steps, styles, is_positive=is_positive)
+    script_callbacks.before_token_counter_callback(params)
+    text = params.prompt
+    steps = params.steps
+    styles = params.styles
+    is_positive = params.is_positive
+
+    if shared.opts.include_styles_into_token_counters:
+        apply_styles = shared.prompt_styles.apply_styles_to_prompt if is_positive else shared.prompt_styles.apply_negative_styles_to_prompt
+        text = apply_styles(text, styles)
+
     try:
         text, _ = extra_networks.parse_prompt(text)
 
@@ -179,8 +190,8 @@ def update_token_counter(text, steps, *, is_positive=True):
     return f"{token_count}/{max_length}"
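`update_token_counter` now expands styles before counting when `include_styles_into_token_counters` is on. A rough sketch of the expansion it relies on, assuming the usual webui style convention where a style containing `{prompt}` substitutes the prompt into itself and plain style text is appended:

```python
# Illustrative sketch only; the real logic lives in shared.prompt_styles.
def apply_style_to_prompt(prompt: str, style_prompt: str) -> str:
    if "{prompt}" in style_prompt:
        return style_prompt.replace("{prompt}", prompt)
    return f"{prompt}, {style_prompt}" if style_prompt else prompt


# With the option enabled, the counter sees the expanded text:
print(apply_style_to_prompt("a cat", "masterpiece, {prompt}, detailed"))
# -> "masterpiece, a cat, detailed"
```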
f"{token_count}/{max_length}" -def update_negative_prompt_token_counter(text, steps): - return update_token_counter(text, steps, is_positive=False) +def update_negative_prompt_token_counter(*args): + return update_token_counter(*args, is_positive=False) def setup_progressbar(*args, **kwargs): @@ -491,8 +502,10 @@ def create_ui(): height, ] - toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter]) - toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter]) + toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter]) + toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter]) + toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter]) + toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter]) extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img') ui_extra_networks.setup_ui(extra_networks_ui, output_panel.gallery) @@ -830,8 +843,10 @@ def create_ui(): **interrogate_args, ) - toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps], outputs=[toprow.token_counter]) - toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter]) + toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter]) + toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter]) + toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter]) + toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter]) img2img_paste_fields = [ (toprow.prompt, "Prompt"), @@ -869,7 +884,7 @@ def create_ui(): ui_postprocessing.create_ui() with gr.Blocks(analytics_enabled=False) as pnginfo_interface: - with gr.Row(equal_height=False): + with ResizeHandleRow(equal_height=False): with gr.Column(variant='panel'): image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil") @@ -897,7 +912,7 @@ def create_ui(): with gr.Row(equal_height=False): gr.HTML(value="

See wiki for detailed explanation.

") - with gr.Row(variant="compact", equal_height=False): + with ResizeHandleRow(variant="compact", equal_height=False): with gr.Tabs(elem_id="train_tabs"): with gr.Tab(label="Create embedding", id="create_embedding"): diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py index 85015db5..91f40ea4 100644 --- a/modules/ui_tempdir.py +++ b/modules/ui_tempdir.py @@ -35,12 +35,7 @@ def save_pil_to_file(self, pil_image, dir=None, format="png"): already_saved_as = getattr(pil_image, 'already_saved_as', None) if already_saved_as and os.path.isfile(already_saved_as): register_tmp_file(shared.demo, already_saved_as) - filename = already_saved_as - - if not shared.opts.save_images_add_number: - filename += f'?{os.path.getmtime(already_saved_as)}' - - return filename + return f'{already_saved_as}?{os.path.getmtime(already_saved_as)}' if shared.opts.temp_dir != "": dir = shared.opts.temp_dir diff --git a/modules/util.py b/modules/util.py index ee373e92..8d1aea44 100644 --- a/modules/util.py +++ b/modules/util.py @@ -42,7 +42,7 @@ def walk_files(path, allowed_extensions=None): for filename in sorted(files, key=natural_sort_key): if allowed_extensions is not None: _, ext = os.path.splitext(filename) - if ext not in allowed_extensions: + if ext.lower() not in allowed_extensions: continue if not shared.opts.list_hidden_files and ("/." in root or "\\." in root): diff --git a/style.css b/style.css index 04b0a2fa..988c28c0 100644 --- a/style.css +++ b/style.css @@ -846,6 +846,20 @@ table.popup-table .link{ display: inline-block; } +/* extensions tab table row hover highlight */ + +#extensions tr:hover td, +#config_state_extensions tr:hover td, +#available_extensions tr:hover td { + background: rgba(0, 0, 0, 0.15) +} + +.dark #extensions tr:hover td , +.dark #config_state_extensions tr:hover td , +.dark #available_extensions tr:hover td { + background: rgba(255, 255, 255, 0.15) +} + /* replace original footer with ours */ footer { diff --git a/webui.sh b/webui.sh index 25b94906..794cfb8a 100755 --- a/webui.sh +++ b/webui.sh @@ -226,30 +226,48 @@ fi # Try using TCMalloc on Linux prepare_tcmalloc() { if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then + # check glibc version + LIBC_LIB="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -P "libc.so.6" | head -n 1)" + LIBC_INFO=$(echo ${LIBC_LIB} | awk '{print $NF}') + LIBC_VER=$(echo $(${LIBC_INFO} | awk 'NR==1 {print $NF}') | grep -oP '\d+\.\d+') + echo "glibc version is $LIBC_VER" + libc_vernum=$(expr $LIBC_VER) + # Since 2.34 libpthread is integrated into libc.so + libc_v234=2.34 # Define Tcmalloc Libs arrays TCMALLOC_LIBS=("libtcmalloc(_minimal|)\.so\.\d" "libtcmalloc\.so\.\d") - # Traversal array for lib in "${TCMALLOC_LIBS[@]}" do - #Determine which type of tcmalloc library the library supports - TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -P $lib | head -n 1)" - TC_INFO=(${TCMALLOC//=>/}) - if [[ ! 
-z "${TC_INFO}" ]]; then - echo "Using TCMalloc: ${TC_INFO}" - #Determine if the library is linked to libptthread and resolve undefined symbol: ptthread_Key_Create - if ldd ${TC_INFO[2]} | grep -q 'libpthread'; then - echo "$TC_INFO is linked with libpthread,execute LD_PRELOAD=${TC_INFO}" - export LD_PRELOAD="${TC_INFO}" - break - else - echo "$TC_INFO is not linked with libpthreadand will trigger undefined symbol: ptthread_Key_Create error" - fi - else - printf "\e[1m\e[31mCannot locate TCMalloc (improves CPU memory usage)\e[0m\n" - fi + # Determine which type of tcmalloc library the library supports + TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -P $lib | head -n 1)" + TC_INFO=(${TCMALLOC//=>/}) + if [[ ! -z "${TC_INFO}" ]]; then + echo "Check TCMalloc: ${TC_INFO}" + # Determine if the library is linked to libptthread and resolve undefined symbol: ptthread_key_create + if [ $(echo "$libc_vernum < $libc_v234" | bc) -eq 1 ]; then + # glibc < 2.33 pthread_key_create into libpthead.so. check linking libpthread.so... + if ldd ${TC_INFO[2]} | grep -q 'libpthread'; then + echo "$TC_INFO is linked with libpthread,execute LD_PRELOAD=${TC_INFO[2]}" + # set fullpath LD_PRELOAD (To be on the safe side) + export LD_PRELOAD="${TC_INFO[2]}" + break + else + echo "$TC_INFO is not linked with libpthread will trigger undefined symbol: pthread_Key_create error" + fi + else + # Version 2.34 of libc.so (glibc) includes the pthead library IN GLIBC. (USE ubuntu 22.04 and modern linux system and WSL) + # libc.so(glibc) is linked with a library that works in ALMOST ALL Linux userlands. SO NO CHECK! + echo "$TC_INFO is linked with libc.so,execute LD_PRELOAD=${TC_INFO[2]}" + # set fullpath LD_PRELOAD (To be on the safe side) + export LD_PRELOAD="${TC_INFO[2]}" + break + fi + fi done - + if [[ -z "${LD_PRELOAD}" ]]; then + printf "\e[1m\e[31mCannot locate TCMalloc. Do you have tcmalloc or gperftools installed on your system? (improves CPU memory usage)\e[0m\n" + fi fi }