From 9a8678f61eff172811498a682c171399b7216e12 Mon Sep 17 00:00:00 2001 From: Billy Cao Date: Tue, 29 Nov 2022 11:11:29 +0800 Subject: [PATCH 01/10] Support changing checkpoint and vae through override_settings --- modules/processing.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index edceb532..a5c72e3d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -20,6 +20,8 @@ import modules.shared as shared import modules.face_restoration import modules.images as images import modules.styles +import modules.sd_models as sd_models +import modules.sd_vae as sd_vae import logging @@ -424,8 +426,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: try: for k, v in p.override_settings.items(): - setattr(opts, k, v) # we don't call onchange for simplicity which makes changing model impossible - if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet since it is relatively fast to load on-change, while SD models are not + setattr(opts, k, v) + if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet + if k == 'sd_model_checkpoint': sd_models.reload_model_weights() # make onchange call for changing SD model + if k == 'sd_vae': sd_vae.reload_vae_weights() # make onchange call for changing VAE res = process_images_inner(p) @@ -433,6 +437,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: for k, v in stored_opts.items(): setattr(opts, k, v) if k == 'sd_hypernetwork': shared.reload_hypernetworks() + if k == 'sd_model_checkpoint': sd_models.reload_model_weights() + if k == 'sd_vae': sd_vae.reload_vae_weights() return res From 2eb5f103ab1b41477440cc391165ea7ef5f7f959 Mon Sep 17 00:00:00 2001 From: apolinario Date: Mon, 5 Dec 2022 16:30:15 +0100 Subject: [PATCH 02/10] Fix WebUI not working inside of iframes --- script.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/script.js b/script.js index 8b3b67e3..d9424de9 100644 --- a/script.js +++ b/script.js @@ -1,5 +1,5 @@ function gradioApp(){ - return document.getElementsByTagName('gradio-app')[0].shadowRoot; + return document } function get_uiCurrentTab() { @@ -82,4 +82,4 @@ function uiElementIsVisible(el) { } } return isVisible; -} \ No newline at end of file +} From 1075819b16ef328805dd946acaffd43efa2eb444 Mon Sep 17 00:00:00 2001 From: apolinario Date: Tue, 6 Dec 2022 15:13:41 +0100 Subject: [PATCH 03/10] Use shadowRoot if inside of an iframe and don't use it if outside This makes sure it will work everywhere --- script.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/script.js b/script.js index d9424de9..49b6c557 100644 --- a/script.js +++ b/script.js @@ -1,5 +1,5 @@ function gradioApp(){ - return document + return !!document.getElementsByTagName('gradio-app')[0].shadowRoot ? document.getElementsByTagName('gradio-app')[0].shadowRoot : document } function get_uiCurrentTab() { From 8eb638cdd3d08ad6e9373569fd81d0a6e8a63f16 Mon Sep 17 00:00:00 2001 From: apolinario Date: Tue, 6 Dec 2022 15:14:22 +0100 Subject: [PATCH 04/10] style change --- script.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/script.js b/script.js index 49b6c557..1c8fcbf1 100644 --- a/script.js +++ b/script.js @@ -1,5 +1,5 @@ function gradioApp(){ - return !!document.getElementsByTagName('gradio-app')[0].shadowRoot ? document.getElementsByTagName('gradio-app')[0].shadowRoot : document + return !!document.getElementsByTagName('gradio-app')[0].shadowRoot ? document.getElementsByTagName('gradio-app')[0].shadowRoot : document; } function get_uiCurrentTab() { From 37139d8aac10bd13758f52e3d361f3d017c4ad46 Mon Sep 17 00:00:00 2001 From: apolinario Date: Sat, 10 Dec 2022 12:51:40 +0100 Subject: [PATCH 05/10] No code repeat --- script.js | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/script.js b/script.js index 1c8fcbf1..9748ec90 100644 --- a/script.js +++ b/script.js @@ -1,5 +1,6 @@ -function gradioApp(){ - return !!document.getElementsByTagName('gradio-app')[0].shadowRoot ? document.getElementsByTagName('gradio-app')[0].shadowRoot : document; +function gradioApp() { + const gradioShadowRoot = document.getElementsByTagName('gradio-app')[0].shadowRoot + return !!gradioShadowRoot ? gradioShadowRoot : document; } function get_uiCurrentTab() { From 991e2dcee9d6baa66b5c0b1969c4c07407be933a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 10 Dec 2022 14:54:02 +0300 Subject: [PATCH 06/10] remove NSFW filter and its dependency; if you still want it, find it in the extensions section --- modules/processing.py | 7 +++---- modules/safety.py | 42 --------------------------------------- modules/scripts.py | 20 +++++++++++++++++++ modules/shared.py | 1 - requirements.txt | 1 - requirements_versions.txt | 1 - 6 files changed, 23 insertions(+), 49 deletions(-) delete mode 100644 modules/safety.py diff --git a/modules/processing.py b/modules/processing.py index 81400d14..056c9322 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -13,7 +13,7 @@ from skimage import exposure from typing import Any, Dict, List, Optional import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -571,9 +571,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: devices.torch_gc() - if opts.filter_nsfw: - import modules.safety as safety - x_samples_ddim = modules.safety.censor_batch(x_samples_ddim) + if p.scripts is not None: + p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n) for i, x_sample in enumerate(x_samples_ddim): x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) diff --git a/modules/safety.py b/modules/safety.py deleted file mode 100644 index cff4b278..00000000 --- a/modules/safety.py +++ /dev/null @@ -1,42 +0,0 @@ -import torch -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from transformers import AutoFeatureExtractor -from PIL import Image - -import modules.shared as shared - -safety_model_id = "CompVis/stable-diffusion-safety-checker" -safety_feature_extractor = None -safety_checker = None - -def numpy_to_pil(images): - """ - Convert a numpy image or a batch of images to a PIL image. - """ - if images.ndim == 3: - images = images[None, ...] - images = (images * 255).round().astype("uint8") - pil_images = [Image.fromarray(image) for image in images] - - return pil_images - -# check and replace nsfw content -def check_safety(x_image): - global safety_feature_extractor, safety_checker - - if safety_feature_extractor is None: - safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) - safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) - - safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") - x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) - - return x_checked_image, has_nsfw_concept - - -def censor_batch(x): - x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy() - x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy) - x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2) - - return x diff --git a/modules/scripts.py b/modules/scripts.py index b934d881..23ca195d 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -88,6 +88,17 @@ class Script: pass + def postprocess_batch(self, p, *args, **kwargs): + """ + Same as process_batch(), but called for every batch after it has been generated. + + **kwargs will have same items as process_batch, and also: + - batch_number - index of current batch, from 0 to number of batches-1 + - images - torch tensor with all generated images, with values ranging from 0 to 1; + """ + + pass + def postprocess(self, p, processed, *args): """ This function is called after processing ends for AlwaysVisible scripts. @@ -347,6 +358,15 @@ class ScriptRunner: print(f"Error running postprocess: {script.filename}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) + def postprocess_batch(self, p, images, **kwargs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.postprocess_batch(p, *script_args, images=images, **kwargs) + except Exception: + print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + def before_component(self, component, **kwargs): for script in self.scripts: try: diff --git a/modules/shared.py b/modules/shared.py index 44922c91..272267c1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -367,7 +367,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), - "filter_nsfw": OptionInfo(False, "Filter NSFW content"), 'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), })) diff --git a/requirements.txt b/requirements.txt index 05818aa6..678acb4d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ accelerate basicsr -diffusers fairscale==0.4.4 fonts font-roboto diff --git a/requirements_versions.txt b/requirements_versions.txt index 035fa82f..185cd066 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -1,5 +1,4 @@ transformers==4.19.2 -diffusers==0.3.0 accelerate==0.12.0 basicsr==1.4.2 gfpgan==1.3.8 From 713c48ddd7f296fe064cf58af7baa31aa5fcffb3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 10 Dec 2022 15:05:22 +0300 Subject: [PATCH 07/10] add an 'installed' tag to extensions --- modules/ui_extensions.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index b487ac25..1434f25f 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -206,12 +206,13 @@ def refresh_available_extensions_from_data(hide_tags): if url is None: continue + existing = installed_extension_urls.get(normalize_git_url(url), None) + extension_tags = extension_tags + ["installed"] if existing else extension_tags + if len([x for x in extension_tags if x in tags_to_hide]) > 0: hidden += 1 continue - existing = installed_extension_urls.get(normalize_git_url(url), None) - install_code = f"""""" tags_text = ", ".join([f"{x}" for x in extension_tags]) @@ -222,7 +223,11 @@ def refresh_available_extensions_from_data(hide_tags): {html.escape(description)} {install_code} - """ + + """ + + for tag in [x for x in extension_tags if x not in tags]: + tags[tag] = tag code += """ @@ -272,7 +277,7 @@ def create_ui(): install_extension_button = gr.Button(elem_id="install_extension_button", visible=False) with gr.Row(): - hide_tags = gr.CheckboxGroup(value=["ads", "localization"], label="Hide extensions with tags", choices=["script", "ads", "localization"]) + hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"]) install_result = gr.HTML() available_extensions_table = gr.HTML() From 718dbe5e8282c8992afc4a67fe270b342013d282 Mon Sep 17 00:00:00 2001 From: Bwin4L Date: Sat, 10 Dec 2022 14:51:11 +0100 Subject: [PATCH 08/10] Fix token counter color on dark theme --- .../prompt-bracket-checker/javascript/prompt-bracket-checker.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js index 3f3bebcd..eccfb0f9 100644 --- a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js +++ b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js @@ -88,7 +88,7 @@ function checkBrackets(evt) { if(counterElt.title != '') { counterElt.style = 'color: #FF5555;'; } else { - counterElt.style = 'color: #000;'; + counterElt.style = ''; } } From 6df316c881b533731faa77494ea01533e35f8dc7 Mon Sep 17 00:00:00 2001 From: wywywywy Date: Sat, 10 Dec 2022 13:54:29 +0000 Subject: [PATCH 09/10] LDSR cache / optimization / opt_channelslast --- extensions-builtin/LDSR/ldsr_model_arch.py | 38 +++++++++++++------ extensions-builtin/LDSR/scripts/ldsr_model.py | 1 + 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py index a87d1ef9..9ec4e67e 100644 --- a/extensions-builtin/LDSR/ldsr_model_arch.py +++ b/extensions-builtin/LDSR/ldsr_model_arch.py @@ -11,25 +11,41 @@ from omegaconf import OmegaConf from ldm.models.diffusion.ddim import DDIMSampler from ldm.util import instantiate_from_config, ismap +from modules import shared, sd_hijack warnings.filterwarnings("ignore", category=UserWarning) +cached_ldsr_model: torch.nn.Module = None + # Create LDSR Class class LDSR: def load_model_from_config(self, half_attention): - print(f"Loading model from {self.modelPath}") - pl_sd = torch.load(self.modelPath, map_location="cpu") - sd = pl_sd["state_dict"] - config = OmegaConf.load(self.yamlPath) - config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1" - model = instantiate_from_config(config.model) - model.load_state_dict(sd, strict=False) - model.cuda() - if half_attention: - model = model.half() + global cached_ldsr_model + + if shared.opts.ldsr_cached and cached_ldsr_model is not None: + print(f"Loading model from cache") + model: torch.nn.Module = cached_ldsr_model + else: + print(f"Loading model from {self.modelPath}") + pl_sd = torch.load(self.modelPath, map_location="cpu") + sd = pl_sd["state_dict"] + config = OmegaConf.load(self.yamlPath) + config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1" + model: torch.nn.Module = instantiate_from_config(config.model) + model.load_state_dict(sd, strict=False) + model = model.to(shared.device) + if half_attention: + model = model.half() + if shared.cmd_opts.opt_channelslast: + model = model.to(memory_format=torch.channels_last) + + sd_hijack.model_hijack.hijack(model) # apply optimization + model.eval() + + if shared.opts.ldsr_cached: + cached_ldsr_model = model - model.eval() return {"model": model} def __init__(self, model_path, yaml_path): diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py index 5c96037d..29d5f94e 100644 --- a/extensions-builtin/LDSR/scripts/ldsr_model.py +++ b/extensions-builtin/LDSR/scripts/ldsr_model.py @@ -59,6 +59,7 @@ def on_ui_settings(): import gradio as gr shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling"))) + shared.opts.add_option("ldsr_cached", shared.OptionInfo(False, "Cache LDSR model in memory", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling"))) script_callbacks.on_ui_settings(on_ui_settings) From 1581d5a1674fbbeaf047b79f3a138781d6322e6e Mon Sep 17 00:00:00 2001 From: wywywywy Date: Sat, 10 Dec 2022 14:07:27 +0000 Subject: [PATCH 10/10] Made device agnostic --- extensions-builtin/LDSR/ldsr_model_arch.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py index 9ec4e67e..8b048ae0 100644 --- a/extensions-builtin/LDSR/ldsr_model_arch.py +++ b/extensions-builtin/LDSR/ldsr_model_arch.py @@ -110,7 +110,8 @@ class LDSR: down_sample_method = 'Lanczos' gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available: + torch.cuda.empty_cache() im_og = image width_og, height_og = im_og.size @@ -147,7 +148,9 @@ class LDSR: del model gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available: + torch.cuda.empty_cache() + return a @@ -162,7 +165,7 @@ def get_cond(selected_path): c = rearrange(c, '1 c h w -> 1 h w c') c = 2. * c - 1. - c = c.to(torch.device("cuda")) + c = c.to(shared.device) example["LR_image"] = c example["image"] = c_up