Mirror of https://github.com/Wan-Video/Wan2.1.git

Merge pull request #5 from songruizz/release_0.0.3
Using fa3 through option '--enable-fa3'
Commit 4057dc55b3

generate.py: 21 changes
generate.py:

```diff
@@ -637,6 +637,11 @@ def _validate_args(args):
         args.
         task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
 
+    if args.enable_fa3:
+        assert torch.cuda.get_device_capability()[0] >= 9, (
+            "FlashAttention v3 requires SM >= 90. "
+        )
+
 
 def _parse_args():
     parser = argparse.ArgumentParser(
```
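The validation hunk keys off torch.cuda.get_device_capability(), which returns the device's (major, minor) compute capability; FlashAttention v3 kernels target Hopper-class GPUs (SM 90) and newer. A minimal standalone sketch of the same gate — the supports_fa3 helper name is ours for illustration, not part of the repo:

```python
import torch

def supports_fa3() -> bool:
    """FlashAttention v3 requires compute capability >= 9.0 (Hopper, e.g. H100)."""
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()  # e.g. (9, 0) on H100
    return major >= 9

if __name__ == "__main__":
    print("FA3-capable GPU:", supports_fa3())
```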
```diff
@@ -805,6 +810,14 @@ def _parse_args():
         default=0.2,
         help="tea_cache threshold")
 
+    parser.add_argument(
+        "--enable-fa3",
+        "--enable_fa3",
+        action="store_true",
+        default=False,
+        help="Use Flash Attention 3 for attention layers or not."
+    )
+
     args = parser.parse_args()
 
     _validate_args(args)
```
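Passing two option strings to one add_argument call is what makes both spellings work: argparse derives the destination from the first long option, so --enable-fa3 and --enable_fa3 both land in args.enable_fa3. A self-contained sketch:

```python
import argparse

parser = argparse.ArgumentParser()
# Two option strings, one flag; dest is inferred from the first long option
# ("--enable-fa3" -> "enable_fa3", hyphens become underscores).
parser.add_argument(
    "--enable-fa3",
    "--enable_fa3",
    action="store_true",
    default=False,
    help="Use Flash Attention 3 for attention layers or not.")

print(parser.parse_args(["--enable-fa3"]).enable_fa3)  # True
print(parser.parse_args(["--enable_fa3"]).enable_fa3)  # True
print(parser.parse_args([]).enable_fa3)                # False
```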
```diff
@@ -929,6 +942,10 @@ def generate(args):
             t5_cpu=args.t5_cpu,
         )
+
+        if args.enable_fa3:
+            for block in wan_t2v.model.blocks:
+                block.self_attn.__class__.enable_fa3 = True
 
         if args.enable_teacache:
             wan_t2v.__class__.generate = t2v_generate
             wan_t2v.model.__class__.enable_teacache = True
```
```diff
@@ -1015,6 +1032,10 @@ def generate(args):
             t5_cpu=args.t5_cpu,
         )
+
+        if args.enable_fa3:
+            for block in wan_i2v.model.blocks:
+                block.self_attn.__class__.enable_fa3 = True
 
         if args.enable_teacache:
             wan_i2v.__class__.generate = i2v_generate
             wan_i2v.model.__class__.enable_teacache = True
```
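Both generate() hunks flip the flag by assigning to the attribute on the class (block.self_attn.__class__.enable_fa3 = True) rather than on each instance, matching the existing enable_teacache pattern. Because instance attribute lookup falls through to the class, a single assignment is visible from every block, so the loop is effectively redundant after its first iteration. A toy sketch with stand-in classes (not the repo's WanModel internals):

```python
class SelfAttention:
    pass

class Block:
    def __init__(self):
        self.self_attn = SelfAttention()

blocks = [Block() for _ in range(3)]

# Mirrors the diff's `block.self_attn.__class__.enable_fa3 = True`: patching
# the shared class once flips the flag for all instances at the same time.
blocks[0].self_attn.__class__.enable_fa3 = True

print(all(getattr(b.self_attn, "enable_fa3", False) for b in blocks))  # True
```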
In the file defining usp_attn_forward:

```diff
@@ -289,7 +289,7 @@ def usp_attn_forward(self,
     # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
     # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
 
-    if FLASH_ATTN_3_AVAILABLE:
+    if hasattr(self, 'enable_fa3') and self.enable_fa3 and FLASH_ATTN_3_AVAILABLE:
         x = xFuserLongContextAttention(attn_type=AttnType.FA3)(
             None,
             query=half(q),
```
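The rewritten guard makes FA3 strictly opt-in: the library being importable (FLASH_ATTN_3_AVAILABLE) no longer selects the FA3 path by itself; the module must also carry the enable_fa3 flag that generate.py patches in. Note that hasattr(self, 'enable_fa3') and self.enable_fa3 is equivalent to getattr(self, 'enable_fa3', False). A minimal sketch of the dispatch with the attention call stubbed out — only the branching logic mirrors the diff:

```python
FLASH_ATTN_3_AVAILABLE = True  # stands in for the real import-time probe

class Attn:
    def forward(self) -> str:
        # Same condition as the diff, written with getattr for brevity.
        if getattr(self, "enable_fa3", False) and FLASH_ATTN_3_AVAILABLE:
            return "fa3 path"    # would call xFuserLongContextAttention(attn_type=AttnType.FA3)
        return "default path"    # falls through to the existing attention backend

attn = Attn()
print(attn.forward())   # "default path": flag never set
Attn.enable_fa3 = True  # what generate.py's class-level patch does
print(attn.forward())   # "fa3 path"
```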