Merge remote-tracking branch 'origin/dev' into soft-inpainting

# Conflicts:
#	modules/processing.py
This commit is contained in:
CodeHatchling 2023-12-02 21:14:02 -07:00
commit 3bd3a09160
92 changed files with 2546 additions and 694 deletions

View File

@ -74,6 +74,7 @@ module.exports = {
create_submit_args: "readonly", create_submit_args: "readonly",
restart_reload: "readonly", restart_reload: "readonly",
updateInput: "readonly", updateInput: "readonly",
onEdit: "readonly",
//extraNetworks.js //extraNetworks.js
requestGet: "readonly", requestGet: "readonly",
popup: "readonly", popup: "readonly",

View File

@ -1,25 +1,45 @@
name: Bug Report name: Bug Report
description: You think somethings is broken in the UI description: You think something is broken in the UI
title: "[Bug]: " title: "[Bug]: "
labels: ["bug-report"] labels: ["bug-report"]
body: body:
- type: checkboxes
attributes:
label: Is there an existing issue for this?
description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit.
options:
- label: I have searched the existing issues and checked the recent builds/commits
required: true
- type: markdown - type: markdown
attributes: attributes:
value: | value: |
*Please fill this form with as much information as possible, don't forget to fill "What OS..." and "What browsers" and *provide screenshots if possible** > The title of the bug report should be short and descriptive.
> Use relevant keywords for searchability.
> Do not leave it blank, but also do not put an entire error log in it.
- type: checkboxes
attributes:
label: Checklist
description: |
Please perform basic debugging to see if extensions or configuration is the cause of the issue.
Basic debug procedure
 1. Disable all third-party extensions - check if extension is the cause
 2. Update extensions and webui - sometimes things just need to be updated
 3. Backup and remove your config.json and ui-config.json - check if the issue is caused by bad configuration
 4. Delete venv with third-party extensions disabled - sometimes extensions might cause wrong libraries to be installed
 5. Try a fresh installation webui in a different directory - see if a clean installation solves the issue
Before making a issue report please, check that the issue hasn't been reported recently.
options:
- label: The issue exists after disabling all extensions
- label: The issue exists on a clean installation of webui
- label: The issue is caused by an extension, but I believe it is caused by a bug in the webui
- label: The issue exists in the current version of the webui
- label: The issue has not been reported before recently
- label: The issue has been reported before but has not been fixed yet
- type: markdown
attributes:
value: |
> Please fill this form with as much information as possible. Don't forget to "Upload Sysinfo" and "What browsers" and provide screenshots if possible
- type: textarea - type: textarea
id: what-did id: what-did
attributes: attributes:
label: What happened? label: What happened?
description: Tell us what happened in a very clear and simple way description: Tell us what happened in a very clear and simple way
placeholder: |
txt2img is not working as intended.
validations: validations:
required: true required: true
- type: textarea - type: textarea
@ -27,9 +47,9 @@ body:
attributes: attributes:
label: Steps to reproduce the problem label: Steps to reproduce the problem
description: Please provide us with precise step by step instructions on how to reproduce the bug description: Please provide us with precise step by step instructions on how to reproduce the bug
value: | placeholder: |
1. Go to .... 1. Go to ...
2. Press .... 2. Press ...
3. ... 3. ...
validations: validations:
required: true required: true
@ -38,13 +58,8 @@ body:
attributes: attributes:
label: What should have happened? label: What should have happened?
description: Tell us what you think the normal behavior should be description: Tell us what you think the normal behavior should be
validations: placeholder: |
required: true WebUI should ...
- type: textarea
id: sysinfo
attributes:
label: Sysinfo
description: System info file, generated by WebUI. You can generate it in settings, on the Sysinfo page. Drag the file into the field to upload it. If you submit your report without including the sysinfo file, the report will be closed. If needed, review the report to make sure it includes no personal information you don't want to share. If you can't start WebUI, you can use --dump-sysinfo commandline argument to generate the file.
validations: validations:
required: true required: true
- type: dropdown - type: dropdown
@ -58,12 +73,25 @@ body:
- Brave - Brave
- Apple Safari - Apple Safari
- Microsoft Edge - Microsoft Edge
- Android
- iOS
- Other - Other
- type: textarea
id: sysinfo
attributes:
label: Sysinfo
description: System info file, generated by WebUI. You can generate it in settings, on the Sysinfo page. Drag the file into the field to upload it. If you submit your report without including the sysinfo file, the report will be closed. If needed, review the report to make sure it includes no personal information you don't want to share. If you can't start WebUI, you can use --dump-sysinfo commandline argument to generate the file.
placeholder: |
1. Go to WebUI Settings -> Sysinfo -> Download system info.
If WebUI fails to launch, use --dump-sysinfo commandline argument to generate the file
2. Upload the Sysinfo as a attached file, Do NOT paste it in as plain text.
validations:
required: true
- type: textarea - type: textarea
id: logs id: logs
attributes: attributes:
label: Console logs label: Console logs
description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after your bug happened. If it's very long, provide a link to pastebin or similar service. description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after the bug occured. If it's very long, provide a link to pastebin or similar service.
render: Shell render: Shell
validations: validations:
required: true required: true
@ -71,4 +99,7 @@ body:
id: misc id: misc
attributes: attributes:
label: Additional information label: Additional information
description: Please provide us with any relevant additional info or context. description: |
Please provide us with any relevant additional info or context.
Examples:
 I have updated my GPU driver recently.

View File

@ -20,7 +20,7 @@ jobs:
# not to have GHA download an (at the time of writing) 4 GB cache # not to have GHA download an (at the time of writing) 4 GB cache
# of PyTorch and other dependencies. # of PyTorch and other dependencies.
- name: Install Ruff - name: Install Ruff
run: pip install ruff==0.0.272 run: pip install ruff==0.1.6
- name: Run Ruff - name: Run Ruff
run: ruff . run: ruff .
lint-js: lint-js:

View File

@ -88,9 +88,10 @@ A browser interface based on Gradio library for Stable Diffusion.
- [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions - [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
- Now without any bad letters! - Now without any bad letters!
- Load checkpoints in safetensors format - Load checkpoints in safetensors format
- Eased resolution restriction: generated image's dimension must be a multiple of 8 rather than 64 - Eased resolution restriction: generated image's dimensions must be a multiple of 8 rather than 64
- Now with a license! - Now with a license!
- Reorder elements in the UI from settings screen - Reorder elements in the UI from settings screen
- [Segmind Stable Diffusion](https://huggingface.co/segmind/SSD-1B) support
## Installation and Running ## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for: Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for:
@ -103,7 +104,7 @@ Alternatively, use online services (like Google Colab):
- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services) - [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
### Installation on Windows 10/11 with NVidia-GPUs using release package ### Installation on Windows 10/11 with NVidia-GPUs using release package
1. Download `sd.webui.zip` from [v1.0.0-pre](https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre) and extract it's contents. 1. Download `sd.webui.zip` from [v1.0.0-pre](https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre) and extract its contents.
2. Run `update.bat`. 2. Run `update.bat`.
3. Run `run.bat`. 3. Run `run.bat`.
> For more details see [Install-and-Run-on-NVidia-GPUs](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) > For more details see [Install-and-Run-on-NVidia-GPUs](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs)
@ -120,7 +121,9 @@ Alternatively, use online services (like Google Colab):
# Debian-based: # Debian-based:
sudo apt install wget git python3 python3-venv libgl1 libglib2.0-0 sudo apt install wget git python3 python3-venv libgl1 libglib2.0-0
# Red Hat-based: # Red Hat-based:
sudo dnf install wget git python3 sudo dnf install wget git python3 gperftools-libs libglvnd-glx
# openSUSE-based:
sudo zypper install wget git python3 libtcmalloc4 libglvnd
# Arch-based: # Arch-based:
sudo pacman -S wget git python3 sudo pacman -S wget git python3
``` ```
@ -146,7 +149,7 @@ For the purposes of getting Google and other search engines to crawl the wiki, h
## Credits ## Credits
Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file. Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
- Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers - Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers
- k-diffusion - https://github.com/crowsonkb/k-diffusion.git - k-diffusion - https://github.com/crowsonkb/k-diffusion.git
- GFPGAN - https://github.com/TencentARC/GFPGAN.git - GFPGAN - https://github.com/TencentARC/GFPGAN.git
- CodeFormer - https://github.com/sczhou/CodeFormer - CodeFormer - https://github.com/sczhou/CodeFormer
@ -173,5 +176,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd - TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
- LyCORIS - KohakuBlueleaf - LyCORIS - KohakuBlueleaf
- Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling - Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling
- Hypertile - tfernd - https://github.com/tfernd/HyperTile
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You) - (You)

View File

@ -0,0 +1,73 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: modules.xlmr_m18.BertSeriesModelWithTransformation
params:
name: "XLMR-Large"

View File

@ -0,0 +1,33 @@
import sys
import copy
import logging
class ColoredFormatter(logging.Formatter):
COLORS = {
"DEBUG": "\033[0;36m", # CYAN
"INFO": "\033[0;32m", # GREEN
"WARNING": "\033[0;33m", # YELLOW
"ERROR": "\033[0;31m", # RED
"CRITICAL": "\033[0;37;41m", # WHITE ON RED
"RESET": "\033[0m", # RESET COLOR
}
def format(self, record):
colored_record = copy.copy(record)
levelname = colored_record.levelname
seq = self.COLORS.get(levelname, self.COLORS["RESET"])
colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
return super().format(colored_record)
logger = logging.getLogger("lora")
logger.propagate = False
if not logger.handlers:
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(
ColoredFormatter("[%(name)s]-%(levelname)s: %(message)s")
)
logger.addHandler(handler)

View File

@ -19,3 +19,50 @@ def rebuild_cp_decomposition(up, down, mid):
up = up.reshape(up.size(0), -1) up = up.reshape(up.size(0), -1)
down = down.reshape(down.size(0), -1) down = down.reshape(down.size(0), -1)
return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down) return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
'''
return a tuple of two value of input dimension decomposed by the number closest to factor
second value is higher or equal than first value.
In LoRA with Kroneckor Product, first value is a value for weight scale.
secon value is a value for weight.
Becuase of non-commutative property, AB BA. Meaning of two matrices is slightly different.
examples)
factor
-1 2 4 8 16 ...
127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127
128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16
250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25
360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30
512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32
1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64
'''
if factor > 0 and (dimension % factor) == 0:
m = factor
n = dimension // factor
if m > n:
n, m = m, n
return m, n
if factor < 0:
factor = dimension
m, n = 1, dimension
length = m + n
while m<n:
new_m = m + 1
while dimension%new_m != 0:
new_m += 1
new_n = dimension // new_m
if new_m + new_n > length or new_m>factor:
break
else:
m, n = new_m, new_n
if m > n:
n, m = m, n
return m, n

View File

@ -93,6 +93,7 @@ class Network: # LoraModule
self.unet_multiplier = 1.0 self.unet_multiplier = 1.0
self.dyn_dim = None self.dyn_dim = None
self.modules = {} self.modules = {}
self.bundle_embeddings = {}
self.mtime = None self.mtime = None
self.mentioned_name = None self.mentioned_name = None

View File

@ -0,0 +1,33 @@
import network
class ModuleTypeGLora(network.ModuleType):
def create_module(self, net: network.Network, weights: network.NetworkWeights):
if all(x in weights.w for x in ["a1.weight", "a2.weight", "alpha", "b1.weight", "b2.weight"]):
return NetworkModuleGLora(net, weights)
return None
# adapted from https://github.com/KohakuBlueleaf/LyCORIS
class NetworkModuleGLora(network.NetworkModule):
def __init__(self, net: network.Network, weights: network.NetworkWeights):
super().__init__(net, weights)
if hasattr(self.sd_module, 'weight'):
self.shape = self.sd_module.weight.shape
self.w1a = weights.w["a1.weight"]
self.w1b = weights.w["b1.weight"]
self.w2a = weights.w["a2.weight"]
self.w2b = weights.w["b2.weight"]
def calc_updown(self, orig_weight):
w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
output_shape = [w1a.size(0), w1b.size(1)]
updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a))
return self.finalize_updown(updown, orig_weight, output_shape)

View File

@ -0,0 +1,97 @@
import torch
import network
from lyco_helpers import factorization
from einops import rearrange
class ModuleTypeOFT(network.ModuleType):
def create_module(self, net: network.Network, weights: network.NetworkWeights):
if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]):
return NetworkModuleOFT(net, weights)
return None
# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
class NetworkModuleOFT(network.NetworkModule):
def __init__(self, net: network.Network, weights: network.NetworkWeights):
super().__init__(net, weights)
self.lin_module = None
self.org_module: list[torch.Module] = [self.sd_module]
# kohya-ss
if "oft_blocks" in weights.w.keys():
self.is_kohya = True
self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
self.alpha = weights.w["alpha"] # alpha is constraint
self.dim = self.oft_blocks.shape[0] # lora dim
# LyCORIS
elif "oft_diag" in weights.w.keys():
self.is_kohya = False
self.oft_blocks = weights.w["oft_diag"]
# self.alpha is unused
self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported
if is_linear:
self.out_dim = self.sd_module.out_features
elif is_conv:
self.out_dim = self.sd_module.out_channels
elif is_other_linear:
self.out_dim = self.sd_module.embed_dim
if self.is_kohya:
self.constraint = self.alpha * self.out_dim
self.num_blocks = self.dim
self.block_size = self.out_dim // self.dim
else:
self.constraint = None
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
def calc_updown_kb(self, orig_weight, multiplier):
oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device)
# This errors out for MultiheadAttention, might need to be handled up-stream
merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
merged_weight = torch.einsum(
'k n m, k n ... -> k m ...',
R,
merged_weight
)
merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
output_shape = orig_weight.shape
return self.finalize_updown(updown, orig_weight, output_shape)
def calc_updown(self, orig_weight):
# if alpha is a very small number as in coft, calc_scale() will return a almost zero number so we ignore it
multiplier = self.multiplier()
return self.calc_updown_kb(orig_weight, multiplier)
# override to remove the multiplier/scale factor; it's already multiplied in get_weight
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
if self.bias is not None:
updown = updown.reshape(self.bias.shape)
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
updown = updown.reshape(output_shape)
if len(output_shape) == 4:
updown = updown.reshape(output_shape)
if orig_weight.size().numel() == updown.size().numel():
updown = updown.reshape(orig_weight.shape)
if ex_bias is not None:
ex_bias = ex_bias * self.multiplier()
return updown, ex_bias

View File

@ -5,16 +5,21 @@ import re
import lora_patches import lora_patches
import network import network
import network_lora import network_lora
import network_glora
import network_hada import network_hada
import network_ia3 import network_ia3
import network_lokr import network_lokr
import network_full import network_full
import network_norm import network_norm
import network_oft
import torch import torch
from typing import Union from typing import Union
from modules import shared, devices, sd_models, errors, scripts, sd_hijack from modules import shared, devices, sd_models, errors, scripts, sd_hijack
import modules.textual_inversion.textual_inversion as textual_inversion
from lora_logger import logger
module_types = [ module_types = [
network_lora.ModuleTypeLora(), network_lora.ModuleTypeLora(),
@ -23,6 +28,8 @@ module_types = [
network_lokr.ModuleTypeLokr(), network_lokr.ModuleTypeLokr(),
network_full.ModuleTypeFull(), network_full.ModuleTypeFull(),
network_norm.ModuleTypeNorm(), network_norm.ModuleTypeNorm(),
network_glora.ModuleTypeGLora(),
network_oft.ModuleTypeOFT(),
] ]
@ -149,9 +156,19 @@ def load_network(name, network_on_disk):
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
matched_networks = {} matched_networks = {}
bundle_embeddings = {}
for key_network, weight in sd.items(): for key_network, weight in sd.items():
key_network_without_network_parts, network_part = key_network.split(".", 1) key_network_without_network_parts, network_part = key_network.split(".", 1)
if key_network_without_network_parts == "bundle_emb":
emb_name, vec_name = network_part.split(".", 1)
emb_dict = bundle_embeddings.get(emb_name, {})
if vec_name.split('.')[0] == 'string_to_param':
_, k2 = vec_name.split('.', 1)
emb_dict['string_to_param'] = {k2: weight}
else:
emb_dict[vec_name] = weight
bundle_embeddings[emb_name] = emb_dict
key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2) key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
sd_module = shared.sd_model.network_layer_mapping.get(key, None) sd_module = shared.sd_model.network_layer_mapping.get(key, None)
@ -174,6 +191,17 @@ def load_network(name, network_on_disk):
key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model") key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
sd_module = shared.sd_model.network_layer_mapping.get(key, None) sd_module = shared.sd_model.network_layer_mapping.get(key, None)
# kohya_ss OFT module
elif sd_module is None and "oft_unet" in key_network_without_network_parts:
key = key_network_without_network_parts.replace("oft_unet", "diffusion_model")
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
# KohakuBlueLeaf OFT module
if sd_module is None and "oft_diag" in key:
key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
if sd_module is None: if sd_module is None:
keys_failed_to_match[key_network] = key keys_failed_to_match[key_network] = key
continue continue
@ -195,6 +223,14 @@ def load_network(name, network_on_disk):
net.modules[key] = net_module net.modules[key] = net_module
embeddings = {}
for emb_name, data in bundle_embeddings.items():
embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
embedding.loaded = None
embeddings[emb_name] = embedding
net.bundle_embeddings = embeddings
if keys_failed_to_match: if keys_failed_to_match:
logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}") logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
@ -210,11 +246,15 @@ def purge_networks_from_memory():
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None): def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
emb_db = sd_hijack.model_hijack.embedding_db
already_loaded = {} already_loaded = {}
for net in loaded_networks: for net in loaded_networks:
if net.name in names: if net.name in names:
already_loaded[net.name] = net already_loaded[net.name] = net
for emb_name, embedding in net.bundle_embeddings.items():
if embedding.loaded:
emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)
loaded_networks.clear() loaded_networks.clear()
@ -257,6 +297,21 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0 net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
loaded_networks.append(net) loaded_networks.append(net)
for emb_name, embedding in net.bundle_embeddings.items():
if embedding.loaded is None and emb_name in emb_db.word_embeddings:
logger.warning(
f'Skip bundle embedding: "{emb_name}"'
' as it was already loaded from embeddings folder'
)
continue
embedding.loaded = False
if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
embedding.loaded = True
emb_db.register_embedding(embedding, shared.sd_model)
else:
emb_db.skipped_embeddings[name] = embedding
if failed_to_load_networks: if failed_to_load_networks:
sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks)) sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
@ -418,6 +473,7 @@ def network_forward(module, input, original_forward):
def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
self.network_current_names = () self.network_current_names = ()
self.network_weights_backup = None self.network_weights_backup = None
self.network_bias_backup = None
def network_Linear_forward(self, input): def network_Linear_forward(self, input):
@ -564,6 +620,7 @@ extra_network_lora = None
available_networks = {} available_networks = {}
available_network_aliases = {} available_network_aliases = {}
loaded_networks = [] loaded_networks = []
loaded_bundle_embeddings = {}
networks_in_memory = {} networks_in_memory = {}
available_network_hash_lookup = {} available_network_hash_lookup = {}
forbidden_network_aliases = {} forbidden_network_aliases = {}

View File

@ -17,6 +17,8 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
def create_item(self, name, index=None, enable_filter=True): def create_item(self, name, index=None, enable_filter=True):
lora_on_disk = networks.available_networks.get(name) lora_on_disk = networks.available_networks.get(name)
if lora_on_disk is None:
return
path, ext = os.path.splitext(lora_on_disk.filename) path, ext = os.path.splitext(lora_on_disk.filename)
@ -66,9 +68,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
return item return item
def list_items(self): def list_items(self):
for index, name in enumerate(networks.available_networks): # instantiate a list to protect against concurrent modification
names = list(networks.available_networks)
for index, name in enumerate(names):
item = self.create_item(name, index) item = self.create_item(name, index)
if item is not None: if item is not None:
yield item yield item

View File

