Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-04 06:15:17 +00:00)
	add support for FA3 on multi-gpu inference
parent e5a741309d
commit 4c35b3fd58
@@ -1,6 +1,10 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import torch
 import torch.cuda.amp as amp
+
+import yunchang
+from yunchang.kernels import AttnType
+
 from xfuser.core.distributed import (
     get_sequence_parallel_rank,
     get_sequence_parallel_world_size,
@@ -10,6 +14,12 @@ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
 
 from ..modules.model import sinusoidal_embedding_1d
 
+try:
+    import flash_attn_interface
+    FLASH_ATTN_3_AVAILABLE = True
+except ModuleNotFoundError:
+    FLASH_ATTN_3_AVAILABLE = False
+
 
 def pad_freqs(original_tensor, target_len):
     seq_len, s1, s2 = original_tensor.shape
@@ -210,6 +220,21 @@ def usp_attn_forward(self,
     #     k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
     #     v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
 
+    if FLASH_ATTN_3_AVAILABLE:
+        x = xFuserLongContextAttention(attn_type=AttnType.FA3)(
+            None,
+            query=half(q),
+            key=half(k),
+            value=half(v),
+            window_size=self.window_size)
+    else:
+        x = xFuserLongContextAttention()(
+            None,
+            query=half(q),
+            key=half(k),
+            value=half(v),
+            window_size=self.window_size)
+
     x = xFuserLongContextAttention()(
         None,
         query=half(q),
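In short, the commit adds a module-level probe for flash_attn_interface and, inside usp_attn_forward, routes the sequence-parallel attention call through xFuserLongContextAttention(attn_type=AttnType.FA3) when that probe succeeds, falling back to the default constructor otherwise. The sketch below isolates that pattern; build_usp_attention is a hypothetical helper used only for illustration (the commit constructs the attention object inline), and it assumes xfuser and yunchang are installed when the call is actually made.

# Minimal sketch of the FA3 fallback pattern this diff introduces.
# NOTE: build_usp_attention() is a hypothetical helper, not part of Wan2.1.

try:
    import flash_attn_interface  # FlashAttention-3 kernels
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False


def build_usp_attention():
    """Return a long-context attention callable, preferring FA3 when available."""
    from xfuser.core.long_ctx_attention import xFuserLongContextAttention

    if FLASH_ATTN_3_AVAILABLE:
        from yunchang.kernels import AttnType
        # Same constructor argument the commit passes to route to FA3 kernels.
        return xFuserLongContextAttention(attn_type=AttnType.FA3)
    # Default constructor keeps the previous FlashAttention-2 path.
    return xFuserLongContextAttention()

Because the probe runs at import time, environments without flash_attn_interface keep the existing attention path untouched, while multi-GPU runs with FA3 installed pick up the new kernels automatically.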