pass options to cross attention class

This commit is contained in:
lllyasviel 2024-03-08 00:50:29 -08:00
parent 10b5ca2541
commit 29be1da7cf

View File

@@ -385,7 +385,7 @@ class CrossAttention(nn.Module):
         self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))

-    def forward(self, x, context=None, value=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None, transformer_options=None):
         q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
@@ -504,7 +504,7 @@ class BasicTransformerBlock(nn.Module):
             n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
             n = self.attn1.to_out(n)
         else:
-            n = self.attn1(n, context=context_attn1, value=value_attn1)
+            n = self.attn1(n, context=context_attn1, value=value_attn1, transformer_options=extra_options)

         if "attn1_output_patch" in transformer_patches:
             patch = transformer_patches["attn1_output_patch"]
@@ -544,7 +544,7 @@ class BasicTransformerBlock(nn.Module):
             n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
             n = self.attn2.to_out(n)
         else:
-            n = self.attn2(n, context=context_attn2, value=value_attn2)
+            n = self.attn2(n, context=context_attn2, value=value_attn2, transformer_options=extra_options)

         if "attn2_output_patch" in transformer_patches:
             patch = transformer_patches["attn2_output_patch"]