mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-04 06:15:17 +00:00 
			
		
		
		
	Merge pull request #5 from songruizz/release_0.0.3
Using fa3 through option '--enable-fa3'
This commit is contained in:
		
						commit
						4057dc55b3
					
				
							
								
								
									
										21
									
								
								generate.py
									
									
									
									
									
								
							
							
						
						
									
										21
									
								
								generate.py
									
									
									
									
									
								
							@ -637,6 +637,11 @@ def _validate_args(args):
 | 
			
		||||
        args.
 | 
			
		||||
        task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
 | 
			
		||||
 | 
			
		||||
    if args.enable_fa3:
 | 
			
		||||
        assert torch.cuda.get_device_capability()[0] >= 9, (
 | 
			
		||||
            "FlashAttention v3 requires SM >= 90. "
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _parse_args():
 | 
			
		||||
    parser = argparse.ArgumentParser(
 | 
			
		||||
@ -804,6 +809,14 @@ def _parse_args():
 | 
			
		||||
        type=float,
 | 
			
		||||
        default= 0.2,
 | 
			
		||||
        help="tea_cache threshold")
 | 
			
		||||
    
 | 
			
		||||
    parser.add_argument(
 | 
			
		||||
        "--enable-fa3",
 | 
			
		||||
        "--enable_fa3",
 | 
			
		||||
        action="store_true",
 | 
			
		||||
        default=False,
 | 
			
		||||
        help="Use Flash Attention 3 for attention layers or not."
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
 | 
			
		||||
@ -928,6 +941,10 @@ def generate(args):
 | 
			
		||||
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
 | 
			
		||||
            t5_cpu=args.t5_cpu,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if args.enable_fa3:
 | 
			
		||||
            for block in wan_t2v.model.blocks:
 | 
			
		||||
                block.self_attn.__class__.enable_fa3 = True
 | 
			
		||||
        
 | 
			
		||||
        if args.enable_teacache:
 | 
			
		||||
            wan_t2v.__class__.generate = t2v_generate
 | 
			
		||||
@ -1014,6 +1031,10 @@ def generate(args):
 | 
			
		||||
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
 | 
			
		||||
            t5_cpu=args.t5_cpu,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if args.enable_fa3:
 | 
			
		||||
            for block in wan_i2v.model.blocks:
 | 
			
		||||
                block.self_attn.__class__.enable_fa3 = True
 | 
			
		||||
        
 | 
			
		||||
        if args.enable_teacache:
 | 
			
		||||
            wan_i2v.__class__.generate = i2v_generate
 | 
			
		||||
 | 
			
		||||
@ -289,7 +289,7 @@ def usp_attn_forward(self,
 | 
			
		||||
    #     k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
 | 
			
		||||
    #     v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
 | 
			
		||||
 | 
			
		||||
    if FLASH_ATTN_3_AVAILABLE:
 | 
			
		||||
    if hasattr(self, 'enable_fa3') and self.enable_fa3 and FLASH_ATTN_3_AVAILABLE:
 | 
			
		||||
        x = xFuserLongContextAttention(attn_type=AttnType.FA3)(
 | 
			
		||||
            None,
 | 
			
		||||
            query=half(q),
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user