From c96e4750d895a47290dc7f96e030197069c75fa4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 7 Aug 2023 08:07:09 +0300 Subject: [PATCH] SD VAE rework 2 - the setting for preferring opts.sd_vae has been inverted and reworded - resolve_vae function made easier to read and now returns an object rather than a tuple - if the checkbox for overriding per-model preferences is checked, opts.sd_vae overrides checkpoint user metadata - changing VAE in user metadata for currently loaded model immediately applies the selection --- modules/sd_models.py | 2 +- modules/sd_vae.py | 71 ++++++++++++++----- modules/shared.py | 6 +- ...xtra_networks_checkpoints_user_metadata.py | 8 ++- webui.py | 2 +- 5 files changed, 69 insertions(+), 20 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index f6051604..d65735e3 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -356,7 +356,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() - vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename) + vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple() sd_vae.load_vae(model, vae_file, vae_source) timer.record("load VAE") diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 0bd5e19b..38bcb840 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,5 +1,7 @@ import os import collections +from dataclasses import dataclass + from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks import glob from copy import deepcopy @@ -97,37 +99,74 @@ def find_vae_near_checkpoint(checkpoint_file): return None -def resolve_vae(checkpoint_file): - if shared.cmd_opts.vae_path is not None: - return shared.cmd_opts.vae_path, 'from commandline argument' +@dataclass +class VaeResolution: + vae: str = None + source: str = None + resolved: bool = True + def tuple(self): + return self.vae, self.source + + +def is_automatic(): + return shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config + + +def resolve_vae_from_setting() -> VaeResolution: + if shared.opts.sd_vae == "None": + return VaeResolution() + + vae_from_options = vae_dict.get(shared.opts.sd_vae, None) + if vae_from_options is not None: + return VaeResolution(vae_from_options, 'specified in settings') + + if not is_automatic(): + print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead") + + return VaeResolution(resolved=False) + + +def resolve_vae_from_user_metadata(checkpoint_file) -> VaeResolution: metadata = extra_networks.get_user_metadata(checkpoint_file) vae_metadata = metadata.get("vae", None) if vae_metadata is not None and vae_metadata != "Automatic": if vae_metadata == "None": - return None, None + return VaeResolution() vae_from_metadata = vae_dict.get(vae_metadata, None) if vae_from_metadata is not None: - return vae_from_metadata, "from user metadata" + return VaeResolution(vae_from_metadata, "from user metadata") - is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config + return VaeResolution(resolved=False) + +def resolve_vae_near_checkpoint(checkpoint_file) -> VaeResolution: vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file) if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic): - return vae_near_checkpoint, 'found near the checkpoint' + return VaeResolution(vae_near_checkpoint, 'found near the checkpoint') - if shared.opts.sd_vae == "None": - return None, None + return VaeResolution(resolved=False) - vae_from_options = vae_dict.get(shared.opts.sd_vae, None) - if vae_from_options is not None: - return vae_from_options, 'specified in settings' - if not is_automatic: - print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead") +def resolve_vae(checkpoint_file) -> VaeResolution: + if shared.cmd_opts.vae_path is not None: + return VaeResolution(shared.cmd_opts.vae_path, 'from commandline argument') - return None, None + if shared.opts.sd_vae_overrides_per_model_preferences and not is_automatic(): + return resolve_vae_from_setting() + + res = resolve_vae_from_user_metadata(checkpoint_file) + if res.resolved: + return res + + res = resolve_vae_near_checkpoint(checkpoint_file) + if res.resolved: + return res + + res = resolve_vae_from_setting() + + return res def load_vae_dict(filename, map_location): @@ -201,7 +240,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified): checkpoint_file = checkpoint_info.filename if vae_file == unspecified: - vae_file, vae_source = resolve_vae(checkpoint_file) + vae_file, vae_source = resolve_vae(checkpoint_file).tuple() else: vae_source = "from function argument" diff --git a/modules/shared.py b/modules/shared.py index 078e8135..da53f2d9 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -479,7 +479,7 @@ For img2img, VAE is used to process user's input image before the sampling, and """), "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"), - "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), + "sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"), "auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"), "sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"), "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to decode latent to image"), @@ -733,6 +733,10 @@ class Options: with open(filename, "r", encoding="utf8") as file: self.data = json.load(file) + # 1.6.0 VAE defaults + if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None: + self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default') + # 1.1.1 quicksettings list migration if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None: self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')] diff --git a/modules/ui_extra_networks_checkpoints_user_metadata.py b/modules/ui_extra_networks_checkpoints_user_metadata.py index 2c69aab8..25df0a80 100644 --- a/modules/ui_extra_networks_checkpoints_user_metadata.py +++ b/modules/ui_extra_networks_checkpoints_user_metadata.py @@ -1,6 +1,6 @@ import gradio as gr -from modules import ui_extra_networks_user_metadata, sd_vae +from modules import ui_extra_networks_user_metadata, sd_vae, shared from modules.ui_common import create_refresh_button @@ -18,6 +18,10 @@ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataE self.write_user_metadata(name, user_metadata) + def update_vae(self, name): + if name == shared.sd_model.sd_checkpoint_info.name_for_extra: + sd_vae.reload_vae_weights() + def put_values_into_components(self, name): user_metadata = self.get_user_metadata(name) values = super().put_values_into_components(name) @@ -58,3 +62,5 @@ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataE ] self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components) + self.button_save.click(fn=self.update_vae, inputs=[self.edit_name_input]) + diff --git a/webui.py b/webui.py index 1803ea8a..a5b11575 100644 --- a/webui.py +++ b/webui.py @@ -211,7 +211,7 @@ def configure_sigint_handler(): def configure_opts_onchange(): shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) - shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) + shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)