@ -0,0 +1,345 @@
"""
Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE
Warn: The patch works well only if the input image has a width and height that are multiples of 128
Original author: @tfernd Github: https://github.com/tfernd/HyperTile
"""
from __future__ import annotations
import functools
from dataclasses import dataclass
from typing import Callable
from functools import wraps, cache
import math
import torch.nn as nn
import random
from einops import rearrange
@dataclass
class HypertileParams:
depth = 0
layer_name = ""
tile_size: int = 0
swap_size: int = 0
aspect_ratio: float = 1.0
forward = None
enabled = False
# TODO add SD-XL layers
DEPTH_LAYERS = {
0: [
# SD 1.5 U-Net (diffusers)
"down_blocks.0.attentions.0.transformer_blocks.0.attn1",
"down_blocks.0.attentions.1.transformer_blocks.0.attn1",
"up_blocks.3.attentions.0.transformer_blocks.0.attn1",
"up_blocks.3.attentions.1.transformer_blocks.0.attn1",
"up_blocks.3.attentions.2.transformer_blocks.0.attn1",
# SD 1.5 U-Net (ldm)
"input_blocks.1.1.transformer_blocks.0.attn1",
"input_blocks.2.1.transformer_blocks.0.attn1",
"output_blocks.9.1.transformer_blocks.0.attn1",
"output_blocks.10.1.transformer_blocks.0.attn1",
"output_blocks.11.1.transformer_blocks.0.attn1",
# SD 1.5 VAE
"decoder.mid_block.attentions.0",
"decoder.mid.attn_1",
],
1: [
# SD 1.5 U-Net (diffusers)
"down_blocks.1.attentions.0.transformer_blocks.0.attn1",
"down_blocks.1.attentions.1.transformer_blocks.0.attn1",
"up_blocks.2.attentions.0.transformer_blocks.0.attn1",
"up_blocks.2.attentions.1.transformer_blocks.0.attn1",
"up_blocks.2.attentions.2.transformer_blocks.0.attn1",
# SD 1.5 U-Net (ldm)
"input_blocks.4.1.transformer_blocks.0.attn1",
"input_blocks.5.1.transformer_blocks.0.attn1",
"output_blocks.6.1.transformer_blocks.0.attn1",
"output_blocks.7.1.transformer_blocks.0.attn1",
"output_blocks.8.1.transformer_blocks.0.attn1",
],
2: [
# SD 1.5 U-Net (diffusers)
"down_blocks.2.attentions.0.transformer_blocks.0.attn1",
"down_blocks.2.attentions.1.transformer_blocks.0.attn1",
"up_blocks.1.attentions.0.transformer_blocks.0.attn1",
"up_blocks.1.attentions.1.transformer_blocks.0.attn1",
"up_blocks.1.attentions.2.transformer_blocks.0.attn1",
# SD 1.5 U-Net (ldm)
"input_blocks.7.1.transformer_blocks.0.attn1",
"input_blocks.8.1.transformer_blocks.0.attn1",
"output_blocks.3.1.transformer_blocks.0.attn1",
"output_blocks.4.1.transformer_blocks.0.attn1",
"output_blocks.5.1.transformer_blocks.0.attn1",
],
3: [
# SD 1.5 U-Net (diffusers)
"mid_block.attentions.0.transformer_blocks.0.attn1",
# SD 1.5 U-Net (ldm)
"middle_block.1.transformer_blocks.0.attn1",
],
}
# XL layers, thanks for GitHub@gel-crabs for the help
DEPTH_LAYERS_XL = {
0: [
# SD 1.5 U-Net (diffusers)
"down_blocks.0.attentions.0.transformer_blocks.0.attn1",
"down_blocks.0.attentions.1.transformer_blocks.0.attn1",
"up_blocks.3.attentions.0.transformer_blocks.0.attn1",
"up_blocks.3.attentions.1.transformer_blocks.0.attn1",
"up_blocks.3.attentions.2.transformer_blocks.0.attn1",
# SD 1.5 U-Net (ldm)
"input_blocks.4.1.transformer_blocks.0.attn1",
"input_blocks.5.1.transformer_blocks.0.attn1",
"output_blocks.3.1.transformer_blocks.0.attn1",
"output_blocks.4.1.transformer_blocks.0.attn1",
"output_blocks.5.1.transformer_blocks.0.attn1",
# SD 1.5 VAE
"decoder.mid_block.attentions.0",
"decoder.mid.attn_1",
],
1: [
# SD 1.5 U-Net (diffusers)
#"down_blocks.1.attentions.0.transformer_blocks.0.attn1",
#"down_blocks.1.attentions.1.transformer_blocks.0.attn1",
#"up_blocks.2.attentions.0.transformer_blocks.0.attn1",
#"up_blocks.2.attentions.1.transformer_blocks.0.attn1",
#"up_blocks.2.attentions.2.transformer_blocks.0.attn1",
# SD 1.5 U-Net (ldm)
"input_blocks.4.1.transformer_blocks.1.attn1",
"input_blocks.5.1.transformer_blocks.1.attn1",
"output_blocks.3.1.transformer_blocks.1.attn1",
"output_blocks.4.1.transformer_blocks.1.attn1",
"output_blocks.5.1.transformer_blocks.1.attn1",
"input_blocks.7.1.transformer_blocks.0.attn1",
"input_blocks.8.1.transformer_blocks.0.attn1",
"output_blocks.0.1.transformer_blocks.0.attn1",
"output_blocks.1.1.transformer_blocks.0.attn1",
"output_blocks.2.1.transformer_blocks.0.attn1",
"input_blocks.7.1.transformer_blocks.1.attn1",
"input_blocks.8.1.transformer_blocks.1.attn1",
"output_blocks.0.1.transformer_blocks.1.attn1",
"output_blocks.1.1.transformer_blocks.1.attn1",
"output_blocks.2.1.transformer_blocks.1.attn1",
"input_blocks.7.1.transformer_blocks.2.attn1",
"input_blocks.8.1.transformer_blocks.2.attn1",
"output_blocks.0.1.transformer_blocks.2.attn1",
"output_blocks.1.1.transformer_blocks.2.attn1",
"output_blocks.2.1.transformer_blocks.2.attn1",
"input_blocks.7.1.transformer_blocks.3.attn1",
"input_blocks.8.1.transformer_blocks.3.attn1",
"output_blocks.0.1.transformer_blocks.3.attn1",
"output_blocks.1.1.transformer_blocks.3.attn1",
"output_blocks.2.1.transformer_blocks.3.attn1",
"input_blocks.7.1.transformer_blocks.4.attn1",
"input_blocks.8.1.transformer_blocks.4.attn1",
"output_blocks.0.1.transformer_blocks.4.attn1",
"output_blocks.1.1.transformer_blocks.4.attn1",
"output_blocks.2.1.transformer_blocks.4.attn1",
"input_blocks.7.1.transformer_blocks.5.attn1",
"input_blocks.8.1.transformer_blocks.5.attn1",
"output_blocks.0.1.transformer_blocks.5.attn1",
"output_blocks.1.1.transformer_blocks.5.attn1",
"output_blocks.2.1.transformer_blocks.5.attn1",
"input_blocks.7.1.transformer_blocks.6.attn1",
"input_blocks.8.1.transformer_blocks.6.attn1",
"output_blocks.0.1.transformer_blocks.6.attn1",
"output_blocks.1.1.transformer_blocks.6.attn1",
"output_blocks.2.1.transformer_blocks.6.attn1",
"input_blocks.7.1.transformer_blocks.7.attn1",
"input_blocks.8.1.transformer_blocks.7.attn1",
"output_blocks.0.1.transformer_blocks.7.attn1",
"output_blocks.1.1.transformer_blocks.7.attn1",
"output_blocks.2.1.transformer_blocks.7.attn1",
"input_blocks.7.1.transformer_blocks.8.attn1",
"input_blocks.8.1.transformer_blocks.8.attn1",
"output_blocks.0.1.transformer_blocks.8.attn1",
"output_blocks.1.1.transformer_blocks.8.attn1",
"output_blocks.2.1.transformer_blocks.8.attn1",
"input_blocks.7.1.transformer_blocks.9.attn1",
"input_blocks.8.1.transformer_blocks.9.attn1",
"output_blocks.0.1.transformer_blocks.9.attn1",
"output_blocks.1.1.transformer_blocks.9.attn1",
"output_blocks.2.1.transformer_blocks.9.attn1",
],
2: [
# SD 1.5 U-Net (diffusers)
"mid_block.attentions.0.transformer_blocks.0.attn1",
# SD 1.5 U-Net (ldm)
"middle_block.1.transformer_blocks.0.attn1",
"middle_block.1.transformer_blocks.1.attn1",
"middle_block.1.transformer_blocks.2.attn1",
"middle_block.1.transformer_blocks.3.attn1",
"middle_block.1.transformer_blocks.4.attn1",
"middle_block.1.transformer_blocks.5.attn1",
"middle_block.1.transformer_blocks.6.attn1",
"middle_block.1.transformer_blocks.7.attn1",
"middle_block.1.transformer_blocks.8.attn1",
"middle_block.1.transformer_blocks.9.attn1",
],
3 : [] # TODO - separate layers for SD-XL
}
RNG_INSTANCE = random.Random()
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
"""
Returns a random divisor of value that
x * min_value <= value
if max_options is 1, the behavior is deterministic
"""
min_value = min(min_value, value)
# All big divisors of value (inclusive)
divisors = [i for i in range(min_value, value + 1) if value % i == 0] # divisors in small -> big order
ns = [value // i for i in divisors[:max_options]] # has at least 1 element # big -> small order
idx = RNG_INSTANCE.randint(0, len(ns) - 1)
return ns[idx]
def set_hypertile_seed(seed: int) -> None:
RNG_INSTANCE.seed(seed)
@functools.cache
def largest_tile_size_available(width: int, height: int) -> int:
"""
Calculates the largest tile size available for a given width and height
Tile size is always a power of 2
"""
gcd = math.gcd(width, height)
largest_tile_size_available = 1
while gcd % (largest_tile_size_available * 2) == 0:
largest_tile_size_available *= 2
return largest_tile_size_available
def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]:
"""
Finds h and w such that h*w = hw and h/w = aspect_ratio
We check all possible divisors of hw and return the closest to the aspect ratio
"""
divisors = [i for i in range(2, hw + 1) if hw % i == 0] # all divisors of hw
pairs = [(i, hw // i) for i in divisors] # all pairs of divisors of hw
ratios = [w/h for h, w in pairs] # all ratios of pairs of divisors of hw
closest_ratio = min(ratios, key=lambda x: abs(x - aspect_ratio)) # closest ratio to aspect_ratio
closest_pair = pairs[ratios.index(closest_ratio)] # closest pair of divisors to aspect_ratio
return closest_pair
@cache
def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]:
"""
Finds h and w such that h*w = hw and h/w = aspect_ratio
"""
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
# find h and w such that h*w = hw and h/w = aspect_ratio
if h * w != hw:
w_candidate = hw / h
# check if w is an integer
if not w_candidate.is_integer():
h_candidate = hw / w
# check if h is an integer
if not h_candidate.is_integer():
return iterative_closest_divisors(hw, aspect_ratio)
else:
h = int(h_candidate)
else:
w = int(w_candidate)
return h, w
def self_attn_forward(params: HypertileParams, scale_depth=True) -> Callable:
@wraps(params.forward)
def wrapper(*args, **kwargs):
if not params.enabled:
return params.forward(*args, **kwargs)
latent_tile_size = max(128, params.tile_size) // 8
x = args[0]
# VAE
if x.ndim == 4:
b, c, h, w = x.shape
nh = random_divisor(h, latent_tile_size, params.swap_size)
nw = random_divisor(w, latent_tile_size, params.swap_size)
if nh * nw > 1:
x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) # split into nh * nw tiles
out = params.forward(x, *args[1:], **kwargs)
if nh * nw > 1:
out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)
# U-Net
else:
hw: int = x.size(1)
h, w = find_hw_candidates(hw, params.aspect_ratio)
assert h * w == hw, f"Invalid aspect ratio {params.aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}"
factor = 2 ** params.depth if scale_depth else 1
nh = random_divisor(h, latent_tile_size * factor, params.swap_size)
nw = random_divisor(w, latent_tile_size * factor, params.swap_size)
if nh * nw > 1:
x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
out = params.forward(x, *args[1:], **kwargs)
if nh * nw > 1:
out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
return out
return wrapper
def hypertile_hook_model(model: nn.Module, width, height, *, enable=False, tile_size_max=128, swap_size=1, max_depth=3, is_sdxl=False):
hypertile_layers = getattr(model, "__webui_hypertile_layers", None)
if hypertile_layers is None:
if not enable:
return
hypertile_layers = {}
layers = DEPTH_LAYERS_XL if is_sdxl else DEPTH_LAYERS
for depth in range(4):
for layer_name, module in model.named_modules():
if any(layer_name.endswith(try_name) for try_name in layers[depth]):
params = HypertileParams()
module.__webui_hypertile_params = params
params.forward = module.forward
params.depth = depth
params.layer_name = layer_name
module.forward = self_attn_forward(params)
hypertile_layers[layer_name] = 1
model.__webui_hypertile_layers = hypertile_layers
aspect_ratio = width / height
tile_size = min(largest_tile_size_available(width, height), tile_size_max)
for layer_name, module in model.named_modules():
if layer_name in hypertile_layers:
params = module.__webui_hypertile_params
params.tile_size = tile_size
params.swap_size = swap_size
params.aspect_ratio = aspect_ratio
params.enabled = enable and params.depth <= max_depth

View File

@ -0,0 +1,73 @@
import hypertile
from modules import scripts, script_callbacks, shared
class ScriptHypertile(scripts.Script):
name = "Hypertile"
def title(self):
return self.name
def show(self, is_img2img):
return scripts.AlwaysVisible
def process(self, p, *args):
hypertile.set_hypertile_seed(p.all_seeds[0])
configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet)
def before_hr(self, p, *args):
configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet)
def configure_hypertile(width, height, enable_unet=True):
hypertile.hypertile_hook_model(
shared.sd_model.first_stage_model,
width,
height,
swap_size=shared.opts.hypertile_swap_size_vae,
max_depth=shared.opts.hypertile_max_depth_vae,
tile_size_max=shared.opts.hypertile_max_tile_vae,
enable=shared.opts.hypertile_enable_vae,
)
hypertile.hypertile_hook_model(
shared.sd_model.model,
width,
height,
swap_size=shared.opts.hypertile_swap_size_unet,
max_depth=shared.opts.hypertile_max_depth_unet,
tile_size_max=shared.opts.hypertile_max_tile_unet,
enable=enable_unet,
is_sdxl=shared.sd_model.is_sdxl
)
def on_ui_settings():
import gradio as gr
options = {
"hypertile_explanation": shared.OptionHTML("""
<a href='https://github.com/tfernd/HyperTile'>Hypertile</a> optimizes the self-attention layer within U-Net and VAE models,
resulting in a reduction in computation time ranging from 1 to 4 times. The larger the generated image is, the greater the
benefit.
"""),
"hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net").info("noticeable change in details of the generated picture; if enabled, overrides the setting below"),
"hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass"),
"hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}),
"hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
"hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-net swap size", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}),
"hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE").info("minimal change in the generated picture"),
"hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}),
"hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
"hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}),
}
for name, opt in options.items():
opt.section = ('hypertile', "Hypertile")
shared.opts.add_option(name, opt)
script_callbacks.on_ui_settings(on_ui_settings)

View File

@ -12,6 +12,8 @@ function isMobile() {
} }
function reportWindowSize() { function reportWindowSize() {
if (gradioApp().querySelector('.toprow-compact-tools')) return; // not applicable for compact prompt layout
var currentlyMobile = isMobile(); var currentlyMobile = isMobile();
if (currentlyMobile == isSetupForMobile) return; if (currentlyMobile == isSetupForMobile) return;
isSetupForMobile = currentlyMobile; isSetupForMobile = currentlyMobile;

View File

@ -119,7 +119,7 @@ window.addEventListener('paste', e => {
} }
const firstFreeImageField = visibleImageFields const firstFreeImageField = visibleImageFields
.filter(el => el.querySelector('input[type=file]'))?.[0]; .filter(el => !el.querySelector('img'))?.[0];
dropReplaceImage( dropReplaceImage(
firstFreeImageField ? firstFreeImageField ?

View File

@ -18,37 +18,43 @@ function keyupEditAttention(event) {
const before = text.substring(0, selectionStart); const before = text.substring(0, selectionStart);
let beforeParen = before.lastIndexOf(OPEN); let beforeParen = before.lastIndexOf(OPEN);
if (beforeParen == -1) return false; if (beforeParen == -1) return false;
let beforeParenClose = before.lastIndexOf(CLOSE);
while (beforeParenClose !== -1 && beforeParenClose > beforeParen) { let beforeClosingParen = before.lastIndexOf(CLOSE);
beforeParen = before.lastIndexOf(OPEN, beforeParen - 1); if (beforeClosingParen != -1 && beforeClosingParen > beforeParen) return false;
beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1);
}
// Find closing parenthesis around current cursor // Find closing parenthesis around current cursor
const after = text.substring(selectionStart); const after = text.substring(selectionStart);
let afterParen = after.indexOf(CLOSE); let afterParen = after.indexOf(CLOSE);
if (afterParen == -1) return false; if (afterParen == -1) return false;
let afterParenOpen = after.indexOf(OPEN);
while (afterParenOpen !== -1 && afterParen > afterParenOpen) { let afterOpeningParen = after.indexOf(OPEN);
afterParen = after.indexOf(CLOSE, afterParen + 1); if (afterOpeningParen != -1 && afterOpeningParen < afterParen) return false;
afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1);
}
if (beforeParen === -1 || afterParen === -1) return false;
// Set the selection to the text between the parenthesis // Set the selection to the text between the parenthesis
const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen); const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
const lastColon = parenContent.lastIndexOf(":"); if (/.*:-?[\d.]+/s.test(parenContent)) {
selectionStart = beforeParen + 1; const lastColon = parenContent.lastIndexOf(":");
selectionEnd = selectionStart + lastColon; selectionStart = beforeParen + 1;
selectionEnd = selectionStart + lastColon;
} else {
selectionStart = beforeParen + 1;
selectionEnd = selectionStart + parenContent.length;
}
target.setSelectionRange(selectionStart, selectionEnd); target.setSelectionRange(selectionStart, selectionEnd);
return true; return true;
} }
function selectCurrentWord() { function selectCurrentWord() {
if (selectionStart !== selectionEnd) return false; if (selectionStart !== selectionEnd) return false;
const delimiters = opts.keyedit_delimiters + " \r\n\t"; const whitespace_delimiters = {"Tab": "\t", "Carriage Return": "\r", "Line Feed": "\n"};
let delimiters = opts.keyedit_delimiters;
// seek backward until to find beggining for (let i of opts.keyedit_delimiters_whitespace) {
delimiters += whitespace_delimiters[i];
}
// seek backward to find beginning
while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) { while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) {
selectionStart--; selectionStart--;
} }
@ -63,7 +69,7 @@ function keyupEditAttention(event) {
} }
// If the user hasn't selected anything, let's select their current parenthesis block or word // If the user hasn't selected anything, let's select their current parenthesis block or word
if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) { if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')') && !selectCurrentParenthesisBlock('[', ']')) {
selectCurrentWord(); selectCurrentWord();
} }
@ -71,33 +77,54 @@ function keyupEditAttention(event) {
var closeCharacter = ')'; var closeCharacter = ')';
var delta = opts.keyedit_precision_attention; var delta = opts.keyedit_precision_attention;
var start = selectionStart > 0 ? text[selectionStart - 1] : "";
var end = text[selectionEnd];
if (selectionStart > 0 && text[selectionStart - 1] == '<') { if (start == '<') {
closeCharacter = '>'; closeCharacter = '>';
delta = opts.keyedit_precision_extra; delta = opts.keyedit_precision_extra;
} else if (selectionStart == 0 || text[selectionStart - 1] != "(") { } else if (start == '(' && end == ')' || start == '[' && end == ']') { // convert old-style (((emphasis)))
let numParen = 0;
while (text[selectionStart - numParen - 1] == start && text[selectionEnd + numParen] == end) {
numParen++;
}
if (start == "[") {
weight = (1 / 1.1) ** numParen;
} else {
weight = 1.1 ** numParen;
}
weight = Math.round(weight / opts.keyedit_precision_attention) * opts.keyedit_precision_attention;
text = text.slice(0, selectionStart - numParen) + "(" + text.slice(selectionStart, selectionEnd) + ":" + weight + ")" + text.slice(selectionEnd + numParen);
selectionStart -= numParen - 1;
selectionEnd -= numParen - 1;
} else if (start != '(') {
// do not include spaces at the end // do not include spaces at the end
while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') { while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') {
selectionEnd -= 1; selectionEnd--;
} }
if (selectionStart == selectionEnd) { if (selectionStart == selectionEnd) {
return; return;
} }
text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd); text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
selectionStart += 1; selectionStart++;
selectionEnd += 1; selectionEnd++;
} }
var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1; if (text[selectionEnd] != ':') return;
var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end)); var weightLength = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + weightLength));
if (isNaN(weight)) return; if (isNaN(weight)) return;
weight += isPlus ? delta : -delta; weight += isPlus ? delta : -delta;
weight = parseFloat(weight.toPrecision(12)); weight = parseFloat(weight.toPrecision(12));
if (String(weight).length == 1) weight += ".0"; if (Number.isInteger(weight)) weight += ".0";
if (closeCharacter == ')' && weight == 1) { if (closeCharacter == ')' && weight == 1) {
var endParenPos = text.substring(selectionEnd).indexOf(')'); var endParenPos = text.substring(selectionEnd).indexOf(')');
@ -105,7 +132,7 @@ function keyupEditAttention(event) {
selectionStart--; selectionStart--;
selectionEnd--; selectionEnd--;
} else { } else {
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + end); text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + weightLength);
} }
target.focus(); target.focus();

View File

