Merge pull request #5 from songruizz/release_0.0.3

Enable FlashAttention v3 (FA3) via the '--enable-fa3' option
This commit is contained in:
zc8gerard 2025-05-29 16:27:09 +08:00 committed by GitHub
commit 4057dc55b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 22 additions and 1 deletions

View File

@ -637,6 +637,11 @@ def _validate_args(args):
args.
task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
if args.enable_fa3:
assert torch.cuda.get_device_capability()[0] >= 9, (
"FlashAttention v3 requires SM >= 90. "
)
def _parse_args():
parser = argparse.ArgumentParser(
@ -804,6 +809,14 @@ def _parse_args():
type=float,
default= 0.2,
help="tea_cache threshold")
parser.add_argument(
"--enable-fa3",
"--enable_fa3",
action="store_true",
default=False,
help="Use Flash Attention 3 for attention layers or not."
)
args = parser.parse_args()
@ -928,6 +941,10 @@ def generate(args):
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
)
if args.enable_fa3:
for block in wan_t2v.model.blocks:
block.self_attn.__class__.enable_fa3 = True
if args.enable_teacache:
wan_t2v.__class__.generate = t2v_generate
@ -1014,6 +1031,10 @@ def generate(args):
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
)
if args.enable_fa3:
for block in wan_i2v.model.blocks:
block.self_attn.__class__.enable_fa3 = True
if args.enable_teacache:
wan_i2v.__class__.generate = i2v_generate

View File

@ -289,7 +289,7 @@ def usp_attn_forward(self,
# k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
# v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
if FLASH_ATTN_3_AVAILABLE:
if hasattr(self, 'enable_fa3') and self.enable_fa3 and FLASH_ATTN_3_AVAILABLE:
x = xFuserLongContextAttention(attn_type=AttnType.FA3)(
None,
query=half(q),