diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 372555ff..f10865cd 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -256,6 +256,9 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + if q.device.type == 'mps': + q, k, v = q.contiguous(), k.contiguous(), v.contiguous() + dtype = q.dtype if shared.opts.upcast_attn: q, k = q.float(), k.float()