fix performance regression caused by duplicated Flash Attention code

songrui.771 2025-05-27 16:14:07 +08:00
parent 3c7c6f8b29
commit 2fc28056fa


@@ -304,13 +304,6 @@ def usp_attn_forward(self,
         value=half(v),
         window_size=self.window_size)
-    x = xFuserLongContextAttention()(
-        None,
-        query=half(q),
-        key=half(k),
-        value=half(v),
-        window_size=self.window_size)
     # TODO: padding after attention.
     # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
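For context, below is a minimal sketch of the de-duplicated forward path. `xFuserLongContextAttention`, `half`, `q`, `k`, `v`, and `self.window_size` are taken from the diff above; the import path and the simplified wrapper function are illustrative assumptions, not the repository's exact code.

# Sketch only: illustrates why the duplicated call was a performance regression.
# Invoking xFuserLongContextAttention twice recomputes the same attention output,
# so the long-context Flash Attention kernel runs twice per forward pass for no benefit.
# The import path below is an assumption about the xfuser package layout.
from xfuser.core.long_ctx_attention import xFuserLongContextAttention


def usp_attn_forward_sketch(self, q, k, v, half):
    # After this commit the attention output is computed exactly once.
    x = xFuserLongContextAttention()(
        None,
        query=half(q),
        key=half(k),
        value=half(v),
        window_size=self.window_size)
    # TODO kept from the original source: padding after attention.
    # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
    return x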