Mirror of https://github.com/Wan-Video/Wan2.1.git
Synced 2025-11-05 06:29:14 +00:00
Fix problem with Sage1
commit 4b04e6971b
parent e0fdec88f4
@@ -119,6 +119,7 @@ def pay_attention(
     deterministic=False,
     dtype=torch.bfloat16,
     version=None,
+    force_attention=None
 ):
     """
     q: [B, Lq, Nq, C1].
@@ -133,7 +134,8 @@ def pay_attention(
     deterministic: bool. If True, slightly slower and uses more memory.
     dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
     """
-    attn = offload.shared_state["_attention"]
+
+    attn = offload.shared_state["_attention"] if force_attention is None else force_attention
     q,k,v = qkv_list
     qkv_list.clear()
 
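For readers skimming the diff, here is a minimal runnable sketch of the override pattern these two hunks introduce. The module-level shared_state dict, the backend names, and the SDPA dispatch body are illustrative assumptions, not Wan2.1's actual dispatch code; only the force_attention fallback line mirrors the committed change.

import torch
import torch.nn.functional as F

shared_state = {"_attention": "sage"}  # assumed stand-in for offload.shared_state

def pay_attention(qkv_list, dropout_p=0.0, causal=False, force_attention=None):
    # The behavior this commit adds: fall back to the globally selected
    # backend only when the caller did not force one.
    attn = shared_state["_attention"] if force_attention is None else force_attention
    q, k, v = qkv_list
    qkv_list.clear()  # drop the caller's references so tensors can be freed early
    if attn == "sdpa":
        # torch's SDPA expects [B, heads, L, head_dim]; per the docstring,
        # q/k/v arrive as [B, L, heads, head_dim], hence the transposes.
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
            dropout_p=dropout_p, is_causal=causal,
        )
        return out.transpose(1, 2)
    raise NotImplementedError(f"backend {attn!r} is not part of this sketch")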
@@ -82,7 +82,7 @@ class SelfAttention(nn.Module):
 
         # compute attention
         p = self.attn_dropout if self.training else 0.0
-        x = pay_attention([q, k, v], dropout_p=p, causal=self.causal, version=2)
+        x = pay_attention([q, k, v], dropout_p=p, causal=self.causal, force_attention="sdpa")
         x = x.reshape(b, s, c)
 
         # output
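The call site above stops requesting a numbered kernel version and instead pins this layer to the SDPA backend, so the globally configured Sage1 path is bypassed for this attention. A quick usage check against the sketch earlier (shapes follow the docstring's [B, L, heads, head_dim] layout; all values are illustrative):

q = torch.randn(2, 16, 8, 64)
k = torch.randn(2, 16, 8, 64)
v = torch.randn(2, 16, 8, 64)

# The global default ("sage" in the sketch) is ignored when a backend is forced:
x = pay_attention([q, k, v], causal=False, force_attention="sdpa")
print(x.shape)  # torch.Size([2, 16, 8, 64])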