Fix problem with Sage1

DeepBeepMeep 2025-03-02 22:16:30 +01:00
parent e0fdec88f4
commit 4b04e6971b
2 changed files with 4 additions and 2 deletions


@@ -119,6 +119,7 @@ def pay_attention(
deterministic=False,
dtype=torch.bfloat16,
version=None,
+force_attention= None
):
"""
q: [B, Lq, Nq, C1].
@@ -133,7 +134,8 @@ def pay_attention(
deterministic: bool. If True, slightly slower and uses more memory.
dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
"""
-attn = offload.shared_state["_attention"]
+attn = offload.shared_state["_attention"] if force_attention== None else force_attention
q,k,v = qkv_list
qkv_list.clear()
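
The hunk above makes the globally selected attention backend (stored in offload.shared_state["_attention"]) overridable per call via the new force_attention argument. Below is a minimal, self-contained sketch of that fallback pattern; the shared_state stand-in, the "sdpa" backend name, and the tensor-layout handling are assumptions inferred from the docstring, not the repository's actual implementation.

import torch
import torch.nn.functional as F

# Stand-in for offload.shared_state; in the real code the backend is set elsewhere
# (e.g. to a Sage attention implementation).
shared_state = {"_attention": "sage"}

def pay_attention(qkv_list, dropout_p=0.0, causal=False, force_attention=None):
    # Use the globally configured backend unless the caller forces a specific one.
    attn = shared_state["_attention"] if force_attention is None else force_attention
    q, k, v = qkv_list
    qkv_list.clear()
    if attn == "sdpa":
        # q/k/v are [B, Lq, Nq, C1] per the docstring; SDPA expects [B, N, L, C].
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
            dropout_p=dropout_p, is_causal=causal,
        )
        return out.transpose(1, 2)
    raise NotImplementedError(f"backend {attn!r} is outside this sketch")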


@@ -82,7 +82,7 @@ class SelfAttention(nn.Module):
# compute attention
p = self.attn_dropout if self.training else 0.0
-x = pay_attention([q, k, v], dropout_p=p, causal=self.causal, version=2)
+x = pay_attention([q, k, v], dropout_p=p, causal=self.causal, force_attention="sdpa")
x = x.reshape(b, s, c)
# output
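
For context, a hypothetical call against the sketch above: passing force_attention="sdpa" bypasses whatever backend is registered globally (here, Sage), which appears to be how this commit works around the Sage1 problem for SelfAttention. The shapes are illustrative only.

# Hypothetical usage of the sketch: batch 1, sequence length 16, 8 heads of dim 64.
q = torch.randn(1, 16, 8, 64)
k = torch.randn(1, 16, 8, 64)
v = torch.randn(1, 16, 8, 64)
x = pay_attention([q, k, v], dropout_p=0.0, causal=False, force_attention="sdpa")
print(x.shape)  # torch.Size([1, 16, 8, 64])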