upstream dev commits

This commit is contained in:
lllyasviel 2024-02-11 16:45:07 -08:00 committed by GitHub
commit bc5589b249
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 241 additions and 47 deletions

View File

@ -2,8 +2,11 @@
function extensions_apply(_disabled_list, _update_list, disable_all) {
var disable = [];
var update = [];
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x) {
const extensions_input = gradioApp().querySelectorAll('#extensions input[type="checkbox"]');
if (extensions_input.length == 0) {
throw Error("Extensions page not yet loaded.");
}
extensions_input.forEach(function(x) {
if (x.name.startsWith("enable_") && !x.checked) {
disable.push(x.name.substring(7));
}

View File

@ -358,6 +358,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
res["Cache FP16 weight for LoRA"] = False
if "Emphasis" not in res:
res["Emphasis"] = "Original"
infotext_versions.backcompat(res)
for key in skip_fields:

View File

@ -458,6 +458,7 @@ class StableDiffusionProcessing:
self.height,
opts.fp8_storage,
opts.cache_fp16_weight,
opts.emphasis,
)
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):

View File

@ -0,0 +1,42 @@
from modules import scripts, shared, script_callbacks
import re
def strip_comments(text):
text = re.sub('(^|\n)#[^\n]*(\n|$)', '\n', text) # while line comment
text = re.sub('#[^\n]*(\n|$)', '\n', text) # in the middle of the line comment
return text
class ScriptStripComments(scripts.Script):
def title(self):
return "Comments"
def show(self, is_img2img):
return scripts.AlwaysVisible
def process(self, p, *args):
if not shared.opts.enable_prompt_comments:
return
p.all_prompts = [strip_comments(x) for x in p.all_prompts]
p.all_negative_prompts = [strip_comments(x) for x in p.all_negative_prompts]
p.main_prompt = strip_comments(p.main_prompt)
p.main_negative_prompt = strip_comments(p.main_negative_prompt)
def before_token_counter(params: script_callbacks.BeforeTokenCounterParams):
if not shared.opts.enable_prompt_comments:
return
params.prompt = strip_comments(params.prompt)
script_callbacks.on_before_token_counter(before_token_counter)
shared.options_templates.update(shared.options_section(('sd', "Stable Diffusion", "sd"), {
"enable_prompt_comments": shared.OptionInfo(True, "Enable comments").info("Use # anywhere in the prompt to hide the text between # and the end of the line from the generation."),
}))

View File

@ -1,3 +1,4 @@
import dataclasses
import inspect
import os
from collections import namedtuple
@ -106,6 +107,15 @@ class ImageGridLoopParams:
self.rows = rows
@dataclasses.dataclass
class BeforeTokenCounterParams:
prompt: str
steps: int
styles: list
is_positive: bool = True
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callback_map = dict(
callbacks_app_started=[],
@ -128,6 +138,7 @@ callback_map = dict(
callbacks_on_reload=[],
callbacks_list_optimizers=[],
callbacks_list_unets=[],
callbacks_before_token_counter=[],
)
event_subscriber_map = dict(
callbacks_setting_updated=[],
@ -312,6 +323,14 @@ def list_unets_callback():
return res
def before_token_counter_callback(params: BeforeTokenCounterParams):
for c in callback_map['callbacks_before_token_counter']:
try:
c.callback(params)
except Exception:
report_exception(c, 'before_token_counter')
def setting_updated_event_subscriber_chain(handler, component, setting_name: str):
"""
Arguments:
@ -505,6 +524,13 @@ def on_list_unets(callback):
add_callback(callback_map['callbacks_list_unets'], callback)
def on_before_token_counter(callback):
"""register a function to be called when UI is counting tokens for a prompt.
The function will be called with one argument of type BeforeTokenCounterParams, and should modify its fields if necessary."""
add_callback(callback_map['callbacks_before_token_counter'], callback)
def on_setting_updated_subscriber(subscriber_params):
"""register a function to be called after settings update. `subscriber_params`
should contain necessary fields to register an gradio event handler. Necessary
@ -513,3 +539,4 @@ def on_setting_updated_subscriber(subscriber_params):
sure to handle these extra params when defining the callback function.
"""
event_subscriber_map['callbacks_setting_updated'].append(subscriber_params)

70
modules/sd_emphasis.py Normal file
View File

@ -0,0 +1,70 @@
from __future__ import annotations
import torch
class Emphasis:
"""Emphasis class decides how to death with (emphasized:1.1) text in prompts"""
name: str = "Base"
description: str = ""
tokens: list[list[int]]
"""tokens from the chunk of the prompt"""
multipliers: torch.Tensor
"""tensor with multipliers, once for each token"""
z: torch.Tensor
"""output of cond transformers network (CLIP)"""
def after_transformers(self):
"""Called after cond transformers network has processed the chunk of the prompt; this function should modify self.z to apply the emphasis"""
pass
class EmphasisNone(Emphasis):
name = "None"
description = "disable the mechanism entirely and treat (:.1.1) as literal characters"
class EmphasisIgnore(Emphasis):
name = "Ignore"
description = "treat all empasised words as if they have no emphasis"
class EmphasisOriginal(Emphasis):
name = "Original"
description = "the orginal emphasis implementation"
def after_transformers(self):
original_mean = self.z.mean()
self.z = self.z * self.multipliers.reshape(self.multipliers.shape + (1,)).expand(self.z.shape)
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
new_mean = self.z.mean()
self.z = self.z * (original_mean / new_mean)
class EmphasisOriginalNoNorm(EmphasisOriginal):
name = "No norm"
description = "same as orginal, but without normalization (seems to work better for SDXL)"
def after_transformers(self):
self.z = self.z * self.multipliers.reshape(self.multipliers.shape + (1,)).expand(self.z.shape)
def get_current_option(emphasis_option_name):
return next(iter([x for x in options if x.name == emphasis_option_name]), EmphasisOriginal)
def get_options_descriptions():
return ", ".join(f"{x.name}: {x.description}" for x in options)
options = [
EmphasisNone,
EmphasisIgnore,
EmphasisOriginal,
EmphasisOriginalNoNorm,
]

View File

@ -3,7 +3,7 @@ from collections import namedtuple
import torch
from modules import prompt_parser, devices, sd_hijack
from modules import prompt_parser, devices, sd_hijack, sd_emphasis
from modules.shared import opts
@ -88,7 +88,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
Returns the list and the total number of tokens in the prompt.
"""
if opts.enable_emphasis:
if opts.emphasis != "None":
parsed = prompt_parser.parse_prompt_attention(line)
else:
parsed = [[line, 1.0]]
@ -249,6 +249,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original":
self.hijack.extra_generation_params["Emphasis"] = opts.emphasis
if getattr(self.wrapped, 'return_pooled', False):
return torch.hstack(zs), zs[0].pooled
else:
@ -274,12 +277,14 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
pooled = getattr(z, 'pooled', None)
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
original_mean = z.mean()
z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
z = z * (original_mean / new_mean)
emphasis = sd_emphasis.get_current_option(opts.emphasis)()
emphasis.tokens = remade_batch_tokens
emphasis.multipliers = torch.asarray(batch_multipliers).to(devices.device)
emphasis.z = z
emphasis.after_transformers()
z = emphasis.z
if pooled is not None:
z.pooled = pooled

View File

@ -32,7 +32,7 @@ def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase,
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
mult_change = self.token_mults.get(token) if shared.opts.emphasis != "None" else None
if mult_change is not None:
mult *= mult_change
i += 1

View File

@ -1,7 +1,7 @@
import os
import gradio as gr
from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes, util
from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes, util, sd_emphasis
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir, default_output_dir # noqa: F401
from modules.shared_cmd_options import cmd_opts
from modules.options import options_section, OptionInfo, OptionHTML, categories
@ -154,7 +154,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion", "sd"), {
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
"sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds").needs_reload_ui(),
"enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"emphasis": OptionInfo("Original", "Emphasis mode", gr.Radio, lambda: {"choices": [x.name for x in sd_emphasis.options]}, infotext="Emphasis").info("makes it possible to make model to pay (more:1.1) or (less:0.9) attention to text when you use the syntax in prompt; " + sd_emphasis.get_options_descriptions()),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}, infotext="Clip skip").link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
@ -271,6 +271,7 @@ options_templates.update(options_section(('ui_prompt_editing', "Prompt editing",
"keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}),
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
"include_styles_into_token_counters": OptionInfo(True, "Count tokens of enabled styles").info("When calculating how many tokens the prompt has, also consider tokens added by enabled styles."),
}))
options_templates.update(options_section(('ui_gallery', "Gallery", "ui"), {

View File

@ -151,7 +151,18 @@ def connect_clear_prompt(button):
)
def update_token_counter(text, steps, *, is_positive=True):
def update_token_counter(text, steps, styles, *, is_positive=True):
params = script_callbacks.BeforeTokenCounterParams(text, steps, styles, is_positive=is_positive)
script_callbacks.before_token_counter_callback(params)
text = params.prompt
steps = params.steps
styles = params.styles
is_positive = params.is_positive
if shared.opts.include_styles_into_token_counters:
apply_styles = shared.prompt_styles.apply_styles_to_prompt if is_positive else shared.prompt_styles.apply_negative_styles_to_prompt
text = apply_styles(text, styles)
try:
text, _ = extra_networks.parse_prompt(text)
@ -179,8 +190,8 @@ def update_token_counter(text, steps, *, is_positive=True):
return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
def update_negative_prompt_token_counter(text, steps):
return update_token_counter(text, steps, is_positive=False)
def update_negative_prompt_token_counter(*args):
return update_token_counter(*args, is_positive=False)
def setup_progressbar(*args, **kwargs):
@ -491,8 +502,10 @@ def create_ui():
height,
]
toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
ui_extra_networks.setup_ui(extra_networks_ui, output_panel.gallery)
@ -830,8 +843,10 @@ def create_ui():
**interrogate_args,
)
toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
img2img_paste_fields = [
(toprow.prompt, "Prompt"),
@ -869,7 +884,7 @@ def create_ui():
ui_postprocessing.create_ui()
with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
with gr.Row(equal_height=False):
with ResizeHandleRow(equal_height=False):
with gr.Column(variant='panel'):
image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")
@ -897,7 +912,7 @@ def create_ui():
with gr.Row(equal_height=False):
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
with gr.Row(variant="compact", equal_height=False):
with ResizeHandleRow(variant="compact", equal_height=False):
with gr.Tabs(elem_id="train_tabs"):
with gr.Tab(label="Create embedding", id="create_embedding"):

View File

@ -35,12 +35,7 @@ def save_pil_to_file(self, pil_image, dir=None, format="png"):
already_saved_as = getattr(pil_image, 'already_saved_as', None)
if already_saved_as and os.path.isfile(already_saved_as):
register_tmp_file(shared.demo, already_saved_as)
filename = already_saved_as
if not shared.opts.save_images_add_number:
filename += f'?{os.path.getmtime(already_saved_as)}'
return filename
return f'{already_saved_as}?{os.path.getmtime(already_saved_as)}'
if shared.opts.temp_dir != "":
dir = shared.opts.temp_dir

View File

@ -42,7 +42,7 @@ def walk_files(path, allowed_extensions=None):
for filename in sorted(files, key=natural_sort_key):
if allowed_extensions is not None:
_, ext = os.path.splitext(filename)
if ext not in allowed_extensions:
if ext.lower() not in allowed_extensions:
continue
if not shared.opts.list_hidden_files and ("/." in root or "\\." in root):

View File

@ -846,6 +846,20 @@ table.popup-table .link{
display: inline-block;
}
/* extensions tab table row hover highlight */
#extensions tr:hover td,
#config_state_extensions tr:hover td,
#available_extensions tr:hover td {
background: rgba(0, 0, 0, 0.15)
}
.dark #extensions tr:hover td ,
.dark #config_state_extensions tr:hover td ,
.dark #available_extensions tr:hover td {
background: rgba(255, 255, 255, 0.15)
}
/* replace original footer with ours */
footer {

View File

@ -226,30 +226,48 @@ fi
# Try using TCMalloc on Linux
prepare_tcmalloc() {
if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then
# check glibc version
LIBC_LIB="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -P "libc.so.6" | head -n 1)"
LIBC_INFO=$(echo ${LIBC_LIB} | awk '{print $NF}')
LIBC_VER=$(echo $(${LIBC_INFO} | awk 'NR==1 {print $NF}') | grep -oP '\d+\.\d+')
echo "glibc version is $LIBC_VER"
libc_vernum=$(expr $LIBC_VER)
# Since 2.34 libpthread is integrated into libc.so
libc_v234=2.34
# Define Tcmalloc Libs arrays
TCMALLOC_LIBS=("libtcmalloc(_minimal|)\.so\.\d" "libtcmalloc\.so\.\d")
# Traversal array
for lib in "${TCMALLOC_LIBS[@]}"
do
#Determine which type of tcmalloc library the library supports
TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -P $lib | head -n 1)"
TC_INFO=(${TCMALLOC//=>/})
if [[ ! -z "${TC_INFO}" ]]; then
echo "Using TCMalloc: ${TC_INFO}"
#Determine if the library is linked to libptthread and resolve undefined symbol: ptthread_Key_Create
if ldd ${TC_INFO[2]} | grep -q 'libpthread'; then
echo "$TC_INFO is linked with libpthread,execute LD_PRELOAD=${TC_INFO}"
export LD_PRELOAD="${TC_INFO}"
break
else
echo "$TC_INFO is not linked with libpthreadand will trigger undefined symbol: ptthread_Key_Create error"
fi
else
printf "\e[1m\e[31mCannot locate TCMalloc (improves CPU memory usage)\e[0m\n"
fi
# Determine which type of tcmalloc library the library supports
TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -P $lib | head -n 1)"
TC_INFO=(${TCMALLOC//=>/})
if [[ ! -z "${TC_INFO}" ]]; then
echo "Check TCMalloc: ${TC_INFO}"
# Determine if the library is linked to libptthread and resolve undefined symbol: ptthread_key_create
if [ $(echo "$libc_vernum < $libc_v234" | bc) -eq 1 ]; then
# glibc < 2.33 pthread_key_create into libpthead.so. check linking libpthread.so...
if ldd ${TC_INFO[2]} | grep -q 'libpthread'; then
echo "$TC_INFO is linked with libpthread,execute LD_PRELOAD=${TC_INFO[2]}"
# set fullpath LD_PRELOAD (To be on the safe side)
export LD_PRELOAD="${TC_INFO[2]}"
break
else
echo "$TC_INFO is not linked with libpthread will trigger undefined symbol: pthread_Key_create error"
fi
else
# Version 2.34 of libc.so (glibc) includes the pthead library IN GLIBC. (USE ubuntu 22.04 and modern linux system and WSL)
# libc.so(glibc) is linked with a library that works in ALMOST ALL Linux userlands. SO NO CHECK!
echo "$TC_INFO is linked with libc.so,execute LD_PRELOAD=${TC_INFO[2]}"
# set fullpath LD_PRELOAD (To be on the safe side)
export LD_PRELOAD="${TC_INFO[2]}"
break
fi
fi
done
if [[ -z "${LD_PRELOAD}" ]]; then
printf "\e[1m\e[31mCannot locate TCMalloc. Do you have tcmalloc or gperftools installed on your system? (improves CPU memory usage)\e[0m\n"
fi
fi
}