diff --git a/wan/modules/attention.py b/wan/modules/attention.py
index 86aea95..eaeb6c5 100644
--- a/wan/modules/attention.py
+++ b/wan/modules/attention.py
@@ -119,6 +119,7 @@ def pay_attention(
     deterministic=False,
     dtype=torch.bfloat16,
     version=None,
+    force_attention= None
 ):
     """
     q: [B, Lq, Nq, C1].
@@ -133,7 +134,8 @@ def pay_attention(
         deterministic: bool. If True, slightly slower and uses more memory.
         dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
     """
-    attn = offload.shared_state["_attention"]
+
+    attn = offload.shared_state["_attention"] if force_attention== None else force_attention
     q,k,v = qkv_list
     qkv_list.clear()
 
diff --git a/wan/modules/clip.py b/wan/modules/clip.py
index f76fbee..6ffefe2 100644
--- a/wan/modules/clip.py
+++ b/wan/modules/clip.py
@@ -82,7 +82,7 @@ class SelfAttention(nn.Module):
 
         # compute attention
         p = self.attn_dropout if self.training else 0.0
-        x = pay_attention([q, k, v], dropout_p=p, causal=self.causal, version=2)
+        x = pay_attention([q, k, v], dropout_p=p, causal=self.causal, force_attention="sdpa")
         x = x.reshape(b, s, c)
 
         # output
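
For context, a minimal self-contained sketch of the backend-override pattern the new force_attention argument enables: pay_attention keeps using the globally configured backend when force_attention is None, and a caller such as the CLIP self-attention can pin a specific backend ("sdpa" here) for just that call. The shared_state dict, the backend labels, and the dispatch body below are assumptions standing in for offload.shared_state and the real backend table, not the repository's implementation.

import torch
import torch.nn.functional as F

shared_state = {"_attention": "flash"}  # stand-in for offload.shared_state

def pay_attention(qkv_list, dropout_p=0.0, causal=False, force_attention=None):
    # Use the globally configured backend unless the caller forces one.
    attn = shared_state["_attention"] if force_attention is None else force_attention
    q, k, v = qkv_list
    qkv_list.clear()
    if attn == "sdpa":
        # PyTorch SDPA expects [B, N, L, C]; inputs here are [B, L, N, C].
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
            dropout_p=dropout_p, is_causal=causal)
        return out.transpose(1, 2)
    raise ValueError(f"unsupported attention backend: {attn}")

# The clip.py change then forces SDPA regardless of the global setting:
b, s, n, c = 2, 16, 8, 64
q, k, v = (torch.randn(b, s, n, c) for _ in range(3))
x = pay_attention([q, k, v], dropout_p=0.0, causal=False, force_attention="sdpa")
print(x.shape)  # torch.Size([2, 16, 8, 64])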