Compare commits

...

2 Commits

Author SHA1 Message Date
Emanuele Bugliarello
7406d62afa
Merge ca23a2fc59 into e4f90fa81f
2025-12-11 18:13:39 +08:00
Emanuele Bugliarello
ca23a2fc59
Fix flash attention
The latest FlashAttention-3 release changed the return shape of the varlen function to be consistent with FlashAttention-2. This PR fixes the FA3 attention call accordingly, as done in https://github.com/Wan-Video/Wan2.2/pull/64.
2025-08-27 11:43:35 +02:00


@@ -107,7 +107,7 @@ def flash_attention(
             max_seqlen_k=lk,
             softmax_scale=softmax_scale,
             causal=causal,
-            deterministic=deterministic)[0].unflatten(0, (b, lq))
+            deterministic=deterministic).unflatten(0, (b, lq))
     else:
         assert FLASH_ATTN_2_AVAILABLE
         x = flash_attn.flash_attn_varlen_func(
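For context, here is a minimal sketch (not part of this PR) of why the [0] index had to be dropped: older FlashAttention-3 builds returned a tuple (out, softmax_lse) from the varlen function, so callers unwrapped the output with [0], while newer builds return only the output tensor, matching FlashAttention-2. A wrapper like the one below would tolerate both conventions; the helper name call_fa3_varlen and its keyword passthrough are illustrative assumptions, not code from this repository.

def call_fa3_varlen(varlen_func, q, k, v, **kwargs):
    # Assumption: varlen_func is FA3's flash_attn_varlen_func, which in
    # older releases returned (out, softmax_lse) and in newer releases
    # returns just the output tensor, consistent with FA2.
    out = varlen_func(q, k, v, **kwargs)
    if isinstance(out, tuple):
        # Older FA3: unwrap the attention output from the tuple.
        out = out[0]
    return out

The PR instead pins the call to the new behavior by simply dropping the [0] index, which is the cleaner choice when the installed FA3 version is controlled by the environment.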