@ -26,8 +26,9 @@ function setupExtraNetworksForTab(tabname) {
var refresh = gradioApp().getElementById(tabname + '_extra_refresh'); var refresh = gradioApp().getElementById(tabname + '_extra_refresh');
var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs'); var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs');
var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input'); var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input');
var promptContainer = gradioApp().querySelector('.prompt-container-compact#' + tabname + '_prompt_container');
var negativePrompt = gradioApp().querySelector('#' + tabname + '_neg_prompt');
sort.dataset.sortkey = 'sortDefault';
tabs.appendChild(searchDiv); tabs.appendChild(searchDiv);
tabs.appendChild(sort); tabs.appendChild(sort);
tabs.appendChild(sortOrder); tabs.appendChild(sortOrder);
@ -49,20 +50,23 @@ function setupExtraNetworksForTab(tabname) {
elem.style.display = visible ? "" : "none"; elem.style.display = visible ? "" : "none";
}); });
applySort();
}; };
var applySort = function() { var applySort = function() {
var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card');
var reverse = sortOrder.classList.contains("sortReverse"); var reverse = sortOrder.classList.contains("sortReverse");
var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim(); var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name";
sortKey = sortKey ? "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1) : ""; sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1);
var sortKeyStore = sortKey ? sortKey + (reverse ? "Reverse" : "") : ""; var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length;
if (!sortKey || sortKeyStore == sort.dataset.sortkey) {
if (sortKeyStore == sort.dataset.sortkey) {
return; return;
} }
sort.dataset.sortkey = sortKeyStore; sort.dataset.sortkey = sortKeyStore;
var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card');
cards.forEach(function(card) { cards.forEach(function(card) {
card.originalParentElement = card.parentElement; card.originalParentElement = card.parentElement;
}); });
@ -88,15 +92,13 @@ function setupExtraNetworksForTab(tabname) {
}; };
search.addEventListener("input", applyFilter); search.addEventListener("input", applyFilter);
applyFilter();
["change", "blur", "click"].forEach(function(evt) {
sort.querySelector("input").addEventListener(evt, applySort);
});
sortOrder.addEventListener("click", function() { sortOrder.addEventListener("click", function() {
sortOrder.classList.toggle("sortReverse"); sortOrder.classList.toggle("sortReverse");
applySort(); applySort();
}); });
applyFilter();
extraNetworksApplySort[tabname] = applySort;
extraNetworksApplyFilter[tabname] = applyFilter; extraNetworksApplyFilter[tabname] = applyFilter;
var showDirsUpdate = function() { var showDirsUpdate = function() {
@ -109,11 +111,51 @@ function setupExtraNetworksForTab(tabname) {
showDirsUpdate(); showDirsUpdate();
} }
function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt) {
if (!gradioApp().querySelector('.toprow-compact-tools')) return; // only applicable for compact prompt layout
var promptContainer = gradioApp().getElementById(tabname + '_prompt_container');
var prompt = gradioApp().getElementById(tabname + '_prompt_row');
var negPrompt = gradioApp().getElementById(tabname + '_neg_prompt_row');
var elem = id ? gradioApp().getElementById(id) : null;
if (showNegativePrompt && elem) {
elem.insertBefore(negPrompt, elem.firstChild);
} else {
promptContainer.insertBefore(negPrompt, promptContainer.firstChild);
}
if (showPrompt && elem) {
elem.insertBefore(prompt, elem.firstChild);
} else {
promptContainer.insertBefore(prompt, promptContainer.firstChild);
}
if (elem) {
elem.classList.toggle('extra-page-prompts-active', showNegativePrompt || showPrompt);
}
}
function extraNetworksUrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate)
extraNetworksMovePromptToTab(tabname, '', false, false);
}
function extraNetworksTabSelected(tabname, id, showPrompt, showNegativePrompt) { // called from python when user selects an extra networks tab
extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt);
}
function applyExtraNetworkFilter(tabname) { function applyExtraNetworkFilter(tabname) {
setTimeout(extraNetworksApplyFilter[tabname], 1); setTimeout(extraNetworksApplyFilter[tabname], 1);
} }
function applyExtraNetworkSort(tabname) {
setTimeout(extraNetworksApplySort[tabname], 1);
}
var extraNetworksApplyFilter = {}; var extraNetworksApplyFilter = {};
var extraNetworksApplySort = {};
var activePromptTextarea = {}; var activePromptTextarea = {};
function setupExtraNetworks() { function setupExtraNetworks() {
@ -140,14 +182,15 @@ function setupExtraNetworks() {
onUiLoaded(setupExtraNetworks); onUiLoaded(setupExtraNetworks);
var re_extranet = /<([^:]+:[^:]+):[\d.]+>(.*)/; var re_extranet = /<([^:^>]+:[^:]+):[\d.]+>(.*)/;
var re_extranet_g = /\s+<([^:]+:[^:]+):[\d.]+>/g; var re_extranet_g = /<([^:^>]+:[^:]+):[\d.]+>/g;
function tryToRemoveExtraNetworkFromPrompt(textarea, text) { function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
var m = text.match(re_extranet); var m = text.match(re_extranet);
var replaced = false; var replaced = false;
var newTextareaText; var newTextareaText;
if (m) { if (m) {
var extraTextBeforeNet = opts.extra_networks_add_text_separator;
var extraTextAfterNet = m[2]; var extraTextAfterNet = m[2];
var partToSearch = m[1]; var partToSearch = m[1];
var foundAtPosition = -1; var foundAtPosition = -1;
@ -161,8 +204,13 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
return found; return found;
}); });
if (foundAtPosition >= 0 && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) { if (foundAtPosition >= 0) {
newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length); if (newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) {
newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length);
}
if (newTextareaText.substr(foundAtPosition - extraTextBeforeNet.length, extraTextBeforeNet.length) == extraTextBeforeNet) {
newTextareaText = newTextareaText.substr(0, foundAtPosition - extraTextBeforeNet.length) + newTextareaText.substr(foundAtPosition);
}
} }
} else { } else {
newTextareaText = textarea.value.replaceAll(new RegExp(text, "g"), function(found) { newTextareaText = textarea.value.replaceAll(new RegExp(text, "g"), function(found) {
@ -216,27 +264,24 @@ function extraNetworksSearchButton(tabs_id, event) {
var globalPopup = null; var globalPopup = null;
var globalPopupInner = null; var globalPopupInner = null;
function closePopup() { function closePopup() {
if (!globalPopup) return; if (!globalPopup) return;
globalPopup.style.display = "none"; globalPopup.style.display = "none";
} }
function popup(contents) { function popup(contents) {
if (!globalPopup) { if (!globalPopup) {
globalPopup = document.createElement('div'); globalPopup = document.createElement('div');
globalPopup.onclick = closePopup;
globalPopup.classList.add('global-popup'); globalPopup.classList.add('global-popup');
var close = document.createElement('div'); var close = document.createElement('div');
close.classList.add('global-popup-close'); close.classList.add('global-popup-close');
close.onclick = closePopup; close.addEventListener("click", closePopup);
close.title = "Close"; close.title = "Close";
globalPopup.appendChild(close); globalPopup.appendChild(close);
globalPopupInner = document.createElement('div'); globalPopupInner = document.createElement('div');
globalPopupInner.onclick = function(event) {
event.stopPropagation(); return false;
};
globalPopupInner.classList.add('global-popup-inner'); globalPopupInner.classList.add('global-popup-inner');
globalPopup.appendChild(globalPopupInner); globalPopup.appendChild(globalPopupInner);
@ -335,7 +380,7 @@ function extraNetworksEditUserMetadata(event, tabname, extraPage, cardName) {
function extraNetworksRefreshSingleCard(page, tabname, name) { function extraNetworksRefreshSingleCard(page, tabname, name) {
requestGet("./sd_extra_networks/get-single-card", {page: page, tabname: tabname, name: name}, function(data) { requestGet("./sd_extra_networks/get-single-card", {page: page, tabname: tabname, name: name}, function(data) {
if (data && data.html) { if (data && data.html) {
var card = gradioApp().querySelector('.card[data-name=' + JSON.stringify(name) + ']'); // likely using the wrong stringify function var card = gradioApp().querySelector(`#${tabname}_${page.replace(" ", "_")}_cards > .card[data-name="${name}"]`);
var newDiv = document.createElement('DIV'); var newDiv = document.createElement('DIV');
newDiv.innerHTML = data.html; newDiv.innerHTML = data.html;

View File

@ -33,8 +33,11 @@ function updateOnBackgroundChange() {
const modalImage = gradioApp().getElementById("modalImage"); const modalImage = gradioApp().getElementById("modalImage");
if (modalImage && modalImage.offsetParent) { if (modalImage && modalImage.offsetParent) {
let currentButton = selected_gallery_button(); let currentButton = selected_gallery_button();
let preview = gradioApp().querySelectorAll('.livePreview > img');
if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) { if (preview.length > 0) {
// show preview image if available
modalImage.src = preview[preview.length - 1].src;
} else if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
modalImage.src = currentButton.children[0].src; modalImage.src = currentButton.children[0].src;
if (modalImage.style.display === 'none') { if (modalImage.style.display === 'none') {
const modal = gradioApp().getElementById("lightboxModal"); const modal = gradioApp().getElementById("lightboxModal");

View File

@ -1,37 +1,68 @@
var observerAccordionOpen = new MutationObserver(function(mutations) {
mutations.forEach(function(mutationRecord) {
var elem = mutationRecord.target;
var open = elem.classList.contains('open');
var accordion = elem.parentNode;
accordion.classList.toggle('input-accordion-open', open);
var checkbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input");
checkbox.checked = open;
updateInput(checkbox);
var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
if (extra) {
extra.style.display = open ? "" : "none";
}
});
});
function inputAccordionChecked(id, checked) { function inputAccordionChecked(id, checked) {
var label = gradioApp().querySelector('#' + id + " .label-wrap"); var accordion = gradioApp().getElementById(id);
if (label.classList.contains('open') != checked) { accordion.visibleCheckbox.checked = checked;
label.click(); accordion.onVisibleCheckboxChange();
}
function setupAccordion(accordion) {
var labelWrap = accordion.querySelector('.label-wrap');
var gradioCheckbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input");
var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
var span = labelWrap.querySelector('span');
var linked = true;
var isOpen = function() {
return labelWrap.classList.contains('open');
};
var observerAccordionOpen = new MutationObserver(function(mutations) {
mutations.forEach(function(mutationRecord) {
accordion.classList.toggle('input-accordion-open', isOpen());
if (linked) {
accordion.visibleCheckbox.checked = isOpen();
accordion.onVisibleCheckboxChange();
}
});
});
observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']});
if (extra) {
labelWrap.insertBefore(extra, labelWrap.lastElementChild);
} }
accordion.onChecked = function(checked) {
if (isOpen() != checked) {
labelWrap.click();
}
};
var visibleCheckbox = document.createElement('INPUT');
visibleCheckbox.type = 'checkbox';
visibleCheckbox.checked = isOpen();
visibleCheckbox.id = accordion.id + "-visible-checkbox";
visibleCheckbox.className = gradioCheckbox.className + " input-accordion-checkbox";
span.insertBefore(visibleCheckbox, span.firstChild);
accordion.visibleCheckbox = visibleCheckbox;
accordion.onVisibleCheckboxChange = function() {
if (linked && isOpen() != visibleCheckbox.checked) {
labelWrap.click();
}
gradioCheckbox.checked = visibleCheckbox.checked;
updateInput(gradioCheckbox);
};
visibleCheckbox.addEventListener('click', function(event) {
linked = false;
event.stopPropagation();
});
visibleCheckbox.addEventListener('input', accordion.onVisibleCheckboxChange);
} }
onUiLoaded(function() { onUiLoaded(function() {
for (var accordion of gradioApp().querySelectorAll('.input-accordion')) { for (var accordion of gradioApp().querySelectorAll('.input-accordion')) {
var labelWrap = accordion.querySelector('.label-wrap'); setupAccordion(accordion);
observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']});
var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
if (extra) {
labelWrap.insertBefore(extra, labelWrap.lastElementChild);
}
} }
}); });

View File

@ -26,7 +26,11 @@ onAfterUiUpdate(function() {
lastHeadImg = headImg; lastHeadImg = headImg;
// play notification sound if available // play notification sound if available
gradioApp().querySelector('#audio_notification audio')?.play(); const notificationAudio = gradioApp().querySelector('#audio_notification audio');
if (notificationAudio) {
notificationAudio.volume = opts.notification_volume / 100.0 || 1.0;
notificationAudio.play();
}
if (document.hasFocus()) return; if (document.hasFocus()) return;

71
javascript/settings.js Normal file
View File

@ -0,0 +1,71 @@
let settingsExcludeTabsFromShowAll = {
settings_tab_defaults: 1,
settings_tab_sysinfo: 1,
settings_tab_actions: 1,
settings_tab_licenses: 1,
};
function settingsShowAllTabs() {
gradioApp().querySelectorAll('#settings > div').forEach(function(elem) {
if (settingsExcludeTabsFromShowAll[elem.id]) return;
elem.style.display = "block";
});
}
function settingsShowOneTab() {
gradioApp().querySelector('#settings_show_one_page').click();
}
onUiLoaded(function() {
var edit = gradioApp().querySelector('#settings_search');
var editTextarea = gradioApp().querySelector('#settings_search > label > input');
var buttonShowAllPages = gradioApp().getElementById('settings_show_all_pages');
var settings_tabs = gradioApp().querySelector('#settings div');
onEdit('settingsSearch', editTextarea, 250, function() {
var searchText = (editTextarea.value || "").trim().toLowerCase();
gradioApp().querySelectorAll('#settings > div[id^=settings_] div[id^=column_settings_] > *').forEach(function(elem) {
var visible = elem.textContent.trim().toLowerCase().indexOf(searchText) != -1;
elem.style.display = visible ? "" : "none";
});
if (searchText != "") {
settingsShowAllTabs();
} else {
settingsShowOneTab();
}
});
settings_tabs.insertBefore(edit, settings_tabs.firstChild);
settings_tabs.appendChild(buttonShowAllPages);
buttonShowAllPages.addEventListener("click", settingsShowAllTabs);
});
onOptionsChanged(function() {
if (gradioApp().querySelector('#settings .settings-category')) return;
var sectionMap = {};
gradioApp().querySelectorAll('#settings > div > button').forEach(function(x) {
sectionMap[x.textContent.trim()] = x;
});
opts._categories.forEach(function(x) {
var section = x[0];
var category = x[1];
var span = document.createElement('SPAN');
span.textContent = category;
span.className = 'settings-category';
var sectionElem = sectionMap[section];
if (!sectionElem) return;
sectionElem.parentElement.insertBefore(span, sectionElem);
});
});

View File

@ -1,10 +1,9 @@
let promptTokenCountDebounceTime = 800; let promptTokenCountUpdateFunctions = {};
let promptTokenCountTimeouts = {};
var promptTokenCountUpdateFunctions = {};
function update_txt2img_tokens(...args) { function update_txt2img_tokens(...args) {
// Called from Gradio // Called from Gradio
update_token_counter("txt2img_token_button"); update_token_counter("txt2img_token_button");
update_token_counter("txt2img_negative_token_button");
if (args.length == 2) { if (args.length == 2) {
return args[0]; return args[0];
} }
@ -14,6 +13,7 @@ function update_txt2img_tokens(...args) {
function update_img2img_tokens(...args) { function update_img2img_tokens(...args) {
// Called from Gradio // Called from Gradio
update_token_counter("img2img_token_button"); update_token_counter("img2img_token_button");
update_token_counter("img2img_negative_token_button");
if (args.length == 2) { if (args.length == 2) {
return args[0]; return args[0];
} }
@ -21,16 +21,7 @@ function update_img2img_tokens(...args) {
} }
function update_token_counter(button_id) { function update_token_counter(button_id) {
if (opts.disable_token_counters) { promptTokenCountUpdateFunctions[button_id]?.();
return;
}
if (promptTokenCountTimeouts[button_id]) {
clearTimeout(promptTokenCountTimeouts[button_id]);
}
promptTokenCountTimeouts[button_id] = setTimeout(
() => gradioApp().getElementById(button_id)?.click(),
promptTokenCountDebounceTime,
);
} }
@ -69,10 +60,11 @@ function setupTokenCounting(id, id_counter, id_button) {
prompt.parentElement.insertBefore(counter, prompt); prompt.parentElement.insertBefore(counter, prompt);
prompt.parentElement.style.position = "relative"; prompt.parentElement.style.position = "relative";
promptTokenCountUpdateFunctions[id] = function() { var func = onEdit(id, textarea, 800, function() {
update_token_counter(id_button); gradioApp().getElementById(id_button)?.click();
}; });
textarea.addEventListener("input", promptTokenCountUpdateFunctions[id]); promptTokenCountUpdateFunctions[id] = func;
promptTokenCountUpdateFunctions[id_button] = func;
} }
function setupTokenCounters() { function setupTokenCounters() {

View File

@ -263,21 +263,6 @@ onAfterUiUpdate(function() {
json_elem.parentElement.style.display = "none"; json_elem.parentElement.style.display = "none";
setupTokenCounters(); setupTokenCounters();
var show_all_pages = gradioApp().getElementById('settings_show_all_pages');
var settings_tabs = gradioApp().querySelector('#settings div');
if (show_all_pages && settings_tabs) {
settings_tabs.appendChild(show_all_pages);
show_all_pages.onclick = function() {
gradioApp().querySelectorAll('#settings > div').forEach(function(elem) {
if (elem.id == "settings_tab_licenses") {
return;
}
elem.style.display = "block";
});
};
}
}); });
onOptionsChanged(function() { onOptionsChanged(function() {
@ -366,3 +351,20 @@ function switchWidthHeight(tabname) {
updateInput(height); updateInput(height);
return []; return [];
} }
var onEditTimers = {};
// calls func after afterMs milliseconds has passed since the input elem has beed enited by user
function onEdit(editId, elem, afterMs, func) {
var edited = function() {
var existingTimer = onEditTimers[editId];
if (existingTimer) clearTimeout(existingTimer);
onEditTimers[editId] = setTimeout(func, afterMs);
};
elem.addEventListener("input", edited);
return edited;
}

View File

@ -17,19 +17,18 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest from secrets import compare_digest
import modules.shared as shared import modules.shared as shared
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models
from modules.api import models from modules.api import models
from modules.shared import opts from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image from PIL import PngImagePlugin, Image
from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
from modules.sd_models_config import find_checkpoint_config_near_filename from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models from modules.realesrgan_model import get_realesrgan_models
from modules import devices from modules import devices
from typing import Dict, List, Any from typing import Any
import piexif import piexif
import piexif.helper import piexif.helper
from contextlib import closing from contextlib import closing
@ -103,7 +102,8 @@ def decode_base64_to_image(encoding):
def encode_pil_to_base64(image): def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes: with io.BytesIO() as output_bytes:
if isinstance(image, str):
return image
if opts.samples_format.lower() == 'png': if opts.samples_format.lower() == 'png':
use_metadata = False use_metadata = False
metadata = PngImagePlugin.PngInfo() metadata = PngImagePlugin.PngInfo()
@ -221,15 +221,15 @@ class Api:
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel) self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel) self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem]) self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem])
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem]) self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem])
self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem]) self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem])
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem]) self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem])
self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem]) self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=list[models.SDVaeItem])
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem]) self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=list[models.HypernetworkItem])
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem]) self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=list[models.FaceRestorerItem])
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem]) self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem]) self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse) self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"]) self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
@ -242,7 +242,8 @@ class Api:
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"]) self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"]) self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=list[models.ScriptInfo])
self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=list[models.ExtensionItem])
if shared.cmd_opts.api_server_stop: if shared.cmd_opts.api_server_stop:
self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"]) self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
@ -473,9 +474,6 @@ class Api:
return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
def pnginfoapi(self, req: models.PNGInfoRequest): def pnginfoapi(self, req: models.PNGInfoRequest):
if(not req.image.strip()):
return models.PNGInfoResponse(info="")
image = decode_base64_to_image(req.image.strip()) image = decode_base64_to_image(req.image.strip())
if image is None: if image is None:
return models.PNGInfoResponse(info="") return models.PNGInfoResponse(info="")
@ -484,9 +482,10 @@ class Api:
if geninfo is None: if geninfo is None:
geninfo = "" geninfo = ""
items = {**{'parameters': geninfo}, **items} params = generation_parameters_copypaste.parse_generation_parameters(geninfo)
script_callbacks.infotext_pasted_callback(geninfo, params)
return models.PNGInfoResponse(info=geninfo, items=items) return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)
def progressapi(self, req: models.ProgressRequest = Depends()): def progressapi(self, req: models.ProgressRequest = Depends()):
# copy from check_progress_call of ui.py # copy from check_progress_call of ui.py
@ -541,12 +540,12 @@ class Api:
return {} return {}
def unloadapi(self): def unloadapi(self):
unload_model_weights() sd_models.unload_model_weights()
return {} return {}
def reloadapi(self): def reloadapi(self):
reload_model_weights() sd_models.send_model_to_device(shared.sd_model)
return {} return {}
@ -564,9 +563,9 @@ class Api:
return options return options
def set_config(self, req: Dict[str, Any]): def set_config(self, req: dict[str, Any]):
checkpoint_name = req.get("sd_model_checkpoint", None) checkpoint_name = req.get("sd_model_checkpoint", None)
if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases: if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
raise RuntimeError(f"model {checkpoint_name!r} not found") raise RuntimeError(f"model {checkpoint_name!r} not found")
for k, v in req.items(): for k, v in req.items():
@ -770,6 +769,25 @@ class Api:
cuda = {'error': f'{err}'} cuda = {'error': f'{err}'}
return models.MemoryResponse(ram=ram, cuda=cuda) return models.MemoryResponse(ram=ram, cuda=cuda)
def get_extensions_list(self):
from modules import extensions
extensions.list_extensions()
ext_list = []
for ext in extensions.extensions:
ext: extensions.Extension
ext.read_info_from_repo()
if ext.remote is not None:
ext_list.append({
"name": ext.name,
"remote": ext.remote,
"branch": ext.branch,
"commit_hash":ext.commit_hash,
"commit_date":ext.commit_date,
"version":ext.version,
"enabled":ext.enabled
})
return ext_list
def launch(self, server_name, port, root_path): def launch(self, server_name, port, root_path):
self.app.include_router(self.router) self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path) uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)

View File

@ -1,12 +1,10 @@
import inspect import inspect
from pydantic import BaseModel, Field, create_model from pydantic import BaseModel, Field, create_model
from typing import Any, Optional from typing import Any, Optional, Literal
from typing_extensions import Literal
from inflection import underscore from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
from modules.shared import sd_upscalers, opts, parser from modules.shared import sd_upscalers, opts, parser
from typing import Dict, List
API_NOT_ALLOWED = [ API_NOT_ALLOWED = [
"self", "self",
@ -130,12 +128,12 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
).generate_model() ).generate_model()
class TextToImageResponse(BaseModel): class TextToImageResponse(BaseModel):
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.") images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict parameters: dict
info: str info: str
class ImageToImageResponse(BaseModel): class ImageToImageResponse(BaseModel):
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.") images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict parameters: dict
info: str info: str
@ -168,17 +166,18 @@ class FileData(BaseModel):
name: str = Field(title="File name") name: str = Field(title="File name")
class ExtrasBatchImagesRequest(ExtrasBaseRequest): class ExtrasBatchImagesRequest(ExtrasBaseRequest):
imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings") imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
class ExtrasBatchImagesResponse(ExtraBaseResponse): class ExtrasBatchImagesResponse(ExtraBaseResponse):
images: List[str] = Field(title="Images", description="The generated images in base64 format.") images: list[str] = Field(title="Images", description="The generated images in base64 format.")
class PNGInfoRequest(BaseModel): class PNGInfoRequest(BaseModel):
image: str = Field(title="Image", description="The base64 encoded PNG image") image: str = Field(title="Image", description="The base64 encoded PNG image")
class PNGInfoResponse(BaseModel): class PNGInfoResponse(BaseModel):
info: str = Field(title="Image info", description="A string with the parameters used to generate the image") info: str = Field(title="Image info", description="A string with the parameters used to generate the image")
items: dict = Field(title="Items", description="An object containing all the info the image had") items: dict = Field(title="Items", description="A dictionary containing all the other fields the image had")
parameters: dict = Field(title="Parameters", description="A dictionary with parsed generation info fields")
class ProgressRequest(BaseModel): class ProgressRequest(BaseModel):
skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization") skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
@ -232,8 +231,8 @@ FlagsModel = create_model("Flags", **flags)
class SamplerItem(BaseModel): class SamplerItem(BaseModel):
name: str = Field(title="Name") name: str = Field(title="Name")
aliases: List[str] = Field(title="Aliases") aliases: list[str] = Field(title="Aliases")
options: Dict[str, str] = Field(title="Options") options: dict[str, str] = Field(title="Options")
class UpscalerItem(BaseModel): class UpscalerItem(BaseModel):
name: str = Field(title="Name") name: str = Field(title="Name")
@ -284,8 +283,8 @@ class EmbeddingItem(BaseModel):
vectors: int = Field(title="Vectors", description="The number of vectors in the embedding") vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
class EmbeddingsResponse(BaseModel): class EmbeddingsResponse(BaseModel):
loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model") loaded: dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)") skipped: dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
class MemoryResponse(BaseModel): class MemoryResponse(BaseModel):
ram: dict = Field(title="RAM", description="System memory stats") ram: dict = Field(title="RAM", description="System memory stats")
@ -303,11 +302,20 @@ class ScriptArg(BaseModel):
minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI") minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI")
maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI") maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI")
step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI") step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI")
choices: Optional[List[str]] = Field(default=None, title="Choices", description="Possible values for the argument") choices: Optional[list[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
class ScriptInfo(BaseModel): class ScriptInfo(BaseModel):
name: str = Field(default=None, title="Name", description="Script name") name: str = Field(default=None, title="Name", description="Script name")
is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script") is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script") is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments") args: list[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
class ExtensionItem(BaseModel):
name: str = Field(title="Name", description="Extension name")
remote: str = Field(title="Remote", description="Extension Repository URL")
branch: str = Field(title="Branch", description="Extension Repository Branch")
commit_hash: str = Field(title="Commit Hash", description="Extension Repository Commit Hash")
version: str = Field(title="Version", description="Extension Version")
commit_date: str = Field(title="Commit Date", description="Extension Repository Commit Date")
enabled: bool = Field(title="Enabled", description="Flag specifying whether this extension is enabled")

View File

@ -32,7 +32,7 @@ def dump_cache():
with cache_lock: with cache_lock:
cache_filename_tmp = cache_filename + "-" cache_filename_tmp = cache_filename + "-"
with open(cache_filename_tmp, "w", encoding="utf8") as file: with open(cache_filename_tmp, "w", encoding="utf8") as file:
json.dump(cache_data, file, indent=4) json.dump(cache_data, file, indent=4, ensure_ascii=False)
os.replace(cache_filename_tmp, cache_filename) os.replace(cache_filename_tmp, cache_filename)

View File

@ -90,7 +90,7 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None) parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument("--enable-console-prompts", action='store_true', help="does not do anything", default=False) # Legacy compatibility, use as default value shared.opts.enable_console_prompts
parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None) parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)") parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
@ -107,13 +107,14 @@ parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, req
parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None) parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True) parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions") parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the default in earlier versions")
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers") parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server') parser.add_argument('--add-stop-route', action='store_true', help='does not do anything')
parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api') parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn') parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False) parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False) parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False)
parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", )

View File

@ -4,7 +4,6 @@ Supports saving and restoring webui and extensions from a known working set of c
import os import os
import json import json
import time
import tqdm import tqdm
from datetime import datetime from datetime import datetime
@ -38,7 +37,7 @@ def list_config_states():
config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True) config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
for cs in config_states: for cs in config_states:
timestamp = time.asctime(time.gmtime(cs["created_at"])) timestamp = datetime.fromtimestamp(cs["created_at"]).strftime('%Y-%m-%d %H:%M:%S')
name = cs.get("name", "Config") name = cs.get("name", "Config")
full_name = f"{name}: {timestamp}" full_name = f"{name}: {timestamp}"
all_config_states[full_name] = cs all_config_states[full_name] = cs

View File

@ -60,7 +60,8 @@ def enable_tf32():
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407 # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())): device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device()
if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"):
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True

View File

@ -6,6 +6,21 @@ import traceback
exception_records = [] exception_records = []
def format_traceback(tb):
return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]
def format_exception(e, tb):
return {"exception": str(e), "traceback": format_traceback(tb)}
def get_exceptions():
try:
return list(reversed(exception_records))
except Exception as e:
return str(e)
def record_exception(): def record_exception():
_, e, tb = sys.exc_info() _, e, tb = sys.exc_info()
if e is None: if e is None:
@ -14,8 +29,7 @@ def record_exception():
if exception_records and exception_records[-1] == e: if exception_records and exception_records[-1] == e:
return return
from modules import sysinfo exception_records.append(format_exception(e, tb))
exception_records.append(sysinfo.format_exception(e, tb))
if len(exception_records) > 5: if len(exception_records) > 5:
exception_records.pop(0) exception_records.pop(0)

View File

@ -1,11 +1,14 @@
from __future__ import annotations
import configparser
import os import os
import threading import threading
import re
from modules import shared, errors, cache, scripts from modules import shared, errors, cache, scripts
from modules.gitpython_hack import Repo from modules.gitpython_hack import Repo
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401 from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
extensions = []
os.makedirs(extensions_dir, exist_ok=True) os.makedirs(extensions_dir, exist_ok=True)
@ -19,11 +22,55 @@ def active():
return [x for x in extensions if x.enabled] return [x for x in extensions if x.enabled]
class ExtensionMetadata:
filename = "metadata.ini"
config: configparser.ConfigParser
canonical_name: str
requires: list
def __init__(self, path, canonical_name):
self.config = configparser.ConfigParser()
filepath = os.path.join(path, self.filename)
if os.path.isfile(filepath):
try:
self.config.read(filepath)
except Exception:
errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True)
self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name)
self.canonical_name = canonical_name.lower().strip()
self.requires = self.get_script_requirements("Requires", "Extension")
def get_script_requirements(self, field, section, extra_section=None):
"""reads a list of requirements from the config; field is the name of the field in the ini file,
like Requires or Before, and section is the name of the [section] in the ini file; additionally,
reads more requirements from [extra_section] if specified."""
x = self.config.get(section, field, fallback='')
if extra_section:
x = x + ', ' + self.config.get(extra_section, field, fallback='')
return self.parse_list(x.lower())
def parse_list(self, text):
"""converts a line from config ("ext1 ext2, ext3 ") into a python list (["ext1", "ext2", "ext3"])"""
if not text:
return []
# both "," and " " are accepted as separator
return [x for x in re.split(r"[,\s]+", text.strip()) if x]
class Extension: class Extension:
lock = threading.Lock() lock = threading.Lock()
cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version'] cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
metadata: ExtensionMetadata
def __init__(self, name, path, enabled=True, is_builtin=False): def __init__(self, name, path, enabled=True, is_builtin=False, metadata=None):
self.name = name self.name = name
self.path = path self.path = path
self.enabled = enabled self.enabled = enabled
@ -36,6 +83,8 @@ class Extension:
self.branch = None self.branch = None
self.remote = None self.remote = None
self.have_info_from_repo = False self.have_info_from_repo = False
self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower())
self.canonical_name = metadata.canonical_name
def to_dict(self): def to_dict(self):
return {x: getattr(self, x) for x in self.cached_fields} return {x: getattr(self, x) for x in self.cached_fields}
@ -56,6 +105,7 @@ class Extension:
self.do_read_info_from_repo() self.do_read_info_from_repo()
return self.to_dict() return self.to_dict()
try: try:
d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo) d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
self.from_dict(d) self.from_dict(d)
@ -136,9 +186,6 @@ class Extension:
def list_extensions(): def list_extensions():
extensions.clear() extensions.clear()
if not os.path.isdir(extensions_dir):
return
if shared.cmd_opts.disable_all_extensions: if shared.cmd_opts.disable_all_extensions:
print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***") print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
elif shared.opts.disable_all_extensions == "all": elif shared.opts.disable_all_extensions == "all":
@ -148,18 +195,43 @@ def list_extensions():
elif shared.opts.disable_all_extensions == "extra": elif shared.opts.disable_all_extensions == "extra":
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***") print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
extension_paths = [] loaded_extensions = {}
for dirname in [extensions_dir, extensions_builtin_dir]:
# scan through extensions directory and load metadata
for dirname in [extensions_builtin_dir, extensions_dir]:
if not os.path.isdir(dirname): if not os.path.isdir(dirname):
return continue
for extension_dirname in sorted(os.listdir(dirname)): for extension_dirname in sorted(os.listdir(dirname)):
path = os.path.join(dirname, extension_dirname) path = os.path.join(dirname, extension_dirname)
if not os.path.isdir(path): if not os.path.isdir(path):
continue continue
extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir)) canonical_name = extension_dirname
metadata = ExtensionMetadata(path, canonical_name)
for dirname, path, is_builtin in extension_paths: # check for duplicated canonical names
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin) already_loaded_extension = loaded_extensions.get(metadata.canonical_name)
extensions.append(extension) if already_loaded_extension is not None:
errors.report(f'Duplicate canonical name "{canonical_name}" found in extensions "{extension_dirname}" and "{already_loaded_extension.name}". Former will be discarded.', exc_info=False)
continue
is_builtin = dirname == extensions_builtin_dir
extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
extensions.append(extension)
loaded_extensions[canonical_name] = extension
# check for requirements
for extension in extensions:
for req in extension.metadata.requires:
required_extension = loaded_extensions.get(req)
if required_extension is None:
errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False)
continue
if not extension.enabled:
errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False)
continue
extensions: list[Extension] = []

