From bde779a5267fffca875e8b84d6879aea9a91db72 Mon Sep 17 00:00:00 2001
From: lllyasviel
Date: Fri, 23 Feb 2024 15:43:27 -0800
Subject: [PATCH] apply_token_merging

---
 ldm_patched/contrib/external_tomesd.py | 33 +++++++-------------------
 modules/processing.py                  | 15 ++++++-------
 modules/sd_models.py                   | 14 ++++++++---
 3 files changed, 26 insertions(+), 36 deletions(-)

diff --git a/ldm_patched/contrib/external_tomesd.py b/ldm_patched/contrib/external_tomesd.py
index 111d3246..1fceef8a 100644
--- a/ldm_patched/contrib/external_tomesd.py
+++ b/ldm_patched/contrib/external_tomesd.py
@@ -1,8 +1,6 @@
-# Taken from https://github.com/comfyanonymous/ComfyUI
-# This file is only for reference, and not used in the backend or runtime.
-
-
-#Taken from: https://github.com/dbolya/tomesd
+# 1st edit: https://github.com/dbolya/tomesd
+# 2nd edit: https://github.com/comfyanonymous/ComfyUI
+# 3rd edit: Forge official
 
 import torch
 from typing import Tuple, Callable
@@ -148,34 +146,19 @@ def get_functions(x, ratio, original_shape):
 
     return nothing, nothing
 
-
-class TomePatchModel:
-    @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "model": ("MODEL",),
-                              "ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
-                              }}
-    RETURN_TYPES = ("MODEL",)
-    FUNCTION = "patch"
-
-    CATEGORY = "_for_testing"
+class TomePatcher:
+    def __init__(self):
+        self.u = None
 
     def patch(self, model, ratio):
-        self.u = None
         def tomesd_m(q, k, v, extra_options):
-            #NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q
-            #however from my basic testing it seems that using q instead gives better results
             m, self.u = get_functions(q, ratio, extra_options["original_shape"])
             return m(q), k, v
+
         def tomesd_u(n, extra_options):
             return self.u(n)
+
         m = model.clone()
         m.set_model_attn1_patch(tomesd_m)
         m.set_model_attn1_output_patch(tomesd_u)
-        return (m, )
-
-
-NODE_CLASS_MAPPINGS = {
-    "TomePatchModel": TomePatchModel,
-}
+        return m
diff --git a/modules/processing.py b/modules/processing.py
index 3919ac55..a23607e7 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -33,6 +33,7 @@
 from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
 from einops import repeat, rearrange
 from blendmodes.blend import blendLayers, BlendType
+from modules.sd_models import apply_token_merging
 
 
 # some of those options should not be changed at all because they would break the model, so I removed them from options.
@@ -747,13 +748,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
             if k == 'sd_vae':
                 sd_vae.reload_vae_weights()
 
-        sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())
-
         res = process_images_inner(p)
 
     finally:
-        sd_models.apply_token_merging(p.sd_model, 0)
-
         # restore opts to original state
         if p.override_settings_restore_afterwards:
             for k, v in stored_opts.items():
@@ -1259,6 +1256,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             x = self.rng.next()
 
         self.sd_model.forge_objects = self.sd_model.forge_objects_after_applying_lora.shallow_copy()
+        apply_token_merging(self.sd_model, self.get_token_merging_ratio())
+
         if self.scripts is not None:
             self.scripts.process_before_every_sampling(self,
                                                        x=x,
@@ -1366,12 +1365,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             with devices.autocast():
                 self.calculate_hr_conds()
 
-        sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
-
         if self.scripts is not None:
             self.scripts.before_hr(self)
 
         self.sd_model.forge_objects = self.sd_model.forge_objects_after_applying_lora.shallow_copy()
+        apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
+
         if self.scripts is not None:
             self.scripts.process_before_every_sampling(self,
                                                        x=samples,
@@ -1385,8 +1384,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 
         samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
 
-        sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
-
         self.sampler = None
         devices.torch_gc()
 
@@ -1687,6 +1684,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
             x *= self.initial_noise_multiplier
 
         self.sd_model.forge_objects = self.sd_model.forge_objects_after_applying_lora.shallow_copy()
+        apply_token_merging(self.sd_model, self.get_token_merging_ratio())
+
         if self.scripts is not None:
             self.scripts.process_before_every_sampling(self,
                                                        x=self.init_latent,
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 83d9cb80..185062db 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -633,8 +633,16 @@ def unload_model_weights(sd_model=None, info=None):
 
 
 def apply_token_merging(sd_model, token_merging_ratio):
-    if token_merging_ratio > 0:
-        print('Token merging is under construction now and the setting will not take effect.')
+    if token_merging_ratio <= 0:
+        return
+
+    print(f'token_merging_ratio = {token_merging_ratio}')
+
+    from ldm_patched.contrib.external_tomesd import TomePatcher
+
+    sd_model.forge_objects.unet = TomePatcher().patch(
+        model=sd_model.forge_objects.unet,
+        ratio=token_merging_ratio
+    )
 
-    # TODO: rework using new UNet patcher system
     return
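
The sketch below (not part of the commit) restates the call pattern this patch introduces in one place. It is a minimal, hypothetical helper: `prepare_unet_for_sampling` does not exist in the codebase, and the sketch assumes a loaded Forge `sd_model` whose `forge_objects.unet` exposes the `clone()` / `set_model_attn1_patch()` interface used by `TomePatcher` above.

    # Hypothetical sketch for illustration only; not part of the patch.
    from modules.sd_models import apply_token_merging

    def prepare_unet_for_sampling(sd_model, ratio):
        # Each sampling pass starts from the post-LoRA UNet snapshot, so the
        # ToMe attn1 patches are applied to a fresh clone and never stack
        # across passes.
        sd_model.forge_objects = sd_model.forge_objects_after_applying_lora.shallow_copy()
        apply_token_merging(sd_model, ratio)  # no-op when ratio <= 0

Because the merged UNet is rebuilt from a clean snapshot before every sampling call, the old pattern in process_images() of applying the ratio up front and resetting it with ratio=0 in the finally block becomes unnecessary, which is why those call sites are deleted in this commit.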