Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-07-13 11:10:11 +00:00)
fix incorrect FA3 usage when both the FA and FA3 packages are installed on a non-Hopper GPU
parent 15e4e0bdb8
commit 2281955f45
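Context for the change below: the patched code only branches on an `enable_fa3` attribute; it does not itself decide when FA3 is safe to use. A minimal sketch of how a loader might set that flag (use FA3 only on Hopper, i.e. sm_90, and only if the FA3 package is importable) is shown here. This is an illustration, not part of this commit; the package name `flash_attn_interface` and the helper name `should_enable_fa3` are assumptions.

# Sketch only: decide whether FA3 should be enabled on this machine.
# Assumes FA3 is distributed as the `flash_attn_interface` package (assumption).
import importlib.util

import torch


def should_enable_fa3() -> bool:
    """Return True only on Hopper GPUs (sm_90) with the FA3 package present."""
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    has_fa3_pkg = importlib.util.find_spec("flash_attn_interface") is not None
    return major == 9 and has_fa3_pkg

# hypothetical usage: attn_block.enable_fa3 = should_enable_fa3()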
@@ -146,12 +146,22 @@ class WanSelfAttention(nn.Module):
         q, k, v = qkv_fn(x)

-        x = flash_attention(
-            q=rope_apply(q, grid_sizes, freqs),
-            k=rope_apply(k, grid_sizes, freqs),
-            v=v,
-            k_lens=seq_lens,
-            window_size=self.window_size)
+        if hasattr(self, 'enable_fa3') and self.enable_fa3:
+            x = flash_attention(
+                q=rope_apply(q, grid_sizes, freqs),
+                k=rope_apply(k, grid_sizes, freqs),
+                v=v,
+                k_lens=seq_lens,
+                window_size=self.window_size,
+                version=3)
+        else:
+            x = flash_attention(
+                q=rope_apply(q, grid_sizes, freqs),
+                k=rope_apply(k, grid_sizes, freqs),
+                v=v,
+                k_lens=seq_lens,
+                window_size=self.window_size,
+                version=2)

         # output
         x = x.flatten(2)
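The two branches added above differ only in the `version` argument passed to the repository's `flash_attention` wrapper (3 for FA3 on Hopper, 2 otherwise). A more compact sketch of the same dispatch, relying on the surrounding module's `flash_attention` and `rope_apply`, would pick the version once; this is an illustration of the pattern, not the code as committed:

        # Sketch: same behavior as the new branch above, version chosen once.
        fa_version = 3 if getattr(self, 'enable_fa3', False) else 2
        x = flash_attention(
            q=rope_apply(q, grid_sizes, freqs),
            k=rope_apply(k, grid_sizes, freqs),
            v=v,
            k_lens=seq_lens,
            window_size=self.window_size,
            version=fa_version)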
@@ -176,7 +186,10 @@ class WanT2VCrossAttention(WanSelfAttention):
         v = self.v(context).view(b, -1, n, d)

         # compute attention
-        x = flash_attention(q, k, v, k_lens=context_lens)
+        if hasattr(self, 'enable_fa3') and self.enable_fa3:
+            x = flash_attention(q, k, v, k_lens=context_lens, version=3)
+        else:
+            x = flash_attention(q, k, v, k_lens=context_lens, version=2)

         # output
         x = x.flatten(2)
@@ -217,9 +230,14 @@ class WanI2VCrossAttention(WanSelfAttention):
         v = self.v(context).view(b, -1, n, d)
         k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
         v_img = self.v_img(context_img).view(b, -1, n, d)
-        img_x = flash_attention(q, k_img, v_img, k_lens=None)
-        # compute attention
-        x = flash_attention(q, k, v, k_lens=context_lens)
+        if hasattr(self, 'enable_fa3') and self.enable_fa3:
+            img_x = flash_attention(q, k_img, v_img, k_lens=None, version=3)
+            # compute attention
+            x = flash_attention(q, k, v, k_lens=context_lens, version=3)
+        else:
+            img_x = flash_attention(q, k_img, v_img, k_lens=None, version=2)
+            # compute attention
+            x = flash_attention(q, k, v, k_lens=context_lens, version=2)

         # output
         x = x.flatten(2)