try solve dtype cast for #112
This commit is contained in:
parent
e1faf8327b
commit
c3a66b016b
@ -463,7 +463,7 @@ class CrossAttentionPatch:
|
||||
ip_k = ip_k * W
|
||||
ip_v = ip_v_offset + ip_v_mean * W
|
||||
|
||||
out_ip = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"])
|
||||
out_ip = optimized_attention(q, ip_k.to(org_dtype), ip_v.to(org_dtype), extra_options["n_heads"])
|
||||
if weight_type.startswith("original"):
|
||||
out_ip = out_ip * weight
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user