From 29be1da7cf2b5dccfc70fbdd33eb35c56a31ffb7 Mon Sep 17 00:00:00 2001
From: lllyasviel
Date: Fri, 8 Mar 2024 00:50:29 -0800
Subject: [PATCH] pass options to cross attention class

---
 ldm_patched/ldm/modules/attention.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/ldm_patched/ldm/modules/attention.py b/ldm_patched/ldm/modules/attention.py
index 6bfa64f9..13183e5f 100644
--- a/ldm_patched/ldm/modules/attention.py
+++ b/ldm_patched/ldm/modules/attention.py
@@ -385,7 +385,7 @@ class CrossAttention(nn.Module):
 
         self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
 
-    def forward(self, x, context=None, value=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None, transformer_options=None):
         q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
@@ -504,7 +504,7 @@ class BasicTransformerBlock(nn.Module):
             n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
             n = self.attn1.to_out(n)
         else:
-            n = self.attn1(n, context=context_attn1, value=value_attn1)
+            n = self.attn1(n, context=context_attn1, value=value_attn1, transformer_options=extra_options)
 
         if "attn1_output_patch" in transformer_patches:
             patch = transformer_patches["attn1_output_patch"]
@@ -544,7 +544,7 @@ class BasicTransformerBlock(nn.Module):
             n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
             n = self.attn2.to_out(n)
         else:
-            n = self.attn2(n, context=context_attn2, value=value_attn2)
+            n = self.attn2(n, context=context_attn2, value=value_attn2, transformer_options=extra_options)
 
         if "attn2_output_patch" in transformer_patches:
             patch = transformer_patches["attn2_output_patch"]
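
Note (not part of the patch): the change above threads a transformer_options dict through CrossAttention.forward, so the per-call options that BasicTransformerBlock already builds as extra_options become visible inside the attention layer itself. Below is a minimal, self-contained sketch of how such an argument could be consumed; TinyCrossAttention, optimized_attention, and the option keys ("block", "log_block") are hypothetical names chosen for illustration and are not code from ldm_patched.

# Sketch only, assuming PyTorch >= 2.0 for scaled_dot_product_attention.
import torch
import torch.nn as nn


def optimized_attention(q, k, v, heads, mask=None):
    # Hypothetical stand-in for the attention kernel selection in attention.py.
    b, n, _ = q.shape
    q, k, v = map(lambda t: t.reshape(b, -1, heads, t.shape[-1] // heads).transpose(1, 2), (q, k, v))
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    return out.transpose(1, 2).reshape(b, n, -1)


class TinyCrossAttention(nn.Module):
    # Minimal stand-in for the patched CrossAttention, showing only the new argument.
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64):
        super().__init__()
        inner_dim = heads * dim_head
        context_dim = context_dim or query_dim
        self.heads = heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context=None, value=None, mask=None, transformer_options=None):
        # The options dict now travels with the call, so code at this level
        # (patches, alternative kernels, per-block settings) can inspect it.
        transformer_options = transformer_options or {}
        context = x if context is None else context
        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(value if value is not None else context)
        if transformer_options.get("log_block"):
            print("attention called for block:", transformer_options.get("block"))
        out = optimized_attention(q, k, v, self.heads, mask)
        return self.to_out(out)


if __name__ == "__main__":
    attn = TinyCrossAttention(query_dim=64, context_dim=128, heads=4, dim_head=16)
    x = torch.randn(1, 77, 64)
    ctx = torch.randn(1, 10, 128)
    # Mirrors how BasicTransformerBlock now forwards extra_options to self.attn1/self.attn2.
    extra_options = {"block": ("input", 1), "log_block": True}
    y = attn(x, context=ctx, transformer_options=extra_options)
    print(y.shape)  # torch.Size([1, 77, 64])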