apply_token_merging
This commit is contained in:
parent
2a7fb1be24
commit
bde779a526
@ -1,8 +1,6 @@
|
||||
# Taken from https://github.com/comfyanonymous/ComfyUI
|
||||
# This file is only for reference, and not used in the backend or runtime.
|
||||
|
||||
|
||||
#Taken from: https://github.com/dbolya/tomesd
|
||||
# 1st edit: https://github.com/dbolya/tomesd
|
||||
# 2nd edit: https://github.com/comfyanonymous/ComfyUI
|
||||
# 3rd edit: Forge official
|
||||
|
||||
import torch
|
||||
from typing import Tuple, Callable
|
||||
@ -148,34 +146,19 @@ def get_functions(x, ratio, original_shape):
|
||||
return nothing, nothing
|
||||
|
||||
|
||||
|
||||
class TomePatchModel:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
class TomePatcher:
|
||||
def __init__(self):
|
||||
self.u = None
|
||||
|
||||
def patch(self, model, ratio):
|
||||
self.u = None
|
||||
def tomesd_m(q, k, v, extra_options):
|
||||
#NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q
|
||||
#however from my basic testing it seems that using q instead gives better results
|
||||
m, self.u = get_functions(q, ratio, extra_options["original_shape"])
|
||||
return m(q), k, v
|
||||
|
||||
def tomesd_u(n, extra_options):
|
||||
return self.u(n)
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_attn1_patch(tomesd_m)
|
||||
m.set_model_attn1_output_patch(tomesd_u)
|
||||
return (m, )
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TomePatchModel": TomePatchModel,
|
||||
}
|
||||
return m
|
||||
|
@ -33,6 +33,7 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
|
||||
|
||||
from einops import repeat, rearrange
|
||||
from blendmodes.blend import blendLayers, BlendType
|
||||
from modules.sd_models import apply_token_merging
|
||||
|
||||
|
||||
# some of those options should not be changed at all because they would break the model, so I removed them from options.
|
||||
@ -747,13 +748,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
if k == 'sd_vae':
|
||||
sd_vae.reload_vae_weights()
|
||||
|
||||
sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())
|
||||
|
||||
res = process_images_inner(p)
|
||||
|
||||
finally:
|
||||
sd_models.apply_token_merging(p.sd_model, 0)
|
||||
|
||||
# restore opts to original state
|
||||
if p.override_settings_restore_afterwards:
|
||||
for k, v in stored_opts.items():
|
||||
@ -1259,6 +1256,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
x = self.rng.next()
|
||||
|
||||
self.sd_model.forge_objects = self.sd_model.forge_objects_after_applying_lora.shallow_copy()
|
||||
apply_token_merging(self.sd_model, self.get_token_merging_ratio())
|
||||
|
||||
if self.scripts is not None:
|
||||
self.scripts.process_before_every_sampling(self,
|
||||
x=x,
|
||||
@ -1366,12 +1365,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
with devices.autocast():
|
||||
self.calculate_hr_conds()
|
||||
|
||||
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
|
||||
|
||||
if self.scripts is not None:
|
||||
self.scripts.before_hr(self)
|
||||
|
||||
self.sd_model.forge_objects = self.sd_model.forge_objects_after_applying_lora.shallow_copy()
|
||||
apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
|
||||
|
||||
if self.scripts is not None:
|
||||
self.scripts.process_before_every_sampling(self,
|
||||
x=samples,
|
||||
@ -1385,8 +1384,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
|
||||
samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
||||
|
||||
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
|
||||
|
||||
self.sampler = None
|
||||
devices.torch_gc()
|
||||
|
||||
@ -1687,6 +1684,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
x *= self.initial_noise_multiplier
|
||||
|
||||
self.sd_model.forge_objects = self.sd_model.forge_objects_after_applying_lora.shallow_copy()
|
||||
apply_token_merging(self.sd_model, self.get_token_merging_ratio())
|
||||
|
||||
if self.scripts is not None:
|
||||
self.scripts.process_before_every_sampling(self,
|
||||
x=self.init_latent,
|
||||
|
@ -633,8 +633,16 @@ def unload_model_weights(sd_model=None, info=None):
|
||||
|
||||
|
||||
def apply_token_merging(sd_model, token_merging_ratio):
|
||||
if token_merging_ratio > 0:
|
||||
print('Token merging is under construction now and the setting will not take effect.')
|
||||
if token_merging_ratio <= 0:
|
||||
return
|
||||
|
||||
print(f'token_merging_ratio = {token_merging_ratio}')
|
||||
|
||||
from ldm_patched.contrib.external_tomesd import TomePatcher
|
||||
|
||||
sd_model.forge_objects.unet = TomePatcher().patch(
|
||||
model=sd_model.forge_objects.unet,
|
||||
ratio=token_merging_ratio
|
||||
)
|
||||
|
||||
# TODO: rework using new UNet patcher system
|
||||
return
|
||||
|
Loading…
Reference in New Issue
Block a user