From 2281955f45b561948b07514b8be3c584bbb9eef7 Mon Sep 17 00:00:00 2001
From: songruizz
Date: Thu, 29 May 2025 23:48:06 +0800
Subject: [PATCH] fix incorrect FA3 usage when both the FA and FA3 packages are installed on a non-Hopper GPU

---
 wan/modules/model.py | 38 ++++++++++++++++++++++++++++----------
 1 file changed, 28 insertions(+), 10 deletions(-)

diff --git a/wan/modules/model.py b/wan/modules/model.py
index a5425da..d5127fd 100644
--- a/wan/modules/model.py
+++ b/wan/modules/model.py
@@ -146,12 +146,22 @@ class WanSelfAttention(nn.Module):
 
         q, k, v = qkv_fn(x)
 
-        x = flash_attention(
-            q=rope_apply(q, grid_sizes, freqs),
-            k=rope_apply(k, grid_sizes, freqs),
-            v=v,
-            k_lens=seq_lens,
-            window_size=self.window_size)
+        if hasattr(self, 'enable_fa3') and self.enable_fa3:
+            x = flash_attention(
+                q=rope_apply(q, grid_sizes, freqs),
+                k=rope_apply(k, grid_sizes, freqs),
+                v=v,
+                k_lens=seq_lens,
+                window_size=self.window_size,
+                version=3)
+        else:
+            x = flash_attention(
+                q=rope_apply(q, grid_sizes, freqs),
+                k=rope_apply(k, grid_sizes, freqs),
+                v=v,
+                k_lens=seq_lens,
+                window_size=self.window_size,
+                version=2)
 
         # output
         x = x.flatten(2)
@@ -176,7 +186,10 @@ class WanT2VCrossAttention(WanSelfAttention):
         v = self.v(context).view(b, -1, n, d)
 
         # compute attention
-        x = flash_attention(q, k, v, k_lens=context_lens)
+        if hasattr(self, 'enable_fa3') and self.enable_fa3:
+            x = flash_attention(q, k, v, k_lens=context_lens, version=3)
+        else:
+            x = flash_attention(q, k, v, k_lens=context_lens, version=2)
 
         # output
         x = x.flatten(2)
@@ -217,9 +230,14 @@ class WanI2VCrossAttention(WanSelfAttention):
         v = self.v(context).view(b, -1, n, d)
         k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
         v_img = self.v_img(context_img).view(b, -1, n, d)
-        img_x = flash_attention(q, k_img, v_img, k_lens=None)
-        # compute attention
-        x = flash_attention(q, k, v, k_lens=context_lens)
+        if hasattr(self, 'enable_fa3') and self.enable_fa3:
+            img_x = flash_attention(q, k_img, v_img, k_lens=None, version=3)
+            # compute attention
+            x = flash_attention(q, k, v, k_lens=context_lens, version=3)
+        else:
+            img_x = flash_attention(q, k_img, v_img, k_lens=None, version=2)
+            # compute attention
+            x = flash_attention(q, k, v, k_lens=context_lens, version=2)
 
         # output
         x = x.flatten(2)
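
Note: the patch only reads `self.enable_fa3`; it never sets it. Below is a minimal sketch of one way the flag could be populated at model setup time, so FA3 is selected only when its package is importable and the GPU is Hopper (compute capability 9.x). The module name `flash_attn_interface`, the capability check, and the helper names are assumptions for illustration, not part of this patch or repository.

import importlib.util

import torch

from wan.modules.model import WanSelfAttention  # the cross-attention classes subclass this


def supports_fa3() -> bool:
    # Assumption: FA3 kernels target Hopper (sm90) and ship as 'flash_attn_interface'.
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability()
    has_fa3_pkg = importlib.util.find_spec('flash_attn_interface') is not None
    return has_fa3_pkg and major == 9


def tag_attention_blocks(model) -> None:
    # Set the flag consumed by the hasattr(self, 'enable_fa3') guards in the patch.
    enable = supports_fa3()
    for module in model.modules():
        if isinstance(module, WanSelfAttention):
            module.enable_fa3 = enable

With this, a machine that has both flash-attn and the FA3 package installed but no Hopper GPU would tag every attention block with enable_fa3=False, and the patched forward() paths fall back to version=2.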