From 4b04e6971bb79fedd7041ad5837e47ab439154d4 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Sun, 2 Mar 2025 22:16:30 +0100
Subject: [PATCH] Fix problem with Sage1

---
 wan/modules/attention.py | 4 +++-
 wan/modules/clip.py      | 2 +-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/wan/modules/attention.py b/wan/modules/attention.py
index 86aea95..eaeb6c5 100644
--- a/wan/modules/attention.py
+++ b/wan/modules/attention.py
@@ -119,6 +119,7 @@ def pay_attention(
     deterministic=False,
     dtype=torch.bfloat16,
     version=None,
+    force_attention=None,
 ):
     """
     q: [B, Lq, Nq, C1].
@@ -133,7 +134,8 @@
     deterministic: bool. If True, slightly slower and uses more memory.
     dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
     """
-    attn = offload.shared_state["_attention"]
+    # Use the globally selected attention backend unless the caller pins one explicitly.
+    attn = offload.shared_state["_attention"] if force_attention is None else force_attention
     q,k,v = qkv_list
     qkv_list.clear()
 
diff --git a/wan/modules/clip.py b/wan/modules/clip.py
index f76fbee..6ffefe2 100644
--- a/wan/modules/clip.py
+++ b/wan/modules/clip.py
@@ -82,7 +82,7 @@ class SelfAttention(nn.Module):
 
         # compute attention
         p = self.attn_dropout if self.training else 0.0
-        x = pay_attention([q, k, v], dropout_p=p, causal=self.causal, version=2)
+        x = pay_attention([q, k, v], dropout_p=p, causal=self.causal, force_attention="sdpa")
         x = x.reshape(b, s, c)
 
         # output
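
Note: the sketch below illustrates the dispatch pattern this patch introduces (a global attention-backend setting that a call site can override via force_attention, as clip.py now does with "sdpa"). It is a minimal standalone example: the shared_state dict, the "sage" key value, and the sdpa_attention helper are illustrative stand-ins, not the repository's actual offload module or backend registry.

import torch
import torch.nn.functional as F

# Stand-in for offload.shared_state: a process-wide default attention backend.
shared_state = {"_attention": "sage"}

def sdpa_attention(q, k, v, causal=False):
    # q, k, v arrive as [B, L, N, C]; PyTorch SDPA expects heads before the sequence dim.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    return out.transpose(1, 2)

def pay_attention(qkv_list, causal=False, force_attention=None):
    # Same override rule as the patch: take the global backend unless the caller
    # pins one explicitly (clip.py pins "sdpa" to sidestep the Sage1 path).
    attn = shared_state["_attention"] if force_attention is None else force_attention
    q, k, v = qkv_list
    qkv_list.clear()
    if attn == "sdpa":
        return sdpa_attention(q, k, v, causal=causal)
    raise NotImplementedError(f"backend {attn!r} is not wired up in this sketch")

q = torch.randn(1, 16, 8, 64)
k = torch.randn(1, 16, 8, 64)
v = torch.randn(1, 16, 8, 64)
x = pay_attention([q, k, v], force_attention="sdpa")  # bypasses the global "sage" setting
print(x.shape)  # torch.Size([1, 16, 8, 64])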