Mirror of https://github.com/Wan-Video/Wan2.1.git
Fix flash attention
The latest FA3 release changed the return shape of the varlen function to be consistent with FA2. This PR fixes the FA3 attention call accordingly, mirroring the change in https://github.com/Wan-Video/Wan2.2/pull/64.
parent 7c81b2f27d
commit ca23a2fc59
@@ -107,7 +107,7 @@ def flash_attention(
             max_seqlen_k=lk,
             softmax_scale=softmax_scale,
             causal=causal,
-            deterministic=deterministic)[0].unflatten(0, (b, lq))
+            deterministic=deterministic).unflatten(0, (b, lq))
     else:
         assert FLASH_ATTN_2_AVAILABLE
         x = flash_attn.flash_attn_varlen_func(
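For context, a minimal compatibility sketch (not part of the commit; the helper name below is illustrative): older FA3 builds returned a (out, softmax_lse) tuple from flash_attn_varlen_func, while newer builds return only the output tensor, matching FA2, so a defensive unwrap keeps one call site working across both versions.

# Hypothetical helper, not from the repository: accept either return style.
def _unwrap_fa_output(ret):
    # Older FA3: ret == (out, softmax_lse); newer FA3 / FA2: ret == out.
    return ret[0] if isinstance(ret, tuple) else ret

# Illustrative use at the patched call site (arguments elided):
# x = _unwrap_fa_output(
#     flash_attn_interface.flash_attn_varlen_func(
#         q=q, k=k, v=v,
#         ...,
#         deterministic=deterministic)).unflatten(0, (b, lq))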