mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-07-13 19:20:09 +00:00)
Using fa3 through option '--enable-fa3'

This commit is contained in:
parent 7ae8070fd9
commit 15e4e0bdb8

generate.py (21 changed lines)
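For context, a hypothetical invocation with the new flag; the task, size, and ckpt_dir values below are illustrative, not taken from this commit:

python generate.py --task t2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "..." --enable-fa3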
@@ -637,6 +637,11 @@ def _validate_args(args):
         args.
         task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
 
+    if args.enable_fa3:
+        assert torch.cuda.get_device_capability()[0] >= 9, (
+            "FlashAttention v3 requires SM >= 90. "
+        )
+
 
 def _parse_args():
     parser = argparse.ArgumentParser(
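The capability gate above relies on torch.cuda.get_device_capability(), which returns a (major, minor) compute-capability tuple; FlashAttention v3 targets Hopper-class GPUs (SM 90, e.g. H100/H800). A minimal standalone sketch of the same check, assuming a CUDA-enabled PyTorch build:

import torch

# get_device_capability() returns (major, minor), e.g. (9, 0) on Hopper.
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"SM {major}{minor}: FA3-capable = {major >= 9}")
else:
    print("No CUDA device visible; FA3 check not applicable.")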
@@ -804,6 +809,14 @@ def _parse_args():
         type=float,
         default= 0.2,
         help="tea_cache threshold")
+
+    parser.add_argument(
+        "--enable-fa3",
+        "--enable_fa3",
+        action="store_true",
+        default=False,
+        help="Use Flash Attention 3 for attention layers or not."
+    )
 
     args = parser.parse_args()
 
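Registering both "--enable-fa3" and "--enable_fa3" makes the two spellings aliases of one option: argparse derives the destination from the first long option string, replacing hyphens with underscores, so both forms land in args.enable_fa3. A small self-contained sketch of that behaviour:

import argparse

parser = argparse.ArgumentParser()
# Both option strings share one dest, "enable_fa3".
parser.add_argument("--enable-fa3", "--enable_fa3", action="store_true", default=False)

print(parser.parse_args(["--enable-fa3"]).enable_fa3)  # True
print(parser.parse_args(["--enable_fa3"]).enable_fa3)  # True
print(parser.parse_args([]).enable_fa3)                # False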
@@ -928,6 +941,10 @@ def generate(args):
             use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
             t5_cpu=args.t5_cpu,
         )
+
+        if args.enable_fa3:
+            for block in wan_t2v.model.blocks:
+                block.self_attn.__class__.enable_fa3 = True
 
         if args.enable_teacache:
             wan_t2v.__class__.generate = t2v_generate
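Note that block.self_attn.__class__.enable_fa3 = True writes to the attention class, not the instance, so a single assignment already flips the flag for every block sharing that class; the loop is harmless but redundant. An illustration of the mechanism with a stand-in class:

class SelfAttention:  # stand-in for the model's attention class
    pass

blocks = [SelfAttention() for _ in range(3)]
blocks[0].__class__.enable_fa3 = True  # one write patches the class itself

# Every instance now sees the flag via class-attribute lookup.
print(all(hasattr(b, "enable_fa3") and b.enable_fa3 for b in blocks))  # True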
@@ -1014,6 +1031,10 @@ def generate(args):
             use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
             t5_cpu=args.t5_cpu,
         )
+
+        if args.enable_fa3:
+            for block in wan_i2v.model.blocks:
+                block.self_attn.__class__.enable_fa3 = True
 
         if args.enable_teacache:
             wan_i2v.__class__.generate = i2v_generate
@@ -289,7 +289,7 @@ def usp_attn_forward(self,
     # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
     # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
 
-    if FLASH_ATTN_3_AVAILABLE:
+    if hasattr(self, 'enable_fa3') and self.enable_fa3 and FLASH_ATTN_3_AVAILABLE:
         x = xFuserLongContextAttention(attn_type=AttnType.FA3)(
             None,
             query=half(q),
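This last hunk, in the USP attention forward path, makes FA3 strictly opt-in: previously FA3 was selected whenever the kernel was importable, whereas now it also requires the enable_fa3 flag planted on the class by generate.py, falling back to the existing non-FA3 branch otherwise. A sketch of the resulting dispatch, with stand-in names:

FLASH_ATTN_3_AVAILABLE = True  # stand-in for the module-level import probe

class Attn:
    def backend(self):
        # hasattr + truth test, as in the patched condition; equivalent
        # to getattr(self, "enable_fa3", False) for this purpose.
        if hasattr(self, "enable_fa3") and self.enable_fa3 and FLASH_ATTN_3_AVAILABLE:
            return "FA3"
        return "FA2"

attn = Attn()
print(attn.backend())   # FA2: flag unset even though FA3 is available
Attn.enable_fa3 = True  # the class-level patch applied by generate.py
print(attn.backend())   # FA3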