View File

@ -9,7 +9,7 @@ from modules.paths import data_path
from modules import shared, ui_tempdir, script_callbacks, processing from modules import shared, ui_tempdir, script_callbacks, processing
from PIL import Image from PIL import Image
re_param_code = r'\s*([\w ]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)' re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code) re_param = re.compile(re_param_code)
re_imagesize = re.compile(r"^(\d+)x(\d+)$") re_imagesize = re.compile(r"^(\d+)x(\d+)$")
re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$") re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")

View File

@ -9,6 +9,7 @@ from modules import paths, shared, devices, modelloader, errors
model_dir = "GFPGAN" model_dir = "GFPGAN"
user_path = None user_path = None
model_path = os.path.join(paths.models_path, model_dir) model_path = os.path.join(paths.models_path, model_dir)
model_file_path = None
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
have_gfpgan = False have_gfpgan = False
loaded_gfpgan_model = None loaded_gfpgan_model = None
@ -17,6 +18,7 @@ loaded_gfpgan_model = None
def gfpgann(): def gfpgann():
global loaded_gfpgan_model global loaded_gfpgan_model
global model_path global model_path
global model_file_path
if loaded_gfpgan_model is not None: if loaded_gfpgan_model is not None:
loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan) loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
return loaded_gfpgan_model return loaded_gfpgan_model
@ -24,17 +26,24 @@ def gfpgann():
if gfpgan_constructor is None: if gfpgan_constructor is None:
return None return None
models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN") models = modelloader.load_models(model_path, model_url, user_path, ext_filter=['.pth'])
if len(models) == 1 and models[0].startswith("http"): if len(models) == 1 and models[0].startswith("http"):
model_file = models[0] model_file = models[0]
elif len(models) != 0: elif len(models) != 0:
latest_file = max(models, key=os.path.getctime) gfp_models = []
for item in models:
if 'GFPGAN' in os.path.basename(item):
gfp_models.append(item)
latest_file = max(gfp_models, key=os.path.getctime)
model_file = latest_file model_file = latest_file
else: else:
print("Unable to load gfpgan model!") print("Unable to load gfpgan model!")
return None return None
if hasattr(facexlib.detection.retinaface, 'device'): if hasattr(facexlib.detection.retinaface, 'device'):
facexlib.detection.retinaface.device = devices.device_gfpgan facexlib.detection.retinaface.device = devices.device_gfpgan
model_file_path = model_file
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan) model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
loaded_gfpgan_model = model loaded_gfpgan_model = model
@ -77,19 +86,25 @@ def setup_model(dirname):
global user_path global user_path
global have_gfpgan global have_gfpgan
global gfpgan_constructor global gfpgan_constructor
global model_file_path
facexlib_path = model_path
if dirname is not None:
facexlib_path = dirname
load_file_from_url_orig = gfpgan.utils.load_file_from_url load_file_from_url_orig = gfpgan.utils.load_file_from_url
facex_load_file_from_url_orig = facexlib.detection.load_file_from_url facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
def my_load_file_from_url(**kwargs): def my_load_file_from_url(**kwargs):
return load_file_from_url_orig(**dict(kwargs, model_dir=model_path)) return load_file_from_url_orig(**dict(kwargs, model_dir=model_file_path))
def facex_load_file_from_url(**kwargs): def facex_load_file_from_url(**kwargs):
return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None)) return facex_load_file_from_url_orig(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
def facex_load_file_from_url2(**kwargs): def facex_load_file_from_url2(**kwargs):
return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None)) return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
gfpgan.utils.load_file_from_url = my_load_file_from_url gfpgan.utils.load_file_from_url = my_load_file_from_url
facexlib.detection.load_file_from_url = facex_load_file_from_url facexlib.detection.load_file_from_url = facex_load_file_from_url

View File

@ -23,7 +23,7 @@ class Git(git.Git):
) )
return self._parse_object_header(ret) return self._parse_object_header(ret)
def stream_object_data(self, ref: str) -> tuple[str, str, int, "Git.CatFileContentStream"]: def stream_object_data(self, ref: str) -> tuple[str, str, int, Git.CatFileContentStream]:
# Not really streaming, per se; this buffers the entire object in memory. # Not really streaming, per se; this buffers the entire object in memory.
# Shouldn't be a problem for our use case, since we're only using this for # Shouldn't be a problem for our use case, since we're only using this for
# object headers (commit objects). # object headers (commit objects).

View File

@ -468,7 +468,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
shared.reload_hypernetworks() shared.reload_hypernetworks()
def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): def train_hypernetwork(id_task, hypernetwork_name: str, learn_rate: float, batch_size: int, gradient_step: int, data_root: str, log_directory: str, training_width: int, training_height: int, varsize: bool, steps: int, clip_grad_mode: str, clip_grad_value: float, shuffle_tags: bool, tag_drop_out: bool, latent_sampling_method: str, use_weight: bool, create_image_every: int, save_hypernetwork_every: int, template_filename: str, preview_from_txt2img: bool, preview_prompt: str, preview_negative_prompt: str, preview_steps: int, preview_sampler_name: str, preview_cfg_scale: float, preview_seed: int, preview_width: int, preview_height: int):
from modules import images, processing from modules import images, processing
save_hypernetwork_every = save_hypernetwork_every or 0 save_hypernetwork_every = save_hypernetwork_every or 0
@ -698,7 +698,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
p.prompt = preview_prompt p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt p.negative_prompt = preview_negative_prompt
p.steps = preview_steps p.steps = preview_steps
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()]
p.cfg_scale = preview_cfg_scale p.cfg_scale = preview_cfg_scale
p.seed = preview_seed p.seed = preview_seed
p.width = preview_width p.width = preview_width

View File

@ -561,6 +561,8 @@ def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_p
}) })
piexif.insert(exif_bytes, filename) piexif.insert(exif_bytes, filename)
elif extension.lower() == ".gif":
image.save(filename, format=image_format, comment=geninfo)
else: else:
image.save(filename, format=image_format, quality=opts.jpeg_quality) image.save(filename, format=image_format, quality=opts.jpeg_quality)
@ -661,7 +663,13 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name) save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)
os.replace(temp_file_path, filename_without_extension + extension) filename = filename_without_extension + extension
if shared.opts.save_images_replace_action != "Replace":
n = 0
while os.path.exists(filename):
n += 1
filename = f"{filename_without_extension}-{n}{extension}"
os.replace(temp_file_path, filename)
fullfn_without_extension, extension = os.path.splitext(params.filename) fullfn_without_extension, extension = os.path.splitext(params.filename)
if hasattr(os, 'statvfs'): if hasattr(os, 'statvfs'):
@ -718,7 +726,12 @@ def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
geninfo = items.pop('parameters', None) geninfo = items.pop('parameters', None)
if "exif" in items: if "exif" in items:
exif = piexif.load(items["exif"]) exif_data = items["exif"]
try:
exif = piexif.load(exif_data)
except OSError:
# memory / exif was not valid so piexif tried to read from a file
exif = None
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'') exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
try: try:
exif_comment = piexif.helper.UserComment.load(exif_comment) exif_comment = piexif.helper.UserComment.load(exif_comment)
@ -728,6 +741,8 @@ def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
if exif_comment: if exif_comment:
items['exif comment'] = exif_comment items['exif comment'] = exif_comment
geninfo = exif_comment geninfo = exif_comment
elif "comment" in items: # for gif
geninfo = items["comment"].decode('utf8', errors="ignore")
for field in IGNORED_INFO_KEYS: for field in IGNORED_INFO_KEYS:
items.pop(field, None) items.pop(field, None)

View File

@ -10,6 +10,7 @@ from modules import images as imgutil
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state from modules.shared import opts, state
from modules.sd_models import get_closet_checkpoint_match
import modules.shared as shared import modules.shared as shared
import modules.processing as processing import modules.processing as processing
from modules.ui import plaintext_to_html from modules.ui import plaintext_to_html
@ -41,7 +42,10 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
cfg_scale = p.cfg_scale cfg_scale = p.cfg_scale
sampler_name = p.sampler_name sampler_name = p.sampler_name
steps = p.steps steps = p.steps
override_settings = p.override_settings
sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None))
batch_results = None
discard_further_results = False
for i, image in enumerate(images): for i, image in enumerate(images):
state.job = f"{i+1} out of {len(images)}" state.job = f"{i+1} out of {len(images)}"
if state.skipped: if state.skipped:
@ -104,16 +108,42 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
p.sampler_name = parsed_parameters.get("Sampler", sampler_name) p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
p.steps = int(parsed_parameters.get("Steps", steps)) p.steps = int(parsed_parameters.get("Steps", steps))
model_info = get_closet_checkpoint_match(parsed_parameters.get("Model hash", None))
if model_info is not None:
p.override_settings['sd_model_checkpoint'] = model_info.name
elif sd_model_checkpoint_override:
p.override_settings['sd_model_checkpoint'] = sd_model_checkpoint_override
else:
p.override_settings.pop("sd_model_checkpoint", None)
if output_dir:
p.outpath_samples = output_dir
p.override_settings['save_to_dirs'] = False
p.override_settings['save_images_replace_action'] = "Add number suffix"
if p.n_iter > 1 or p.batch_size > 1:
p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
else:
p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'
proc = modules.scripts.scripts_img2img.run(p, *args) proc = modules.scripts.scripts_img2img.run(p, *args)
if proc is None: if proc is None:
if output_dir: p.override_settings.pop('save_images_replace_action', None)
p.outpath_samples = output_dir proc = process_images(p)
p.override_settings['save_to_dirs'] = False
if p.n_iter > 1 or p.batch_size > 1: if not discard_further_results and proc:
p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]' if batch_results:
else: batch_results.images.extend(proc.images)
p.override_settings['samples_filename_pattern'] = f'{image_path.stem}' batch_results.infotexts.extend(proc.infotexts)
process_images(p) else:
batch_results = proc
if 0 <= shared.opts.img2img_batch_show_results_limit < len(batch_results.images):
discard_further_results = True
batch_results.images = batch_results.images[:int(shared.opts.img2img_batch_show_results_limit)]
batch_results.infotexts = batch_results.infotexts[:int(shared.opts.img2img_batch_show_results_limit)]
return batch_results
def img2img(id_task: str, def img2img(id_task: str,
@ -232,7 +262,7 @@ def img2img(id_task: str,
p.user = request.username p.user = request.username
if shared.cmd_opts.enable_console_prompts: if shared.opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out) print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
if mask: if mask:
@ -244,10 +274,10 @@ def img2img(id_task: str,
with closing(p): with closing(p):
if is_batch: if is_batch:
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
processed = process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir) if processed is None:
processed = Processed(p, [], p.seed, "")
processed = Processed(p, [], p.seed, "")
else: else:
processed = modules.scripts.scripts_img2img.run(p, *args) processed = modules.scripts.scripts_img2img.run(p, *args)
if processed is None: if processed is None:

View File

@ -151,8 +151,8 @@ def initialize_rest(*, reload_script_modules=False):
from modules import devices from modules import devices
devices.first_time_calculation() devices.first_time_calculation()
if not shared.cmd_opts.skip_load_model_at_start:
Thread(target=load_model).start() Thread(target=load_model).start()
from modules import shared_items from modules import shared_items
shared_items.reload_hypernetworks() shared_items.reload_hypernetworks()

View File

@ -150,10 +150,14 @@ def dumpstacks():
def configure_sigint_handler(): def configure_sigint_handler():
# make the program just exit at ctrl+c without waiting for anything # make the program just exit at ctrl+c without waiting for anything
from modules import shared
def sigint_handler(sig, frame): def sigint_handler(sig, frame):
print(f'Interrupted with signal {sig} in {frame}') print(f'Interrupted with signal {sig} in {frame}')
dumpstacks() if shared.opts.dump_stacks_on_signal:
dumpstacks()
os._exit(0) os._exit(0)

View File

@ -64,7 +64,7 @@ Use --skip-python-version-check to suppress this warning.
@lru_cache() @lru_cache()
def commit_hash(): def commit_hash():
try: try:
return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip() return subprocess.check_output([git, "-C", script_path, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
except Exception: except Exception:
return "<none>" return "<none>"
@ -72,7 +72,7 @@ def commit_hash():
@lru_cache() @lru_cache()
def git_tag(): def git_tag():
try: try:
return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip() return subprocess.check_output([git, "-C", script_path, "describe", "--tags"], shell=False, encoding='utf8').strip()
except Exception: except Exception:
try: try:
@ -441,7 +441,7 @@ def dump_sysinfo():
import datetime import datetime
text = sysinfo.get() text = sysinfo.get()
filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt" filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json"
with open(filename, "w", encoding="utf8") as file: with open(filename, "w", encoding="utf8") as file:
file.write(text) file.write(text)

View File

@ -14,21 +14,24 @@ def list_localizations(dirname):
if ext.lower() != ".json": if ext.lower() != ".json":
continue continue
localizations[fn] = os.path.join(dirname, file) localizations[fn] = [os.path.join(dirname, file)]
for file in scripts.list_scripts("localizations", ".json"): for file in scripts.list_scripts("localizations", ".json"):
fn, ext = os.path.splitext(file.filename) fn, ext = os.path.splitext(file.filename)
localizations[fn] = file.path if fn not in localizations:
localizations[fn] = []
localizations[fn].append(file.path)
def localization_js(current_localization_name: str) -> str: def localization_js(current_localization_name: str) -> str:
fn = localizations.get(current_localization_name, None) fns = localizations.get(current_localization_name, None)
data = {} data = {}
if fn is not None: if fns is not None:
try: for fn in fns:
with open(fn, "r", encoding="utf8") as file: try:
data = json.load(file) with open(fn, "r", encoding="utf8") as file:
except Exception: data.update(json.load(file))
errors.report(f"Error loading localization from {fn}", exc_info=True) except Exception:
errors.report(f"Error loading localization from {fn}", exc_info=True)
return f"window.localization = {json.dumps(data)}" return f"window.localization = {json.dumps(data)}"

View File

@ -1,16 +1,41 @@
import os import os
import logging import logging
try:
from tqdm.auto import tqdm
class TqdmLoggingHandler(logging.Handler):
def __init__(self, level=logging.INFO):
super().__init__(level)
def emit(self, record):
try:
msg = self.format(record)
tqdm.write(msg)
self.flush()
except Exception:
self.handleError(record)
TQDM_IMPORTED = True
except ImportError:
# tqdm does not exist before first launch
# I will import once the UI finishes seting up the enviroment and reloads.
TQDM_IMPORTED = False
def setup_logging(loglevel): def setup_logging(loglevel):
if loglevel is None: if loglevel is None:
loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL") loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")
loghandlers = []
if TQDM_IMPORTED:
loghandlers.append(TqdmLoggingHandler())
if loglevel: if loglevel:
log_level = getattr(logging, loglevel.upper(), None) or logging.INFO log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
logging.basicConfig( logging.basicConfig(
level=log_level, level=log_level,
format='%(asctime)s %(levelname)s [%(name)s] %(message)s', format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S', datefmt='%Y-%m-%d %H:%M:%S',
handlers=loghandlers
) )

View File

@ -1,5 +1,6 @@
import json import json
import sys import sys
from dataclasses import dataclass
import gradio as gr import gradio as gr
@ -8,13 +9,14 @@ from modules.shared_cmd_options import cmd_opts
class OptionInfo: class OptionInfo:
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False): def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False, category_id=None):
self.default = default self.default = default
self.label = label self.label = label
self.component = component self.component = component
self.component_args = component_args self.component_args = component_args
self.onchange = onchange self.onchange = onchange
self.section = section self.section = section
self.category_id = category_id
self.refresh = refresh self.refresh = refresh
self.do_not_save = False self.do_not_save = False
@ -63,7 +65,11 @@ class OptionHTML(OptionInfo):
def options_section(section_identifier, options_dict): def options_section(section_identifier, options_dict):
for v in options_dict.values(): for v in options_dict.values():
v.section = section_identifier if len(section_identifier) == 2:
v.section = section_identifier
elif len(section_identifier) == 3:
v.section = section_identifier[0:2]
v.category_id = section_identifier[2]
return options_dict return options_dict
@ -76,7 +82,7 @@ class Options:
def __init__(self, data_labels: dict[str, OptionInfo], restricted_opts): def __init__(self, data_labels: dict[str, OptionInfo], restricted_opts):
self.data_labels = data_labels self.data_labels = data_labels
self.data = {k: v.default for k, v in self.data_labels.items()} self.data = {k: v.default for k, v in self.data_labels.items() if not v.do_not_save}
self.restricted_opts = restricted_opts self.restricted_opts = restricted_opts
def __setattr__(self, key, value): def __setattr__(self, key, value):
@ -158,7 +164,7 @@ class Options:
assert not cmd_opts.freeze_settings, "saving settings is disabled" assert not cmd_opts.freeze_settings, "saving settings is disabled"
with open(filename, "w", encoding="utf8") as file: with open(filename, "w", encoding="utf8") as file:
json.dump(self.data, file, indent=4) json.dump(self.data, file, indent=4, ensure_ascii=False)
def same_type(self, x, y): def same_type(self, x, y):
if x is None or y is None: if x is None or y is None:
@ -206,21 +212,59 @@ class Options:
d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()} d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None} d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None} d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
item_categories = {}
for item in self.data_labels.values():
category = categories.mapping.get(item.category_id)
category = "Uncategorized" if category is None else category.label
if category not in item_categories:
item_categories[category] = item.section[1]
# _categories is a list of pairs: [section, category]. Each section (a setting page) will get a special heading above it with the category as text.
d["_categories"] = [[v, k] for k, v in item_categories.items()] + [["Defaults", "Other"]]
return json.dumps(d) return json.dumps(d)
def add_option(self, key, info): def add_option(self, key, info):
self.data_labels[key] = info self.data_labels[key] = info
if key not in self.data and not info.do_not_save:
self.data[key] = info.default
def reorder(self): def reorder(self):
"""reorder settings so that all items related to section always go together""" """Reorder settings so that:
- all items related to section always go together
- all sections belonging to a category go together
- sections inside a category are ordered alphabetically
- categories are ordered by creation order
Category is a superset of sections: for category "postprocessing" there could be multiple sections: "face restoration", "upscaling".
This function also changes items' category_id so that all items belonging to a section have the same category_id.
"""
category_ids = {}
section_categories = {}
section_ids = {}
settings_items = self.data_labels.items() settings_items = self.data_labels.items()
for _, item in settings_items: for _, item in settings_items:
if item.section not in section_ids: if item.section not in section_categories:
section_ids[item.section] = len(section_ids) section_categories[item.section] = item.category_id
self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section])) for _, item in settings_items:
item.category_id = section_categories.get(item.section)
for category_id in categories.mapping:
if category_id not in category_ids:
category_ids[category_id] = len(category_ids)
def sort_key(x):
item: OptionInfo = x[1]
category_order = category_ids.get(item.category_id, len(category_ids))
section_order = item.section[1]
return category_order, section_order
self.data_labels = dict(sorted(settings_items, key=sort_key))
def cast_value(self, key, value): def cast_value(self, key, value):
"""casts an arbitrary to the same type as this setting's value with key """casts an arbitrary to the same type as this setting's value with key
@ -243,3 +287,22 @@ class Options:
value = expected_type(value) value = expected_type(value)
return value return value
@dataclass
class OptionsCategory:
id: str
label: str
class OptionsCategories:
def __init__(self):
self.mapping = {}
def register_category(self, category_id, label):
if category_id in self.mapping:
return category_id
self.mapping[category_id] = OptionsCategory(category_id, label)
categories = OptionsCategories()

View File

@ -1,6 +1,6 @@
import os import os
import sys import sys
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir # noqa: F401 from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, cwd # noqa: F401
import modules.safe # noqa: F401 import modules.safe # noqa: F401

View File

@ -8,6 +8,7 @@ import shlex
commandline_args = os.environ.get('COMMANDLINE_ARGS', "") commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
sys.argv += shlex.split(commandline_args) sys.argv += shlex.split(commandline_args)
cwd = os.getcwd()
modules_path = os.path.dirname(os.path.realpath(__file__)) modules_path = os.path.dirname(os.path.realpath(__file__))
script_path = os.path.dirname(modules_path) script_path = os.path.dirname(modules_path)

View File

@ -78,7 +78,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
image_data.close() image_data.close()
devices.torch_gc() devices.torch_gc()
shared.state.end()
return outputs, ui_common.plaintext_to_html(infotext), '' return outputs, ui_common.plaintext_to_html(infotext), ''

View File

@ -149,7 +149,7 @@ class StableDiffusionProcessing:
masks_for_overlay: list = None masks_for_overlay: list = None
eta: float = None eta: float = None
do_not_reload_embeddings: bool = False do_not_reload_embeddings: bool = False
denoising_strength: float = 0 denoising_strength: float = None
ddim_discretize: str = None ddim_discretize: str = None
s_min_uncond: float = None s_min_uncond: float = None
s_churn: float = None s_churn: float = None
@ -303,7 +303,7 @@ class StableDiffusionProcessing:
return conditioning return conditioning
def edit_image_conditioning(self, source_image): def edit_image_conditioning(self, source_image):
conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method)) conditioning_image = shared.sd_model.encode_first_stage(source_image).mode()
return conditioning_image return conditioning_image
@ -537,6 +537,7 @@ class Processed:
self.all_seeds = all_seeds or p.all_seeds or [self.seed] self.all_seeds = all_seeds or p.all_seeds or [self.seed]
self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed] self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
self.infotexts = infotexts or [info] self.infotexts = infotexts or [info]
self.version = program_version()
def js(self): def js(self):
obj = { obj = {
@ -571,6 +572,7 @@ class Processed:
"job_timestamp": self.job_timestamp, "job_timestamp": self.job_timestamp,
"clip_skip": self.clip_skip, "clip_skip": self.clip_skip,
"is_using_inpainting_conditioning": self.is_using_inpainting_conditioning, "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
"version": self.version,
} }
return json.dumps(obj) return json.dumps(obj)
@ -713,7 +715,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None: if p.scripts is not None:
p.scripts.before_process(p) p.scripts.before_process(p)
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()} stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data}
try: try:
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
@ -801,7 +803,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
infotexts = [] infotexts = []
output_images = [] output_images = []
with torch.no_grad(), p.sd_model.ema_scope(): with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast(): with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds) p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
@ -876,7 +877,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
else: else:
if opts.sd_vae_decode_method != 'Full': if opts.sd_vae_decode_method != 'Full':
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
# Generate the mask(s) based on similarity between the original and denoised latent vectors # Generate the mask(s) based on similarity between the original and denoised latent vectors
if getattr(p, "image_mask", None) is not None: if getattr(p, "image_mask", None) is not None:
# latent_mask = p.nmask[0].float().cpu() # latent_mask = p.nmask[0].float().cpu()
@ -943,6 +943,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc() devices.torch_gc()
state.nextjob()
if p.scripts is not None: if p.scripts is not None:
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n) p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
@ -1025,7 +1027,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc() devices.torch_gc()
state.nextjob() if not infotexts:
infotexts.append(Processed(p, []).infotext(p, 0))
p.color_corrections = None p.color_corrections = None
@ -1211,6 +1214,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if not self.enable_hr: if not self.enable_hr:
return samples return samples
devices.torch_gc()
if self.latent_scale_mode is None: if self.latent_scale_mode is None:
decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
@ -1220,8 +1224,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
with sd_models.SkipWritingToConfig(): with sd_models.SkipWritingToConfig():
sd_models.reload_model_weights(info=self.hr_checkpoint_info) sd_models.reload_model_weights(info=self.hr_checkpoint_info)
devices.torch_gc()
return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts) return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts): def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
@ -1229,7 +1231,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return samples return samples
self.is_hr_pass = True self.is_hr_pass = True
target_width = self.hr_upscale_to_x target_width = self.hr_upscale_to_x
target_height = self.hr_upscale_to_y target_height = self.hr_upscale_to_y
@ -1318,7 +1319,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True) decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
self.is_hr_pass = False self.is_hr_pass = False
return decoded_samples return decoded_samples
def close(self): def close(self):

