try solve dtype cast for #112

This commit is contained in:
lllyasviel 2024-02-07 12:39:55 -08:00
parent e1faf8327b
commit c3a66b016b

View File

@ -463,7 +463,7 @@ class CrossAttentionPatch:
ip_k = ip_k * W
ip_v = ip_v_offset + ip_v_mean * W
out_ip = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"])
out_ip = optimized_attention(q, ip_k.to(org_dtype), ip_v.to(org_dtype), extra_options["n_heads"])
if weight_type.startswith("original"):
out_ip = out_ip * weight