mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-04 06:15:17 +00:00 
			
		
		
		
	Merge fdbc5f0588 into 7c81b2f27d
				
					
				
			This commit is contained in:
		
						commit
						9146033b23
					
				
							
								
								
									
										4
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										4
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@ -34,4 +34,6 @@ Wan2.1-T2V-14B/
 | 
				
			|||||||
Wan2.1-T2V-1.3B/
 | 
					Wan2.1-T2V-1.3B/
 | 
				
			||||||
Wan2.1-I2V-14B-480P/
 | 
					Wan2.1-I2V-14B-480P/
 | 
				
			||||||
Wan2.1-I2V-14B-720P/
 | 
					Wan2.1-I2V-14B-720P/
 | 
				
			||||||
poetry.lock
 | 
					poetry.lock
 | 
				
			||||||
 | 
					wok/37ec512624d61f7aa208f7ea8140a131f93afc9a
 | 
				
			||||||
 | 
					wok/t2v-1.3b
 | 
				
			||||||
							
								
								
									
										90
									
								
								PR_SUMMARY.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										90
									
								
								PR_SUMMARY.md
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,90 @@
 | 
				
			|||||||
 | 
					# Pull Request Summary
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Title
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					feat: add --vae_cpu flag for improved VRAM optimization on consumer GPUs
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Description
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Problem
 | 
				
			||||||
 | 
					Users with consumer-grade GPUs (like RTX 4090 with 11.49 GB VRAM) encounter OOM errors when running the T2V-1.3B model even with existing optimization flags (`--offload_model True --t5_cpu`). The OOM occurs because the VAE remains on GPU throughout the entire generation pipeline despite only being needed briefly for encoding/decoding.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Solution
 | 
				
			||||||
 | 
					This PR adds a `--vae_cpu` flag that works similarly to the existing `--t5_cpu` flag. When enabled:
 | 
				
			||||||
 | 
					- VAE initializes on CPU instead of GPU
 | 
				
			||||||
 | 
					- VAE moves to GPU only when needed for encode/decode operations
 | 
				
			||||||
 | 
					- VAE returns to CPU after use, freeing VRAM for other models
 | 
				
			||||||
 | 
					- Saves ~100-200MB VRAM without performance degradation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Implementation Details
 | 
				
			||||||
 | 
					1. **Added `--vae_cpu` argument** to `generate.py` (mirrors `--t5_cpu` pattern)
 | 
				
			||||||
 | 
					2. **Updated all 4 pipelines**: WanT2V, WanI2V, WanFLF2V, WanVace
 | 
				
			||||||
 | 
					3. **Fixed critical DiT offloading**: When `offload_model=True` and `t5_cpu=False`, DiT now offloads before T5 loads to prevent OOM
 | 
				
			||||||
 | 
					4. **Handled VAE scale tensors**: Ensured `mean` and `std` tensors move with the model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Test Results
 | 
				
			||||||
 | 
					**Hardware:** RTX-class GPU with 11.49 GB VRAM
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					| Test | Flags | Result | Notes |
 | 
				
			||||||
 | 
					|------|-------|--------|-------|
 | 
				
			||||||
 | 
					| Baseline | None | ❌ OOM | Failed at T5 load, needed 80MB but only 85MB free |
 | 
				
			||||||
 | 
					| `--vae_cpu` | VAE offload only | ✅ Success | Fixed the OOM issue |
 | 
				
			||||||
 | 
					| `--t5_cpu` | T5 offload only | ✅ Success | Also works |
 | 
				
			||||||
 | 
					| Both | `--vae_cpu --t5_cpu` | ✅ Success | Maximum VRAM savings |
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Usage Examples
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**Before (OOM on consumer GPUs):**
 | 
				
			||||||
 | 
					```bash
 | 
				
			||||||
 | 
					python generate.py --task t2v-1.3B --size 480*832 --ckpt_dir ./t2v-1.3b \
 | 
				
			||||||
 | 
					  --offload_model True --prompt "your prompt"
 | 
				
			||||||
 | 
					# Result: OOM Error
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**After (works on consumer GPUs):**
 | 
				
			||||||
 | 
					```bash
 | 
				
			||||||
 | 
					python generate.py --task t2v-1.3B --size 480*832 --ckpt_dir ./t2v-1.3b \
 | 
				
			||||||
 | 
					  --offload_model True --vae_cpu --prompt "your prompt"
 | 
				
			||||||
 | 
					# Result: Success!
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**Maximum VRAM savings:**
 | 
				
			||||||
 | 
					```bash
 | 
				
			||||||
 | 
					python generate.py --task t2v-1.3B --size 480*832 --ckpt_dir ./t2v-1.3b \
 | 
				
			||||||
 | 
					  --offload_model True --vae_cpu --t5_cpu --prompt "your prompt"
 | 
				
			||||||
 | 
					# Result: Success with lowest memory footprint
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Benefits
 | 
				
			||||||
 | 
					1. ✅ Enables T2V-1.3B on more consumer GPUs without OOM
 | 
				
			||||||
 | 
					2. ✅ Backward compatible (default=False, no behavior change)
 | 
				
			||||||
 | 
					3. ✅ Consistent with existing `--t5_cpu` pattern
 | 
				
			||||||
 | 
					4. ✅ Works across all 4 pipelines (T2V, I2V, FLF2V, VACE)
 | 
				
			||||||
 | 
					5. ✅ No performance degradation (same math, just different memory placement)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Files Modified
 | 
				
			||||||
 | 
					- `generate.py` - Added `--vae_cpu` argument
 | 
				
			||||||
 | 
					- `wan/text2video.py` - WanT2V pipeline with conditional VAE offloading
 | 
				
			||||||
 | 
					- `wan/image2video.py` - WanI2V pipeline with conditional VAE offloading
 | 
				
			||||||
 | 
					- `wan/first_last_frame2video.py` - WanFLF2V pipeline with conditional VAE offloading
 | 
				
			||||||
 | 
					- `wan/vace.py` - WanVace pipeline with conditional VAE offloading
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Related
 | 
				
			||||||
 | 
					This extends the existing OOM mitigation mentioned in the README (line 168-172) for RTX 4090 users.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Optional: Documentation Update
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Consider updating the README.md section on OOM handling:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**Current (line 168-172):**
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					If you encounter OOM (Out-of-Memory) issues, you can use the `--offload_model True` and `--t5_cpu` options to reduce GPU memory usage.
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**Suggested addition:**
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					If you encounter OOM (Out-of-Memory) issues, you can use the `--offload_model True`, `--t5_cpu`, and `--vae_cpu` options to reduce GPU memory usage. For maximum VRAM savings, use all three flags together.
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
							
								
								
									
										143
									
								
								VAE_OFFLOAD_PLAN.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										143
									
								
								VAE_OFFLOAD_PLAN.md
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,143 @@
 | 
				
			|||||||
 | 
					# VAE Offloading Implementation & Testing Plan
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Overview
 | 
				
			||||||
 | 
					Add `--vae_cpu` flag to enable VAE offloading to save ~100-200MB VRAM during text-to-video generation.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Implementation Plan
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Phase 1: Code Changes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**1. Add `--vae_cpu` flag to generate.py**
 | 
				
			||||||
 | 
					- Add argument to parser (similar to `--t5_cpu`)
 | 
				
			||||||
 | 
					- Default: `False` (maintain current upstream behavior)
 | 
				
			||||||
 | 
					- Pass to pipeline constructors
 | 
				
			||||||
 | 
					- Independent flag (works regardless of `offload_model` setting)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**2. Update Pipeline Constructors**
 | 
				
			||||||
 | 
					- Add `vae_cpu` parameter to `__init__` methods in:
 | 
				
			||||||
 | 
					  - `WanT2V` (text2video.py)
 | 
				
			||||||
 | 
					  - `WanI2V` (image2video.py)
 | 
				
			||||||
 | 
					  - `WanFLF2V` (first_last_frame2video.py)
 | 
				
			||||||
 | 
					  - `WanVace` (vace.py)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**3. Conditional VAE Initialization**
 | 
				
			||||||
 | 
					- If `vae_cpu=True`: Initialize VAE on CPU
 | 
				
			||||||
 | 
					- If `vae_cpu=False`: Initialize VAE on GPU (current behavior)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**4. Update Offload Logic**
 | 
				
			||||||
 | 
					- Only move VAE to/from GPU when `vae_cpu=True`
 | 
				
			||||||
 | 
					- When `vae_cpu=False`, VAE stays on GPU (no extra transfers)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Phase 2: Testing Plan
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Test Scripts to Create:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```bash
 | 
				
			||||||
 | 
					# wok/test1_baseline.sh - No flags (expect OOM)
 | 
				
			||||||
 | 
					python ../generate.py --task t2v-1.3B --size 480*832 --ckpt_dir ./t2v-1.3b --offload_model True --prompt "..."
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# wok/test2_vae_cpu.sh - Only VAE offloading
 | 
				
			||||||
 | 
					python ../generate.py --task t2v-1.3B --size 480*832 --ckpt_dir ./t2v-1.3b --offload_model True --vae_cpu --prompt "..."
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# wok/test3_t5_cpu.sh - Only T5 offloading
 | 
				
			||||||
 | 
					python ../generate.py --task t2v-1.3B --size 480*832 --ckpt_dir ./t2v-1.3b --offload_model True --t5_cpu --prompt "..."
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# wok/test4_both.sh - Both flags
 | 
				
			||||||
 | 
					python ../generate.py --task t2v-1.3B --size 480*832 --ckpt_dir ./t2v-1.3b --offload_model True --vae_cpu --t5_cpu --prompt "..."
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Expected Results:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					| Test | Flags | Expected Outcome | Memory Peak |
 | 
				
			||||||
 | 
					|------|-------|------------------|-------------|
 | 
				
			||||||
 | 
					| 1 | None | ❌ OOM Error | ~VRAM_MAX + 100MB |
 | 
				
			||||||
 | 
					| 2 | `--vae_cpu` | ✅ Success | ~VRAM_MAX - 100-200MB |
 | 
				
			||||||
 | 
					| 3 | `--t5_cpu` | ? (might still OOM) | ~VRAM_MAX - 50MB |
 | 
				
			||||||
 | 
					| 4 | `--vae_cpu --t5_cpu` | ✅ Success | ~VRAM_MAX - 150-250MB |
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Actual Test Results:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**Hardware:** 11.49 GiB VRAM GPU
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					| Test | Flags | Actual Outcome | Notes |
 | 
				
			||||||
 | 
					|------|-------|----------------|-------|
 | 
				
			||||||
 | 
					| 1 | None | ❌ OOM Error | Failed trying to allocate 80MB, only 85.38MB free |
 | 
				
			||||||
 | 
					| 2 | `--vae_cpu` | ✅ Success | Completed successfully after fixes |
 | 
				
			||||||
 | 
					| 3 | `--t5_cpu` | ✅ Success | No OOM, completed successfully |
 | 
				
			||||||
 | 
					| 4 | `--vae_cpu --t5_cpu` | ✅ Success | Completed with maximum VRAM savings |
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**Key Findings:**
 | 
				
			||||||
 | 
					- Baseline OOM occurred when trying to move T5 to GPU with DiT already loaded
 | 
				
			||||||
 | 
					- VAE offloading alone is sufficient to fix the OOM
 | 
				
			||||||
 | 
					- T5 offloading alone is also sufficient (surprising but effective!)
 | 
				
			||||||
 | 
					- Both flags together provide maximum VRAM savings for users with limited GPU memory
 | 
				
			||||||
 | 
					- All approaches work by freeing VRAM at critical moments during the pipeline execution
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**Conclusion:**
 | 
				
			||||||
 | 
					The `--vae_cpu` flag is a valuable addition for consumer GPU users, complementing the existing `--t5_cpu` optimization and following the same design pattern.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Phase 3: Documentation & PR
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### 1. Results Document
 | 
				
			||||||
 | 
					- Memory usage for each test
 | 
				
			||||||
 | 
					- Performance impact (if any) from CPU↔GPU transfers
 | 
				
			||||||
 | 
					- Recommendations for users
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### 2. PR Components
 | 
				
			||||||
 | 
					- Feature description
 | 
				
			||||||
 | 
					- Memory savings benchmarks
 | 
				
			||||||
 | 
					- Backward compatible (default=False)
 | 
				
			||||||
 | 
					- Use cases: when to enable `--vae_cpu`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Design Decisions
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					1. **Independence**: `vae_cpu` works independently of `offload_model` flag (mirrors `t5_cpu` behavior)
 | 
				
			||||||
 | 
					2. **Default False**: Maintains current upstream behavior for backward compatibility
 | 
				
			||||||
 | 
					3. **Conditional Transfers**: Only add GPU↔CPU transfers when flag is enabled
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Memory Analysis
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**Current Pipeline Memory Timeline:**
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					Init:    [T5-CPU] [VAE-GPU] [DiT-GPU]  <- OOM here during init!
 | 
				
			||||||
 | 
					Encode:  [T5-GPU] [VAE-GPU] [DiT-GPU]
 | 
				
			||||||
 | 
					Loop:    [T5-CPU] [VAE-GPU] [DiT-GPU]  <- VAE not needed but wasting VRAM
 | 
				
			||||||
 | 
					Decode:  [T5-CPU] [VAE-GPU] [DiT-CPU]  <- Only now is VAE actually used
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**With `--vae_cpu` Enabled:**
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					Init:    [T5-CPU] [VAE-CPU] [DiT-GPU]  <- VAE no longer occupying VRAM
 | 
				
			||||||
 | 
					Encode:  [T5-GPU] [VAE-CPU] [DiT-GPU]
 | 
				
			||||||
 | 
					Loop:    [T5-CPU] [VAE-CPU] [DiT-GPU]  <- VAE stays on CPU during loop
 | 
				
			||||||
 | 
					Decode:  [T5-CPU] [VAE-GPU] [DiT-CPU]  <- VAE moved to GPU only for decode
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Implementation Details
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Critical Fixes Applied:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					1. **DiT Offloading Before T5 Load** (when `offload_model=True` and `t5_cpu=False`)
 | 
				
			||||||
 | 
					   - DiT must be offloaded to CPU before loading T5 to GPU
 | 
				
			||||||
 | 
					   - Otherwise T5 allocation fails with OOM
 | 
				
			||||||
 | 
					   - Added automatic DiT→CPU before T5→GPU transition
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					2. **VAE Scale Tensors** (when `vae_cpu=True`)
 | 
				
			||||||
 | 
					   - VAE wrapper class stores `mean` and `std` tensors separately
 | 
				
			||||||
 | 
					   - These don't move with `.model.to(device)`
 | 
				
			||||||
 | 
					   - Must explicitly move scale tensors along with model
 | 
				
			||||||
 | 
					   - Fixed in all encode/decode operations
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					3. **Conditional Offloading Logic**
 | 
				
			||||||
 | 
					   - VAE offloading only triggers when `vae_cpu=True`
 | 
				
			||||||
 | 
					   - Works independently of `offload_model` flag
 | 
				
			||||||
 | 
					   - Mirrors `t5_cpu` behavior for consistency
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Files Modified
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					1. `generate.py` - Add argument parser
 | 
				
			||||||
 | 
					2. `wan/text2video.py` - WanT2V pipeline
 | 
				
			||||||
 | 
					3. `wan/image2video.py` - WanI2V pipeline
 | 
				
			||||||
 | 
					4. `wan/first_last_frame2video.py` - WanFLF2V pipeline
 | 
				
			||||||
 | 
					5. `wan/vace.py` - WanVace pipeline
 | 
				
			||||||
 | 
					6. `wok/test*.sh` - Test scripts
 | 
				
			||||||
							
								
								
									
										24
									
								
								environment.yml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								environment.yml
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,24 @@
 | 
				
			|||||||
 | 
					name: wan21
 | 
				
			||||||
 | 
					channels:
 | 
				
			||||||
 | 
					  - conda-forge
 | 
				
			||||||
 | 
					  - defaults
 | 
				
			||||||
 | 
					dependencies:
 | 
				
			||||||
 | 
					  - python>=3.10
 | 
				
			||||||
 | 
					  - pytorch>=2.4.0
 | 
				
			||||||
 | 
					  - torchvision>=0.19.0
 | 
				
			||||||
 | 
					  - tqdm
 | 
				
			||||||
 | 
					  - imageio
 | 
				
			||||||
 | 
					  - imageio-ffmpeg
 | 
				
			||||||
 | 
					  - numpy>=1.23.5,<2
 | 
				
			||||||
 | 
					  - pip
 | 
				
			||||||
 | 
					  - pip:
 | 
				
			||||||
 | 
					    - opencv-python>=4.9.0.80
 | 
				
			||||||
 | 
					    - diffusers>=0.31.0
 | 
				
			||||||
 | 
					    - transformers>=4.49.0
 | 
				
			||||||
 | 
					    - tokenizers>=0.20.3
 | 
				
			||||||
 | 
					    - accelerate>=1.1.1
 | 
				
			||||||
 | 
					    - easydict
 | 
				
			||||||
 | 
					    - ftfy
 | 
				
			||||||
 | 
					    - dashscope
 | 
				
			||||||
 | 
					    - flash_attn
 | 
				
			||||||
 | 
					    - gradio>=5.0.0
 | 
				
			||||||