View File

@ -29,8 +29,8 @@ class ScriptSeed(scripts.ScriptBuiltinUI):
else: else:
self.seed = gr.Number(label='Seed', value=-1, elem_id=self.elem_id("seed"), min_width=100, precision=0) self.seed = gr.Number(label='Seed', value=-1, elem_id=self.elem_id("seed"), min_width=100, precision=0)
random_seed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_seed"), label='Random seed') random_seed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_seed"), tooltip="Set seed to -1, which will cause a new random number to be used every time")
reuse_seed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_seed"), label='Reuse seed') reuse_seed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_seed"), tooltip="Reuse seed from last generation, mostly useful if it was randomized")
seed_checkbox = gr.Checkbox(label='Extra', elem_id=self.elem_id("subseed_show"), value=False) seed_checkbox = gr.Checkbox(label='Extra', elem_id=self.elem_id("subseed_show"), value=False)

View File

@ -2,10 +2,9 @@ from __future__ import annotations
import re import re
from collections import namedtuple from collections import namedtuple
from typing import List
import lark import lark
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][: in background:0.25] [shoddy:masterful:0.5]"
# will be represented with prompt_schedule like this (assuming steps=100): # will be represented with prompt_schedule like this (assuming steps=100):
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy'] # [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy'] # [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
@ -240,14 +239,14 @@ def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
class ComposableScheduledPromptConditioning: class ComposableScheduledPromptConditioning:
def __init__(self, schedules, weight=1.0): def __init__(self, schedules, weight=1.0):
self.schedules: List[ScheduledPromptConditioning] = schedules self.schedules: list[ScheduledPromptConditioning] = schedules
self.weight: float = weight self.weight: float = weight
class MulticondLearnedConditioning: class MulticondLearnedConditioning:
def __init__(self, shape, batch): def __init__(self, shape, batch):
self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch self.batch: list[list[ComposableScheduledPromptConditioning]] = batch
def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning: def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:
@ -278,7 +277,7 @@ class DictWithShape(dict):
return self["crossattn"].shape return self["crossattn"].shape
def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step): def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
param = c[0][0].cond param = c[0][0].cond
is_dict = isinstance(param, dict) is_dict = isinstance(param, dict)

View File

@ -14,7 +14,9 @@ def is_restartable() -> bool:
def restart_program() -> None: def restart_program() -> None:
"""creates file tmp/restart and immediately stops the process, which webui.bat/webui.sh interpret as a command to start webui again""" """creates file tmp/restart and immediately stops the process, which webui.bat/webui.sh interpret as a command to start webui again"""
(Path(script_path) / "tmp" / "restart").touch() tmpdir = Path(script_path) / "tmp"
tmpdir.mkdir(parents=True, exist_ok=True)
(tmpdir / "restart").touch()
stop_program() stop_program()

View File

@ -110,7 +110,7 @@ class ImageRNG:
self.is_first = True self.is_first = True
def first(self): def first(self):
noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8) noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], int(self.seed_resize_from_h) // 8, int(self.seed_resize_from_w // 8))
xs = [] xs = []

View File

@ -1,7 +1,7 @@
import inspect import inspect
import os import os
from collections import namedtuple from collections import namedtuple
from typing import Optional, Dict, Any from typing import Optional, Any
from fastapi import FastAPI from fastapi import FastAPI
from gradio import Blocks from gradio import Blocks
@ -258,7 +258,7 @@ def image_grid_callback(params: ImageGridLoopParams):
report_exception(c, 'image_grid') report_exception(c, 'image_grid')
def infotext_pasted_callback(infotext: str, params: Dict[str, Any]): def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
for c in callback_map['callbacks_infotext_pasted']: for c in callback_map['callbacks_infotext_pasted']:
try: try:
c.callback(infotext, params) c.callback(infotext, params)
@ -449,7 +449,7 @@ def on_infotext_pasted(callback):
"""register a function to be called before applying an infotext. """register a function to be called before applying an infotext.
The callback is called with two arguments: The callback is called with two arguments:
- infotext: str - raw infotext. - infotext: str - raw infotext.
- result: Dict[str, any] - parsed infotext parameters. - result: dict[str, any] - parsed infotext parameters.
""" """
add_callback(callback_map['callbacks_infotext_pasted'], callback) add_callback(callback_map['callbacks_infotext_pasted'], callback)

View File

@ -311,20 +311,113 @@ scripts_data = []
postprocessing_scripts_data = [] postprocessing_scripts_data = []
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"]) ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
def topological_sort(dependencies):
"""Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies.
Ignores errors relating to missing dependeencies or circular dependencies
"""
visited = {}
result = []
def inner(name):
visited[name] = True
for dep in dependencies.get(name, []):
if dep in dependencies and dep not in visited:
inner(dep)
result.append(name)
for depname in dependencies:
if depname not in visited:
inner(depname)
return result
@dataclass
class ScriptWithDependencies:
script_canonical_name: str
file: ScriptFile
requires: list
load_before: list
load_after: list
def list_scripts(scriptdirname, extension, *, include_extensions=True): def list_scripts(scriptdirname, extension, *, include_extensions=True):
scripts_list = [] scripts = {}
basedir = os.path.join(paths.script_path, scriptdirname) loaded_extensions = {ext.canonical_name: ext for ext in extensions.active()}
if os.path.exists(basedir): loaded_extensions_scripts = {ext.canonical_name: [] for ext in extensions.active()}
for filename in sorted(os.listdir(basedir)):
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename))) # build script dependency map
root_script_basedir = os.path.join(paths.script_path, scriptdirname)
if os.path.exists(root_script_basedir):
for filename in sorted(os.listdir(root_script_basedir)):
if not os.path.isfile(os.path.join(root_script_basedir, filename)):
continue
if os.path.splitext(filename)[1].lower() != extension:
continue
script_file = ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename))
scripts[filename] = ScriptWithDependencies(filename, script_file, [], [], [])
if include_extensions: if include_extensions:
for ext in extensions.active(): for ext in extensions.active():
scripts_list += ext.list_files(scriptdirname, extension) extension_scripts_list = ext.list_files(scriptdirname, extension)
for extension_script in extension_scripts_list:
if not os.path.isfile(extension_script.path):
continue
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] script_canonical_name = ("builtin/" if ext.is_builtin else "") + ext.canonical_name + "/" + extension_script.filename
relative_path = scriptdirname + "/" + extension_script.filename
script = ScriptWithDependencies(
script_canonical_name=script_canonical_name,
file=extension_script,
requires=ext.metadata.get_script_requirements("Requires", relative_path, scriptdirname),
load_before=ext.metadata.get_script_requirements("Before", relative_path, scriptdirname),
load_after=ext.metadata.get_script_requirements("After", relative_path, scriptdirname),
)
scripts[script_canonical_name] = script
loaded_extensions_scripts[ext.canonical_name].append(script)
for script_canonical_name, script in scripts.items():
# load before requires inverse dependency
# in this case, append the script name into the load_after list of the specified script
for load_before in script.load_before:
# if this requires an individual script to be loaded before
other_script = scripts.get(load_before)
if other_script:
other_script.load_after.append(script_canonical_name)
# if this requires an extension
other_extension_scripts = loaded_extensions_scripts.get(load_before)
if other_extension_scripts:
for other_script in other_extension_scripts:
other_script.load_after.append(script_canonical_name)
# if After mentions an extension, remove it and instead add all of its scripts
for load_after in list(script.load_after):
if load_after not in scripts and load_after in loaded_extensions_scripts:
script.load_after.remove(load_after)
for other_script in loaded_extensions_scripts.get(load_after, []):
script.load_after.append(other_script.script_canonical_name)
dependencies = {}
for script_canonical_name, script in scripts.items():
for required_script in script.requires:
if required_script not in scripts and required_script not in loaded_extensions:
errors.report(f'Script "{script_canonical_name}" requires "{required_script}" to be loaded, but it is not.', exc_info=False)
dependencies[script_canonical_name] = script.load_after
ordered_scripts = topological_sort(dependencies)
scripts_list = [scripts[script_canonical_name].file for script_canonical_name in ordered_scripts]
return scripts_list return scripts_list
@ -365,15 +458,9 @@ def load_scripts():
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing): elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module)) postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
def orderby(basedir): # here the scripts_list is already ordered
# 1st webui, 2nd extensions-builtin, 3rd extensions # processing_script is not considered though
priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0} for scriptfile in scripts_list:
for key in priority:
if basedir.startswith(key):
return priority[key]
return 9999
for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
try: try:
if scriptfile.basedir != paths.script_path: if scriptfile.basedir != paths.script_path:
sys.path = [scriptfile.basedir] + sys.path sys.path = [scriptfile.basedir] + sys.path
@ -491,11 +578,15 @@ class ScriptRunner:
arg_info = api_models.ScriptArg(label=control.label or "") arg_info = api_models.ScriptArg(label=control.label or "")
for field in ("value", "minimum", "maximum", "step", "choices"): for field in ("value", "minimum", "maximum", "step"):
v = getattr(control, field, None) v = getattr(control, field, None)
if v is not None: if v is not None:
setattr(arg_info, field, v) setattr(arg_info, field, v)
choices = getattr(control, 'choices', None) # as of gradio 3.41, some items in choices are strings, and some are tuples where the first elem is the string
if choices is not None:
arg_info.choices = [x[0] if isinstance(x, tuple) else x for x in choices]
api_args.append(arg_info) api_args.append(arg_info)
script.api_info = api_models.ScriptInfo( script.api_info = api_models.ScriptInfo(

View File

@ -2,14 +2,15 @@ import torch
from torch.nn.functional import silu from torch.nn.functional import silu
from types import MethodType from types import MethodType
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
import ldm.modules.attention import ldm.modules.attention
import ldm.modules.diffusionmodules.model import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel import ldm.modules.diffusionmodules.openaimodel
import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms import ldm.models.diffusion.plms
import ldm.modules.encoders.modules import ldm.modules.encoders.modules
@ -37,6 +38,8 @@ ldm.models.diffusion.ddpm.print = shared.ldm_print
optimizers = [] optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None current_optimizer: sd_hijack_optimizations.SdOptimization = None
ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
def list_optimizers(): def list_optimizers():
new_optimizers = script_callbacks.list_optimizers_callback() new_optimizers = script_callbacks.list_optimizers_callback()
@ -181,6 +184,20 @@ class StableDiffusionModelHijack:
errors.display(e, "applying cross attention optimization") errors.display(e, "applying cross attention optimization")
undo_optimizations() undo_optimizations()
def convert_sdxl_to_ssd(self, m):
"""Converts an SDXL model to a Segmind Stable Diffusion model (see https://huggingface.co/segmind/SSD-1B)"""
delattr(m.model.diffusion_model.middle_block, '1')
delattr(m.model.diffusion_model.middle_block, '2')
for i in ['9', '8', '7', '6', '5', '4']:
delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks, i)
delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks, i)
delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks, i)
delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks, i)
delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks, '1')
delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks, '1')
devices.torch_gc()
def hijack(self, m): def hijack(self, m):
conditioner = getattr(m, 'conditioner', None) conditioner = getattr(m, 'conditioner', None)
if conditioner: if conditioner:
@ -208,7 +225,7 @@ class StableDiffusionModelHijack:
else: else:
m.cond_stage_model = conditioner m.cond_stage_model = conditioner
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation or type(m.cond_stage_model) == xlmr_m18.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self) model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self) m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
@ -239,10 +256,17 @@ class StableDiffusionModelHijack:
self.layers = flatten(m) self.layers = flatten(m)
if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'): import modules.models.diffusion.ddpm_edit
ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward
if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion):
sd_unet.original_forward = ldm_original_forward
elif isinstance(m, modules.models.diffusion.ddpm_edit.LatentDiffusion):
sd_unet.original_forward = ldm_original_forward
elif isinstance(m, sgm.models.diffusion.DiffusionEngine):
sd_unet.original_forward = sgm_original_forward
else:
sd_unet.original_forward = None
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
def undo_hijack(self, m): def undo_hijack(self, m):
conditioner = getattr(m, 'conditioner', None) conditioner = getattr(m, 'conditioner', None)
@ -279,7 +303,8 @@ class StableDiffusionModelHijack:
self.layers = None self.layers = None
self.clip = None self.clip = None
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui sd_unet.original_forward = None
def apply_circular(self, enable): def apply_circular(self, enable):
if self.circular_enabled == enable: if self.circular_enabled == enable:

View File

@ -1,22 +1,22 @@
import collections import collections
import os.path import os.path
import sys import sys
import gc
import threading import threading
import torch import torch
import re import re
import safetensors.torch import safetensors.torch
from omegaconf import OmegaConf from omegaconf import OmegaConf, ListConfig
from os import mkdir from os import mkdir
from urllib import request from urllib import request
import ldm.modules.midas as midas import ldm.modules.midas as midas
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
from modules.timer import Timer from modules.timer import Timer
import tomesd import tomesd
import numpy as np
model_dir = "Stable-diffusion" model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir)) model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@ -49,11 +49,12 @@ class CheckpointInfo:
def __init__(self, filename): def __init__(self, filename):
self.filename = filename self.filename = filename
abspath = os.path.abspath(filename) abspath = os.path.abspath(filename)
abs_ckpt_dir = os.path.abspath(shared.cmd_opts.ckpt_dir) if shared.cmd_opts.ckpt_dir is not None else None
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors" self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir): if abs_ckpt_dir and abspath.startswith(abs_ckpt_dir):
name = abspath.replace(shared.cmd_opts.ckpt_dir, '') name = abspath.replace(abs_ckpt_dir, '')
elif abspath.startswith(model_path): elif abspath.startswith(model_path):
name = abspath.replace(model_path, '') name = abspath.replace(model_path, '')
else: else:
@ -129,9 +130,12 @@ except Exception:
def setup_model(): def setup_model():
"""called once at startup to do various one-time tasks related to SD models"""
os.makedirs(model_path, exist_ok=True) os.makedirs(model_path, exist_ok=True)
enable_midas_autodownload() enable_midas_autodownload()
patch_given_betas()
def checkpoint_tiles(use_short=False): def checkpoint_tiles(use_short=False):
@ -309,6 +313,8 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
if checkpoint_info in checkpoints_loaded: if checkpoint_info in checkpoints_loaded:
# use checkpoint cache # use checkpoint cache
print(f"Loading weights [{sd_model_hash}] from cache") print(f"Loading weights [{sd_model_hash}] from cache")
# move to end as latest
checkpoints_loaded.move_to_end(checkpoint_info)
return checkpoints_loaded[checkpoint_info] return checkpoints_loaded[checkpoint_info]
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}") print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
@ -346,16 +352,19 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
model.is_sdxl = hasattr(model, 'conditioner') model.is_sdxl = hasattr(model, 'conditioner')
model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model') model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
model.is_sd1 = not model.is_sdxl and not model.is_sd2 model.is_sd1 = not model.is_sdxl and not model.is_sd2
model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
if model.is_sdxl: if model.is_sdxl:
sd_models_xl.extend_sdxl(model) sd_models_xl.extend_sdxl(model)
model.load_state_dict(state_dict, strict=False) if model.is_ssd:
timer.record("apply weights to model") sd_hijack.model_hijack.convert_sdxl_to_ssd(model)
if shared.opts.sd_checkpoint_cache > 0: if shared.opts.sd_checkpoint_cache > 0:
# cache newly loaded model # cache newly loaded model
checkpoints_loaded[checkpoint_info] = state_dict checkpoints_loaded[checkpoint_info] = state_dict.copy()
model.load_state_dict(state_dict, strict=False)
timer.record("apply weights to model")
del state_dict del state_dict
@ -453,6 +462,20 @@ def enable_midas_autodownload():
midas.api.load_model = load_model_wrapper midas.api.load_model = load_model_wrapper
def patch_given_betas():
import ldm.models.diffusion.ddpm
def patched_register_schedule(*args, **kwargs):
"""a modified version of register_schedule function that converts plain list from Omegaconf into numpy"""
if isinstance(args[1], ListConfig):
args = (args[0], np.array(args[1]), *args[2:])
original_register_schedule(*args, **kwargs)
original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
def repair_config(sd_config): def repair_config(sd_config):
if not hasattr(sd_config.model.params, "use_ema"): if not hasattr(sd_config.model.params, "use_ema"):
@ -777,17 +800,7 @@ def reload_model_weights(sd_model=None, info=None):
def unload_model_weights(sd_model=None, info=None): def unload_model_weights(sd_model=None, info=None):
timer = Timer() send_model_to_cpu(sd_model or shared.sd_model)
if model_data.sd_model:
model_data.sd_model.to(devices.cpu)
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
model_data.sd_model = None
sd_model = None
gc.collect()
devices.torch_gc()
print(f"Unloaded weights {timer.summary()}.")
return sd_model return sd_model

View File

@ -21,7 +21,7 @@ config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inf
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml") config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
def is_using_v_parameterization_for_sd2(state_dict): def is_using_v_parameterization_for_sd2(state_dict):
""" """
@ -95,7 +95,10 @@ def guess_model_config_from_state_dict(sd, filename):
if diffusion_model_input.shape[1] == 8: if diffusion_model_input.shape[1] == 8:
return config_instruct_pix2pix return config_instruct_pix2pix
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None: if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
return config_alt_diffusion_m18
return config_alt_diffusion return config_alt_diffusion
return config_default return config_default

View File

@ -22,7 +22,10 @@ class WebuiSdModel(LatentDiffusion):
"""structure with additional information about the file with model's weights""" """structure with additional information about the file with model's weights"""
is_sdxl: bool is_sdxl: bool
"""True if the model's architecture is SDXL""" """True if the model's architecture is SDXL or SSD"""
is_ssd: bool
"""True if the model is SSD"""
is_sd2: bool is_sd2: bool
"""True if the model's architecture is SD 2.x""" """True if the model's architecture is SD 2.x"""

View File

@ -60,7 +60,7 @@ def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=No
sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1] sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]
while restart_times > 0: while restart_times > 0:
restart_times -= 1 restart_times -= 1
step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])]) step_list.extend(zip(sigma_restart[:-1], sigma_restart[1:]))
last_sigma = None last_sigma = None
for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable): for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable):

View File

@ -1,11 +1,11 @@
import torch.nn import torch.nn
import ldm.modules.diffusionmodules.openaimodel
from modules import script_callbacks, shared, devices from modules import script_callbacks, shared, devices
unet_options = [] unet_options = []
current_unet_option = None current_unet_option = None
current_unet = None current_unet = None
original_forward = None
def list_unets(): def list_unets():
@ -88,5 +88,5 @@ def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
if current_unet is not None: if current_unet is not None:
return current_unet.forward(x, timesteps, context, *args, **kwargs) return current_unet.forward(x, timesteps, context, *args, **kwargs)
return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs) return original_forward(self, x, timesteps, context, *args, **kwargs)

View File

@ -14,5 +14,5 @@ if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
else: else:
cmd_opts, _ = parser.parse_known_args() cmd_opts, _ = parser.parse_known_args()
cmd_opts.webui_is_non_local = any([cmd_opts.share, cmd_opts.listen, cmd_opts.ngrok, cmd_opts.server_name])
cmd_opts.disable_extension_access = any([cmd_opts.share, cmd_opts.listen, cmd_opts.ngrok, cmd_opts.server_name]) and not cmd_opts.enable_insecure_extension_access cmd_opts.disable_extension_access = cmd_opts.webui_is_non_local and not cmd_opts.enable_insecure_extension_access

View File

