Mirror of https://github.com/Wan-Video/Wan2.1.git, synced 2025-07-14 03:30:10 +00:00

Merge pull request #6 from songruizz/release_0.0.3

Fix incorrect FA3 usage when both the FA and FA3 packages are installed but the GPU is not a Hopper part.

Commit feba2a62c8
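The diff below gates every attention call on an `enable_fa3` flag, but the hunks do not show where that flag comes from. A minimal sketch of how such a flag could be derived, assuming the usual `flash_attn_interface` module name for the FA3 package and PyTorch's device-capability API (the helper name is hypothetical, not part of this commit):

import torch

def detect_fa3_support() -> bool:
    # FA3 ships Hopper (SM 9.x) kernels, which is why the commit pins
    # version=2 elsewhere: only report FA3 as usable when the package
    # imports AND the current GPU is actually Hopper.
    try:
        import flash_attn_interface  # noqa: F401  (assumed FA3 module name)
    except ImportError:
        return False
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major == 9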
@@ -146,12 +146,22 @@ class WanSelfAttention(nn.Module):
 
         q, k, v = qkv_fn(x)
 
-        x = flash_attention(
-            q=rope_apply(q, grid_sizes, freqs),
-            k=rope_apply(k, grid_sizes, freqs),
-            v=v,
-            k_lens=seq_lens,
-            window_size=self.window_size)
+        if hasattr(self, 'enable_fa3') and self.enable_fa3:
+            x = flash_attention(
+                q=rope_apply(q, grid_sizes, freqs),
+                k=rope_apply(k, grid_sizes, freqs),
+                v=v,
+                k_lens=seq_lens,
+                window_size=self.window_size,
+                version=3)
+        else:
+            x = flash_attention(
+                q=rope_apply(q, grid_sizes, freqs),
+                k=rope_apply(k, grid_sizes, freqs),
+                v=v,
+                k_lens=seq_lens,
+                window_size=self.window_size,
+                version=2)
 
         # output
         x = x.flatten(2)
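The reason an explicit `version` fixes anything lives in the shared `flash_attention` wrapper (in Wan2.1, `wan/modules/attention.py`): with `version=None` it prefers FA3 whenever that package merely imports, which broke machines that have FA3 installed but no Hopper GPU. A simplified sketch of that dispatch, not the verbatim upstream code:

# Simplified sketch of the wrapper's backend selection. The
# import-success flags mirror the ones the Wan2.1 wrapper keeps.
try:
    import flash_attn_interface  # noqa: F401  (FA3)
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn  # noqa: F401  (FA2)
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

def select_fa_version(version=None):
    # Pre-fix behavior: callers passed version=None, so an installed
    # FA3 package won automatically, even on non-Hopper GPUs where its
    # SM90 kernels cannot run. Pinning version=2, as the diff now does
    # on the fallback path, never enters the FA3 branch.
    if version is None:
        return 3 if FLASH_ATTN_3_AVAILABLE else 2
    if version == 3 and not FLASH_ATTN_3_AVAILABLE:
        raise RuntimeError('version=3 requested but FA3 is not installed')
    if version == 2 and not FLASH_ATTN_2_AVAILABLE:
        raise RuntimeError('version=2 requested but FA2 is not installed')
    return version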
@@ -176,7 +186,10 @@ class WanT2VCrossAttention(WanSelfAttention):
         v = self.v(context).view(b, -1, n, d)
 
         # compute attention
-        x = flash_attention(q, k, v, k_lens=context_lens)
+        if hasattr(self, 'enable_fa3') and self.enable_fa3:
+            x = flash_attention(q, k, v, k_lens=context_lens, version=3)
+        else:
+            x = flash_attention(q, k, v, k_lens=context_lens, version=2)
 
         # output
         x = x.flatten(2)
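The same guard is repeated in all three attention modules. Reading the flag with `hasattr(...) and self.enable_fa3` keeps modules built before this change, which never had the attribute, on the FA2 path; the idiomatic one-step equivalent, shown here as a drop-in rewrite of the guard rather than standalone code, is `getattr` with a default:

# Drop-in equivalent of the repeated guard: default to FA2 whenever
# the attribute was never set on the module.
if getattr(self, 'enable_fa3', False):
    x = flash_attention(q, k, v, k_lens=context_lens, version=3)
else:
    x = flash_attention(q, k, v, k_lens=context_lens, version=2)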
@@ -217,9 +230,14 @@ class WanI2VCrossAttention(WanSelfAttention):
         v = self.v(context).view(b, -1, n, d)
         k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
         v_img = self.v_img(context_img).view(b, -1, n, d)
-        img_x = flash_attention(q, k_img, v_img, k_lens=None)
-        # compute attention
-        x = flash_attention(q, k, v, k_lens=context_lens)
+        if hasattr(self, 'enable_fa3') and self.enable_fa3:
+            img_x = flash_attention(q, k_img, v_img, k_lens=None, version=3)
+            # compute attention
+            x = flash_attention(q, k, v, k_lens=context_lens, version=3)
+        else:
+            img_x = flash_attention(q, k_img, v_img, k_lens=None, version=2)
+            # compute attention
+            x = flash_attention(q, k, v, k_lens=context_lens, version=2)
 
         # output
         x = x.flatten(2)
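Since all three modules only ever read `enable_fa3` through the guard, opting a whole model in or out reduces to setting one attribute per attention module. A hypothetical helper, usable together with the capability probe sketched above (neither is part of the diff):

import torch.nn as nn

def configure_fa3(model: nn.Module, use_fa3: bool) -> None:
    # Matched by class name to keep the sketch free of wan-internal
    # imports; covers WanSelfAttention and both cross-attention
    # subclasses shown in the diff.
    for module in model.modules():
        if 'Attention' in type(module).__name__:
            module.enable_fa3 = use_fa3

# e.g.: configure_fa3(model, detect_fa3_support())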