mirror of https://github.com/Wan-Video/Wan2.1.git
Merge pull request #6 from songruizz/release_0.0.3
fix incorrect FA3 usage when both the FA and FA3 packages are installed on a non-Hopper GPU
commit feba2a62c8
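The `enable_fa3` attribute tested throughout the diff below is set elsewhere in the codebase; this commit only consumes it. As context, a minimal sketch of how such a flag could be derived — the helper name and the Hopper (SM90) capability check are illustrative assumptions, not code from this commit:

import torch

def supports_flash_attention_3() -> bool:
    # FlashAttention-3 kernels target Hopper (compute capability 9.x);
    # on earlier GPUs the FA3 package may import cleanly but must not
    # be selected, which is the bug this commit fixes.
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability()
    return major >= 9

# Hypothetical wiring: each attention module gets the flag so its
# forward pass can request version=3 only on supported hardware, e.g.
#     attn_module.enable_fa3 = supports_flash_attention_3()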
@@ -146,12 +146,22 @@ class WanSelfAttention(nn.Module):
         q, k, v = qkv_fn(x)
 
-        x = flash_attention(
-            q=rope_apply(q, grid_sizes, freqs),
-            k=rope_apply(k, grid_sizes, freqs),
-            v=v,
-            k_lens=seq_lens,
-            window_size=self.window_size)
+        if hasattr(self, 'enable_fa3') and self.enable_fa3:
+            x = flash_attention(
+                q=rope_apply(q, grid_sizes, freqs),
+                k=rope_apply(k, grid_sizes, freqs),
+                v=v,
+                k_lens=seq_lens,
+                window_size=self.window_size,
+                version=3)
+        else:
+            x = flash_attention(
+                q=rope_apply(q, grid_sizes, freqs),
+                k=rope_apply(k, grid_sizes, freqs),
+                v=v,
+                k_lens=seq_lens,
+                window_size=self.window_size,
+                version=2)
 
         # output
         x = x.flatten(2)
@@ -176,7 +186,10 @@ class WanT2VCrossAttention(WanSelfAttention):
         v = self.v(context).view(b, -1, n, d)
 
         # compute attention
-        x = flash_attention(q, k, v, k_lens=context_lens)
+        if hasattr(self, 'enable_fa3') and self.enable_fa3:
+            x = flash_attention(q, k, v, k_lens=context_lens, version=3)
+        else:
+            x = flash_attention(q, k, v, k_lens=context_lens, version=2)
 
         # output
         x = x.flatten(2)
@@ -217,9 +230,14 @@ class WanI2VCrossAttention(WanSelfAttention):
         v = self.v(context).view(b, -1, n, d)
         k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
         v_img = self.v_img(context_img).view(b, -1, n, d)
-        img_x = flash_attention(q, k_img, v_img, k_lens=None)
-        # compute attention
-        x = flash_attention(q, k, v, k_lens=context_lens)
+        if hasattr(self, 'enable_fa3') and self.enable_fa3:
+            img_x = flash_attention(q, k_img, v_img, k_lens=None, version=3)
+            # compute attention
+            x = flash_attention(q, k, v, k_lens=context_lens, version=3)
+        else:
+            img_x = flash_attention(q, k_img, v_img, k_lens=None, version=2)
+            # compute attention
+            x = flash_attention(q, k, v, k_lens=context_lens, version=2)
 
         # output
         x = x.flatten(2)
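Why pinning `version` at every call site matters: a wrapper whose default is "use FA3 whenever the package imports" will silently route through Hopper-only kernels on, say, an A100 (SM80) whenever both flash-attn and flash-attn-3 are installed. A rough sketch of defensive selection logic under that assumption — the flag names and structure are illustrative, not the repository's actual `flash_attention` wrapper:

import torch

try:
    import flash_attn_interface  # FlashAttention-3 package
    FA3_INSTALLED = True
except ImportError:
    FA3_INSTALLED = False

try:
    import flash_attn  # FlashAttention-2 package
    FA2_INSTALLED = True
except ImportError:
    FA2_INSTALLED = False


def resolve_attention_version(requested=None):
    # Package presence alone is not sufficient for FA3: its kernels
    # require a Hopper-class GPU (compute capability >= 9.0).
    on_hopper = (torch.cuda.is_available()
                 and torch.cuda.get_device_capability()[0] >= 9)
    if requested == 3 or (requested is None and FA3_INSTALLED and on_hopper):
        assert FA3_INSTALLED, 'version=3 requested but FA3 is not installed'
        return 3
    assert FA2_INSTALLED, 'FlashAttention-2 is required as the fallback'
    return 2

With this commit, the call sites above no longer rely on any such auto-detection: they pass version=3 or version=2 explicitly based on enable_fa3.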