From 861db783c7acfcb93cf0b5191db3d50f9a9bc531 Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 11 Oct 2022 05:13:17 -0400 Subject: [PATCH] Use apply_hypernetwork function --- modules/sd_hijack_optimizations.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index f006427f..79405525 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -202,16 +202,10 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): q = self.to_q(x) context = default(context, x) - hypernetwork = shared.loaded_hypernetwork - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - - if hypernetwork_layers is not None: - k = self.to_k(hypernetwork_layers[0](context)) * self.scale - v = self.to_v(hypernetwork_layers[1](context)) - else: - k = self.to_k(context) * self.scale - v = self.to_v(context) - del context, x + context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + k = self.to_k(context_k) * self.scale + v = self.to_v(context_v) + del context, context_k, context_v, x q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) r = einsum_op(q, k, v)