Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-04 14:16:57 +00:00)
Commit 4b04e6971b (parent e0fdec88f4): "Fix pb with Sage1"

The commit works around a problem with Sage1 (presumably SageAttention v1): it adds a force_attention argument to pay_attention so a caller can override the globally selected attention backend, and uses it in SelfAttention to pin that module to the "sdpa" backend.
@@ -119,6 +119,7 @@ def pay_attention(
     deterministic=False,
     dtype=torch.bfloat16,
     version=None,
+    force_attention= None
 ):
     """
     q: [B, Lq, Nq, C1].
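The first hunk extends the pay_attention signature with a force_attention keyword that defaults to None. A minimal sketch of the resulting signature follows; the leading parameters (qkv_list, dropout_p, causal) are not shown in this hunk and are assumed from the call sites further down.

import torch

# Sketch of the pay_attention signature after this commit. Only the
# parameters visible in the diff are certain; qkv_list, dropout_p and
# causal are assumed from the call in SelfAttention below.
def pay_attention(
    qkv_list,                # [q, k, v]; the function empties this list
    dropout_p=0.0,
    causal=False,
    deterministic=False,     # bool: slightly slower and uses more memory
    dtype=torch.bfloat16,    # applied when q/k/v are not float16/bfloat16
    version=None,
    force_attention=None,    # new: lets the caller pin a specific backend
):
    ...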
@@ -133,7 +134,8 @@ def pay_attention(
         deterministic: bool. If True, slightly slower and uses more memory.
         dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
     """
-    attn = offload.shared_state["_attention"]
+
+    attn = offload.shared_state["_attention"] if force_attention== None else force_attention
     q,k,v = qkv_list
     qkv_list.clear()

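The second hunk makes the backend lookup honour the new argument: the previously unconditional read of offload.shared_state["_attention"] is now used only when force_attention is None. A self-contained sketch of that selection logic is below; the shared_state dict and the backend names stand in for the real offload module, which is not part of the diff.

# Hedged sketch of the selection logic added in this hunk; "shared_state"
# stands in for offload.shared_state and the backend names are illustrative.
shared_state = {"_attention": "sage"}   # globally selected backend

def pick_backend(force_attention=None):
    # Same shape as the committed line:
    # attn = offload.shared_state["_attention"] if force_attention == None else force_attention
    return shared_state["_attention"] if force_attention is None else force_attention

print(pick_backend())         # -> "sage"  (global choice, e.g. Sage1)
print(pick_backend("sdpa"))   # -> "sdpa"  (caller forces a backend)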
@@ -82,7 +82,7 @@ class SelfAttention(nn.Module):

         # compute attention
         p = self.attn_dropout if self.training else 0.0
-        x = pay_attention([q, k, v], dropout_p=p, causal=self.causal, version=2)
+        x = pay_attention([q, k, v], dropout_p=p, causal=self.causal, force_attention="sdpa")
         x = x.reshape(b, s, c)

         # output
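The third hunk is the actual fix: SelfAttention stops asking for version=2 and instead forces the "sdpa" backend, so this call is routed away from Sage1 even when Sage1 is the globally selected backend. The "sdpa" name presumably maps to PyTorch's built-in scaled_dot_product_attention; the sketch below shows that fallback path under that assumption, independent of the repository's wrapper.

import torch
import torch.nn.functional as F

# Hedged sketch: what forcing the "sdpa" backend presumably falls back to,
# i.e. PyTorch's built-in scaled dot-product attention instead of Sage1.
# Shapes follow the docstring in the diff: q is [B, Lq, Nq, C1].
def sdpa_attention(q, k, v, dropout_p=0.0, causal=False):
    # scaled_dot_product_attention expects [B, heads, L, head_dim]
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, is_causal=causal)
    return out.transpose(1, 2)   # back to [B, Lq, Nq, C1]

b, s, n, d = 2, 16, 8, 64
q = k = v = torch.randn(b, s, n, d)
print(sdpa_attention(q, k, v).shape)   # torch.Size([2, 16, 8, 64])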