pass options to cross attention class
This commit is contained in:
parent
10b5ca2541
commit
29be1da7cf
@ -385,7 +385,7 @@ class CrossAttention(nn.Module):
|
|||||||
|
|
||||||
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
||||||
|
|
||||||
def forward(self, x, context=None, value=None, mask=None):
|
def forward(self, x, context=None, value=None, mask=None, transformer_options=None):
|
||||||
q = self.to_q(x)
|
q = self.to_q(x)
|
||||||
context = default(context, x)
|
context = default(context, x)
|
||||||
k = self.to_k(context)
|
k = self.to_k(context)
|
||||||
@ -504,7 +504,7 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
|
n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
|
||||||
n = self.attn1.to_out(n)
|
n = self.attn1.to_out(n)
|
||||||
else:
|
else:
|
||||||
n = self.attn1(n, context=context_attn1, value=value_attn1)
|
n = self.attn1(n, context=context_attn1, value=value_attn1, transformer_options=extra_options)
|
||||||
|
|
||||||
if "attn1_output_patch" in transformer_patches:
|
if "attn1_output_patch" in transformer_patches:
|
||||||
patch = transformer_patches["attn1_output_patch"]
|
patch = transformer_patches["attn1_output_patch"]
|
||||||
@ -544,7 +544,7 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
|
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
|
||||||
n = self.attn2.to_out(n)
|
n = self.attn2.to_out(n)
|
||||||
else:
|
else:
|
||||||
n = self.attn2(n, context=context_attn2, value=value_attn2)
|
n = self.attn2(n, context=context_attn2, value=value_attn2, transformer_options=extra_options)
|
||||||
|
|
||||||
if "attn2_output_patch" in transformer_patches:
|
if "attn2_output_patch" in transformer_patches:
|
||||||
patch = transformer_patches["attn2_output_patch"]
|
patch = transformer_patches["attn2_output_patch"]
|
||||||
|
Loading…
Reference in New Issue
Block a user