@ -44,9 +44,9 @@ def refresh_unet_list():
modules.sd_unet.list_unets() modules.sd_unet.list_unets()
def list_checkpoint_tiles(): def list_checkpoint_tiles(use_short=False):
import modules.sd_models import modules.sd_models
return modules.sd_models.checkpoint_tiles() return modules.sd_models.checkpoint_tiles(use_short)
def refresh_checkpoints(): def refresh_checkpoints():
@ -67,6 +67,8 @@ def reload_hypernetworks():
ui_reorder_categories_builtin_items = [ ui_reorder_categories_builtin_items = [
"prompt",
"image",
"inpaint", "inpaint",
"sampler", "sampler",
"accordions", "accordions",

View File

@ -3,7 +3,7 @@ import gradio as gr
from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401 from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
from modules.shared_cmd_options import cmd_opts from modules.shared_cmd_options import cmd_opts
from modules.options import options_section, OptionInfo, OptionHTML from modules.options import options_section, OptionInfo, OptionHTML, categories
options_templates = {} options_templates = {}
hide_dirs = shared.hide_dirs hide_dirs = shared.hide_dirs
@ -21,12 +21,19 @@ restricted_opts = {
"outdir_init_images" "outdir_init_images"
} }
options_templates.update(options_section(('saving-images', "Saving images/grids"), { categories.register_category("saving", "Saving images")
categories.register_category("sd", "Stable Diffusion")
categories.register_category("ui", "User Interface")
categories.register_category("system", "System")
categories.register_category("postprocessing", "Postprocessing")
categories.register_category("training", "Training")
options_templates.update(options_section(('saving-images', "Saving images/grids", "saving"), {
"samples_save": OptionInfo(True, "Always save all generated images"), "samples_save": OptionInfo(True, "Always save all generated images"),
"samples_format": OptionInfo('png', 'File format for images'), "samples_format": OptionInfo('png', 'File format for images'),
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"), "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
"save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs), "save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
"save_images_replace_action": OptionInfo("Replace", "Saving the image to an existing file", gr.Radio, {"choices": ["Replace", "Add number suffix"], **hide_dirs}),
"grid_save": OptionInfo(True, "Always save all generated image grids"), "grid_save": OptionInfo(True, "Always save all generated image grids"),
"grid_format": OptionInfo('png', 'File format for grids'), "grid_format": OptionInfo('png', 'File format for grids'),
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"), "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
@ -62,9 +69,12 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"), "clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
"save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."), "save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."),
"notification_audio": OptionInfo(True, "Play notification sound after image generation").info("notification.mp3 should be present in the root directory").needs_reload_ui(),
"notification_volume": OptionInfo(100, "Notification sound volume", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1}).info("in %"),
})) }))
options_templates.update(options_section(('saving-paths', "Paths for saving"), { options_templates.update(options_section(('saving-paths', "Paths for saving", "saving"), {
"outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs), "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
"outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs), "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
"outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs), "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
@ -76,7 +86,7 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), {
"outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs), "outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
})) }))
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), { options_templates.update(options_section(('saving-to-dirs', "Saving to a directory", "saving"), {
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"), "save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"), "grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"), "use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
@ -84,22 +94,23 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}), "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
})) }))
options_templates.update(options_section(('upscaling', "Upscaling"), { options_templates.update(options_section(('upscaling', "Upscaling", "postprocessing"), {
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"), "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}), "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}),
})) }))
options_templates.update(options_section(('face-restoration', "Face restoration"), { options_templates.update(options_section(('face-restoration', "Face restoration", "postprocessing"), {
"face_restoration": OptionInfo(False, "Restore faces", infotext='Face restoration').info("will use a third-party model on generation result to reconstruct faces"), "face_restoration": OptionInfo(False, "Restore faces", infotext='Face restoration').info("will use a third-party model on generation result to reconstruct faces"),
"face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in shared.face_restorers]}), "face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in shared.face_restorers]}),
"code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"), "code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"), "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
})) }))
options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('system', "System", "system"), {
"auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}), "auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}),
"enable_console_prompts": OptionInfo(shared.cmd_opts.enable_console_prompts, "Print prompts to console when generating with txt2img and img2img."),
"show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(), "show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(),
"show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_reload_ui(), "show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_reload_ui(),
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"), "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
@ -109,15 +120,16 @@ options_templates.update(options_section(('system', "System"), {
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""), "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
"hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."), "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
"dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."),
})) }))
options_templates.update(options_section(('API', "API"), { options_templates.update(options_section(('API', "API", "system"), {
"api_enable_requests": OptionInfo(True, "Allow http:// and https:// URLs for input images in API", restrict_api=True), "api_enable_requests": OptionInfo(True, "Allow http:// and https:// URLs for input images in API", restrict_api=True),
"api_forbid_local_requests": OptionInfo(True, "Forbid URLs to local resources", restrict_api=True), "api_forbid_local_requests": OptionInfo(True, "Forbid URLs to local resources", restrict_api=True),
"api_useragent": OptionInfo("", "User agent for requests", restrict_api=True), "api_useragent": OptionInfo("", "User agent for requests", restrict_api=True),
})) }))
options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('training', "Training", "training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."), "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
@ -132,8 +144,8 @@ options_templates.update(options_section(('training', "Training"), {
"training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."), "training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
})) }))
options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('sd', "Stable Diffusion", "sd"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles()}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'), "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'),
"sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}), "sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
"sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"), "sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
@ -149,14 +161,14 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"hires_fix_refiner_pass": OptionInfo("second pass", "Hires fix: which pass to enable refiner for", gr.Radio, {"choices": ["first pass", "second pass", "both passes"]}, infotext="Hires refiner"), "hires_fix_refiner_pass": OptionInfo("second pass", "Hires fix: which pass to enable refiner for", gr.Radio, {"choices": ["first pass", "second pass", "both passes"]}, infotext="Hires refiner"),
})) }))
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), { options_templates.update(options_section(('sdxl', "Stable Diffusion XL", "sd"), {
"sdxl_crop_top": OptionInfo(0, "crop top coordinate"), "sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
"sdxl_crop_left": OptionInfo(0, "crop left coordinate"), "sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
"sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"), "sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"), "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
})) }))
options_templates.update(options_section(('vae', "VAE"), { options_templates.update(options_section(('vae', "VAE", "sd"), {
"sd_vae_explanation": OptionHTML(""" "sd_vae_explanation": OptionHTML("""
<abbr title='Variational autoencoder'>VAE</abbr> is a neural network that transforms a standard <abbr title='red/green/blue'>RGB</abbr> <abbr title='Variational autoencoder'>VAE</abbr> is a neural network that transforms a standard <abbr title='red/green/blue'>RGB</abbr>
image into latent space representation and back. Latent space representation is what stable diffusion is working on during sampling image into latent space representation and back. Latent space representation is what stable diffusion is working on during sampling
@ -171,7 +183,7 @@ For img2img, VAE is used to process user's input image before the sampling, and
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"), "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"),
})) }))
options_templates.update(options_section(('img2img', "img2img"), { options_templates.update(options_section(('img2img', "img2img", "sd"), {
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Conditional mask weight'), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Conditional mask weight'),
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.0, "maximum": 1.5, "step": 0.001}, infotext='Noise multiplier'), "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.0, "maximum": 1.5, "step": 0.001}, infotext='Noise multiplier'),
"img2img_extra_noise": OptionInfo(0.0, "Extra noise multiplier for img2img and hires fix", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Extra noise').info("0 = disabled (default); should be lower than denoising strength"), "img2img_extra_noise": OptionInfo(0.0, "Extra noise multiplier for img2img and hires fix", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Extra noise').info("0 = disabled (default); should be lower than denoising strength"),
@ -184,9 +196,10 @@ options_templates.update(options_section(('img2img', "img2img"), {
"img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_reload_ui(), "img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_reload_ui(),
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"), "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"), "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
"img2img_batch_show_results_limit": OptionInfo(32, "Show the first N batch img2img results in UI", gr.Slider, {"minimum": -1, "maximum": 1000, "step": 1}).info('0: disable, -1: show all images. Too many images can cause lag'),
})) }))
options_templates.update(options_section(('optimizations', "Optimizations"), { options_templates.update(options_section(('optimizations', "Optimizations", "sd"), {
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}), "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
"s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"), "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"), "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
@ -197,7 +210,7 @@ options_templates.update(options_section(('optimizations', "Optimizations"), {
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"), "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
})) }))
options_templates.update(options_section(('compatibility', "Compatibility"), { options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."), "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."), "no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
@ -222,7 +235,7 @@ options_templates.update(options_section(('interrogate', "Interrogate"), {
"deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"), "deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"),
})) }))
options_templates.update(options_section(('extra_networks', "Extra Networks"), { options_templates.update(options_section(('extra_networks', "Extra Networks", "sd"), {
"extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."), "extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
"extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'), "extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
"extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}), "extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}),
@ -230,6 +243,8 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"), "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
"extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"), "extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
"extra_networks_card_show_desc": OptionInfo(True, "Show description on card"), "extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
"extra_networks_card_order_field": OptionInfo("Path", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Path', 'Name', 'Date Created', 'Date Modified']}).needs_reload_ui(),
"extra_networks_card_order": OptionInfo("Ascending", "Default order for Extra Networks cards", gr.Dropdown, {"choices": ['Ascending', 'Descending']}).needs_reload_ui(),
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"), "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(), "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
"textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"), "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
@ -237,7 +252,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *shared.hypernetworks]}, refresh=shared_items.reload_hypernetworks), "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *shared.hypernetworks]}, refresh=shared_items.reload_hypernetworks),
})) }))
options_templates.update(options_section(('ui', "User interface"), { options_templates.update(options_section(('ui', "User interface", "ui"), {
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(), "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(),
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(), "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(),
"gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"), "gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"),
@ -255,19 +270,24 @@ options_templates.update(options_section(('ui', "User interface"), {
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(), "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(),
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"), "keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Ctrl+up/down word delimiters"),
"keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}),
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"), "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(), "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(),
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(), "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(), "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
"ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(), "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(),
"sd_checkpoint_dropdown_use_short": OptionInfo(False, "Checkpoint dropdown: use filenames without paths").info("models in subdirectories like photo/sd15.ckpt will be listed as just sd15.ckpt"),
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(), "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(),
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(), "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(), "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
"txt2img_settings_accordion": OptionInfo(False, "Settings in txt2img hidden under Accordion").needs_reload_ui(),
"img2img_settings_accordion": OptionInfo(False, "Settings in img2img hidden under Accordion").needs_reload_ui(),
"compact_prompt_box": OptionInfo(False, "Compact prompt layout").info("puts prompt and negative prompt inside the Generate tab, leaving more vertical space for the image on the right").needs_reload_ui(),
})) }))
options_templates.update(options_section(('infotext', "Infotext"), { options_templates.update(options_section(('infotext', "Infotext", "ui"), {
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"), "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
"add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"), "add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
@ -282,7 +302,7 @@ options_templates.update(options_section(('infotext', "Infotext"), {
})) }))
options_templates.update(options_section(('ui', "Live previews"), { options_templates.update(options_section(('ui', "Live previews", "ui"), {
"show_progressbar": OptionInfo(True, "Show progressbar"), "show_progressbar": OptionInfo(True, "Show progressbar"),
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"), "live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
"live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}), "live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
@ -295,7 +315,7 @@ options_templates.update(options_section(('ui', "Live previews"), {
"live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"), "live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"),
})) }))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), { options_templates.update(options_section(('sampler-params', "Sampler parameters", "sd"), {
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(), "hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(),
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unpredictable results"), "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unpredictable results"),
"eta_ancestral": OptionInfo(1.0, "Eta for k-diffusion samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"), "eta_ancestral": OptionInfo(1.0, "Eta for k-diffusion samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"),
@ -305,8 +325,8 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"), 's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'), 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'),
'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"), 'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule max sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"), 'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule min sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule min sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"), 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule max sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
'rho': OptionInfo(0.0, "rho", gr.Number, infotext='Schedule rho').info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"), 'rho': OptionInfo(0.0, "rho", gr.Number, infotext='Schedule rho').info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}, infotext='ENSD').info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"), 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}, infotext='ENSD').info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma", infotext='Discard penultimate sigma').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"), 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma", infotext='Discard penultimate sigma').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
@ -317,7 +337,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'), 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'),
})) }))
options_templates.update(options_section(('postprocessing', "Postprocessing"), { options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), {
'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), 'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), 'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), 'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
@ -329,4 +349,3 @@ options_templates.update(options_section((None, "Hidden options"), {
"restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"), "restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"), "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
})) }))

View File

@ -103,6 +103,7 @@ class State:
def begin(self, job: str = "(unknown)"): def begin(self, job: str = "(unknown)"):
self.sampling_step = 0 self.sampling_step = 0
self.time_start = time.time()
self.job_count = -1 self.job_count = -1
self.processing_has_refined_job_count = False self.processing_has_refined_job_count = False
self.job_no = 0 self.job_no = 0
@ -114,7 +115,6 @@ class State:
self.skipped = False self.skipped = False
self.interrupted = False self.interrupted = False
self.textinfo = None self.textinfo = None
self.time_start = time.time()
self.job = job self.job = job
devices.torch_gc() devices.torch_gc()
log.info("Starting job %s", job) log.info("Starting job %s", job)

View File

@ -15,7 +15,7 @@ import torch
from torch import Tensor from torch import Tensor
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
import math import math
from typing import Optional, NamedTuple, List from typing import Optional, NamedTuple
def narrow_trunc( def narrow_trunc(
@ -97,7 +97,7 @@ def _query_chunk_attention(
) )
return summarize_chunk(query, key_chunk, value_chunk) return summarize_chunk(query, key_chunk, value_chunk)
chunks: List[AttnChunk] = [ chunks: list[AttnChunk] = [
chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size) chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
] ]
acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks))) acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))

View File

@ -1,7 +1,6 @@
import json import json
import os import os
import sys import sys
import traceback
import platform import platform
import hashlib import hashlib
@ -84,7 +83,7 @@ def get_dict():
"Checksum": checksum_token, "Checksum": checksum_token,
"Commandline": get_argv(), "Commandline": get_argv(),
"Torch env info": get_torch_sysinfo(), "Torch env info": get_torch_sysinfo(),
"Exceptions": get_exceptions(), "Exceptions": errors.get_exceptions(),
"CPU": { "CPU": {
"model": platform.processor(), "model": platform.processor(),
"count logical": psutil.cpu_count(logical=True), "count logical": psutil.cpu_count(logical=True),
@ -104,21 +103,6 @@ def get_dict():
return res return res
def format_traceback(tb):
return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]
def format_exception(e, tb):
return {"exception": str(e), "traceback": format_traceback(tb)}
def get_exceptions():
try:
return list(reversed(errors.exception_records))
except Exception as e:
return str(e)
def get_environment(): def get_environment():
return {k: os.environ[k] for k in sorted(os.environ) if k in environment_whitelist} return {k: os.environ[k] for k in sorted(os.environ) if k in environment_whitelist}

View File

@ -181,40 +181,7 @@ class EmbeddingDatabase:
else: else:
return return
embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
vec = emb.detach().to(devices.device, dtype=torch.float32)
shape = vec.shape[-1]
vectors = vec.shape[0]
elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
vectors = data['clip_g'].shape[0]
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
emb = next(iter(data.values()))
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
vec = emb.detach().to(devices.device, dtype=torch.float32)
shape = vec.shape[-1]
vectors = vec.shape[0]
else:
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
embedding = Embedding(vec, name)
embedding.step = data.get('step', None)
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
embedding.vectors = vectors
embedding.shape = shape
embedding.filename = path
embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
if self.expected_shape == -1 or self.expected_shape == embedding.shape: if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model) self.register_embedding(embedding, shared.sd_model)
@ -313,6 +280,45 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
return fn return fn
def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None):
if 'string_to_param' in data: # textual inversion embeddings
param_dict = data['string_to_param']
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
vec = emb.detach().to(devices.device, dtype=torch.float32)
shape = vec.shape[-1]
vectors = vec.shape[0]
elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
vectors = data['clip_g'].shape[0]
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
emb = next(iter(data.values()))
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
vec = emb.detach().to(devices.device, dtype=torch.float32)
shape = vec.shape[-1]
vectors = vec.shape[0]
else:
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
embedding = Embedding(vec, name)
embedding.step = data.get('step', None)
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
embedding.vectors = vectors
embedding.shape = shape
if filepath:
embedding.filename = filepath
embedding.set_hash(hashes.sha256(filepath, "textual_inversion/" + name) or '')
return embedding
def write_loss(log_directory, filename, step, epoch_len, values): def write_loss(log_directory, filename, step, epoch_len, values):
if shared.opts.training_write_csv_every == 0: if shared.opts.training_write_csv_every == 0:
return return
@ -386,7 +392,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
assert log_directory, "Log directory is empty" assert log_directory, "Log directory is empty"
def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_name, preview_cfg_scale, preview_seed, preview_width, preview_height):
from modules import processing from modules import processing
save_embedding_every = save_embedding_every or 0 save_embedding_every = save_embedding_every or 0
@ -590,7 +596,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
p.prompt = preview_prompt p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt p.negative_prompt = preview_negative_prompt
p.steps = preview_steps p.steps = preview_steps
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()]
p.cfg_scale = preview_cfg_scale p.cfg_scale = preview_cfg_scale
p.seed = preview_seed p.seed = preview_seed
p.width = preview_width p.width = preview_width

View File

@ -3,7 +3,7 @@ from contextlib import closing
import modules.scripts import modules.scripts
from modules import processing from modules import processing
from modules.generation_parameters_copypaste import create_override_settings_dict from modules.generation_parameters_copypaste import create_override_settings_dict
from modules.shared import opts, cmd_opts from modules.shared import opts
import modules.shared as shared import modules.shared as shared
from modules.ui import plaintext_to_html from modules.ui import plaintext_to_html
import gradio as gr import gradio as gr
@ -45,7 +45,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
p.user = request.username p.user = request.username
if cmd_opts.enable_console_prompts: if shared.opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
with closing(p): with closing(p):

View File

@ -4,6 +4,7 @@ import os
import sys import sys
from functools import reduce from functools import reduce
import warnings import warnings
from contextlib import ExitStack
import gradio as gr import gradio as gr
import gradio.utils import gradio.utils
@ -12,7 +13,7 @@ from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import gradio_extensons # noqa: F401 from modules import gradio_extensons # noqa: F401
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers, processing, ui_extra_networks from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
from modules.paths import script_path from modules.paths import script_path
from modules.ui_common import create_refresh_button from modules.ui_common import create_refresh_button
@ -25,7 +26,6 @@ import modules.hypernetworks.ui as hypernetworks_ui
import modules.textual_inversion.ui as textual_inversion_ui import modules.textual_inversion.ui as textual_inversion_ui
import modules.textual_inversion.textual_inversion as textual_inversion import modules.textual_inversion.textual_inversion as textual_inversion
import modules.shared as shared import modules.shared as shared
import modules.images
from modules import prompt_parser from modules import prompt_parser
from modules.sd_hijack import model_hijack from modules.sd_hijack import model_hijack
from modules.generation_parameters_copypaste import image_from_url_text from modules.generation_parameters_copypaste import image_from_url_text
@ -151,11 +151,15 @@ def connect_clear_prompt(button):
) )
def update_token_counter(text, steps): def update_token_counter(text, steps, *, is_positive=True):
try: try:
text, _ = extra_networks.parse_prompt(text) text, _ = extra_networks.parse_prompt(text)
_, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) if is_positive:
_, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
else:
prompt_flat_list = [text]
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
except Exception: except Exception:
@ -169,76 +173,9 @@ def update_token_counter(text, steps):
return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>" return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
class Toprow: def update_negative_prompt_token_counter(text, steps):
"""Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation""" return update_token_counter(text, steps, is_positive=False)
def __init__(self, is_img2img):
id_part = "img2img" if is_img2img else "txt2img"
self.id_part = id_part
with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
self.prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
self.prompt_img = gr.File(label="", elem_id=f"{id_part}_prompt_image", file_count="single", type="binary", visible=False)
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
self.button_interrogate = None
self.button_deepbooru = None
if is_img2img:
with gr.Column(scale=1, elem_classes="interrogate-col"):
self.button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
self.button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
self.interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
self.skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
self.submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
self.skip.click(
fn=lambda: shared.state.skip(),
inputs=[],
outputs=[],
)
self.interrupt.click(
fn=lambda: shared.state.interrupt(),
inputs=[],
outputs=[],
)
with gr.Row(elem_id=f"{id_part}_tools"):
self.paste = ToolButton(value=paste_symbol, elem_id="paste")
self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
self.token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
self.negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
self.negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
self.clear_prompt_button.click(
fn=lambda *x: x,
_js="confirm_clear_prompt",
inputs=[self.prompt, self.negative_prompt],
outputs=[self.prompt, self.negative_prompt],
)
self.ui_styles = ui_prompt_styles.UiPromptStyles(id_part, self.prompt, self.negative_prompt)
self.prompt_img.change(
fn=modules.images.image_data,
inputs=[self.prompt_img],
outputs=[self.prompt, self.prompt_img],
show_progress=False,
)
def setup_progressbar(*args, **kwargs): def setup_progressbar(*args, **kwargs):
@ -278,8 +215,8 @@ def apply_setting(key, value):
return getattr(opts, key) return getattr(opts, key)
def create_output_panel(tabname, outdir): def create_output_panel(tabname, outdir, toprow=None):
return ui_common.create_output_panel(tabname, outdir) return ui_common.create_output_panel(tabname, outdir, toprow)
def create_sampler_and_steps_selection(choices, tabname): def create_sampler_and_steps_selection(choices, tabname):
@ -326,7 +263,7 @@ def create_ui():
scripts.scripts_txt2img.initialize_scripts(is_img2img=False) scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Blocks(analytics_enabled=False) as txt2img_interface:
toprow = Toprow(is_img2img=False) toprow = ui_toprow.Toprow(is_img2img=False, is_compact=shared.opts.compact_prompt_box)
dummy_component = gr.Label(visible=False) dummy_component = gr.Label(visible=False)
@ -334,10 +271,17 @@ def create_ui():
extra_tabs.__enter__() extra_tabs.__enter__()
with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False): with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False):
with gr.Column(variant='compact', elem_id="txt2img_settings"): with ExitStack() as stack:
if shared.opts.txt2img_settings_accordion:
stack.enter_context(gr.Accordion("Open for Settings", open=False))
stack.enter_context(gr.Column(variant='compact', elem_id="txt2img_settings"))
scripts.scripts_txt2img.prepare_ui() scripts.scripts_txt2img.prepare_ui()
for category in ordered_ui_categories(): for category in ordered_ui_categories():
if category == "prompt":
toprow.create_inline_toprow_prompts()
if category == "sampler": if category == "sampler":
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img") steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img")
@ -348,7 +292,7 @@ def create_ui():
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"): with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", label="Switch dims") res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", tooltip="Switch width/height")
if opts.dimensions_and_batch_together: if opts.dimensions_and_batch_together:
with gr.Column(elem_id="txt2img_column_batch"): with gr.Column(elem_id="txt2img_column_batch"):
@ -432,7 +376,7 @@ def create_ui():
show_progress=False, show_progress=False,
) )
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow)
txt2img_args = dict( txt2img_args = dict(
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
@ -533,7 +477,7 @@ def create_ui():
] ]
toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter]) toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter]) toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img') extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery) ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
@ -544,13 +488,17 @@ def create_ui():
scripts.scripts_img2img.initialize_scripts(is_img2img=True) scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as img2img_interface: with gr.Blocks(analytics_enabled=False) as img2img_interface:
toprow = Toprow(is_img2img=True) toprow = ui_toprow.Toprow(is_img2img=True, is_compact=shared.opts.compact_prompt_box)
extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs") extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs")
extra_tabs.__enter__() extra_tabs.__enter__()
with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False): with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False):
with gr.Column(variant='compact', elem_id="img2img_settings"): with ExitStack() as stack:
if shared.opts.img2img_settings_accordion:
stack.enter_context(gr.Accordion("Open for Settings", open=False))
stack.enter_context(gr.Column(variant='compact', elem_id="img2img_settings"))
copy_image_buttons = [] copy_image_buttons = []
copy_image_destinations = {} copy_image_destinations = {}
@ -567,85 +515,89 @@ def create_ui():
button = gr.Button(title) button = gr.Button(title)
copy_image_buttons.append((button, name, elem)) copy_image_buttons.append((button, name, elem))
with gr.Tabs(elem_id="mode_img2img"):
img2img_selected_tab = gr.State(0)
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height)
add_copy_image_controls('img2img', init_img)
with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
add_copy_image_controls('sketch', sketch)
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color)
add_copy_image_controls('inpaint', init_img_with_mask)
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
inpaint_color_sketch_orig = gr.State(None)
add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
def update_orig(image, state):
if image is not None:
same_size = state is not None and state.size == image.size
has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
edited = same_size and has_exact_match
return image if not edited or state is None else state
inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig)
with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask")
with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
gr.HTML(
"<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
"<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
f"<br>Add inpaint batch mask directory to enable inpaint batch processing."
f"{hidden}</p>"
)
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
with gr.Accordion("PNG info", open=False):
img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info")
img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir")
img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.")
img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
for i, tab in enumerate(img2img_tabs):
tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab])
def copy_image(img):
if isinstance(img, dict) and 'image' in img:
return img['image']
return img
for button, name, elem in copy_image_buttons:
button.click(
fn=copy_image,
inputs=[elem],
outputs=[copy_image_destinations[name]],
)
button.click(
fn=lambda: None,
_js=f"switch_to_{name.replace(' ', '_')}",
inputs=[],
outputs=[],
)
with FormRow():
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
scripts.scripts_img2img.prepare_ui() scripts.scripts_img2img.prepare_ui()
for category in ordered_ui_categories(): for category in ordered_ui_categories():
if category == "prompt":
toprow.create_inline_toprow_prompts()
if category == "image":
with gr.Tabs(elem_id="mode_img2img"):
img2img_selected_tab = gr.State(0)
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height)
add_copy_image_controls('img2img', init_img)
with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
add_copy_image_controls('sketch', sketch)
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color)
add_copy_image_controls('inpaint', init_img_with_mask)
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
inpaint_color_sketch_orig = gr.State(None)
add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
def update_orig(image, state):
if image is not None:
same_size = state is not None and state.size == image.size
has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
edited = same_size and has_exact_match
return image if not edited or state is None else state
inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig)
with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask")
with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
gr.HTML(
"<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
"<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
f"<br>Add inpaint batch mask directory to enable inpaint batch processing."
f"{hidden}</p>"
)
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
with gr.Accordion("PNG info", open=False):
img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info")
img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir")
img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps", "Model hash"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.")
img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
for i, tab in enumerate(img2img_tabs):
tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab])
def copy_image(img):
if isinstance(img, dict) and 'image' in img:
return img['image']
return img
for button, name, elem in copy_image_buttons:
button.click(
fn=copy_image,
inputs=[elem],
outputs=[copy_image_destinations[name]],
)
button.click(
fn=lambda: None,
_js=f"switch_to_{name.replace(' ', '_')}",
inputs=[],
outputs=[],
)
with FormRow():
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
if category == "sampler": if category == "sampler":
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img") steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img")
@ -661,8 +613,8 @@ def create_ui():
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"): with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn") res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn", tooltip="Switch width/height")
detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn") detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn", tooltip="Auto detect size from img2img")
with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by: with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by:
scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale") scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
@ -683,12 +635,6 @@ def create_ui():
scale_by.release(**on_change_args) scale_by.release(**on_change_args)
button_update_resize_to.click(**on_change_args) button_update_resize_to.click(**on_change_args)
# the code below is meant to update the resolution label after the image in the image selection UI has changed.
# as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests.
# I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs.
for component in [init_img, sketch]:
component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False)
tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab]) tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab])
tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab]) tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab])
@ -749,20 +695,26 @@ def create_ui():
with gr.Column(scale=4): with gr.Column(scale=4):
inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
def select_img2img_tab(tab):
return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
for i, elem in enumerate(img2img_tabs):
elem.select(
fn=lambda tab=i: select_img2img_tab(tab),
inputs=[],
outputs=[inpaint_controls, mask_alpha],
)
if category not in {"accordions"}: if category not in {"accordions"}:
scripts.scripts_img2img.setup_ui_for_section(category) scripts.scripts_img2img.setup_ui_for_section(category)
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) # the code below is meant to update the resolution label after the image in the image selection UI has changed.
# as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests.
# I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs.
for component in [init_img, sketch]:
component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False)
def select_img2img_tab(tab):
return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
for i, elem in enumerate(img2img_tabs):
elem.select(
fn=lambda tab=i: select_img2img_tab(tab),
inputs=[],
outputs=[inpaint_controls, mask_alpha],
)
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples, toprow)
img2img_args = dict( img2img_args = dict(
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
@ -1295,7 +1247,7 @@ def create_ui():
loadsave.setup_ui() loadsave.setup_ui()
if os.path.exists(os.path.join(script_path, "notification.mp3")): if os.path.exists(os.path.join(script_path, "notification.mp3")) and shared.opts.notification_audio:
gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
footer = shared.html("footer.html") footer = shared.html("footer.html")
@ -1347,7 +1299,6 @@ checkpoint: <a id="sd_checkpoint_hash">N/A</a>
def setup_ui_api(app): def setup_ui_api(app):
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import List
class QuicksettingsHint(BaseModel): class QuicksettingsHint(BaseModel):
name: str = Field(title="Name of the quicksettings field") name: str = Field(title="Name of the quicksettings field")
@ -1356,7 +1307,7 @@ def setup_ui_api(app):
def quicksettings_hint(): def quicksettings_hint():
return [QuicksettingsHint(name=k, label=v.label) for k, v in opts.data_labels.items()] return [QuicksettingsHint(name=k, label=v.label) for k, v in opts.data_labels.items()]
app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=List[QuicksettingsHint]) app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=list[QuicksettingsHint])
app.add_api_route("/internal/ping", lambda: {}, methods=["GET"]) app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])
@ -1366,7 +1317,7 @@ def setup_ui_api(app):
from fastapi.responses import PlainTextResponse from fastapi.responses import PlainTextResponse
text = sysinfo.get() text = sysinfo.get()
filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt" filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json"
return PlainTextResponse(text, headers={'Content-Disposition': f'{"attachment" if attachment else "inline"}; filename="{filename}"'}) return PlainTextResponse(text, headers={'Content-Disposition': f'{"attachment" if attachment else "inline"}; filename="{filename}"'})

