Fix flash attention

The latest FA3 release changed the return shape of the varlen function to be consistent with FA2. This PR fixes the FA3 attention call accordingly, as done in https://github.com/Wan-Video/Wan2.2/pull/64.
This commit is contained in:

parent 7c81b2f27d
commit ca23a2fc59
@@ -107,7 +107,7 @@ def flash_attention(
             max_seqlen_k=lk,
             softmax_scale=softmax_scale,
             causal=causal,
-            deterministic=deterministic)[0].unflatten(0, (b, lq))
+            deterministic=deterministic).unflatten(0, (b, lq))
     else:
         assert FLASH_ATTN_2_AVAILABLE
         x = flash_attn.flash_attn_varlen_func(
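For context, a minimal version-agnostic sketch (not part of this commit): instead of pinning to one FA3 return convention, the output could be unwrapped before reshaping. The helper name _varlen_output below is hypothetical.

def _varlen_output(ret):
    # Older FA3 builds return (out, softmax_lse); newer FA3 builds and FA2
    # return just the output tensor.
    return ret[0] if isinstance(ret, tuple) else ret

# Sketched usage inside flash_attention(), mirroring the line changed above
# (the "..." stands for the same varlen arguments shown in the diff):
#     x = _varlen_output(
#         flash_attn_varlen_func(...)).unflatten(0, (b, lq))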