fix flash attention

DeepBeepMeep 2025-04-18 09:46:52 +02:00
parent e97599635d
commit a63aff0377


@@ -276,7 +276,7 @@ def pay_attention(
         k=k,
         v=v,
         cu_seqlens_q= cu_seqlens_q,
-        cu_seqlens_kv= cu_seqlens_k,
+        cu_seqlens_k= cu_seqlens_k,
         seqused_q=None,
         seqused_k=None,
         max_seqlen_q=lq,
@@ -289,8 +289,8 @@ def pay_attention(
         q=q,
         k=k,
         v=v,
-        cu_seqlens_q= [0, lq],
-        cu_seqlens_kv=[0, lk],
+        cu_seqlens_q= cu_seqlens_q,
+        cu_seqlens_k= cu_seqlens_k,
         max_seqlen_q=lq,
         max_seqlen_k=lk,
         dropout_p=dropout_p,
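
The fix forwards the precomputed cu_seqlens_q / cu_seqlens_k tensors instead of plain Python lists ([0, lq], [0, lk]) and uses the cu_seqlens_k keyword expected by the varlen flash-attention call. As a minimal sketch (not part of this commit), cumulative sequence-length tensors of this kind are commonly built as int32 prefix sums of per-sample lengths; the names q_lens and k_lens below are assumptions for illustration:

import torch

def build_cu_seqlens(lens: torch.Tensor) -> torch.Tensor:
    # Varlen attention kernels expect an int32 prefix-sum tensor that
    # starts at 0, e.g. lengths [3, 5] -> cu_seqlens [0, 3, 8].
    return torch.nn.functional.pad(lens.cumsum(0, dtype=torch.int32), (1, 0))

q_lens = torch.tensor([3, 5])            # hypothetical query lengths per sample
k_lens = torch.tensor([4, 6])            # hypothetical key lengths per sample
cu_seqlens_q = build_cu_seqlens(q_lens)  # tensor([0, 3, 8], dtype=torch.int32)
cu_seqlens_k = build_cu_seqlens(k_lens)  # tensor([0, 4, 10], dtype=torch.int32)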