Fix flash attention

The latest FlashAttention-3 release changed the return shape of the varlen function to be consistent with FA2. This PR fixes the FA3 attention call accordingly, as done in https://github.com/Wan-Video/Wan2.2/pull/64.
Emanuele Bugliarello 2025-08-27 11:43:35 +02:00 committed by GitHub
parent 7c81b2f27d
commit ca23a2fc59


@@ -107,7 +107,7 @@ def flash_attention(
             max_seqlen_k=lk,
             softmax_scale=softmax_scale,
             causal=causal,
-            deterministic=deterministic)[0].unflatten(0, (b, lq))
+            deterministic=deterministic).unflatten(0, (b, lq))
     else:
         assert FLASH_ATTN_2_AVAILABLE
         x = flash_attn.flash_attn_varlen_func(
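
For codebases that need to run against both the old and the new FA3 return convention, a small compatibility shim along these lines could absorb the difference in return shape (a minimal sketch, not part of this commit; unwrap_varlen_output is a hypothetical helper name):

def unwrap_varlen_output(ret):
    # Older FA3 builds return a (output, softmax_lse) tuple from the
    # varlen function; newer builds return just the output tensor,
    # matching FA2. Accept either form.
    return ret[0] if isinstance(ret, tuple) else ret

# Usage at the call site shown in the diff above, e.g.:
#   x = unwrap_varlen_output(
#       flash_attn_interface.flash_attn_varlen_func(...)
#   ).unflatten(0, (b, lq))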