View File

@ -104,7 +104,7 @@ def save_files(js_data, images, do_make_zip, index):
return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
def create_output_panel(tabname, outdir): def create_output_panel(tabname, outdir, toprow=None):
def open_folder(f): def open_folder(f):
if not os.path.exists(f): if not os.path.exists(f):
@ -130,12 +130,15 @@ Requested path was: {f}
else: else:
sp.Popen(["xdg-open", path]) sp.Popen(["xdg-open", path])
with gr.Column(variant='panel', elem_id=f"{tabname}_results"): with gr.Column(elem_id=f"{tabname}_results"):
with gr.Group(elem_id=f"{tabname}_gallery_container"): if toprow:
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None) toprow.create_inline_toprow_image()
generation_info = None with gr.Column(variant='panel', elem_id=f"{tabname}_results_panel"):
with gr.Column(): with gr.Group(elem_id=f"{tabname}_gallery_container"):
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)
generation_info = None
with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"): with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.") open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.")

View File

@ -65,7 +65,7 @@ def save_config_state(name):
filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json") filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json")
print(f"Saving backup of webui/extension state to {filename}.") print(f"Saving backup of webui/extension state to {filename}.")
with open(filename, "w", encoding="utf-8") as f: with open(filename, "w", encoding="utf-8") as f:
json.dump(current_config_state, f, indent=4) json.dump(current_config_state, f, indent=4, ensure_ascii=False)
config_states.list_config_states() config_states.list_config_states()
new_value = next(iter(config_states.all_config_states.keys()), "Current") new_value = next(iter(config_states.all_config_states.keys()), "Current")
new_choices = ["Current"] + list(config_states.all_config_states.keys()) new_choices = ["Current"] + list(config_states.all_config_states.keys())
@ -197,7 +197,7 @@ def update_config_states_table(state_name):
config_state = config_states.all_config_states[state_name] config_state = config_states.all_config_states[state_name]
config_name = config_state.get("name", "Config") config_name = config_state.get("name", "Config")
created_date = time.asctime(time.gmtime(config_state["created_at"])) created_date = datetime.fromtimestamp(config_state["created_at"]).strftime('%Y-%m-%d %H:%M:%S')
filepath = config_state.get("filepath", "<unknown>") filepath = config_state.get("filepath", "<unknown>")
try: try:

View File

@ -1,3 +1,4 @@
import functools
import os.path import os.path
import urllib.parse import urllib.parse
from pathlib import Path from pathlib import Path
@ -15,6 +16,17 @@ from modules.ui_components import ToolButton
extra_pages = [] extra_pages = []
allowed_dirs = set() allowed_dirs = set()
default_allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"]
@functools.cache
def allowed_preview_extensions_with_extra(extra_extensions=None):
return set(default_allowed_preview_extensions) | set(extra_extensions or [])
def allowed_preview_extensions():
return allowed_preview_extensions_with_extra((shared.opts.samples_format, ))
def register_page(page): def register_page(page):
"""registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions""" """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
@ -33,9 +45,9 @@ def fetch_file(filename: str = ""):
if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs): if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs):
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
ext = os.path.splitext(filename)[1].lower() ext = os.path.splitext(filename)[1].lower()[1:]
if ext not in (".png", ".jpg", ".jpeg", ".webp", ".gif"): if ext not in allowed_preview_extensions():
raise ValueError(f"File cannot be fetched: {filename}. Only png, jpg, webp, and gif.") raise ValueError(f"File cannot be fetched: {filename}. Extensions allowed: {allowed_preview_extensions()}.")
# would profit from returning 304 # would profit from returning 304
return FileResponse(filename, headers={"Accept-Ranges": "bytes"}) return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
@ -91,6 +103,7 @@ class ExtraNetworksPage:
self.name = title.lower() self.name = title.lower()
self.id_page = self.name.replace(" ", "_") self.id_page = self.name.replace(" ", "_")
self.card_page = shared.html("extra-networks-card.html") self.card_page = shared.html("extra-networks-card.html")
self.allow_prompt = True
self.allow_negative_prompt = False self.allow_negative_prompt = False
self.metadata = {} self.metadata = {}
self.items = {} self.items = {}
@ -213,9 +226,9 @@ class ExtraNetworksPage:
metadata_button = "" metadata_button = ""
metadata = item.get("metadata") metadata = item.get("metadata")
if metadata: if metadata:
metadata_button = f"<div class='metadata-button card-button' title='Show internal metadata' onclick='extraNetworksRequestMetadata(event, {quote_js(self.name)}, {quote_js(item['name'])})'></div>" metadata_button = f"<div class='metadata-button card-button' title='Show internal metadata' onclick='extraNetworksRequestMetadata(event, {quote_js(self.name)}, {quote_js(html.escape(item['name']))})'></div>"
edit_button = f"<div class='edit-button card-button' title='Edit metadata' onclick='extraNetworksEditUserMetadata(event, {quote_js(tabname)}, {quote_js(self.id_page)}, {quote_js(item['name'])})'></div>" edit_button = f"<div class='edit-button card-button' title='Edit metadata' onclick='extraNetworksEditUserMetadata(event, {quote_js(tabname)}, {quote_js(self.id_page)}, {quote_js(html.escape(item['name']))})'></div>"
local_path = "" local_path = ""
filename = item.get("filename", "") filename = item.get("filename", "")
@ -235,7 +248,7 @@ class ExtraNetworksPage:
if search_only and shared.opts.extra_networks_hidden_models == "Never": if search_only and shared.opts.extra_networks_hidden_models == "Never":
return "" return ""
sort_keys = " ".join([html.escape(f'data-sort-{k}={v}') for k, v in item.get("sort_keys", {}).items()]).strip() sort_keys = " ".join([f'data-sort-{k}="{html.escape(str(v))}"' for k, v in item.get("sort_keys", {}).items()]).strip()
args = { args = {
"background_image": background_image, "background_image": background_image,
@ -266,6 +279,7 @@ class ExtraNetworksPage:
"date_created": int(stat.st_ctime or 0), "date_created": int(stat.st_ctime or 0),
"date_modified": int(stat.st_mtime or 0), "date_modified": int(stat.st_mtime or 0),
"name": pth.name.lower(), "name": pth.name.lower(),
"path": str(pth.parent).lower(),
} }
def find_preview(self, path): def find_preview(self, path):
@ -273,11 +287,7 @@ class ExtraNetworksPage:
Find a preview PNG for a given path (without extension) and call link_preview on it. Find a preview PNG for a given path (without extension) and call link_preview on it.
""" """
preview_extensions = ["png", "jpg", "jpeg", "webp"] potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in allowed_preview_extensions()], [])
if shared.opts.samples_format not in preview_extensions:
preview_extensions.append(shared.opts.samples_format)
potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in preview_extensions], [])
for file in potential_files: for file in potential_files:
if os.path.isfile(file): if os.path.isfile(file):
@ -359,7 +369,10 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
related_tabs = [] related_tabs = []
for page in ui.stored_extra_pages: for page in ui.stored_extra_pages:
with gr.Tab(page.title, id=page.id_page) as tab: with gr.Tab(page.title, elem_id=f"{tabname}_{page.id_page}", elem_classes=["extra-page"]) as tab:
with gr.Column(elem_id=f"{tabname}_{page.id_page}_prompts", elem_classes=["extra-page-prompts"]):
pass
elem_id = f"{tabname}_{page.id_page}_cards_html" elem_id = f"{tabname}_{page.id_page}_cards_html"
page_elem = gr.HTML('Loading...', elem_id=elem_id) page_elem = gr.HTML('Loading...', elem_id=elem_id)
ui.pages.append(page_elem) ui.pages.append(page_elem)
@ -373,19 +386,28 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
related_tabs.append(tab) related_tabs.append(tab)
edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False) button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order")
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False) checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False)
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
for tab in unrelated_tabs: tab_controls = [edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs]
tab.select(fn=lambda: [gr.update(visible=False) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False)
for tab in related_tabs: for tab in unrelated_tabs:
tab.select(fn=lambda: [gr.update(visible=True) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False) tab.select(fn=lambda: [gr.update(visible=False) for _ in tab_controls], _js='function(){ extraNetworksUrelatedTabSelected("' + tabname + '"); }', inputs=[], outputs=tab_controls, show_progress=False)
for page, tab in zip(ui.stored_extra_pages, related_tabs):
allow_prompt = "true" if page.allow_prompt else "false"
allow_negative_prompt = "true" if page.allow_negative_prompt else "false"
jscode = 'extraNetworksTabSelected("' + tabname + '", "' + f"{tabname}_{page.id_page}_prompts" + '", ' + allow_prompt + ', ' + allow_negative_prompt + ');'
tab.select(fn=lambda: [gr.update(visible=True) for _ in tab_controls], _js='function(){ ' + jscode + ' }', inputs=[], outputs=tab_controls, show_progress=False)
dropdown_sort.change(fn=lambda: None, _js="function(){ applyExtraNetworkSort('" + tabname + "'); }")
def pages_html(): def pages_html():
if not ui.pages_contents: if not ui.pages_contents:

View File

@ -10,11 +10,16 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
def __init__(self): def __init__(self):
super().__init__('Checkpoints') super().__init__('Checkpoints')
self.allow_prompt = False
def refresh(self): def refresh(self):
shared.refresh_checkpoints() shared.refresh_checkpoints()
def create_item(self, name, index=None, enable_filter=True): def create_item(self, name, index=None, enable_filter=True):
checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name) checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
if checkpoint is None:
return
path, ext = os.path.splitext(checkpoint.filename) path, ext = os.path.splitext(checkpoint.filename)
return { return {
"name": checkpoint.name_for_extra, "name": checkpoint.name_for_extra,
@ -30,9 +35,12 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
} }
def list_items(self): def list_items(self):
# instantiate a list to protect against concurrent modification
names = list(sd_models.checkpoints_list) names = list(sd_models.checkpoints_list)
for index, name in enumerate(names): for index, name in enumerate(names):
yield self.create_item(name, index) item = self.create_item(name, index)
if item is not None:
yield item
def allowed_directories_for_previews(self): def allowed_directories_for_previews(self):
return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None] return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]

View File

@ -13,7 +13,10 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
shared.reload_hypernetworks() shared.reload_hypernetworks()
def create_item(self, name, index=None, enable_filter=True): def create_item(self, name, index=None, enable_filter=True):
full_path = shared.hypernetworks[name] full_path = shared.hypernetworks.get(name)
if full_path is None:
return
path, ext = os.path.splitext(full_path) path, ext = os.path.splitext(full_path)
sha256 = sha256_from_cache(full_path, f'hypernet/{name}') sha256 = sha256_from_cache(full_path, f'hypernet/{name}')
shorthash = sha256[0:10] if sha256 else None shorthash = sha256[0:10] if sha256 else None
@ -31,8 +34,12 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
} }
def list_items(self): def list_items(self):
for index, name in enumerate(shared.hypernetworks): # instantiate a list to protect against concurrent modification
yield self.create_item(name, index) names = list(shared.hypernetworks)
for index, name in enumerate(names):
item = self.create_item(name, index)
if item is not None:
yield item
def allowed_directories_for_previews(self): def allowed_directories_for_previews(self):
return [shared.cmd_opts.hypernetwork_dir] return [shared.cmd_opts.hypernetwork_dir]

View File

@ -14,6 +14,8 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
def create_item(self, name, index=None, enable_filter=True): def create_item(self, name, index=None, enable_filter=True):
embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name) embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
if embedding is None:
return
path, ext = os.path.splitext(embedding.filename) path, ext = os.path.splitext(embedding.filename)
return { return {
@ -29,8 +31,12 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
} }
def list_items(self): def list_items(self):
for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings): # instantiate a list to protect against concurrent modification
yield self.create_item(name, index) names = list(sd_hijack.model_hijack.embedding_db.word_embeddings)
for index, name in enumerate(names):
item = self.create_item(name, index)
if item is not None:
yield item
def allowed_directories_for_previews(self): def allowed_directories_for_previews(self):
return list(sd_hijack.model_hijack.embedding_db.embedding_dirs) return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)

View File

@ -134,7 +134,7 @@ class UserMetadataEditor:
basename, ext = os.path.splitext(filename) basename, ext = os.path.splitext(filename)
with open(basename + '.json', "w", encoding="utf8") as file: with open(basename + '.json', "w", encoding="utf8") as file:
json.dump(metadata, file, indent=4) json.dump(metadata, file, indent=4, ensure_ascii=False)
def save_user_metadata(self, name, desc, notes): def save_user_metadata(self, name, desc, notes):
user_metadata = self.get_user_metadata(name) user_metadata = self.get_user_metadata(name)

View File

@ -2,12 +2,12 @@ import os
import gradio as gr import gradio as gr
from modules import localization, shared, scripts from modules import localization, shared, scripts
from modules.paths import script_path, data_path from modules.paths import script_path, data_path, cwd
def webpath(fn): def webpath(fn):
if fn.startswith(script_path): if fn.startswith(cwd):
web_path = os.path.relpath(fn, script_path).replace('\\', '/') web_path = os.path.relpath(fn, cwd)
else: else:
web_path = os.path.abspath(fn) web_path = os.path.abspath(fn)

View File

@ -4,7 +4,7 @@ import os
import gradio as gr import gradio as gr
from modules import errors from modules import errors
from modules.ui_components import ToolButton from modules.ui_components import ToolButton, InputAccordion
def radio_choices(comp): # gradio 3.41 changes choices from list of values to list of pairs def radio_choices(comp): # gradio 3.41 changes choices from list of values to list of pairs
@ -32,8 +32,6 @@ class UiLoadsave:
self.error_loading = True self.error_loading = True
errors.display(e, "loading settings") errors.display(e, "loading settings")
def add_component(self, path, x): def add_component(self, path, x):
"""adds component to the registry of tracked components""" """adds component to the registry of tracked components"""
@ -43,20 +41,24 @@ class UiLoadsave:
key = f"{path}/{field}" key = f"{path}/{field}"
if getattr(obj, 'custom_script_source', None) is not None: if getattr(obj, 'custom_script_source', None) is not None:
key = f"customscript/{obj.custom_script_source}/{key}" key = f"customscript/{obj.custom_script_source}/{key}"
if getattr(obj, 'do_not_save_to_config', False): if getattr(obj, 'do_not_save_to_config', False):
return return
saved_value = self.ui_settings.get(key, None) saved_value = self.ui_settings.get(key, None)
if isinstance(obj, gr.Accordion) and isinstance(x, InputAccordion) and field == 'value':
field = 'open'
if saved_value is None: if saved_value is None:
self.ui_settings[key] = getattr(obj, field) self.ui_settings[key] = getattr(obj, field)
elif condition and not condition(saved_value): elif condition and not condition(saved_value):
pass pass
else: else:
if isinstance(x, gr.Textbox) and field == 'value': # due to an undesirable behavior of gr.Textbox, if you give it an int value instead of str, everything dies if isinstance(obj, gr.Textbox) and field == 'value': # due to an undesirable behavior of gr.Textbox, if you give it an int value instead of str, everything dies
saved_value = str(saved_value) saved_value = str(saved_value)
elif isinstance(x, gr.Number) and field == 'value': elif isinstance(obj, gr.Number) and field == 'value':
try: try:
saved_value = float(saved_value) saved_value = float(saved_value)
except ValueError: except ValueError:
@ -67,7 +69,7 @@ class UiLoadsave:
init_field(saved_value) init_field(saved_value)
if field == 'value' and key not in self.component_mapping: if field == 'value' and key not in self.component_mapping:
self.component_mapping[key] = x self.component_mapping[key] = obj
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton, gr.Button] and x.visible: if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton, gr.Button] and x.visible:
apply_field(x, 'visible') apply_field(x, 'visible')
@ -100,6 +102,12 @@ class UiLoadsave:
apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None)) apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
if type(x) == InputAccordion:
if x.accordion.visible:
apply_field(x.accordion, 'visible')
apply_field(x, 'value')
apply_field(x.accordion, 'value')
def check_tab_id(tab_id): def check_tab_id(tab_id):
tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children)) tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
if type(tab_id) == str: if type(tab_id) == str:
@ -133,7 +141,7 @@ class UiLoadsave:
def write_to_file(self, current_ui_settings): def write_to_file(self, current_ui_settings):
with open(self.filename, "w", encoding="utf8") as file: with open(self.filename, "w", encoding="utf8") as file:
json.dump(current_ui_settings, file, indent=4) json.dump(current_ui_settings, file, indent=4, ensure_ascii=False)
def dump_defaults(self): def dump_defaults(self):
"""saves default values to a file unless tjhe file is present and there was an error loading default values at start""" """saves default values to a file unless tjhe file is present and there was an error loading default values at start"""

View File

@ -4,6 +4,7 @@ from modules import shared, ui_common, ui_components, styles
styles_edit_symbol = '\U0001f58c\uFE0F' # 🖌️ styles_edit_symbol = '\U0001f58c\uFE0F' # 🖌️
styles_materialize_symbol = '\U0001f4cb' # 📋 styles_materialize_symbol = '\U0001f4cb' # 📋
styles_copy_symbol = '\U0001f4dd' # 📝
def select_style(name): def select_style(name):
@ -52,6 +53,8 @@ def refresh_styles():
class UiPromptStyles: class UiPromptStyles:
def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt): def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):
self.tabname = tabname self.tabname = tabname
self.main_ui_prompt = main_ui_prompt
self.main_ui_negative_prompt = main_ui_negative_prompt
with gr.Row(elem_id=f"{tabname}_styles_row"): with gr.Row(elem_id=f"{tabname}_styles_row"):
self.dropdown = gr.Dropdown(label="Styles", show_label=False, elem_id=f"{tabname}_styles", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip="Styles") self.dropdown = gr.Dropdown(label="Styles", show_label=False, elem_id=f"{tabname}_styles", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip="Styles")
@ -61,13 +64,14 @@ class UiPromptStyles:
with gr.Row(): with gr.Row():
self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.") self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.")
ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles") ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles")
self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.") self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply_dialog", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.")
self.copy = ui_components.ToolButton(value=styles_copy_symbol, elem_id=f"{tabname}_style_copy", tooltip="Copy main UI prompt to style.")
with gr.Row(): with gr.Row():
self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3) self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3, elem_classes=["prompt"])
with gr.Row(): with gr.Row():
self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3) self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3, elem_classes=["prompt"])
with gr.Row(): with gr.Row():
self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False) self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)
@ -96,15 +100,21 @@ class UiPromptStyles:
show_progress=False, show_progress=False,
).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False) ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
self.materialize.click( self.setup_apply_button(self.materialize)
fn=materialize_styles,
inputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown], self.copy.click(
outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown], fn=lambda p, n: (p, n),
inputs=[main_ui_prompt, main_ui_negative_prompt],
outputs=[self.prompt, self.neg_prompt],
show_progress=False, show_progress=False,
).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False) )
ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close) ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close)
def setup_apply_button(self, button):
button.click(
fn=materialize_styles,
inputs=[self.main_ui_prompt, self.main_ui_negative_prompt, self.dropdown],
outputs=[self.main_ui_prompt, self.main_ui_negative_prompt, self.dropdown],
show_progress=False,
).then(fn=None, _js="function(){update_"+self.tabname+"_tokens(); closePopup();}", show_progress=False)

View File

@ -1,10 +1,11 @@
import gradio as gr import gradio as gr
from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer
from modules.call_queue import wrap_gradio_call from modules.call_queue import wrap_gradio_call
from modules.shared import opts from modules.shared import opts
from modules.ui_components import FormRow from modules.ui_components import FormRow
from modules.ui_gradio_extensions import reload_javascript from modules.ui_gradio_extensions import reload_javascript
from concurrent.futures import ThreadPoolExecutor, as_completed
def get_value_for_setting(key): def get_value_for_setting(key):
@ -63,6 +64,9 @@ class UiSettings:
quicksettings_list = None quicksettings_list = None
quicksettings_names = None quicksettings_names = None
text_settings = None text_settings = None
show_all_pages = None
show_one_page = None
search_input = None
def run_settings(self, *args): def run_settings(self, *args):
changed = [] changed = []
@ -135,7 +139,7 @@ class UiSettings:
gr.Group() gr.Group()
current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text) current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
current_tab.__enter__() current_tab.__enter__()
current_row = gr.Column(variant='compact') current_row = gr.Column(elem_id=f"column_settings_{elem_id}", variant='compact')
current_row.__enter__() current_row.__enter__()
previous_section = item.section previous_section = item.section
@ -173,26 +177,43 @@ class UiSettings:
download_localization = gr.Button(value='Download localization template', elem_id="download_localization") download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
with gr.Row(): with gr.Row():
unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model") unload_sd_model = gr.Button(value='Unload SD checkpoint to RAM', elem_id="sett_unload_sd_model")
reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model") reload_sd_model = gr.Button(value='Load SD checkpoint to VRAM from RAM', elem_id="sett_reload_sd_model")
with gr.Row():
calculate_all_checkpoint_hash = gr.Button(value='Calculate hash for all checkpoint', elem_id="calculate_all_checkpoint_hash")
calculate_all_checkpoint_hash_threads = gr.Number(value=1, label="Number of parallel calculations", elem_id="calculate_all_checkpoint_hash_threads", precision=0, minimum=1)
with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"): with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
gr.HTML(shared.html("licenses.html"), elem_id="licenses") gr.HTML(shared.html("licenses.html"), elem_id="licenses")
gr.Button(value="Show all pages", elem_id="settings_show_all_pages") self.show_all_pages = gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
self.show_one_page = gr.Button(value="Show only one page", elem_id="settings_show_one_page", visible=False)
self.show_one_page.click(lambda: None)
self.search_input = gr.Textbox(value="", elem_id="settings_search", max_lines=1, placeholder="Search...", show_label=False)
self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
def call_func_and_return_text(func, text):
def handler():
t = timer.Timer()
func()
t.record(text)
return f'{text} in {t.total:.1f}s'
return handler
unload_sd_model.click( unload_sd_model.click(
fn=sd_models.unload_model_weights, fn=call_func_and_return_text(sd_models.unload_model_weights, 'Unloaded the checkpoint'),
inputs=[], inputs=[],
outputs=[] outputs=[self.result]
) )
reload_sd_model.click( reload_sd_model.click(
fn=sd_models.reload_model_weights, fn=call_func_and_return_text(lambda: sd_models.send_model_to_device(shared.sd_model), 'Loaded the checkpoint'),
inputs=[], inputs=[],
outputs=[] outputs=[self.result]
) )
request_notifications.click( request_notifications.click(
@ -241,6 +262,21 @@ class UiSettings:
outputs=[sysinfo_check_output], outputs=[sysinfo_check_output],
) )
def calculate_all_checkpoint_hash_fn(max_thread):
checkpoints_list = sd_models.checkpoints_list.values()
with ThreadPoolExecutor(max_workers=max_thread) as executor:
futures = [executor.submit(checkpoint.calculate_shorthash) for checkpoint in checkpoints_list]
completed = 0
for _ in as_completed(futures):
completed += 1
print(f"{completed} / {len(checkpoints_list)} ")
print("Finish calculating hash for all checkpoints")
calculate_all_checkpoint_hash.click(
fn=calculate_all_checkpoint_hash_fn,
inputs=[calculate_all_checkpoint_hash_threads],
)
self.interface = settings_interface self.interface = settings_interface
def add_quicksettings(self): def add_quicksettings(self):
@ -294,3 +330,8 @@ class UiSettings:
outputs=[self.component_dict[k] for k in component_keys], outputs=[self.component_dict[k] for k in component_keys],
queue=False, queue=False,
) )
def search(self, text):
print(text)
return [gr.update(visible=text in (comp.label or "")) for comp in self.components]

141
modules/ui_toprow.py Normal file
View File

@ -0,0 +1,141 @@
import gradio as gr
from modules import shared, ui_prompt_styles
import modules.images
from modules.ui_components import ToolButton
class Toprow:
"""Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation"""
prompt = None
prompt_img = None
negative_prompt = None
button_interrogate = None
button_deepbooru = None
interrupt = None
skip = None
submit = None
paste = None
clear_prompt_button = None
apply_styles = None
restore_progress_button = None
token_counter = None
token_button = None
negative_token_counter = None
negative_token_button = None
ui_styles = None
submit_box = None
def __init__(self, is_img2img, is_compact=False):
id_part = "img2img" if is_img2img else "txt2img"
self.id_part = id_part
self.is_img2img = is_img2img
self.is_compact = is_compact
if not is_compact:
with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
self.create_classic_toprow()
else:
self.create_submit_box()
def create_classic_toprow(self):
self.create_prompts()
with gr.Column(scale=1, elem_id=f"{self.id_part}_actions_column"):
self.create_submit_box()
self.create_tools_row()
self.create_styles_ui()
def create_inline_toprow_prompts(self):
if not self.is_compact:
return
self.create_prompts()
with gr.Row(elem_classes=["toprow-compact-stylerow"]):
with gr.Column(elem_classes=["toprow-compact-tools"]):
self.create_tools_row()
with gr.Column():
self.create_styles_ui()
def create_inline_toprow_image(self):
if not self.is_compact:
return
self.submit_box.render()
def create_prompts(self):
with gr.Column(elem_id=f"{self.id_part}_prompt_container", elem_classes=["prompt-container-compact"] if self.is_compact else [], scale=6):
with gr.Row(elem_id=f"{self.id_part}_prompt_row", elem_classes=["prompt-row"]):
self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
self.prompt_img = gr.File(label="", elem_id=f"{self.id_part}_prompt_image", file_count="single", type="binary", visible=False)
with gr.Row(elem_id=f"{self.id_part}_neg_prompt_row", elem_classes=["prompt-row"]):
self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
self.prompt_img.change(
fn=modules.images.image_data,
inputs=[self.prompt_img],
outputs=[self.prompt, self.prompt_img],
show_progress=False,
)
def create_submit_box(self):
with gr.Row(elem_id=f"{self.id_part}_generate_box", elem_classes=["generate-box"] + (["generate-box-compact"] if self.is_compact else []), render=not self.is_compact) as submit_box:
self.submit_box = submit_box
self.interrupt = gr.Button('Interrupt', elem_id=f"{self.id_part}_interrupt", elem_classes="generate-box-interrupt")
self.skip = gr.Button('Skip', elem_id=f"{self.id_part}_skip", elem_classes="generate-box-skip")
self.submit = gr.Button('Generate', elem_id=f"{self.id_part}_generate", variant='primary')
self.skip.click(
fn=lambda: shared.state.skip(),
inputs=[],
outputs=[],
)
self.interrupt.click(
fn=lambda: shared.state.interrupt(),
inputs=[],
outputs=[],
)
def create_tools_row(self):
with gr.Row(elem_id=f"{self.id_part}_tools"):
from modules.ui import paste_symbol, clear_prompt_symbol, restore_progress_symbol
self.paste = ToolButton(value=paste_symbol, elem_id="paste", tooltip="Read generation parameters from prompt or last generation if prompt is empty into user interface.")
self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{self.id_part}_clear_prompt", tooltip="Clear prompt")
self.apply_styles = ToolButton(value=ui_prompt_styles.styles_materialize_symbol, elem_id=f"{self.id_part}_style_apply", tooltip="Apply all selected styles to prompts.")
if self.is_img2img:
self.button_interrogate = ToolButton('📎', tooltip='Interrogate CLIP - use CLIP neural network to create a text describing the image, and put it into the prompt field', elem_id="interrogate")
self.button_deepbooru = ToolButton('📦', tooltip='Interrogate DeepBooru - use DeepBooru neural network to create a text describing the image, and put it into the prompt field', elem_id="deepbooru")
self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{self.id_part}_restore_progress", visible=False, tooltip="Restore progress")
self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{self.id_part}_token_counter", elem_classes=["token-counter"])
self.token_button = gr.Button(visible=False, elem_id=f"{self.id_part}_token_button")
self.negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{self.id_part}_negative_token_counter", elem_classes=["token-counter"])
self.negative_token_button = gr.Button(visible=False, elem_id=f"{self.id_part}_negative_token_button")
self.clear_prompt_button.click(
fn=lambda *x: x,
_js="confirm_clear_prompt",
inputs=[self.prompt, self.negative_prompt],
outputs=[self.prompt, self.negative_prompt],
)
def create_styles_ui(self):
self.ui_styles = ui_prompt_styles.UiPromptStyles(self.id_part, self.prompt, self.negative_prompt)
self.ui_styles.setup_apply_button(self.apply_styles)

164
modules/xlmr_m18.py Normal file
View File

@ -0,0 +1,164 @@
from transformers import BertPreTrainedModel,BertConfig
import torch.nn as nn
import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional
class BertSeriesConfig(BertConfig):
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
self.project_dim = project_dim
self.pooler_fn = pooler_fn
self.learn_encoder = learn_encoder
class RobertaSeriesConfig(XLMRobertaConfig):
def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
self.project_dim = project_dim
self.pooler_fn = pooler_fn
self.learn_encoder = learn_encoder
class BertSeriesModelWithTransformation(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
config_class = BertSeriesConfig
def __init__(self, config=None, **kargs):
# modify initialization for autoloading
if config is None:
config = XLMRobertaConfig()
config.attention_probs_dropout_prob= 0.1
config.bos_token_id=0
config.eos_token_id=2
config.hidden_act='gelu'
config.hidden_dropout_prob=0.1
config.hidden_size=1024
config.initializer_range=0.02
config.intermediate_size=4096
config.layer_norm_eps=1e-05
config.max_position_embeddings=514
config.num_attention_heads=16
config.num_hidden_layers=24
config.output_past=True
config.pad_token_id=1
config.position_embedding_type= "absolute"
config.type_vocab_size= 1
config.use_cache=True
config.vocab_size= 250002
config.project_dim = 1024
config.learn_encoder = False
super().__init__(config)
self.roberta = XLMRobertaModel(config)
self.transformation = nn.Linear(config.hidden_size,config.project_dim)
# self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
# self.pooler = lambda x: x[:,0]
# self.post_init()
self.has_pre_transformation = True
if self.has_pre_transformation:
self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_init()
def encode(self,c):
device = next(self.parameters()).device
text = self.tokenizer(c,
truncation=True,
max_length=77,
return_length=False,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt")
text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
text["attention_mask"] = torch.tensor(
text['attention_mask']).to(device)
features = self(**text)
return features['projection_state']
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) :
r"""
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.roberta(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=True,
return_dict=return_dict,
)
# # last module outputs
# sequence_output = outputs[0]
# # project every module
# sequence_output_ln = self.pre_LN(sequence_output)
# # pooler
# pooler_output = self.pooler(sequence_output_ln)
# pooler_output = self.transformation(pooler_output)
# projection_state = self.transformation(outputs.last_hidden_state)
if self.has_pre_transformation:
sequence_output2 = outputs["hidden_states"][-2]
sequence_output2 = self.pre_LN(sequence_output2)
projection_state2 = self.transformation_pre(sequence_output2)
return {
"projection_state": projection_state2,
"last_hidden_state": outputs.last_hidden_state,
"hidden_states": outputs.hidden_states,
"attentions": outputs.attentions,
}
else:
projection_state = self.transformation(outputs.last_hidden_state)
return {
"projection_state": projection_state,
"last_hidden_state": outputs.last_hidden_state,
"hidden_states": outputs.hidden_states,
"attentions": outputs.attentions,
}
# return {
# 'pooler_output':pooler_output,
# 'last_hidden_state':outputs.last_hidden_state,
# 'hidden_states':outputs.hidden_states,
# 'attentions':outputs.attentions,
# 'projection_state':projection_state,
# 'sequence_out': sequence_output
# }
class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
base_model_prefix = 'roberta'
config_class= RobertaSeriesConfig

