mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-07-13 11:10:11 +00:00
Merge pull request #5 from songruizz/release_0.0.3
Using fa3 through option '--enable-fa3'
This commit is contained in:
commit
4057dc55b3
21
generate.py
21
generate.py
@ -637,6 +637,11 @@ def _validate_args(args):
|
||||
args.
|
||||
task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
|
||||
|
||||
if args.enable_fa3:
|
||||
assert torch.cuda.get_device_capability()[0] >= 9, (
|
||||
"FlashAttention v3 requires SM >= 90. "
|
||||
)
|
||||
|
||||
|
||||
def _parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -804,6 +809,14 @@ def _parse_args():
|
||||
type=float,
|
||||
default= 0.2,
|
||||
help="tea_cache threshold")
|
||||
|
||||
parser.add_argument(
|
||||
"--enable-fa3",
|
||||
"--enable_fa3",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use Flash Attention 3 for attention layers or not."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -928,6 +941,10 @@ def generate(args):
|
||||
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
||||
t5_cpu=args.t5_cpu,
|
||||
)
|
||||
|
||||
if args.enable_fa3:
|
||||
for block in wan_t2v.model.blocks:
|
||||
block.self_attn.__class__.enable_fa3 = True
|
||||
|
||||
if args.enable_teacache:
|
||||
wan_t2v.__class__.generate = t2v_generate
|
||||
@ -1014,6 +1031,10 @@ def generate(args):
|
||||
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
||||
t5_cpu=args.t5_cpu,
|
||||
)
|
||||
|
||||
if args.enable_fa3:
|
||||
for block in wan_i2v.model.blocks:
|
||||
block.self_attn.__class__.enable_fa3 = True
|
||||
|
||||
if args.enable_teacache:
|
||||
wan_i2v.__class__.generate = i2v_generate
|
||||
|
@ -289,7 +289,7 @@ def usp_attn_forward(self,
|
||||
# k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
|
||||
# v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
|
||||
|
||||
if FLASH_ATTN_3_AVAILABLE:
|
||||
if hasattr(self, 'enable_fa3') and self.enable_fa3 and FLASH_ATTN_3_AVAILABLE:
|
||||
x = xFuserLongContextAttention(attn_type=AttnType.FA3)(
|
||||
None,
|
||||
query=half(q),
|
||||
|
Loading…
Reference in New Issue
Block a user