From 15e4e0bdb81e57ccb73b068cbe20d36a8ad9151a Mon Sep 17 00:00:00 2001
From: songruizz
Date: Thu, 29 May 2025 11:42:34 +0800
Subject: [PATCH] Enable FlashAttention 3 via the '--enable-fa3' option

---
 generate.py                              | 21 +++++++++++++++++++++
 wan/distributed/xdit_context_parallel.py |  2 +-
 2 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/generate.py b/generate.py
index ce19aae..739de46 100644
--- a/generate.py
+++ b/generate.py
@@ -637,6 +637,11 @@ def _validate_args(args):
         args.
         task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
 
+    if args.enable_fa3:
+        assert torch.cuda.get_device_capability()[0] >= 9, (
+            "FlashAttention v3 requires SM >= 90."
+        )
+
 
 def _parse_args():
     parser = argparse.ArgumentParser(
@@ -804,6 +809,14 @@ def _parse_args():
         type=float,
         default=0.2,
         help="tea_cache threshold")
+
+    parser.add_argument(
+        "--enable-fa3",
+        "--enable_fa3",
+        action="store_true",
+        default=False,
+        help="Whether to use FlashAttention 3 in attention layers."
+    )
 
     args = parser.parse_args()
 
@@ -928,6 +941,10 @@ def generate(args):
             use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
             t5_cpu=args.t5_cpu,
         )
+
+        if args.enable_fa3:
+            for block in wan_t2v.model.blocks:
+                block.self_attn.__class__.enable_fa3 = True
 
         if args.enable_teacache:
             wan_t2v.__class__.generate = t2v_generate
@@ -1014,6 +1031,10 @@ def generate(args):
             use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
             t5_cpu=args.t5_cpu,
         )
+
+        if args.enable_fa3:
+            for block in wan_i2v.model.blocks:
+                block.self_attn.__class__.enable_fa3 = True
 
         if args.enable_teacache:
             wan_i2v.__class__.generate = i2v_generate

diff --git a/wan/distributed/xdit_context_parallel.py b/wan/distributed/xdit_context_parallel.py
index b264d84..7871d6b 100644
--- a/wan/distributed/xdit_context_parallel.py
+++ b/wan/distributed/xdit_context_parallel.py
@@ -289,7 +289,7 @@ def usp_attn_forward(self,
     # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
     # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
 
-    if FLASH_ATTN_3_AVAILABLE:
+    if hasattr(self, 'enable_fa3') and self.enable_fa3 and FLASH_ATTN_3_AVAILABLE:
         x = xFuserLongContextAttention(attn_type=AttnType.FA3)(
             None,
             query=half(q),
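
Reviewer note: the FA3 toggle is written to the attention module's class
(block.self_attn.__class__.enable_fa3 = True), so a single assignment is
visible to every instance of that class; the per-block loop is effectively
redundant after its first iteration. A minimal standalone sketch of the
pattern, with a hypothetical SelfAttention class standing in for the real
module:

    class SelfAttention:
        # enable_fa3 is not defined here; the patch injects it at runtime,
        # which is why the forward path guards with hasattr().
        def forward(self):
            if hasattr(self, 'enable_fa3') and self.enable_fa3:
                return 'fa3 path'
            return 'default path'

    blocks = [SelfAttention() for _ in range(3)]
    blocks[0].__class__.enable_fa3 = True   # mirrors the patch's assignment
    print([b.forward() for b in blocks])    # ['fa3 path', 'fa3 path', 'fa3 path']

If per-block control were ever wanted, assigning to the instance instead
(block.self_attn.enable_fa3 = True) would scope the flag to that block alone.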
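
With the patch applied, an invocation might look like the following; --task,
--size, and --prompt are generate.py's existing options, and the values shown
are illustrative only. The parser accepts '--enable-fa3' and '--enable_fa3'
interchangeably:

    python generate.py --task t2v-14B --size "1280*720" \
        --prompt "a cat surfing a wave" --enable-fa3

On a GPU with compute capability below 9.0 (pre-Hopper), _validate_args()
aborts with the FlashAttention v3 assertion added above.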