apply_token_merging

This commit is contained in:
Author: lllyasviel
Date: 2024-02-23 15:43:27 -08:00
Parent commit: 2a7fb1be24
Commit: bde779a526
3 changed files with 26 additions and 36 deletions

View File

@@ -1,8 +1,6 @@
# Taken from https://github.com/comfyanonymous/ComfyUI
# This file is only for reference, and not used in the backend or runtime.
#Taken from: https://github.com/dbolya/tomesd
# 1st edit: https://github.com/dbolya/tomesd
# 2nd edit: https://github.com/comfyanonymous/ComfyUI
# 3rd edit: Forge official
import torch
from typing import Tuple, Callable
@@ -148,34 +146,19 @@ def get_functions(x, ratio, original_shape):
return nothing, nothing
class TomePatchModel:
# NOTE(review): this span is the interior of a rendered diff; removed and
# added lines are interleaved without +/- markers and leading indentation was
# stripped by the viewer, so it is not valid Python as displayed.
# TomePatchModel is the old ComfyUI node wrapper (apparently removed by this
# commit): it declares a MODEL input plus a float `ratio` in [0.0, 1.0]
# (default 0.3, step 0.01) and is registered in NODE_CLASS_MAPPINGS below.
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
# TomePatcher (apparently added by this commit): a plain class replacing the
# node wrapper. Instance attribute `u` caches the un-merge callable produced
# by get_functions so the output patch can undo the token merge.
class TomePatcher:
def __init__(self):
self.u = None
def patch(self, model, ratio):
self.u = None
# Merge hook installed on attn1: merges tokens in q before attention.
def tomesd_m(q, k, v, extra_options):
#NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q
#however from my basic testing it seems that using q instead gives better results
m, self.u = get_functions(q, ratio, extra_options["original_shape"])
return m(q), k, v
# Un-merge hook installed on attn1's output: restores the token count
# using the callable stored by tomesd_m.
def tomesd_u(n, extra_options):
return self.u(n)
m = model.clone()
m.set_model_attn1_patch(tomesd_m)
m.set_model_attn1_output_patch(tomesd_u)
return (m, )
# Old node-registry entry (removed together with TomePatchModel above).
NODE_CLASS_MAPPINGS = {
"TomePatchModel": TomePatchModel,
}
# Replacement return of the new patch() (apparently added): returns the patched
# model directly instead of the ComfyUI one-tuple `(m, )` above.
return m

View File

@@ -33,6 +33,7 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType
from modules.sd_models import apply_token_merging
# some of those options should not be changed at all because they would break the model, so I removed them from options.
@@ -747,13 +748,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if k == 'sd_vae':
sd_vae.reload_vae_weights()
sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())
res = process_images_inner(p)
finally:
sd_models.apply_token_merging(p.sd_model, 0)
# restore opts to original state
if p.override_settings_restore_afterwards:
for k, v in stored_opts.items():
@@ -1259,6 +1256,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = self.rng.next()
self.sd_model.forge_objects = self.sd_model.forge_objects_after_applying_lora.shallow_copy()
apply_token_merging(self.sd_model, self.get_token_merging_ratio())
if self.scripts is not None:
self.scripts.process_before_every_sampling(self,
x=x,
@@ -1366,12 +1365,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
with devices.autocast():
self.calculate_hr_conds()
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
if self.scripts is not None:
self.scripts.before_hr(self)
self.sd_model.forge_objects = self.sd_model.forge_objects_after_applying_lora.shallow_copy()
apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
if self.scripts is not None:
self.scripts.process_before_every_sampling(self,
x=samples,
@@ -1385,8 +1384,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
self.sampler = None
devices.torch_gc()
@@ -1687,6 +1684,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
x *= self.initial_noise_multiplier
self.sd_model.forge_objects = self.sd_model.forge_objects_after_applying_lora.shallow_copy()
apply_token_merging(self.sd_model, self.get_token_merging_ratio())
if self.scripts is not None:
self.scripts.process_before_every_sampling(self,
x=self.init_latent,

View File

@@ -633,8 +633,16 @@ def unload_model_weights(sd_model=None, info=None):
def apply_token_merging(sd_model, token_merging_ratio):
# NOTE(review): diff interior — the old stub body and the new implementation
# are interleaved without +/- markers; indentation was stripped by the viewer.
# Old stub body (apparently removed by this commit): only warned that the
# token-merging setting had no effect yet.
if token_merging_ratio > 0:
print('Token merging is under construction now and the setting will not take effect.')
# TODO: rework using new UNet patcher system
# New implementation (apparently added): no-op when the ratio is <= 0,
# otherwise wraps the current UNet with the TomePatcher defined in
# ldm_patched/contrib/external_tomesd (see the first file in this commit).
if token_merging_ratio <= 0:
return
print(f'token_merging_ratio = {token_merging_ratio}')
from ldm_patched.contrib.external_tomesd import TomePatcher
sd_model.forge_objects.unet = TomePatcher().patch(
model=sd_model.forge_objects.unet,
ratio=token_merging_ratio
)
return