Merge pull request #6 from songruizz/release_0.0.3

fix incorrect FA3 usage when both the FA and FA3 packages are installed on a non-Hopper GPU
zc8gerard authored on 2025-05-30 00:28:33 +08:00, committed by GitHub
commit feba2a62c8

@@ -146,12 +146,22 @@ class WanSelfAttention(nn.Module):
         q, k, v = qkv_fn(x)
 
-        x = flash_attention(
-            q=rope_apply(q, grid_sizes, freqs),
-            k=rope_apply(k, grid_sizes, freqs),
-            v=v,
-            k_lens=seq_lens,
-            window_size=self.window_size)
+        if hasattr(self, 'enable_fa3') and self.enable_fa3:
+            x = flash_attention(
+                q=rope_apply(q, grid_sizes, freqs),
+                k=rope_apply(k, grid_sizes, freqs),
+                v=v,
+                k_lens=seq_lens,
+                window_size=self.window_size,
+                version=3)
+        else:
+            x = flash_attention(
+                q=rope_apply(q, grid_sizes, freqs),
+                k=rope_apply(k, grid_sizes, freqs),
+                v=v,
+                k_lens=seq_lens,
+                window_size=self.window_size,
+                version=2)
 
         # output
         x = x.flatten(2)
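Note: enable_fa3 itself is set outside this diff; the change only makes the FlashAttention version explicit at each call site. As a minimal sketch of how a caller might derive that flag, assuming FA3 kernels only run on Hopper-class GPUs (the helper name supports_fa3 is hypothetical, not part of this repository):

import torch

def supports_fa3() -> bool:
    # FlashAttention-3 targets Hopper (compute capability 9.x). Having the
    # FA3 package importable does not imply the GPU can run its kernels,
    # which is exactly the failure mode this commit works around.
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability()
    return major == 9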
@@ -176,7 +186,10 @@ class WanT2VCrossAttention(WanSelfAttention):
         v = self.v(context).view(b, -1, n, d)
 
         # compute attention
-        x = flash_attention(q, k, v, k_lens=context_lens)
+        if hasattr(self, 'enable_fa3') and self.enable_fa3:
+            x = flash_attention(q, k, v, k_lens=context_lens, version=3)
+        else:
+            x = flash_attention(q, k, v, k_lens=context_lens, version=2)
 
         # output
         x = x.flatten(2)
@@ -217,9 +230,14 @@ class WanI2VCrossAttention(WanSelfAttention):
         v = self.v(context).view(b, -1, n, d)
         k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
         v_img = self.v_img(context_img).view(b, -1, n, d)
-        img_x = flash_attention(q, k_img, v_img, k_lens=None)
-        # compute attention
-        x = flash_attention(q, k, v, k_lens=context_lens)
+        if hasattr(self, 'enable_fa3') and self.enable_fa3:
+            img_x = flash_attention(q, k_img, v_img, k_lens=None, version=3)
+            # compute attention
+            x = flash_attention(q, k, v, k_lens=context_lens, version=3)
+        else:
+            img_x = flash_attention(q, k_img, v_img, k_lens=None, version=2)
+            # compute attention
+            x = flash_attention(q, k, v, k_lens=context_lens, version=2)
 
         # output
         x = x.flatten(2)
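For background on why passing version explicitly fixes the bug: a flash_attention wrapper that auto-selects its backend typically prefers FA3 whenever the FA3 package imports, and the import succeeds even on GPUs that cannot execute FA3 kernels. A minimal sketch of that selection logic, assuming flash_attn_interface and flash_attn as the FA3 and FA2 package names (pick_fa_version is an illustrative helper, not the repository's exact code):

try:
    import flash_attn_interface  # FlashAttention-3
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn  # FlashAttention-2
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

def pick_fa_version(version=None):
    # With version=None, import success alone decides the backend, so a
    # non-Hopper GPU with both packages installed would wrongly get FA3.
    # Passing version=2 or version=3 explicitly, as the diff above does,
    # bypasses this auto-selection.
    if version is None:
        return 3 if FLASH_ATTN_3_AVAILABLE else 2
    return version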