View File

@ -16,6 +16,7 @@ exclude = [
ignore = [ ignore = [
"E501", # Line too long "E501", # Line too long
"E721", # Do not compare types, use `isinstance`
"E731", # Do not assign a `lambda` expression, use a `def` "E731", # Do not assign a `lambda` expression, use a `def`
"I001", # Import block is un-sorted or un-formatted "I001", # Import block is un-sorted or un-formatted

View File

@ -27,6 +27,6 @@ timm==0.9.2
tomesd==0.1.3 tomesd==0.1.3
torch torch
torchdiffeq==0.2.3 torchdiffeq==0.2.3
torchsde==0.2.5 torchsde==0.2.6
transformers==4.30.2 transformers==4.30.2
httpx==0.24.1 httpx==0.24.1

View File

@ -124,16 +124,29 @@ document.addEventListener("DOMContentLoaded", function() {
* Add a ctrl+enter as a shortcut to start a generation * Add a ctrl+enter as a shortcut to start a generation
*/ */
document.addEventListener('keydown', function(e) { document.addEventListener('keydown', function(e) {
var handled = false; const isEnter = e.key === 'Enter' || e.keyCode === 13;
if (e.key !== undefined) { const isModifierKey = e.metaKey || e.ctrlKey || e.altKey;
if ((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
} else if (e.keyCode !== undefined) { const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]');
if ((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true; const generateButton = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
}
if (handled) { if (isEnter && isModifierKey) {
var button = get_uiCurrentTabContent().querySelector('button[id$=_generate]'); if (interruptButton.style.display === 'block') {
if (button) { interruptButton.click();
button.click(); const callback = (mutationList) => {
for (const mutation of mutationList) {
if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
if (interruptButton.style.display === 'none') {
generateButton.click();
observer.disconnect();
}
}
}
};
const observer = new MutationObserver(callback);
observer.observe(interruptButton, {attributes: true});
} else {
generateButton.click();
} }
e.preventDefault(); e.preventDefault();
} }

View File

@ -29,7 +29,7 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
upscaling_resize_w = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="extras_upscaling_resize_w") upscaling_resize_w = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="extras_upscaling_resize_w")
upscaling_resize_h = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="extras_upscaling_resize_h") upscaling_resize_h = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="extras_upscaling_resize_h")
with gr.Column(elem_id="upscaling_dimensions_row", scale=1, elem_classes="dimensions-tools"): with gr.Column(elem_id="upscaling_dimensions_row", scale=1, elem_classes="dimensions-tools"):
upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn") upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn", tooltip="Switch width/height")
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
with FormRow(): with FormRow():

View File

@ -5,11 +5,17 @@ import shlex
import modules.scripts as scripts import modules.scripts as scripts
import gradio as gr import gradio as gr
from modules import sd_samplers, errors from modules import sd_samplers, errors, sd_models
from modules.processing import Processed, process_images from modules.processing import Processed, process_images
from modules.shared import state from modules.shared import state
def process_model_tag(tag):
info = sd_models.get_closet_checkpoint_match(tag)
assert info is not None, f'Unknown checkpoint: {tag}'
return info.name
def process_string_tag(tag): def process_string_tag(tag):
return tag return tag
@ -27,7 +33,7 @@ def process_boolean_tag(tag):
prompt_tags = { prompt_tags = {
"sd_model": None, "sd_model": process_model_tag,
"outpath_samples": process_string_tag, "outpath_samples": process_string_tag,
"outpath_grids": process_string_tag, "outpath_grids": process_string_tag,
"prompt_for_display": process_string_tag, "prompt_for_display": process_string_tag,
@ -108,6 +114,7 @@ class Script(scripts.Script):
def ui(self, is_img2img): def ui(self, is_img2img):
checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate")) checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate"))
checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch")) checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch"))
prompt_position = gr.Radio(["start", "end"], label="Insert prompts at the", elem_id=self.elem_id("prompt_position"), value="start")
prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt")) prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt"))
file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file")) file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file"))
@ -118,9 +125,9 @@ class Script(scripts.Script):
# We don't shrink back to 1, because that causes the control to ignore [enter], and it may # We don't shrink back to 1, because that causes the control to ignore [enter], and it may
# be unclear to the user that shift-enter is needed. # be unclear to the user that shift-enter is needed.
prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt], show_progress=False) prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt], show_progress=False)
return [checkbox_iterate, checkbox_iterate_batch, prompt_txt] return [checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt]
def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str): def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt: str):
lines = [x for x in (x.strip() for x in prompt_txt.splitlines()) if x] lines = [x for x in (x.strip() for x in prompt_txt.splitlines()) if x]
p.do_not_save_grid = True p.do_not_save_grid = True
@ -156,7 +163,22 @@ class Script(scripts.Script):
copy_p = copy.copy(p) copy_p = copy.copy(p)
for k, v in args.items(): for k, v in args.items():
setattr(copy_p, k, v) if k == "sd_model":
copy_p.override_settings['sd_model_checkpoint'] = v
else:
setattr(copy_p, k, v)
if args.get("prompt") and p.prompt:
if prompt_position == "start":
copy_p.prompt = args.get("prompt") + " " + p.prompt
else:
copy_p.prompt = p.prompt + " " + args.get("prompt")
if args.get("negative_prompt") and p.negative_prompt:
if prompt_position == "start":
copy_p.negative_prompt = args.get("negative_prompt") + " " + p.negative_prompt
else:
copy_p.negative_prompt = p.negative_prompt + " " + args.get("negative_prompt")
proc = process_images(copy_p) proc = process_images(copy_p)
images += proc.images images += proc.images

View File

@ -205,13 +205,14 @@ def csv_string_to_list_strip(data_str):
class AxisOption: class AxisOption:
def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None): def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None, prepare=None):
self.label = label self.label = label
self.type = type self.type = type
self.apply = apply self.apply = apply
self.format_value = format_value self.format_value = format_value
self.confirm = confirm self.confirm = confirm
self.cost = cost self.cost = cost
self.prepare = prepare
self.choices = choices self.choices = choices
@ -536,6 +537,8 @@ class Script(scripts.Script):
if opt.choices is not None and not csv_mode: if opt.choices is not None and not csv_mode:
valslist = vals_dropdown valslist = vals_dropdown
elif opt.prepare is not None:
valslist = opt.prepare(vals)
else: else:
valslist = csv_string_to_list_strip(vals) valslist = csv_string_to_list_strip(vals)
@ -773,6 +776,8 @@ class Script(scripts.Script):
# TODO: See previous comment about intentional data misalignment. # TODO: See previous comment about intentional data misalignment.
adj_g = g-1 if g > 0 else g adj_g = g-1 if g > 0 else g
images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed) images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed)
if not include_sub_grids: # if not include_sub_grids then skip saving after the first grid
break
if not include_sub_grids: if not include_sub_grids:
# Done with sub-grids, drop all related information: # Done with sub-grids, drop all related information:

View File

@ -83,8 +83,10 @@ div.compact{
white-space: nowrap; white-space: nowrap;
} }
.gradio-dropdown ul.options li.item { @media (pointer:fine) {
padding: 0.05em 0; .gradio-dropdown ul.options li.item {
padding: 0.05em 0;
}
} }
.gradio-dropdown ul.options li.item.selected { .gradio-dropdown ul.options li.item.selected {
@ -202,6 +204,11 @@ div.block.gradio-accordion {
padding: 8px 8px; padding: 8px 8px;
} }
input[type="checkbox"].input-accordion-checkbox{
vertical-align: sub;
margin-right: 0.5em;
}
/* txt2img/img2img specific */ /* txt2img/img2img specific */
@ -289,6 +296,13 @@ div.block.gradio-accordion {
min-height: 4.5em; min-height: 4.5em;
} }
#txt2img_generate, #img2img_generate {
min-height: 4.5em;
}
.generate-box-compact #txt2img_generate, .generate-box-compact #img2img_generate {
min-height: 3em;
}
@media screen and (min-width: 2500px) { @media screen and (min-width: 2500px) {
#txt2img_gallery, #img2img_gallery { #txt2img_gallery, #img2img_gallery {
min-height: 768px; min-height: 768px;
@ -396,6 +410,15 @@ div#extras_scale_to_tab div.form{
min-width: 0.5em; min-width: 0.5em;
} }
div.toprow-compact-stylerow{
margin: 0.5em 0;
}
div.toprow-compact-tools{
min-width: fit-content !important;
max-width: fit-content;
}
/* settings */ /* settings */
#quicksettings { #quicksettings {
align-items: end; align-items: end;
@ -421,6 +444,7 @@ div#extras_scale_to_tab div.form{
#settings > div{ #settings > div{
border: none; border: none;
margin-left: 10em; margin-left: 10em;
padding: 0 var(--spacing-xl);
} }
#settings > div.tab-nav{ #settings > div.tab-nav{
@ -435,6 +459,16 @@ div#extras_scale_to_tab div.form{
border: none; border: none;
text-align: left; text-align: left;
white-space: initial; white-space: initial;
padding: 4px;
}
#settings > div.tab-nav .settings-category{
display: block;
margin: 1em 0 0.25em 0;
font-weight: bold;
text-decoration: underline;
cursor: default;
user-select: none;
} }
#settings_result{ #settings_result{
@ -516,7 +550,8 @@ table.popup-table .link{
height: 20px; height: 20px;
background: #b4c0cc; background: #b4c0cc;
border-radius: 3px !important; border-radius: 3px !important;
top: -20px; top: -14px;
left: 0px;
width: 100%; width: 100%;
} }
@ -581,7 +616,6 @@ table.popup-table .link{
width: 100%; width: 100%;
height: 100%; height: 100%;
overflow: auto; overflow: auto;
background-color: rgba(20, 20, 20, 0.95);
} }
.global-popup *{ .global-popup *{
@ -590,9 +624,6 @@ table.popup-table .link{
.global-popup-close:before { .global-popup-close:before {
content: "×"; content: "×";
}
.global-popup-close{
position: fixed; position: fixed;
right: 0.25em; right: 0.25em;
top: 0; top: 0;
@ -601,10 +632,20 @@ table.popup-table .link{
font-size: 32pt; font-size: 32pt;
} }
.global-popup-close{
position: fixed;
left: 0;
top: 0;
width: 100%;
height: 100%;
background-color: rgba(20, 20, 20, 0.95);
}
.global-popup-inner{ .global-popup-inner{
display: inline-block; display: inline-block;
margin: auto; margin: auto;
padding: 2em; padding: 2em;
z-index: 1001;
} }
/* fullpage image viewer */ /* fullpage image viewer */
@ -808,6 +849,18 @@ footer {
/* extra networks UI */ /* extra networks UI */
.extra-page > div.gap{
gap: 0;
}
.extra-page-prompts{
margin-bottom: 0;
}
.extra-page-prompts.extra-page-prompts-active{
margin-bottom: 1em;
}
.extra-network-cards{ .extra-network-cards{
height: calc(100vh - 24rem); height: calc(100vh - 24rem);
overflow: clip scroll; overflow: clip scroll;

View File

@ -1,6 +1,11 @@
@echo off @echo off
if exist webui.settings.bat (
call webui.settings.bat
)
if not defined PYTHON (set PYTHON=python) if not defined PYTHON (set PYTHON=python)
if defined GIT (set "GIT_PYTHON_GIT_EXECUTABLE=%GIT%")
if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv") if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv")
set SD_WEBUI_RESTART=tmp/restart set SD_WEBUI_RESTART=tmp/restart

View File

@ -74,7 +74,7 @@ def webui():
if shared.opts.auto_launch_browser == "Remote" or cmd_opts.autolaunch: if shared.opts.auto_launch_browser == "Remote" or cmd_opts.autolaunch:
auto_launch_browser = True auto_launch_browser = True
elif shared.opts.auto_launch_browser == "Local": elif shared.opts.auto_launch_browser == "Local":
auto_launch_browser = not any([cmd_opts.listen, cmd_opts.share, cmd_opts.ngrok, cmd_opts.server_name]) auto_launch_browser = not cmd_opts.webui_is_non_local
app, local_url, share_url = shared.demo.launch( app, local_url, share_url = shared.demo.launch(
share=cmd_opts.share, share=cmd_opts.share,

View File

@ -4,12 +4,6 @@
# change the variables in webui-user.sh instead # # change the variables in webui-user.sh instead #
################################################# #################################################
use_venv=1
if [[ $venv_dir == "-" ]]; then
use_venv=0
fi
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
@ -28,6 +22,12 @@ then
source "$SCRIPT_DIR"/webui-user.sh source "$SCRIPT_DIR"/webui-user.sh
fi fi
# If $venv_dir is "-", then disable venv support
use_venv=1
if [[ $venv_dir == "-" ]]; then
use_venv=0
fi
# Set defaults # Set defaults
# Install directory without trailing slash # Install directory without trailing slash
if [[ -z "${install_dir}" ]] if [[ -z "${install_dir}" ]]
@ -51,6 +51,8 @@ fi
if [[ -z "${GIT}" ]] if [[ -z "${GIT}" ]]
then then
export GIT="git" export GIT="git"
else
export GIT_PYTHON_GIT_EXECUTABLE="${GIT}"
fi fi
# python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv) # python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv)
@ -87,7 +89,7 @@ delimiter="################################################################"
printf "\n%s\n" "${delimiter}" printf "\n%s\n" "${delimiter}"
printf "\e[1m\e[32mInstall script for stable-diffusion + Web UI\n" printf "\e[1m\e[32mInstall script for stable-diffusion + Web UI\n"
printf "\e[1m\e[34mTested on Debian 11 (Bullseye)\e[0m" printf "\e[1m\e[34mTested on Debian 11 (Bullseye), Fedora 34+ and openSUSE Leap 15.4 or newer.\e[0m"
printf "\n%s\n" "${delimiter}" printf "\n%s\n" "${delimiter}"
# Do not run as root # Do not run as root
@ -141,9 +143,8 @@ case "$gpu_info" in
*"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0 *"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
;; ;;
*"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \ *"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \
export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.6" export TORCH_COMMAND="pip install torch torchvision --index-url https://download.pytorch.org/whl/test/rocm5.6"
# Navi 3 needs at least 5.5 which is only on the nightly chain, previous versions are no longer online (torch==2.1.0.dev-20230614+rocm5.5 torchvision==0.16.0.dev-20230614+rocm5.5 torchaudio==2.1.0.dev-20230614+rocm5.5) # Navi 3 needs at least 5.5 which is only on the torch 2.1.0 release candidates right now
# so switch to nightly rocm5.6 without explicit versions this time
;; ;;
*"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0 *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
printf "\n%s\n" "${delimiter}" printf "\n%s\n" "${delimiter}"
@ -222,7 +223,7 @@ fi
# Try using TCMalloc on Linux # Try using TCMalloc on Linux
prepare_tcmalloc() { prepare_tcmalloc() {
if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then
TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -Po "libtcmalloc(_minimal|)\.so\.\d" | head -n 1)" TCMALLOC="$(PATH=/sbin:$PATH ldconfig -p | grep -Po "libtcmalloc(_minimal|)\.so\.\d" | head -n 1)"
if [[ ! -z "${TCMALLOC}" ]]; then if [[ ! -z "${TCMALLOC}" ]]; then
echo "Using TCMalloc: ${TCMALLOC}" echo "Using TCMalloc: ${TCMALLOC}"
export LD_PRELOAD="${TCMALLOC}" export LD_PRELOAD="${TCMALLOC}"