From e14b586d0494d6c5cc3cbc45b5fa00c03d052443 Mon Sep 17 00:00:00 2001 From: Sakura-Luna <53183413+Sakura-Luna@users.noreply.github.com> Date: Sun, 14 May 2023 12:42:44 +0800 Subject: [PATCH 1/6] Add Tiny AE live preview --- modules/sd_samplers_common.py | 21 ++++++---- modules/sd_vae_taesd.py | 76 +++++++++++++++++++++++++++++++++++ modules/shared.py | 2 +- webui.py | 11 +++++ 4 files changed, 101 insertions(+), 9 deletions(-) create mode 100644 modules/sd_vae_taesd.py diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index bc074238..d3dc130c 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -2,7 +2,7 @@ from collections import namedtuple import numpy as np import torch from PIL import Image -from modules import devices, processing, images, sd_vae_approx +from modules import devices, processing, images, sd_vae_approx, sd_vae_taesd from modules.shared import opts, state import modules.shared as shared @@ -22,21 +22,26 @@ def setup_img2img_steps(p, steps=None): return steps, t_enc -approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2} +approximation_indexes = {"Full": 0, "Tiny AE": 1, "Approx NN": 2, "Approx cheap": 3} def single_sample_to_image(sample, approximation=None): if approximation is None: approximation = approximation_indexes.get(opts.show_progress_type, 0) - if approximation == 2: - x_sample = sd_vae_approx.cheap_approximation(sample) - elif approximation == 1: - x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() + if approximation == 1: + x_sample = sd_vae_taesd.decode()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() + x_sample = sd_vae_taesd.TAESD.unscale_latents(x_sample) + x_sample = torch.clamp((x_sample * 0.25) + 0.5, 0, 1) else: - x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] + if approximation == 3: + x_sample = sd_vae_approx.cheap_approximation(sample) + elif approximation == 2: + x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() + else: + x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] + x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) - x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) return Image.fromarray(x_sample) diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py new file mode 100644 index 00000000..ccc97959 --- /dev/null +++ b/modules/sd_vae_taesd.py @@ -0,0 +1,76 @@ +""" +Tiny AutoEncoder for Stable Diffusion +(DNN for encoding / decoding SD's latent space) + +https://github.com/madebyollin/taesd +""" +import os +import torch +import torch.nn as nn + +from modules import devices, paths_internal + +sd_vae_taesd = None + + +def conv(n_in, n_out, **kwargs): + return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) + + +class Clamp(nn.Module): + @staticmethod + def forward(x): + return torch.tanh(x / 3) * 3 + + +class Block(nn.Module): + def __init__(self, n_in, n_out): + super().__init__() + self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out)) + self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() + self.fuse = nn.ReLU() + + def forward(self, x): + return self.fuse(self.conv(x) + self.skip(x)) + + +def decoder(): + return nn.Sequential( + Clamp(), conv(4, 64), nn.ReLU(), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), conv(64, 3), + ) + + +class TAESD(nn.Module): + latent_magnitude = 2 + latent_shift = 0.5 + + def __init__(self, decoder_path="taesd_decoder.pth"): + """Initialize pretrained TAESD on the given device from the given checkpoints.""" + super().__init__() + self.decoder = decoder() + self.decoder.load_state_dict( + torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None)) + + @staticmethod + def unscale_latents(x): + """[0, 1] -> raw latents""" + return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) + + +def decode(): + global sd_vae_taesd + + if sd_vae_taesd is None: + model_path = os.path.join(paths_internal.models_path, "VAE-approx", "taesd_decoder.pth") + if os.path.exists(model_path): + sd_vae_taesd = TAESD(model_path) + sd_vae_taesd.eval() + sd_vae_taesd.to(devices.device, devices.dtype) + else: + raise FileNotFoundError('Tiny AE mdoel not found') + + return sd_vae_taesd.decoder diff --git a/modules/shared.py b/modules/shared.py index 4631965b..6760a900 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -425,7 +425,7 @@ options_templates.update(options_section(('ui', "Live previews"), { "live_previews_enable": OptionInfo(True, "Show live previews of the created image"), "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), - "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}), + "show_progress_type": OptionInfo("Tiny AE", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Tiny AE", "Approx NN", "Approx cheap"]}), "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), "live_preview_refresh_period": OptionInfo(1000, "Progressbar/preview update period, in milliseconds") })) diff --git a/webui.py b/webui.py index 727ebd31..0a928434 100644 --- a/webui.py +++ b/webui.py @@ -144,10 +144,21 @@ Use --skip-version-check commandline argument to disable this check. """.strip()) +def check_taesd(): + from modules.paths_internal import models_path + + model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth' + model_path = os.path.join(models_path, "VAE-approx", "taesd_decoder.pth") + if not os.path.exists(model_path): + print('download taesd model') + torch.hub.download_url_to_file(model_url, os.path.dirname(model_path)) + + def initialize(): fix_asyncio_event_loop_policy() check_versions() + check_taesd() extensions.list_extensions() localization.list_localizations(cmd_opts.localizations_dir) From bd9b9d425a355e151b43047a5df5fcead2fcdc52 Mon Sep 17 00:00:00 2001 From: Sakura-Luna <53183413+Sakura-Luna@users.noreply.github.com> Date: Sun, 14 May 2023 13:19:11 +0800 Subject: [PATCH 2/6] Add live preview mode check --- modules/sd_samplers_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index d3dc130c..b1e8a780 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -26,8 +26,8 @@ approximation_indexes = {"Full": 0, "Tiny AE": 1, "Approx NN": 2, "Approx cheap" def single_sample_to_image(sample, approximation=None): - if approximation is None: - approximation = approximation_indexes.get(opts.show_progress_type, 0) + if approximation is None or approximation not in approximation_indexes.keys(): + approximation = approximation_indexes.get(opts.show_progress_type, 1) if approximation == 1: x_sample = sd_vae_taesd.decode()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() From 742da3193290f5692901c4c614c98bec291163f2 Mon Sep 17 00:00:00 2001 From: Sakura-Luna <53183413+Sakura-Luna@users.noreply.github.com> Date: Mon, 15 May 2023 03:04:34 +0800 Subject: [PATCH 3/6] Minor changes --- modules/sd_vae_taesd.py | 2 +- webui.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index ccc97959..927a7298 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -71,6 +71,6 @@ def decode(): sd_vae_taesd.eval() sd_vae_taesd.to(devices.device, devices.dtype) else: - raise FileNotFoundError('Tiny AE mdoel not found') + raise FileNotFoundError('Tiny AE model not found') return sd_vae_taesd.decoder diff --git a/webui.py b/webui.py index 0a928434..0d0816bc 100644 --- a/webui.py +++ b/webui.py @@ -151,7 +151,7 @@ def check_taesd(): model_path = os.path.join(models_path, "VAE-approx", "taesd_decoder.pth") if not os.path.exists(model_path): print('download taesd model') - torch.hub.download_url_to_file(model_url, os.path.dirname(model_path)) + torch.hub.download_url_to_file(model_url, model_path) def initialize(): From 4fb2cc0f060d1f63e0e62e38d37e983745ce3fda Mon Sep 17 00:00:00 2001 From: Sakura-Luna <53183413+Sakura-Luna@users.noreply.github.com> Date: Wed, 17 May 2023 00:32:32 +0800 Subject: [PATCH 4/6] Minor change --- webui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui.py b/webui.py index 0d0816bc..0aa03ea8 100644 --- a/webui.py +++ b/webui.py @@ -150,7 +150,7 @@ def check_taesd(): model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth' model_path = os.path.join(models_path, "VAE-approx", "taesd_decoder.pth") if not os.path.exists(model_path): - print('download taesd model') + print('From taesd repo download decoder model') torch.hub.download_url_to_file(model_url, model_path) From b217ebc49000b41baab3094dbc8caaf33eaf5579 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 17 May 2023 08:41:21 +0300 Subject: [PATCH 5/6] add credits --- README.md | 1 + html/licenses.html | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/README.md b/README.md index 67a1a83a..c1e193c0 100644 --- a/README.md +++ b/README.md @@ -158,5 +158,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al - Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix - Security advice - RyotaK - UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC +- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. - (You) diff --git a/html/licenses.html b/html/licenses.html index bc995aa0..ef6f2c0a 100644 --- a/html/licenses.html +++ b/html/licenses.html @@ -661,4 +661,30 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + +

TAESD

+Tiny AutoEncoder for Stable Diffusion option for live previews +
+MIT License
+
+Copyright (c) 2023 Ollin Boer Bohan
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
 
\ No newline at end of file From 56a2672831751480f94a018f861f0143a8234ae8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 17 May 2023 09:24:01 +0300 Subject: [PATCH 6/6] return live preview defaults to how they were only download TAESD model when it's needed return calculations in single_sample_to_image to just if/elif/elif blocks keep taesd model in its own directory --- modules/sd_samplers_common.py | 29 +++++++++++++++-------------- modules/sd_vae_taesd.py | 18 +++++++++++++++--- modules/shared.py | 2 +- webui.py | 11 ----------- 4 files changed, 31 insertions(+), 29 deletions(-) diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index b1e8a780..20a9af20 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -22,28 +22,29 @@ def setup_img2img_steps(p, steps=None): return steps, t_enc -approximation_indexes = {"Full": 0, "Tiny AE": 1, "Approx NN": 2, "Approx cheap": 3} +approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3} def single_sample_to_image(sample, approximation=None): - if approximation is None or approximation not in approximation_indexes.keys(): - approximation = approximation_indexes.get(opts.show_progress_type, 1) - if approximation == 1: - x_sample = sd_vae_taesd.decode()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() - x_sample = sd_vae_taesd.TAESD.unscale_latents(x_sample) - x_sample = torch.clamp((x_sample * 0.25) + 0.5, 0, 1) + if approximation is None: + approximation = approximation_indexes.get(opts.show_progress_type, 0) + + if approximation == 2: + x_sample = sd_vae_approx.cheap_approximation(sample) + elif approximation == 1: + x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() + elif approximation == 3: + x_sample = sd_vae_taesd.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() + x_sample = sd_vae_taesd.TAESD.unscale_latents(x_sample) # returns value in [-2, 2] + x_sample = x_sample * 0.5 else: - if approximation == 3: - x_sample = sd_vae_approx.cheap_approximation(sample) - elif approximation == 2: - x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() - else: - x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] - x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) + x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] + x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) + return Image.fromarray(x_sample) diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index 927a7298..d23812ef 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -61,16 +61,28 @@ class TAESD(nn.Module): return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) -def decode(): +def download_model(model_path): + model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth' + + if not os.path.exists(model_path): + os.makedirs(os.path.dirname(model_path), exist_ok=True) + + print(f'Downloading TAESD decoder to: {model_path}') + torch.hub.download_url_to_file(model_url, model_path) + + +def model(): global sd_vae_taesd if sd_vae_taesd is None: - model_path = os.path.join(paths_internal.models_path, "VAE-approx", "taesd_decoder.pth") + model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth") + download_model(model_path) + if os.path.exists(model_path): sd_vae_taesd = TAESD(model_path) sd_vae_taesd.eval() sd_vae_taesd.to(devices.device, devices.dtype) else: - raise FileNotFoundError('Tiny AE model not found') + raise FileNotFoundError('TAESD model not found') return sd_vae_taesd.decoder diff --git a/modules/shared.py b/modules/shared.py index 6760a900..96036d38 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -425,7 +425,7 @@ options_templates.update(options_section(('ui', "Live previews"), { "live_previews_enable": OptionInfo(True, "Show live previews of the created image"), "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), - "show_progress_type": OptionInfo("Tiny AE", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Tiny AE", "Approx NN", "Approx cheap"]}), + "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}), "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), "live_preview_refresh_period": OptionInfo(1000, "Progressbar/preview update period, in milliseconds") })) diff --git a/webui.py b/webui.py index 0aa03ea8..727ebd31 100644 --- a/webui.py +++ b/webui.py @@ -144,21 +144,10 @@ Use --skip-version-check commandline argument to disable this check. """.strip()) -def check_taesd(): - from modules.paths_internal import models_path - - model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth' - model_path = os.path.join(models_path, "VAE-approx", "taesd_decoder.pth") - if not os.path.exists(model_path): - print('From taesd repo download decoder model') - torch.hub.download_url_to_file(model_url, model_path) - - def initialize(): fix_asyncio_event_loop_policy() check_versions() - check_taesd() extensions.list_extensions() localization.list_localizations(cmd_opts.localizations_dir)