@ -150,6 +150,11 @@ def _parse_args():
 | 
				
			|||||||
        action="store_true",
 | 
					        action="store_true",
 | 
				
			||||||
        default=False,
 | 
					        default=False,
 | 
				
			||||||
        help="Whether to place T5 model on CPU.")
 | 
					        help="Whether to place T5 model on CPU.")
 | 
				
			||||||
 | 
					    parser.add_argument(
 | 
				
			||||||
 | 
					        "--vae_cpu",
 | 
				
			||||||
 | 
					        action="store_true",
 | 
				
			||||||
 | 
					        default=False,
 | 
				
			||||||
 | 
					        help="Whether to place VAE model on CPU to save VRAM. VAE will be moved to GPU only when needed for encoding/decoding.")
 | 
				
			||||||
    parser.add_argument(
 | 
					    parser.add_argument(
 | 
				
			||||||
        "--dit_fsdp",
 | 
					        "--dit_fsdp",
 | 
				
			||||||
        action="store_true",
 | 
					        action="store_true",
 | 
				
			||||||
@ -366,6 +371,7 @@ def generate(args):
 | 
				
			|||||||
            dit_fsdp=args.dit_fsdp,
 | 
					            dit_fsdp=args.dit_fsdp,
 | 
				
			||||||
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
 | 
					            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
 | 
				
			||||||
            t5_cpu=args.t5_cpu,
 | 
					            t5_cpu=args.t5_cpu,
 | 
				
			||||||
 | 
					            vae_cpu=args.vae_cpu,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        logging.info(
 | 
					        logging.info(
 | 
				
			||||||
@ -423,6 +429,7 @@ def generate(args):
 | 
				
			|||||||
            dit_fsdp=args.dit_fsdp,
 | 
					            dit_fsdp=args.dit_fsdp,
 | 
				
			||||||
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
 | 
					            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
 | 
				
			||||||
            t5_cpu=args.t5_cpu,
 | 
					            t5_cpu=args.t5_cpu,
 | 
				
			||||||
 | 
					            vae_cpu=args.vae_cpu,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        logging.info("Generating video ...")
 | 
					        logging.info("Generating video ...")
 | 
				
			||||||
@ -481,6 +488,7 @@ def generate(args):
 | 
				
			|||||||
            dit_fsdp=args.dit_fsdp,
 | 
					            dit_fsdp=args.dit_fsdp,
 | 
				
			||||||
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
 | 
					            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
 | 
				
			||||||
            t5_cpu=args.t5_cpu,
 | 
					            t5_cpu=args.t5_cpu,
 | 
				
			||||||
 | 
					            vae_cpu=args.vae_cpu,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        logging.info("Generating video ...")
 | 
					        logging.info("Generating video ...")
 | 
				
			||||||
@ -529,6 +537,7 @@ def generate(args):
 | 
				
			|||||||
            dit_fsdp=args.dit_fsdp,
 | 
					            dit_fsdp=args.dit_fsdp,
 | 
				
			||||||
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
 | 
					            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
 | 
				
			||||||
            t5_cpu=args.t5_cpu,
 | 
					            t5_cpu=args.t5_cpu,
 | 
				
			||||||
 | 
					            vae_cpu=args.vae_cpu,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        src_video, src_mask, src_ref_images = wan_vace.prepare_source(
 | 
					        src_video, src_mask, src_ref_images = wan_vace.prepare_source(
 | 
				
			||||||
 | 
				
			|||||||
@ -41,6 +41,7 @@ class WanFLF2V:
 | 
				
			|||||||
        dit_fsdp=False,
 | 
					        dit_fsdp=False,
 | 
				
			||||||
        use_usp=False,
 | 
					        use_usp=False,
 | 
				
			||||||
        t5_cpu=False,
 | 
					        t5_cpu=False,
 | 
				
			||||||
 | 
					        vae_cpu=False,
 | 
				
			||||||
        init_on_cpu=True,
 | 
					        init_on_cpu=True,
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
        r"""
 | 
					        r"""
 | 
				
			||||||
@ -63,6 +64,8 @@ class WanFLF2V:
 | 
				
			|||||||
                Enable distribution strategy of USP.
 | 
					                Enable distribution strategy of USP.
 | 
				
			||||||
            t5_cpu (`bool`, *optional*, defaults to False):
 | 
					            t5_cpu (`bool`, *optional*, defaults to False):
 | 
				
			||||||
                Whether to place T5 model on CPU. Only works without t5_fsdp.
 | 
					                Whether to place T5 model on CPU. Only works without t5_fsdp.
 | 
				
			||||||
 | 
					            vae_cpu (`bool`, *optional*, defaults to False):
 | 
				
			||||||
 | 
					                Whether to place VAE model on CPU to save VRAM. VAE will be moved to GPU only when needed.
 | 
				
			||||||
            init_on_cpu (`bool`, *optional*, defaults to True):
 | 
					            init_on_cpu (`bool`, *optional*, defaults to True):
 | 
				
			||||||
                Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
 | 
					                Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
@ -71,6 +74,7 @@ class WanFLF2V:
 | 
				
			|||||||
        self.rank = rank
 | 
					        self.rank = rank
 | 
				
			||||||
        self.use_usp = use_usp
 | 
					        self.use_usp = use_usp
 | 
				
			||||||
        self.t5_cpu = t5_cpu
 | 
					        self.t5_cpu = t5_cpu
 | 
				
			||||||
 | 
					        self.vae_cpu = vae_cpu
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.num_train_timesteps = config.num_train_timesteps
 | 
					        self.num_train_timesteps = config.num_train_timesteps
 | 
				
			||||||
        self.param_dtype = config.param_dtype
 | 
					        self.param_dtype = config.param_dtype
 | 
				
			||||||
@ -87,9 +91,12 @@ class WanFLF2V:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        self.vae_stride = config.vae_stride
 | 
					        self.vae_stride = config.vae_stride
 | 
				
			||||||
        self.patch_size = config.patch_size
 | 
					        self.patch_size = config.patch_size
 | 
				
			||||||
 | 
					        # Initialize VAE on CPU if vae_cpu=True to save VRAM during pipeline initialization and diffusion loop
 | 
				
			||||||
 | 
					        # VAE is only needed for encoding first/last frames and decoding final latents
 | 
				
			||||||
 | 
					        vae_device = torch.device('cpu') if vae_cpu else self.device
 | 
				
			||||||
        self.vae = WanVAE(
 | 
					        self.vae = WanVAE(
 | 
				
			||||||
            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
 | 
					            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
 | 
				
			||||||
            device=self.device)
 | 
					            device=vae_device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.clip = CLIPModel(
 | 
					        self.clip = CLIPModel(
 | 
				
			||||||
            dtype=config.clip_dtype,
 | 
					            dtype=config.clip_dtype,
 | 
				
			||||||
@ -242,11 +249,16 @@ class WanFLF2V:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        # preprocess
 | 
					        # preprocess
 | 
				
			||||||
        if not self.t5_cpu:
 | 
					        if not self.t5_cpu:
 | 
				
			||||||
 | 
					            # Offload DiT to CPU first if needed to make room for T5
 | 
				
			||||||
 | 
					            if offload_model:
 | 
				
			||||||
 | 
					                self.model.cpu()
 | 
				
			||||||
 | 
					                torch.cuda.empty_cache()
 | 
				
			||||||
            self.text_encoder.model.to(self.device)
 | 
					            self.text_encoder.model.to(self.device)
 | 
				
			||||||
            context = self.text_encoder([input_prompt], self.device)
 | 
					            context = self.text_encoder([input_prompt], self.device)
 | 
				
			||||||
            context_null = self.text_encoder([n_prompt], self.device)
 | 
					            context_null = self.text_encoder([n_prompt], self.device)
 | 
				
			||||||
            if offload_model:
 | 
					            if offload_model:
 | 
				
			||||||
                self.text_encoder.model.cpu()
 | 
					                self.text_encoder.model.cpu()
 | 
				
			||||||
 | 
					                torch.cuda.empty_cache()
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            context = self.text_encoder([input_prompt], torch.device('cpu'))
 | 
					            context = self.text_encoder([input_prompt], torch.device('cpu'))
 | 
				
			||||||
            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
 | 
					            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
 | 
				
			||||||
@ -259,6 +271,13 @@ class WanFLF2V:
 | 
				
			|||||||
        if offload_model:
 | 
					        if offload_model:
 | 
				
			||||||
            self.clip.model.cpu()
 | 
					            self.clip.model.cpu()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Move VAE to GPU for encoding first and last frames if it's on CPU
 | 
				
			||||||
 | 
					        if self.vae_cpu:
 | 
				
			||||||
 | 
					            self.vae.model.to(self.device)
 | 
				
			||||||
 | 
					            # Also move scale tensors to GPU
 | 
				
			||||||
 | 
					            self.vae.mean = self.vae.mean.to(self.device)
 | 
				
			||||||
 | 
					            self.vae.std = self.vae.std.to(self.device)
 | 
				
			||||||
 | 
					            self.vae.scale = [self.vae.mean, 1.0 / self.vae.std]
 | 
				
			||||||
        y = self.vae.encode([
 | 
					        y = self.vae.encode([
 | 
				
			||||||
            torch.concat([
 | 
					            torch.concat([
 | 
				
			||||||
                torch.nn.functional.interpolate(
 | 
					                torch.nn.functional.interpolate(
 | 
				
			||||||
@ -274,6 +293,12 @@ class WanFLF2V:
 | 
				
			|||||||
                         dim=1).to(self.device)
 | 
					                         dim=1).to(self.device)
 | 
				
			||||||
        ])[0]
 | 
					        ])[0]
 | 
				
			||||||
        y = torch.concat([msk, y])
 | 
					        y = torch.concat([msk, y])
 | 
				
			||||||
 | 
					        # Offload VAE back to CPU after encoding
 | 
				
			||||||
 | 
					        if self.vae_cpu and offload_model:
 | 
				
			||||||
 | 
					            self.vae.model.cpu()
 | 
				
			||||||
 | 
					            self.vae.mean = self.vae.mean.cpu()
 | 
				
			||||||
 | 
					            self.vae.std = self.vae.std.cpu()
 | 
				
			||||||
 | 
					            self.vae.scale = [self.vae.mean, 1.0 / self.vae.std]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @contextmanager
 | 
					        @contextmanager
 | 
				
			||||||
        def noop_no_sync():
 | 
					        def noop_no_sync():
 | 
				
			||||||
@ -364,7 +389,21 @@ class WanFLF2V:
 | 
				
			|||||||
                torch.cuda.empty_cache()
 | 
					                torch.cuda.empty_cache()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if self.rank == 0:
 | 
					            if self.rank == 0:
 | 
				
			||||||
 | 
					                # Move VAE to GPU for decoding if it's on CPU
 | 
				
			||||||
 | 
					                if self.vae_cpu:
 | 
				
			||||||
 | 
					                    self.vae.model.to(self.device)
 | 
				
			||||||
 | 
					                    # Also move scale tensors to GPU
 | 
				
			||||||
 | 
					                    self.vae.mean = self.vae.mean.to(self.device)
 | 
				
			||||||
 | 
					                    self.vae.std = self.vae.std.to(self.device)
 | 
				
			||||||
 | 
					                    self.vae.scale = [self.vae.mean, 1.0 / self.vae.std]
 | 
				
			||||||
                videos = self.vae.decode(x0)
 | 
					                videos = self.vae.decode(x0)
 | 
				
			||||||
 | 
					                # Offload VAE back to CPU after decoding to free VRAM
 | 
				
			||||||
 | 
					                if self.vae_cpu and offload_model:
 | 
				
			||||||
 | 
					                    self.vae.model.cpu()
 | 
				
			||||||
 | 
					                    self.vae.mean = self.vae.mean.cpu()
 | 
				
			||||||
 | 
					                    self.vae.std = self.vae.std.cpu()
 | 
				
			||||||
 | 
					                    self.vae.scale = [self.vae.mean, 1.0 / self.vae.std]
 | 
				
			||||||
 | 
					                    torch.cuda.empty_cache()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        del noise, latent
 | 
					        del noise, latent
 | 
				
			||||||
        del sample_scheduler
 | 
					        del sample_scheduler
 | 
				
			||||||
 | 
				
			|||||||
@ -41,6 +41,7 @@ class WanI2V:
 | 
				
			|||||||
        dit_fsdp=False,
 | 
					        dit_fsdp=False,
 | 
				
			||||||
        use_usp=False,
 | 
					        use_usp=False,
 | 
				
			||||||
        t5_cpu=False,
 | 
					        t5_cpu=False,
 | 
				
			||||||
 | 
					        vae_cpu=False,
 | 
				
			||||||
        init_on_cpu=True,
 | 
					        init_on_cpu=True,
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
        r"""
 | 
					        r"""
 | 
				
			||||||
@ -63,6 +64,8 @@ class WanI2V:
 | 
				
			|||||||
                Enable distribution strategy of USP.
 | 
					                Enable distribution strategy of USP.
 | 
				
			||||||
            t5_cpu (`bool`, *optional*, defaults to False):
 | 
					            t5_cpu (`bool`, *optional*, defaults to False):
 | 
				
			||||||
                Whether to place T5 model on CPU. Only works without t5_fsdp.
 | 
					                Whether to place T5 model on CPU. Only works without t5_fsdp.
 | 
				
			||||||
 | 
					            vae_cpu (`bool`, *optional*, defaults to False):
 | 
				
			||||||
 | 
					                Whether to place VAE model on CPU to save VRAM. VAE will be moved to GPU only when needed.
 | 
				
			||||||
            init_on_cpu (`bool`, *optional*, defaults to True):
 | 
					            init_on_cpu (`bool`, *optional*, defaults to True):
 | 
				
			||||||
                Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
 | 
					                Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
@ -71,6 +74,7 @@ class WanI2V:
 | 
				
			|||||||
        self.rank = rank
 | 
					        self.rank = rank
 | 
				
			||||||
        self.use_usp = use_usp
 | 
					        self.use_usp = use_usp
 | 
				
			||||||
        self.t5_cpu = t5_cpu
 | 
					        self.t5_cpu = t5_cpu
 | 
				
			||||||
 | 
					        self.vae_cpu = vae_cpu
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.num_train_timesteps = config.num_train_timesteps
 | 
					        self.num_train_timesteps = config.num_train_timesteps
 | 
				
			||||||
        self.param_dtype = config.param_dtype
 | 
					        self.param_dtype = config.param_dtype
 | 
				
			||||||
@ -87,9 +91,12 @@ class WanI2V:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        self.vae_stride = config.vae_stride
 | 
					        self.vae_stride = config.vae_stride
 | 
				
			||||||
        self.patch_size = config.patch_size
 | 
					        self.patch_size = config.patch_size
 | 
				
			||||||
 | 
					        # Initialize VAE on CPU if vae_cpu=True to save VRAM during pipeline initialization and diffusion loop
 | 
				
			||||||
 | 
					        # VAE is only needed for encoding input images and decoding final latents
 | 
				
			||||||
 | 
					        vae_device = torch.device('cpu') if vae_cpu else self.device
 | 
				
			||||||
        self.vae = WanVAE(
 | 
					        self.vae = WanVAE(
 | 
				
			||||||
            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
 | 
					            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
 | 
				
			||||||
            device=self.device)
 | 
					            device=vae_device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.clip = CLIPModel(
 | 
					        self.clip = CLIPModel(
 | 
				
			||||||
            dtype=config.clip_dtype,
 | 
					            dtype=config.clip_dtype,
 | 
				
			||||||
@ -221,11 +228,16 @@ class WanI2V:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        # preprocess
 | 
					        # preprocess
 | 
				
			||||||
        if not self.t5_cpu:
 | 
					        if not self.t5_cpu:
 | 
				
			||||||
 | 
					            # Offload DiT to CPU first if needed to make room for T5
 | 
				
			||||||
 | 
					            if offload_model:
 | 
				
			||||||
 | 
					                self.model.cpu()
 | 
				
			||||||
 | 
					                torch.cuda.empty_cache()
 | 
				
			||||||
            self.text_encoder.model.to(self.device)
 | 
					            self.text_encoder.model.to(self.device)
 | 
				
			||||||
            context = self.text_encoder([input_prompt], self.device)
 | 
					            context = self.text_encoder([input_prompt], self.device)
 | 
				
			||||||
            context_null = self.text_encoder([n_prompt], self.device)
 | 
					            context_null = self.text_encoder([n_prompt], self.device)
 | 
				
			||||||
            if offload_model:
 | 
					            if offload_model:
 | 
				
			||||||
                self.text_encoder.model.cpu()
 | 
					                self.text_encoder.model.cpu()
 | 
				
			||||||
 | 
					                torch.cuda.empty_cache()
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            context = self.text_encoder([input_prompt], torch.device('cpu'))
 | 
					            context = self.text_encoder([input_prompt], torch.device('cpu'))
 | 
				
			||||||
            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
 | 
					            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
 | 
				
			||||||
@ -237,6 +249,13 @@ class WanI2V:
 | 
				
			|||||||
        if offload_model:
 | 
					        if offload_model:
 | 
				
			||||||
            self.clip.model.cpu()
 | 
					            self.clip.model.cpu()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Move VAE to GPU for encoding input image if it's on CPU
 | 
				
			||||||
 | 
					        if self.vae_cpu:
 | 
				
			||||||
 | 
					            self.vae.model.to(self.device)
 | 
				
			||||||
 | 
					            # Also move scale tensors to GPU
 | 
				
			||||||
 | 
					            self.vae.mean = self.vae.mean.to(self.device)
 | 
				
			||||||
 | 
					            self.vae.std = self.vae.std.to(self.device)
 | 
				
			||||||
 | 
					            self.vae.scale = [self.vae.mean, 1.0 / self.vae.std]
 | 
				
			||||||
        y = self.vae.encode([
 | 
					        y = self.vae.encode([
 | 
				
			||||||
            torch.concat([
 | 
					            torch.concat([
 | 
				
			||||||
                torch.nn.functional.interpolate(
 | 
					                torch.nn.functional.interpolate(
 | 
				
			||||||
@ -247,6 +266,12 @@ class WanI2V:
 | 
				
			|||||||
                         dim=1).to(self.device)
 | 
					                         dim=1).to(self.device)
 | 
				
			||||||
        ])[0]
 | 
					        ])[0]
 | 
				
			||||||
        y = torch.concat([msk, y])
 | 
					        y = torch.concat([msk, y])
 | 
				
			||||||
 | 
					        # Offload VAE back to CPU after encoding
 | 
				
			||||||
 | 
					        if self.vae_cpu and offload_model:
 | 
				
			||||||
 | 
					            self.vae.model.cpu()
 | 
				
			||||||
 | 
					            self.vae.mean = self.vae.mean.cpu()
 | 
				
			||||||
 | 
					            self.vae.std = self.vae.std.cpu()
 | 
				
			||||||
 | 
					            self.vae.scale = [self.vae.mean, 1.0 / self.vae.std]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        @contextmanager
 | 
					        @contextmanager
 | 
				
			||||||
        def noop_no_sync():
 | 
					        def noop_no_sync():
 | 
				
			||||||
@ -337,7 +362,21 @@ class WanI2V:
 | 
				
			|||||||
                torch.cuda.empty_cache()
 | 
					                torch.cuda.empty_cache()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if self.rank == 0:
 | 
					            if self.rank == 0:
 | 
				
			||||||
 | 
					                # Move VAE to GPU for decoding if it's on CPU
 | 
				
			||||||
 | 
					                if self.vae_cpu:
 | 
				
			||||||
 | 
					                    self.vae.model.to(self.device)
 | 
				
			||||||
 | 
					                    # Also move scale tensors to GPU
 | 
				
			||||||
 | 
					                    self.vae.mean = self.vae.mean.to(self.device)
 | 
				
			||||||
 | 
					                    self.vae.std = self.vae.std.to(self.device)
 | 
				
			||||||
 | 
					                    self.vae.scale = [self.vae.mean, 1.0 / self.vae.std]
 | 
				
			||||||
                videos = self.vae.decode(x0)
 | 
					                videos = self.vae.decode(x0)
 | 
				
			||||||
 | 
					                # Offload VAE back to CPU after decoding to free VRAM
 | 
				
			||||||
 | 
					                if self.vae_cpu and offload_model:
 | 
				
			||||||
 | 
					                    self.vae.model.cpu()
 | 
				
			||||||
 | 
					                    self.vae.mean = self.vae.mean.cpu()
 | 
				
			||||||
 | 
					                    self.vae.std = self.vae.std.cpu()
 | 
				
			||||||
 | 
					                    self.vae.scale = [self.vae.mean, 1.0 / self.vae.std]
 | 
				
			||||||
 | 
					                    torch.cuda.empty_cache()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        del noise, latent
 | 
					        del noise, latent
 | 
				
			||||||
        del sample_scheduler
 | 
					        del sample_scheduler
 | 
				
			||||||
 | 
				
			|||||||
@ -38,6 +38,7 @@ class WanT2V:
 | 
				
			|||||||
        dit_fsdp=False,
 | 
					        dit_fsdp=False,
 | 
				
			||||||
        use_usp=False,
 | 
					        use_usp=False,
 | 
				
			||||||
        t5_cpu=False,
 | 
					        t5_cpu=False,
 | 
				
			||||||
 | 
					        vae_cpu=False,
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
        r"""
 | 
					        r"""
 | 
				
			||||||
        Initializes the Wan text-to-video generation model components.
 | 
					        Initializes the Wan text-to-video generation model components.
 | 
				
			||||||
@ -59,11 +60,14 @@ class WanT2V:
 | 
				
			|||||||
                Enable distribution strategy of USP.
 | 
					                Enable distribution strategy of USP.
 | 
				
			||||||
            t5_cpu (`bool`, *optional*, defaults to False):
 | 
					            t5_cpu (`bool`, *optional*, defaults to False):
 | 
				
			||||||
                Whether to place T5 model on CPU. Only works without t5_fsdp.
 | 
					                Whether to place T5 model on CPU. Only works without t5_fsdp.
 | 
				
			||||||
 | 
					            vae_cpu (`bool`, *optional*, defaults to False):
 | 
				
			||||||
 | 
					                Whether to place VAE model on CPU to save VRAM. VAE will be moved to GPU only when needed.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        self.device = torch.device(f"cuda:{device_id}")
 | 
					        self.device = torch.device(f"cuda:{device_id}")
 | 
				
			||||||
        self.config = config
 | 
					        self.config = config
 | 
				
			||||||
        self.rank = rank
 | 
					        self.rank = rank
 | 
				
			||||||
        self.t5_cpu = t5_cpu
 | 
					        self.t5_cpu = t5_cpu
 | 
				
			||||||
 | 
					        self.vae_cpu = vae_cpu
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.num_train_timesteps = config.num_train_timesteps
 | 
					        self.num_train_timesteps = config.num_train_timesteps
 | 
				
			||||||
        self.param_dtype = config.param_dtype
 | 
					        self.param_dtype = config.param_dtype
 | 
				
			||||||
@ -79,9 +83,12 @@ class WanT2V:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        self.vae_stride = config.vae_stride
 | 
					        self.vae_stride = config.vae_stride
 | 
				
			||||||
        self.patch_size = config.patch_size
 | 
					        self.patch_size = config.patch_size
 | 
				
			||||||
 | 
					        # Initialize VAE on CPU if vae_cpu=True to save VRAM during pipeline initialization and diffusion loop
 | 
				
			||||||
 | 
					        # VAE is only needed at the end for decoding latents to pixels
 | 
				
			||||||
 | 
					        vae_device = torch.device('cpu') if vae_cpu else self.device
 | 
				
			||||||
        self.vae = WanVAE(
 | 
					        self.vae = WanVAE(
 | 
				
			||||||
            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
 | 
					            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
 | 
				
			||||||
            device=self.device)
 | 
					            device=vae_device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        logging.info(f"Creating WanModel from {checkpoint_dir}")
 | 
					        logging.info(f"Creating WanModel from {checkpoint_dir}")
 | 
				
			||||||
        self.model = WanModel.from_pretrained(checkpoint_dir)
 | 
					        self.model = WanModel.from_pretrained(checkpoint_dir)
 | 
				
			||||||
@ -172,11 +179,16 @@ class WanT2V:
 | 
				
			|||||||
        seed_g.manual_seed(seed)
 | 
					        seed_g.manual_seed(seed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if not self.t5_cpu:
 | 
					        if not self.t5_cpu:
 | 
				
			||||||
 | 
					            # Offload DiT to CPU first if needed to make room for T5
 | 
				
			||||||
 | 
					            if offload_model:
 | 
				
			||||||
 | 
					                self.model.cpu()
 | 
				
			||||||
 | 
					                torch.cuda.empty_cache()
 | 
				
			||||||
            self.text_encoder.model.to(self.device)
 | 
					            self.text_encoder.model.to(self.device)
 | 
				
			||||||
            context = self.text_encoder([input_prompt], self.device)
 | 
					            context = self.text_encoder([input_prompt], self.device)
 | 
				
			||||||
            context_null = self.text_encoder([n_prompt], self.device)
 | 
					            context_null = self.text_encoder([n_prompt], self.device)
 | 
				
			||||||
            if offload_model:
 | 
					            if offload_model:
 | 
				
			||||||
                self.text_encoder.model.cpu()
 | 
					                self.text_encoder.model.cpu()
 | 
				
			||||||
 | 
					                torch.cuda.empty_cache()
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            context = self.text_encoder([input_prompt], torch.device('cpu'))
 | 
					            context = self.text_encoder([input_prompt], torch.device('cpu'))
 | 
				
			||||||
            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
 | 
					            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
 | 
				
			||||||
@ -258,7 +270,21 @@ class WanT2V:
 | 
				
			|||||||
                self.model.cpu()
 | 
					                self.model.cpu()
 | 
				
			||||||
                torch.cuda.empty_cache()
 | 
					                torch.cuda.empty_cache()
 | 
				
			||||||
            if self.rank == 0:
 | 
					            if self.rank == 0:
 | 
				
			||||||
 | 
					                # Move VAE to GPU for decoding if it's on CPU
 | 
				
			||||||
 | 
					                if self.vae_cpu:
 | 
				
			||||||
 | 
					                    self.vae.model.to(self.device)
 | 
				
			||||||
 | 
					                    # Also move scale tensors to GPU
 | 
				
			||||||
 | 
					                    self.vae.mean = self.vae.mean.to(self.device)
 | 
				
			||||||
 | 
					                    self.vae.std = self.vae.std.to(self.device)
 | 
				
			||||||
 | 
					                    self.vae.scale = [self.vae.mean, 1.0 / self.vae.std]
 | 
				
			||||||
                videos = self.vae.decode(x0)
 | 
					                videos = self.vae.decode(x0)
 | 
				
			||||||
 | 
					                # Offload VAE back to CPU after decoding to free VRAM
 | 
				
			||||||
 | 
					                if self.vae_cpu and offload_model:
 | 
				
			||||||
 | 
					                    self.vae.model.cpu()
 | 
				
			||||||
 | 
					                    self.vae.mean = self.vae.mean.cpu()
 | 
				
			||||||
 | 
					                    self.vae.std = self.vae.std.cpu()
 | 
				
			||||||
 | 
					                    self.vae.scale = [self.vae.mean, 1.0 / self.vae.std]
 | 
				
			||||||
 | 
					                    torch.cuda.empty_cache()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        del noise, latents
 | 
					        del noise, latents
 | 
				
			||||||
        del sample_scheduler
 | 
					        del sample_scheduler
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										41
									
								
								wan/vace.py
									
									
									
									
									
								
							
							
						
						
									
										41
									
								
								wan/vace.py
									
									
									
									
									
								
							@ -46,6 +46,7 @@ class WanVace(WanT2V):
 | 
				
			|||||||
        dit_fsdp=False,
 | 
					        dit_fsdp=False,
 | 
				
			||||||
        use_usp=False,
 | 
					        use_usp=False,
 | 
				
			||||||
        t5_cpu=False,
 | 
					        t5_cpu=False,
 | 
				
			||||||
 | 
					        vae_cpu=False,
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
        r"""
 | 
					        r"""
 | 
				
			||||||
        Initializes the Wan text-to-video generation model components.
 | 
					        Initializes the Wan text-to-video generation model components.
 | 
				
			||||||
@ -67,11 +68,14 @@ class WanVace(WanT2V):
 | 
				
			|||||||
                Enable distribution strategy of USP.
 | 
					                Enable distribution strategy of USP.
 | 
				
			||||||
            t5_cpu (`bool`, *optional*, defaults to False):
 | 
					            t5_cpu (`bool`, *optional*, defaults to False):
 | 
				
			||||||
                Whether to place T5 model on CPU. Only works without t5_fsdp.
 | 
					                Whether to place T5 model on CPU. Only works without t5_fsdp.
 | 
				
			||||||
 | 
					            vae_cpu (`bool`, *optional*, defaults to False):
 | 
				
			||||||
 | 
					                Whether to place VAE model on CPU to save VRAM. VAE will be moved to GPU only when needed.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        self.device = torch.device(f"cuda:{device_id}")
 | 
					        self.device = torch.device(f"cuda:{device_id}")
 | 
				
			||||||
        self.config = config
 | 
					        self.config = config
 | 
				
			||||||
        self.rank = rank
 | 
					        self.rank = rank
 | 
				
			||||||
        self.t5_cpu = t5_cpu
 | 
					        self.t5_cpu = t5_cpu
 | 
				
			||||||
 | 
					        self.vae_cpu = vae_cpu
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.num_train_timesteps = config.num_train_timesteps
 | 
					        self.num_train_timesteps = config.num_train_timesteps
 | 
				
			||||||
        self.param_dtype = config.param_dtype
 | 
					        self.param_dtype = config.param_dtype
 | 
				
			||||||
@ -87,9 +91,12 @@ class WanVace(WanT2V):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        self.vae_stride = config.vae_stride
 | 
					        self.vae_stride = config.vae_stride
 | 
				
			||||||
        self.patch_size = config.patch_size
 | 
					        self.patch_size = config.patch_size
 | 
				
			||||||
 | 
					        # Initialize VAE on CPU if vae_cpu=True to save VRAM during pipeline initialization and diffusion loop
 | 
				
			||||||
 | 
					        # VAE is only needed for encoding frames/masks and decoding final latents
 | 
				
			||||||
 | 
					        vae_device = torch.device('cpu') if vae_cpu else self.device
 | 
				
			||||||
        self.vae = WanVAE(
 | 
					        self.vae = WanVAE(
 | 
				
			||||||
            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
 | 
					            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
 | 
				
			||||||
            device=self.device)
 | 
					            device=vae_device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        logging.info(f"Creating VaceWanModel from {checkpoint_dir}")
 | 
					        logging.info(f"Creating VaceWanModel from {checkpoint_dir}")
 | 
				
			||||||
        self.model = VaceWanModel.from_pretrained(checkpoint_dir)
 | 
					        self.model = VaceWanModel.from_pretrained(checkpoint_dir)
 | 
				
			||||||
@ -357,11 +364,16 @@ class WanVace(WanT2V):
 | 
				
			|||||||
        seed_g.manual_seed(seed)
 | 
					        seed_g.manual_seed(seed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if not self.t5_cpu:
 | 
					        if not self.t5_cpu:
 | 
				
			||||||
 | 
					            # Offload DiT to CPU first if needed to make room for T5
 | 
				
			||||||
 | 
					            if offload_model:
 | 
				
			||||||
 | 
					                self.model.cpu()
 | 
				
			||||||
 | 
					                torch.cuda.empty_cache()
 | 
				
			||||||
            self.text_encoder.model.to(self.device)
 | 
					            self.text_encoder.model.to(self.device)
 | 
				
			||||||
            context = self.text_encoder([input_prompt], self.device)
 | 
					            context = self.text_encoder([input_prompt], self.device)
 | 
				
			||||||
            context_null = self.text_encoder([n_prompt], self.device)
 | 
					            context_null = self.text_encoder([n_prompt], self.device)
 | 
				
			||||||
            if offload_model:
 | 
					            if offload_model:
 | 
				
			||||||
                self.text_encoder.model.cpu()
 | 
					                self.text_encoder.model.cpu()
 | 
				
			||||||
 | 
					                torch.cuda.empty_cache()
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            context = self.text_encoder([input_prompt], torch.device('cpu'))
 | 
					            context = self.text_encoder([input_prompt], torch.device('cpu'))
 | 
				
			||||||
            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
 | 
					            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
 | 
				
			||||||
@ -369,10 +381,23 @@ class WanVace(WanT2V):
 | 
				
			|||||||
            context_null = [t.to(self.device) for t in context_null]
 | 
					            context_null = [t.to(self.device) for t in context_null]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # vace context encode
 | 
					        # vace context encode
 | 
				
			||||||
 | 
					        # Move VAE to GPU for encoding frames and masks if it's on CPU
 | 
				
			||||||
 | 
					        if self.vae_cpu:
 | 
				
			||||||
 | 
					            self.vae.model.to(self.device)
 | 
				
			||||||
 | 
					            # Also move scale tensors to GPU
 | 
				
			||||||
 | 
					            self.vae.mean = self.vae.mean.to(self.device)
 | 
				
			||||||
 | 
					            self.vae.std = self.vae.std.to(self.device)
 | 
				
			||||||
 | 
					            self.vae.scale = [self.vae.mean, 1.0 / self.vae.std]
 | 
				
			||||||
        z0 = self.vace_encode_frames(
 | 
					        z0 = self.vace_encode_frames(
 | 
				
			||||||
            input_frames, input_ref_images, masks=input_masks)
 | 
					            input_frames, input_ref_images, masks=input_masks)
 | 
				
			||||||
        m0 = self.vace_encode_masks(input_masks, input_ref_images)
 | 
					        m0 = self.vace_encode_masks(input_masks, input_ref_images)
 | 
				
			||||||
        z = self.vace_latent(z0, m0)
 | 
					        z = self.vace_latent(z0, m0)
 | 
				
			||||||
 | 
					        # Offload VAE back to CPU after encoding
 | 
				
			||||||
 | 
					        if self.vae_cpu and offload_model:
 | 
				
			||||||
 | 
					            self.vae.model.cpu()
 | 
				
			||||||
 | 
					            self.vae.mean = self.vae.mean.cpu()
 | 
				
			||||||
 | 
					            self.vae.std = self.vae.std.cpu()
 | 
				
			||||||
 | 
					            self.vae.scale = [self.vae.mean, 1.0 / self.vae.std]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        target_shape = list(z0[0].shape)
 | 
					        target_shape = list(z0[0].shape)
 | 
				
			||||||
        target_shape[0] = int(target_shape[0] / 2)
 | 
					        target_shape[0] = int(target_shape[0] / 2)
 | 
				
			||||||
@ -462,7 +487,21 @@ class WanVace(WanT2V):
 | 
				
			|||||||
                self.model.cpu()
 | 
					                self.model.cpu()
 | 
				
			||||||
                torch.cuda.empty_cache()
 | 
					                torch.cuda.empty_cache()
 | 
				
			||||||
            if self.rank == 0:
 | 
					            if self.rank == 0:
 | 
				
			||||||
 | 
					                # Move VAE to GPU for decoding if it's on CPU
 | 
				
			||||||
 | 
					                if self.vae_cpu:
 | 
				
			||||||
 | 
					                    self.vae.model.to(self.device)
 | 
				
			||||||
 | 
					                    # Also move scale tensors to GPU
 | 
				
			||||||
 | 
					                    self.vae.mean = self.vae.mean.to(self.device)
 | 
				
			||||||
 | 
					                    self.vae.std = self.vae.std.to(self.device)
 | 
				
			||||||
 | 
					                    self.vae.scale = [self.vae.mean, 1.0 / self.vae.std]
 | 
				
			||||||
                videos = self.decode_latent(x0, input_ref_images)
 | 
					                videos = self.decode_latent(x0, input_ref_images)
 | 
				
			||||||
 | 
					                # Offload VAE back to CPU after decoding to free VRAM
 | 
				
			||||||
 | 
					                if self.vae_cpu and offload_model:
 | 
				
			||||||
 | 
					                    self.vae.model.cpu()
 | 
				
			||||||
 | 
					                    self.vae.mean = self.vae.mean.cpu()
 | 
				
			||||||
 | 
					                    self.vae.std = self.vae.std.cpu()
 | 
				
			||||||
 | 
					                    self.vae.scale = [self.vae.mean, 1.0 / self.vae.std]
 | 
				
			||||||
 | 
					                    torch.cuda.empty_cache()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        del noise, latents
 | 
					        del noise, latents
 | 
				
			||||||
        del sample_scheduler
 | 
					        del sample_scheduler
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										2
									
								
								wok/go.sh
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										2
									
								
								wok/go.sh
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,2 @@
 | 
				
			|||||||
 | 
					#!/usr/bin/bash
 | 
				
			||||||
 | 
					python ../generate.py  --task t2v-1.3B --size 480*832 --ckpt_dir ./t2v-1.3b --offload_model True --prompt "happy the dwarf and sneezy the dwarf wrestle to the death at madison square garden"
 | 
				
			||||||
							
								
								
									
										5
									
								
								wok/test1_baseline.sh
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										5
									
								
								wok/test1_baseline.sh
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,5 @@
 | 
				
			|||||||
 | 
					#!/usr/bin/bash
 | 
				
			||||||
 | 
					# Test 1: Baseline (no flags) - expect OOM
 | 
				
			||||||
 | 
					echo "=== TEST 1: Baseline (no VAE offloading, no T5 offloading) ==="
 | 
				
			||||||
 | 
					echo "Expected: OOM Error during pipeline initialization"
 | 
				
			||||||
 | 
					python ../generate.py --task t2v-1.3B --size 480*832 --ckpt_dir ./t2v-1.3b --offload_model True --prompt "happy the dwarf and sneezy the dwarf wrestle to the death at madison square garden"
 | 
				
			||||||
							
								
								
									
										5
									
								
								wok/test2_vae_cpu.sh
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										5
									
								
								wok/test2_vae_cpu.sh
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,5 @@
 | 
				
			|||||||
 | 
					#!/usr/bin/bash
 | 
				
			||||||
 | 
					# Test 2: VAE CPU offloading only
 | 
				
			||||||
 | 
					echo "=== TEST 2: VAE offloading enabled (--vae_cpu) ==="
 | 
				
			||||||
 | 
					echo "Expected: Success - should save 100-200MB VRAM"
 | 
				
			||||||
 | 
					python ../generate.py --task t2v-1.3B --size 480*832 --ckpt_dir ./t2v-1.3b --offload_model True --vae_cpu --prompt "happy the dwarf and sneezy the dwarf wrestle to the death at madison square garden"
 | 
				
			||||||
							
								
								
									
										5
									
								
								wok/test3_t5_cpu.sh
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										5
									
								
								wok/test3_t5_cpu.sh
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,5 @@
 | 
				
			|||||||
 | 
					#!/usr/bin/bash
 | 
				
			||||||
 | 
					# Test 3: T5 CPU offloading only
 | 
				
			||||||
 | 
					echo "=== TEST 3: T5 offloading enabled (--t5_cpu) ==="
 | 
				
			||||||
 | 
					echo "Expected: Unknown - might still OOM, depends on T5 memory footprint"
 | 
				
			||||||
 | 
					python ../generate.py --task t2v-1.3B --size 480*832 --ckpt_dir ./t2v-1.3b --offload_model True --t5_cpu --prompt "happy the dwarf and sneezy the dwarf wrestle to the death at madison square garden"
 | 
				
			||||||
							
								
								
									
										5
									
								
								wok/test4_both.sh
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										5
									
								
								wok/test4_both.sh
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,5 @@
 | 
				
			|||||||
 | 
					#!/usr/bin/bash
 | 
				
			||||||
 | 
					# Test 4: Both VAE and T5 CPU offloading
 | 
				
			||||||
 | 
					echo "=== TEST 4: Both VAE and T5 offloading enabled (--vae_cpu --t5_cpu) ==="
 | 
				
			||||||
 | 
					echo "Expected: Success - should save 150-250MB VRAM total"
 | 
				
			||||||
 | 
					python ../generate.py --task t2v-1.3B --size 480*832 --ckpt_dir ./t2v-1.3b --offload_model True --vae_cpu --t5_cpu --prompt "happy the dwarf and sneezy the dwarf wrestle to the death at madison square garden"
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user