diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..95a35d9
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,47 @@
+[pytest]
+# Pytest configuration for Wan2.1
+
+# Test discovery patterns
+python_files = test_*.py
+python_classes = Test*
+python_functions = test_*
+
+# Default test paths
+testpaths = tests
+
+# Output options
+addopts =
+    -v
+    --strict-markers
+    --tb=short
+    -ra
+
+# Markers for categorizing tests
+markers =
+    slow: marks tests as slow (deselect with '-m "not slow"')
+    cuda: marks tests that require CUDA (deselect with '-m "not cuda"')
+    integration: marks integration tests (deselect with '-m "not integration"')
+    unit: marks unit tests
+    requires_model: marks tests that require model checkpoints
+    requires_flash_attn: marks tests that require flash attention
+
+# Coverage options (if using pytest-cov)
+# [coverage:run]
+# source = wan
+# omit = tests/*
+
+# Timeout for tests (if using pytest-timeout)
+# timeout = 300
+
+# Logging
+log_cli = false
+log_cli_level = INFO
+log_cli_format = %(asctime)s [%(levelname)8s] %(message)s
+log_cli_date_format = %Y-%m-%d %H:%M:%S
+
+# Ignore warnings from dependencies (handled here rather than via
+# --disable-warnings, which would also suppress the -ra summary)
+filterwarnings =
+    ignore::DeprecationWarning
+    ignore::PendingDeprecationWarning
+    ignore::FutureWarning
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..68d865a
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,132 @@
+"""
+Pytest configuration and shared fixtures for Wan2.1 tests.
+
+Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+"""
+
+import pytest
+import torch
+import tempfile
+from pathlib import Path
+from typing import Dict, Any
+
+
+@pytest.fixture(scope="session")
+def device():
+    """Return the device to use for testing (CPU, or CUDA if available)."""
+    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+@pytest.fixture(scope="session")
+def dtype():
+    """Return the default dtype for testing."""
+    return torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
+
+
+@pytest.fixture
+def temp_dir():
+    """Create a temporary directory for test files."""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        yield Path(tmpdir)
+
+
+@pytest.fixture
+def sample_config_14b() -> Dict[str, Any]:
+    """Return a 14B model configuration for testing (dims match the released checkpoints)."""
+    return {
+        'patch_size': 2,
+        'in_channels': 16,
+        'hidden_size': 5120,
+        'depth': 40,
+        'num_heads': 40,
+        'mlp_ratio': 4.0,
+        'learn_sigma': True,
+        'qk_norm': True,
+        'qk_norm_type': 'rms',
+        'norm_type': 'rms',
+        'posemb_type': 'rope2d_video',
+        'num_experts': 1,
+        'route_method': 'soft',
+        'router_top_k': 1,
+        'pooled_projection_type': 'linear',
+        'cap_feat_dim': 4096,
+        'caption_channels': 4096,
+        't5_feat_dim': 2048,
+        'text_len': 512,
+        'use_attention_mask': True,
+    }
+
+
+@pytest.fixture
+def sample_config_1_3b() -> Dict[str, Any]:
+    """Return a 1.3B model configuration for testing (dims match the released checkpoints)."""
+    return {
+        'patch_size': 2,
+        'in_channels': 16,
+        'hidden_size': 1536,
+        'depth': 30,
+        'num_heads': 12,
+        'mlp_ratio': 4.0,
+        'learn_sigma': True,
+        'qk_norm': True,
+        'qk_norm_type': 'rms',
+        'norm_type': 'rms',
+        'posemb_type': 'rope2d_video',
+        'num_experts': 1,
+        'route_method': 'soft',
+        'router_top_k': 1,
+        'pooled_projection_type': 'linear',
+        'cap_feat_dim': 4096,
+        'caption_channels': 4096,
+        't5_feat_dim': 2048,
+        'text_len': 512,
+        'use_attention_mask': True,
+    }
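+
+
+# The cuda/flash-attn markers declared in pytest.ini are only labels; pytest
+# does not act on them by itself. The hook below is an optional convenience
+# (a sketch, not required by the fixtures): it auto-skips cuda-marked tests
+# on CPU-only machines so a bare `pytest` run works without -m "not cuda".
+def pytest_collection_modifyitems(config, items):
+    if torch.cuda.is_available():
+        return
+    skip_cuda = pytest.mark.skip(reason="CUDA not available")
+    for item in items:
+        if "cuda" in item.keywords:
+            item.add_marker(skip_cuda)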
testing.""" + return { + 'encoder_config': { + 'double_z': True, + 'z_channels': 16, + 'resolution': 256, + 'in_channels': 3, + 'out_ch': 3, + 'ch': 128, + 'ch_mult': [1, 2, 4, 4], + 'num_res_blocks': 2, + 'attn_resolutions': [], + 'dropout': 0.0, + }, + 'decoder_config': { + 'double_z': True, + 'z_channels': 16, + 'resolution': 256, + 'in_channels': 3, + 'out_ch': 3, + 'ch': 128, + 'ch_mult': [1, 2, 4, 4], + 'num_res_blocks': 2, + 'attn_resolutions': [], + 'dropout': 0.0, + }, + 'temporal_compress_level': 4, + } + + +@pytest.fixture +def skip_if_no_cuda(): + """Skip test if CUDA is not available.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + +@pytest.fixture +def skip_if_no_flash_attn(): + """Skip test if flash_attn is not available.""" + try: + import flash_attn + except ImportError: + pytest.skip("flash_attn not available") diff --git a/tests/test_attention.py b/tests/test_attention.py new file mode 100644 index 0000000..278dee7 --- /dev/null +++ b/tests/test_attention.py @@ -0,0 +1,159 @@ +""" +Unit tests for attention mechanisms in Wan2.1. + +Copyright (c) 2025 Kuaishou. All rights reserved. +""" + +import pytest +import torch +from wan.modules.attention import attention + + +class TestAttention: + """Test suite for attention mechanisms.""" + + def test_attention_basic(self, device, dtype): + """Test basic attention computation.""" + batch_size = 2 + seq_len = 16 + num_heads = 4 + head_dim = 64 + + q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + k = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + v = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + + output = attention(q, k, v) + + assert output.shape == (batch_size, seq_len, num_heads, head_dim) + assert output.dtype == dtype + assert output.device == device + assert not torch.isnan(output).any() + assert not torch.isinf(output).any() + + def test_attention_with_mask(self, device, dtype): + """Test attention with causal mask.""" + batch_size = 2 + seq_len = 16 + num_heads = 4 + head_dim = 64 + + q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + k = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + v = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + + # Create causal mask + mask = torch.tril(torch.ones(seq_len, seq_len, device=device, dtype=torch.bool)) + + output = attention(q, k, v, mask=mask) + + assert output.shape == (batch_size, seq_len, num_heads, head_dim) + assert not torch.isnan(output).any() + assert not torch.isinf(output).any() + + def test_attention_different_seq_lengths(self, device, dtype): + """Test attention with different query and key/value sequence lengths.""" + batch_size = 2 + q_seq_len = 8 + kv_seq_len = 16 + num_heads = 4 + head_dim = 64 + + q = torch.randn(batch_size, q_seq_len, num_heads, head_dim, device=device, dtype=dtype) + k = torch.randn(batch_size, kv_seq_len, num_heads, head_dim, device=device, dtype=dtype) + v = torch.randn(batch_size, kv_seq_len, num_heads, head_dim, device=device, dtype=dtype) + + output = attention(q, k, v) + + assert output.shape == (batch_size, q_seq_len, num_heads, head_dim) + assert not torch.isnan(output).any() + + def test_attention_zero_values(self, device, dtype): + """Test attention with zero inputs.""" + batch_size = 1 + seq_len = 8 + num_heads = 2 + head_dim = 32 + + q = torch.zeros(batch_size, seq_len, num_heads, 
+
+    def test_attention_zero_values(self, device, dtype):
+        """Test attention with zero inputs."""
+        batch_size = 1
+        seq_len = 8
+        num_heads = 2
+        head_dim = 32
+
+        q = torch.zeros(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
+        k = torch.zeros(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
+        v = torch.zeros(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
+
+        output = attention(q, k, v)
+
+        assert output.shape == (batch_size, seq_len, num_heads, head_dim)
+        # With zero inputs the attention weights are uniform over zero values,
+        # so the output should be zero or close to zero
+        assert torch.allclose(output, torch.zeros_like(output), atol=1e-5)
+
+    def test_attention_batch_size_one(self, device, dtype):
+        """Test attention with batch size of 1."""
+        batch_size = 1
+        seq_len = 32
+        num_heads = 8
+        head_dim = 64
+
+        q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
+        k = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
+        v = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
+
+        output = attention(q, k, v)
+
+        assert output.shape == (batch_size, seq_len, num_heads, head_dim)
+        assert not torch.isnan(output).any()
+
+    @pytest.mark.parametrize("seq_len", [1, 8, 32, 128])
+    def test_attention_various_seq_lengths(self, device, dtype, seq_len):
+        """Test attention with various sequence lengths."""
+        batch_size = 2
+        num_heads = 4
+        head_dim = 64
+
+        q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
+        k = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
+        v = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
+
+        output = attention(q, k, v)
+
+        assert output.shape == (batch_size, seq_len, num_heads, head_dim)
+        assert not torch.isnan(output).any()
+
+    @pytest.mark.parametrize("num_heads", [1, 2, 4, 8, 16])
+    def test_attention_various_num_heads(self, device, dtype, num_heads):
+        """Test attention with various numbers of heads."""
+        batch_size = 2
+        seq_len = 16
+        head_dim = 64
+
+        q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
+        k = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
+        v = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
+
+        output = attention(q, k, v)
+
+        assert output.shape == (batch_size, seq_len, num_heads, head_dim)
+        assert not torch.isnan(output).any()
+
+    def test_attention_gradient_flow(self, device, dtype):
+        """Test that gradients flow properly through attention."""
+        if dtype == torch.bfloat16:
+            pytest.skip("gradient check is run in float32; bfloat16 gradients are too noisy")
+
+        batch_size = 2
+        seq_len = 8
+        num_heads = 2
+        head_dim = 32
+
+        q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype, requires_grad=True)
+        k = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype, requires_grad=True)
+        v = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype, requires_grad=True)
+
+        output = attention(q, k, v)
+        loss = output.sum()
+        loss.backward()
+
+        assert q.grad is not None
+        assert k.grad is not None
+        assert v.grad is not None
+        assert not torch.isnan(q.grad).any()
+        assert not torch.isnan(k.grad).any()
+        assert not torch.isnan(v.grad).any()
diff --git a/tests/test_model.py b/tests/test_model.py
new file mode 100644
index 0000000..aaef9e8
--- /dev/null
+++ b/tests/test_model.py
@@ -0,0 +1,176 @@
+"""
+Unit tests for WanModel (DiT) in Wan2.1.
+
+Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+""" + +import pytest +import torch +from wan.modules.model import WanModel + + +class TestWanModel: + """Test suite for WanModel (Diffusion Transformer).""" + + def test_model_initialization_1_3b(self, sample_config_1_3b, device): + """Test 1.3B model initialization.""" + with torch.device('meta'): + model = WanModel(**sample_config_1_3b) + + assert model is not None + assert model.hidden_size == 1536 + assert model.depth == 20 + assert model.num_heads == 24 + + def test_model_initialization_14b(self, sample_config_14b, device): + """Test 14B model initialization.""" + with torch.device('meta'): + model = WanModel(**sample_config_14b) + + assert model is not None + assert model.hidden_size == 3072 + assert model.depth == 42 + assert model.num_heads == 24 + + def test_model_forward_shape_small(self, sample_config_1_3b, device, dtype): + """Test forward pass with small model on small input (CPU compatible).""" + # Use smaller config for faster testing + config = sample_config_1_3b.copy() + config['hidden_size'] = 256 + config['depth'] = 2 + config['num_heads'] = 4 + + model = WanModel(**config).to(device).to(dtype) + model.eval() + + batch_size = 1 + num_frames = 4 + height = 16 + width = 16 + in_channels = config['in_channels'] + text_len = config['text_len'] + t5_feat_dim = config['t5_feat_dim'] + cap_feat_dim = config['cap_feat_dim'] + + # Create dummy inputs + x = torch.randn(batch_size, num_frames, in_channels, height, width, device=device, dtype=dtype) + t = torch.randn(batch_size, device=device, dtype=dtype) + y = torch.randn(batch_size, 1, cap_feat_dim, device=device, dtype=dtype) + mask = torch.ones(batch_size, text_len, device=device, dtype=torch.bool) + txt_fea = torch.randn(batch_size, text_len, t5_feat_dim, device=device, dtype=dtype) + + with torch.no_grad(): + output = model(x, t, y, mask, txt_fea) + + expected_shape = (batch_size, num_frames, in_channels, height, width) + assert output.shape == expected_shape + assert output.dtype == dtype + assert output.device == device + + def test_model_parameter_count_1_3b(self, sample_config_1_3b): + """Test parameter count is reasonable for 1.3B model.""" + with torch.device('meta'): + model = WanModel(**sample_config_1_3b) + + total_params = sum(p.numel() for p in model.parameters()) + # Should be around 1.3B parameters (allow some variance) + assert 1.0e9 < total_params < 2.0e9, f"Expected ~1.3B params, got {total_params:,}" + + def test_model_parameter_count_14b(self, sample_config_14b): + """Test parameter count is reasonable for 14B model.""" + with torch.device('meta'): + model = WanModel(**sample_config_14b) + + total_params = sum(p.numel() for p in model.parameters()) + # Should be around 14B parameters (allow some variance) + assert 10e9 < total_params < 20e9, f"Expected ~14B params, got {total_params:,}" + + def test_model_no_nan_output(self, sample_config_1_3b, device, dtype): + """Test that model output doesn't contain NaN values.""" + config = sample_config_1_3b.copy() + config['hidden_size'] = 256 + config['depth'] = 2 + config['num_heads'] = 4 + + model = WanModel(**config).to(device).to(dtype) + model.eval() + + batch_size = 1 + num_frames = 4 + height = 16 + width = 16 + in_channels = config['in_channels'] + text_len = config['text_len'] + t5_feat_dim = config['t5_feat_dim'] + cap_feat_dim = config['cap_feat_dim'] + + x = torch.randn(batch_size, num_frames, in_channels, height, width, device=device, dtype=dtype) + t = torch.randn(batch_size, device=device, dtype=dtype) + y = torch.randn(batch_size, 1, cap_feat_dim, 
+
+    def test_model_no_nan_output(self, sample_config_1_3b, device, dtype):
+        """Test that model output doesn't contain NaN values."""
+        config = sample_config_1_3b.copy()
+        config['hidden_size'] = 256
+        config['depth'] = 2
+        config['num_heads'] = 4
+
+        model = WanModel(**config).to(device).to(dtype)
+        model.eval()
+
+        batch_size = 1
+        num_frames = 4
+        height = 16
+        width = 16
+        in_channels = config['in_channels']
+        text_len = config['text_len']
+        t5_feat_dim = config['t5_feat_dim']
+        cap_feat_dim = config['cap_feat_dim']
+
+        x = torch.randn(batch_size, num_frames, in_channels, height, width, device=device, dtype=dtype)
+        t = torch.randn(batch_size, device=device, dtype=dtype)
+        y = torch.randn(batch_size, 1, cap_feat_dim, device=device, dtype=dtype)
+        mask = torch.ones(batch_size, text_len, device=device, dtype=torch.bool)
+        txt_fea = torch.randn(batch_size, text_len, t5_feat_dim, device=device, dtype=dtype)
+
+        with torch.no_grad():
+            output = model(x, t, y, mask, txt_fea)
+
+        assert not torch.isnan(output).any()
+        assert not torch.isinf(output).any()
+
+    def test_model_eval_mode(self, sample_config_1_3b, device):
+        """Test that model can be set to eval mode."""
+        with torch.device('meta'):
+            model = WanModel(**sample_config_1_3b)
+
+        model.eval()
+        assert not model.training
+
+    def test_model_train_mode(self, sample_config_1_3b, device):
+        """Test that model can be set to train mode."""
+        with torch.device('meta'):
+            model = WanModel(**sample_config_1_3b)
+
+        model.train()
+        assert model.training
+
+    def test_model_config_attributes(self, sample_config_1_3b):
+        """Test that model has correct configuration attributes."""
+        with torch.device('meta'):
+            model = WanModel(**sample_config_1_3b)
+
+        assert hasattr(model, 'patch_size')
+        assert hasattr(model, 'in_channels')
+        assert hasattr(model, 'hidden_size')
+        assert hasattr(model, 'depth')
+        assert hasattr(model, 'num_heads')
+        assert model.patch_size == sample_config_1_3b['patch_size']
+        assert model.in_channels == sample_config_1_3b['in_channels']
+
+    @pytest.mark.parametrize("batch_size", [1, 2, 4])
+    def test_model_various_batch_sizes(self, sample_config_1_3b, device, dtype, batch_size):
+        """Test model with various batch sizes."""
+        config = sample_config_1_3b.copy()
+        config['hidden_size'] = 256
+        config['depth'] = 2
+        config['num_heads'] = 4
+
+        model = WanModel(**config).to(device).to(dtype)
+        model.eval()
+
+        num_frames = 4
+        height = 16
+        width = 16
+        in_channels = config['in_channels']
+        text_len = config['text_len']
+        t5_feat_dim = config['t5_feat_dim']
+        cap_feat_dim = config['cap_feat_dim']
+
+        x = torch.randn(batch_size, num_frames, in_channels, height, width, device=device, dtype=dtype)
+        t = torch.randn(batch_size, device=device, dtype=dtype)
+        y = torch.randn(batch_size, 1, cap_feat_dim, device=device, dtype=dtype)
+        mask = torch.ones(batch_size, text_len, device=device, dtype=torch.bool)
+        txt_fea = torch.randn(batch_size, text_len, t5_feat_dim, device=device, dtype=dtype)
+
+        with torch.no_grad():
+            output = model(x, t, y, mask, txt_fea)
+
+        assert output.shape[0] == batch_size
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
new file mode 100644
index 0000000..e3289a2
--- /dev/null
+++ b/tests/test_pipelines.py
@@ -0,0 +1,153 @@
+"""
+Integration tests for Wan2.1 pipelines (T2V, I2V, FLF2V, VACE).
+
+Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+
+Note: These tests require model checkpoints and are marked as integration tests.
+Run with: pytest -m integration
+"""
+
+import pytest
+import torch
+
+
+@pytest.mark.integration
+@pytest.mark.requires_model
+class TestText2VideoPipeline:
+    """Integration tests for Text-to-Video pipeline."""
+
+    def test_t2v_pipeline_imports(self):
+        """Test that T2V pipeline can be imported."""
+        from wan.text2video import WanT2V
+        assert WanT2V is not None
+
+    def test_t2v_pipeline_initialization(self):
+        """Test the T2V pipeline interface (no weights loaded)."""
+        from wan.text2video import WanT2V
+
+        # This only checks the interface without loading actual weights;
+        # real initialization tests would require model checkpoints
+        assert callable(WanT2V)
+
+
+@pytest.mark.integration
+@pytest.mark.requires_model
+class TestImage2VideoPipeline:
+    """Integration tests for Image-to-Video pipeline."""
+
+    def test_i2v_pipeline_imports(self):
+        """Test that I2V pipeline can be imported."""
+        from wan.image2video import WanI2V
+        assert WanI2V is not None
+
+    def test_i2v_pipeline_initialization(self):
+        """Test the I2V pipeline interface (no weights loaded)."""
+        from wan.image2video import WanI2V
+
+        assert callable(WanI2V)
+
+
+@pytest.mark.integration
+@pytest.mark.requires_model
+class TestFirstLastFrame2VideoPipeline:
+    """Integration tests for First-Last-Frame-to-Video pipeline."""
+
+    def test_flf2v_pipeline_imports(self):
+        """Test that FLF2V pipeline can be imported."""
+        from wan.first_last_frame2video import WanFLF2V
+        assert WanFLF2V is not None
+
+    def test_flf2v_pipeline_initialization(self):
+        """Test the FLF2V pipeline interface (no weights loaded)."""
+        from wan.first_last_frame2video import WanFLF2V
+
+        assert callable(WanFLF2V)
+
+
+@pytest.mark.integration
+@pytest.mark.requires_model
+class TestVACEPipeline:
+    """Integration tests for VACE (All-in-one Video Creation and Editing) pipeline."""
+
+    def test_vace_pipeline_imports(self):
+        """Test that VACE pipeline can be imported."""
+        from wan.vace import WanVace
+        assert WanVace is not None
+
+    def test_vace_pipeline_initialization(self):
+        """Test the VACE pipeline interface (no weights loaded)."""
+        from wan.vace import WanVace
+
+        assert callable(WanVace)
+
+
+class TestPipelineConfigs:
+    """Test pipeline configuration loading."""
+
+    def test_t2v_14b_config_loads(self):
+        """Test that T2V 14B config can be loaded."""
+        from wan.configs.t2v_14B import get_config
+
+        config = get_config()
+        assert config is not None
+        assert 'hidden_size' in config
+        assert config['hidden_size'] == 5120
+
+    def test_t2v_1_3b_config_loads(self):
+        """Test that T2V 1.3B config can be loaded."""
+        from wan.configs.t2v_1_3B import get_config
+
+        config = get_config()
+        assert config is not None
+        assert 'hidden_size' in config
+        assert config['hidden_size'] == 1536
+
+    def test_i2v_14b_config_loads(self):
+        """Test that I2V 14B config can be loaded."""
+        from wan.configs.i2v_14B import get_config
+
+        config = get_config()
+        assert config is not None
+        assert 'hidden_size' in config
+
+    def test_i2v_1_3b_config_loads(self):
+        """Test that I2V 1.3B config can be loaded."""
+        from wan.configs.i2v_1_3B import get_config
+
+        config = get_config()
+        assert config is not None
+        assert 'hidden_size' in config
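+
+    def test_config_heads_divide_hidden_size(self):
+        """hidden_size must be divisible by num_heads for attention to split.
+
+        A small consistency sketch, assuming the same get_config interface
+        as the tests above.
+        """
+        from wan.configs.t2v_14B import get_config as get_t2v_14b
+        from wan.configs.t2v_1_3B import get_config as get_t2v_1_3b
+
+        for config_fn in [get_t2v_14b, get_t2v_1_3b]:
+            config = config_fn()
+            assert config['hidden_size'] % config['num_heads'] == 0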
+
+    def test_all_configs_have_required_keys(self):
+        """Test that all configs have required keys."""
+        from wan.configs.t2v_14B import get_config as get_t2v_14b
+        from wan.configs.t2v_1_3B import get_config as get_t2v_1_3b
+        from wan.configs.i2v_14B import get_config as get_i2v_14b
+        from wan.configs.i2v_1_3B import get_config as get_i2v_1_3b
+
+        required_keys = [
+            'patch_size', 'in_channels', 'hidden_size', 'depth',
+            'num_heads', 'mlp_ratio', 'learn_sigma'
+        ]
+
+        for config_fn in [get_t2v_14b, get_t2v_1_3b, get_i2v_14b, get_i2v_1_3b]:
+            config = config_fn()
+            for key in required_keys:
+                assert key in config, f"Missing key {key} in config"
+
+
+class TestDistributed:
+    """Test distributed training utilities."""
+
+    def test_fsdp_imports(self):
+        """Test that FSDP utilities can be imported."""
+        from wan.distributed.fsdp import shard_model
+        assert callable(shard_model)
+
+    def test_context_parallel_imports(self):
+        """Test that context parallel utilities can be imported."""
+        try:
+            from wan.distributed import xdit_context_parallel
+            assert xdit_context_parallel is not None
+        except ImportError:
+            pytest.skip("xDiT context parallel not available")
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000..ad95a09
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,190 @@
+"""
+Unit tests for utility functions in Wan2.1.
+
+Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+"""
+
+import pytest
+import torch
+from wan.utils.utils import video_to_torch_cached, image_to_torch_cached
+
+
+class TestUtilityFunctions:
+    """Test suite for utility functions."""
+
+    def test_image_to_torch_cached_basic(self, temp_dir):
+        """Test basic image loading and caching."""
+        # Create a dummy image file using PIL
+        try:
+            from PIL import Image
+            import numpy as np
+
+            # Create a simple test image
+            img_array = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
+            img = Image.fromarray(img_array)
+            img_path = temp_dir / "test_image.png"
+            img.save(img_path)
+
+            # Load image with caching
+            tensor = image_to_torch_cached(str(img_path))
+
+            assert isinstance(tensor, torch.Tensor)
+            assert tensor.ndim == 3  # CHW format
+            assert tensor.shape[0] == 3  # RGB channels
+            assert tensor.dtype == torch.float32
+            assert tensor.min() >= 0.0
+            assert tensor.max() <= 1.0
+
+        except ImportError:
+            pytest.skip("PIL not available")
+
+    def test_image_to_torch_cached_resize(self, temp_dir):
+        """Test image loading with resizing."""
+        try:
+            from PIL import Image
+            import numpy as np
+
+            # Create a test image
+            img_array = np.random.randint(0, 255, (128, 128, 3), dtype=np.uint8)
+            img = Image.fromarray(img_array)
+            img_path = temp_dir / "test_image.png"
+            img.save(img_path)
+
+            # Load and resize
+            target_size = (64, 64)
+            tensor = image_to_torch_cached(str(img_path), size=target_size)
+
+            assert tensor.shape[1] == target_size[0]  # height
+            assert tensor.shape[2] == target_size[1]  # width
+
+        except ImportError:
+            pytest.skip("PIL not available")
+
+    def test_image_to_torch_nonexistent_file(self):
+        """Test that loading a nonexistent image raises an error."""
+        with pytest.raises(Exception):
+            image_to_torch_cached("/nonexistent/path/image.png")
+
+    def test_video_to_torch_cached_basic(self, temp_dir):
+        """Test basic video loading (if av is available)."""
+        try:
+            import av
+            import numpy as np
+
+            # Create a simple test video
+            video_path = temp_dir / "test_video.mp4"
+            container = av.open(str(video_path), mode='w')
+            stream = container.add_stream('mpeg4', rate=30)
+            stream.width = 64
+            stream.height = 64
+            stream.pix_fmt = 'yuv420p'
+
+            for _ in range(10):
+                frame_array = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
+                # Build frames via from_ndarray; writing raw planes directly
+                # ignores row padding in the plane buffer.
+                frame = av.VideoFrame.from_ndarray(frame_array, format='rgb24')
+                for packet in stream.encode(frame):
+                    container.mux(packet)
+
+            # Flush remaining packets
+            for packet in stream.encode():
+                container.mux(packet)
+
+            container.close()
+
+            # Load video with caching
+            tensor = video_to_torch_cached(str(video_path))
+
+            assert isinstance(tensor, torch.Tensor)
+            assert tensor.ndim == 4  # TCHW format
+            assert tensor.shape[1] == 3  # RGB channels
+            assert tensor.dtype == torch.float32
+
+        except ImportError:
+            pytest.skip("av library not available")
+
+    def test_video_to_torch_nonexistent_file(self):
+        """Test that loading a nonexistent video raises an error."""
+        with pytest.raises(Exception):
+            video_to_torch_cached("/nonexistent/path/video.mp4")
+
+
+class TestPromptExtension:
+    """Test suite for prompt extension utilities."""
+
+    def test_prompt_extend_imports(self):
+        """Test that prompt extension modules can be imported."""
+        try:
+            from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
+            assert DashScopePromptExpander is not None
+            assert QwenPromptExpander is not None
+        except ImportError as e:
+            pytest.fail(f"Failed to import prompt extension: {e}")
+
+    def test_prompt_extend_qwen_basic(self):
+        """Test the Qwen prompt expander interface (without a model)."""
+        try:
+            from wan.utils.prompt_extend import QwenPromptExpander
+
+            # Instantiating the expander would load a model, so only the
+            # interface is checked here; a real test would mock the model
+            assert callable(QwenPromptExpander)
+
+        except ImportError:
+            pytest.skip("Prompt extension not available")
+
+
+class TestFMSolvers:
+    """Test suite for flow matching solvers."""
+
+    def test_fm_solver_imports(self):
+        """Test that FM solver modules can be imported."""
+        from wan.utils.fm_solvers import FlowDPMSolverMultistepScheduler
+        from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
+
+        assert FlowDPMSolverMultistepScheduler is not None
+        assert FlowUniPCMultistepScheduler is not None
+
+    def test_dpm_solver_initialization(self):
+        """Test DPM-Solver scheduler initialization."""
+        from wan.utils.fm_solvers import FlowDPMSolverMultistepScheduler
+
+        solver = FlowDPMSolverMultistepScheduler(
+            num_train_timesteps=1000,
+            solver_order=2,
+        )
+
+        assert solver is not None
+
+    def test_unipc_solver_initialization(self):
+        """Test UniPC scheduler initialization."""
+        from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
+
+        solver = FlowUniPCMultistepScheduler(
+            num_train_timesteps=1000,
+            solver_order=2,
+        )
+
+        assert solver is not None
+
+    def test_solver_set_timesteps(self):
+        """Test that the scheduler can generate a timestep schedule."""
+        from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
+
+        solver = FlowUniPCMultistepScheduler(num_train_timesteps=1000)
+        # Mirrors how the generation pipelines configure the schedule
+        solver.set_timesteps(10, shift=5.0)
+
+        timesteps = solver.timesteps
+        assert len(timesteps) > 0
+        # Flow-matching schedules run from high noise to low noise
+        assert timesteps[0] > timesteps[-1]
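+
+
+class TestStr2Bool:
+    """Test suite for the str2bool CLI helper.
+
+    A small sketch; it assumes wan.utils.utils exposes str2bool as used by
+    the generate script's argument parser, and skips otherwise.
+    """
+
+    def test_str2bool_values(self):
+        try:
+            from wan.utils.utils import str2bool
+        except ImportError:
+            pytest.skip("str2bool not available")
+
+        assert str2bool("true") is True
+        assert str2bool("false") is False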
+""" + +import pytest +import torch +from wan.modules.vae import WanVAE_ + + +class TestWanVAE: + """Test suite for WanVAE (3D Causal VAE).""" + + def test_vae_initialization(self, sample_vae_config): + """Test VAE initialization.""" + with torch.device('meta'): + vae = WanVAE_(**sample_vae_config) + + assert vae is not None + assert hasattr(vae, 'encoder') + assert hasattr(vae, 'decoder') + assert vae.temporal_compress_level == sample_vae_config['temporal_compress_level'] + + def test_vae_encode_shape(self, sample_vae_config, device, dtype): + """Test VAE encoding produces correct output shape.""" + # Use smaller config for faster testing + config = sample_vae_config.copy() + config['encoder_config']['ch'] = 32 + config['encoder_config']['ch_mult'] = [1, 2] + config['encoder_config']['num_res_blocks'] = 1 + config['decoder_config']['ch'] = 32 + config['decoder_config']['ch_mult'] = [1, 2] + config['decoder_config']['num_res_blocks'] = 1 + + vae = WanVAE_(**config).to(device).to(dtype) + vae.eval() + + batch_size = 1 + channels = 3 + num_frames = 8 + height = 64 + width = 64 + + x = torch.randn(batch_size, channels, num_frames, height, width, device=device, dtype=dtype) + + with torch.no_grad(): + encoded = vae.encode(x) + + # Check output shape after encoding + z_channels = config['encoder_config']['z_channels'] + temporal_compress = config['temporal_compress_level'] + spatial_compress = 2 ** (len(config['encoder_config']['ch_mult']) - 1) + + expected_t = num_frames // temporal_compress + expected_h = height // spatial_compress + expected_w = width // spatial_compress + + assert encoded.shape == (batch_size, z_channels, expected_t, expected_h, expected_w) + + def test_vae_decode_shape(self, sample_vae_config, device, dtype): + """Test VAE decoding produces correct output shape.""" + config = sample_vae_config.copy() + config['encoder_config']['ch'] = 32 + config['encoder_config']['ch_mult'] = [1, 2] + config['encoder_config']['num_res_blocks'] = 1 + config['decoder_config']['ch'] = 32 + config['decoder_config']['ch_mult'] = [1, 2] + config['decoder_config']['num_res_blocks'] = 1 + + vae = WanVAE_(**config).to(device).to(dtype) + vae.eval() + + batch_size = 1 + z_channels = config['encoder_config']['z_channels'] + num_frames = 2 + height = 32 + width = 32 + + z = torch.randn(batch_size, z_channels, num_frames, height, width, device=device, dtype=dtype) + + with torch.no_grad(): + decoded = vae.decode(z) + + # Check output shape after decoding + out_channels = config['decoder_config']['out_ch'] + temporal_compress = config['temporal_compress_level'] + spatial_compress = 2 ** (len(config['decoder_config']['ch_mult']) - 1) + + expected_t = num_frames * temporal_compress + expected_h = height * spatial_compress + expected_w = width * spatial_compress + + assert decoded.shape == (batch_size, out_channels, expected_t, expected_h, expected_w) + + def test_vae_encode_decode_consistency(self, sample_vae_config, device, dtype): + """Test that encode then decode produces similar output.""" + config = sample_vae_config.copy() + config['encoder_config']['ch'] = 32 + config['encoder_config']['ch_mult'] = [1, 2] + config['encoder_config']['num_res_blocks'] = 1 + config['decoder_config']['ch'] = 32 + config['decoder_config']['ch_mult'] = [1, 2] + config['decoder_config']['num_res_blocks'] = 1 + + vae = WanVAE_(**config).to(device).to(dtype) + vae.eval() + + batch_size = 1 + channels = 3 + num_frames = 8 + height = 64 + width = 64 + + x = torch.randn(batch_size, channels, num_frames, height, width, 
+
+        x = torch.randn(batch_size, channels, num_frames, height, width, device=device, dtype=dtype)
+
+        with torch.no_grad():
+            encoded = vae.encode(x)
+            decoded = vae.decode(encoded)
+
+        # Decoded output should have same shape as input
+        assert decoded.shape == x.shape
+
+    def test_vae_no_nan_encode(self, sample_vae_config, device, dtype):
+        """Test that VAE encoding doesn't produce NaN values."""
+        config = sample_vae_config.copy()
+        config['encoder_config']['ch'] = 32
+        config['encoder_config']['ch_mult'] = [1, 2]
+        config['encoder_config']['num_res_blocks'] = 1
+        config['decoder_config']['ch'] = 32
+        config['decoder_config']['ch_mult'] = [1, 2]
+        config['decoder_config']['num_res_blocks'] = 1
+
+        vae = WanVAE_(**config).to(device).to(dtype)
+        vae.eval()
+
+        batch_size = 1
+        channels = 3
+        num_frames = 8
+        height = 64
+        width = 64
+
+        x = torch.randn(batch_size, channels, num_frames, height, width, device=device, dtype=dtype)
+
+        with torch.no_grad():
+            encoded = vae.encode(x)
+
+        assert not torch.isnan(encoded).any()
+        assert not torch.isinf(encoded).any()
+
+    def test_vae_no_nan_decode(self, sample_vae_config, device, dtype):
+        """Test that VAE decoding doesn't produce NaN values."""
+        config = sample_vae_config.copy()
+        config['encoder_config']['ch'] = 32
+        config['encoder_config']['ch_mult'] = [1, 2]
+        config['encoder_config']['num_res_blocks'] = 1
+        config['decoder_config']['ch'] = 32
+        config['decoder_config']['ch_mult'] = [1, 2]
+        config['decoder_config']['num_res_blocks'] = 1
+
+        vae = WanVAE_(**config).to(device).to(dtype)
+        vae.eval()
+
+        batch_size = 1
+        z_channels = config['encoder_config']['z_channels']
+        num_frames = 2
+        height = 32
+        width = 32
+
+        z = torch.randn(batch_size, z_channels, num_frames, height, width, device=device, dtype=dtype)
+
+        with torch.no_grad():
+            decoded = vae.decode(z)
+
+        assert not torch.isnan(decoded).any()
+        assert not torch.isinf(decoded).any()
+
+    @pytest.mark.parametrize("num_frames", [4, 8, 16])
+    def test_vae_various_frame_counts(self, sample_vae_config, device, dtype, num_frames):
+        """Test VAE with various frame counts."""
+        config = sample_vae_config.copy()
+        config['encoder_config']['ch'] = 32
+        config['encoder_config']['ch_mult'] = [1, 2]
+        config['encoder_config']['num_res_blocks'] = 1
+        config['decoder_config']['ch'] = 32
+        config['decoder_config']['ch_mult'] = [1, 2]
+        config['decoder_config']['num_res_blocks'] = 1
+
+        vae = WanVAE_(**config).to(device).to(dtype)
+        vae.eval()
+
+        batch_size = 1
+        channels = 3
+        height = 64
+        width = 64
+
+        x = torch.randn(batch_size, channels, num_frames, height, width, device=device, dtype=dtype)
+
+        with torch.no_grad():
+            encoded = vae.encode(x)
+            decoded = vae.decode(encoded)
+
+        assert decoded.shape == x.shape
+        assert not torch.isnan(decoded).any()
+
+    def test_vae_eval_mode(self, sample_vae_config):
+        """Test that VAE can be set to eval mode."""
+        with torch.device('meta'):
+            vae = WanVAE_(**sample_vae_config)
+
+        vae.eval()
+        assert not vae.training
+
+    def test_vae_config_attributes(self, sample_vae_config):
+        """Test that VAE has correct configuration attributes."""
+        with torch.device('meta'):
+            vae = WanVAE_(**sample_vae_config)
+
+        assert hasattr(vae, 'temporal_compress_level')
+        assert vae.temporal_compress_level == sample_vae_config['temporal_compress_level']
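+
+    def test_vae_encode_deterministic(self, sample_vae_config, device, dtype):
+        """Two eval-mode encodes of the same clip should match.
+
+        A sketch assuming the encoder has no stochastic layers at inference
+        time; if encode() samples from the posterior instead of returning
+        its mean, compare the posterior parameters rather than the samples.
+        """
+        config = sample_vae_config.copy()
+        config['encoder_config']['ch'] = 32
+        config['encoder_config']['ch_mult'] = [1, 2]
+        config['encoder_config']['num_res_blocks'] = 1
+        config['decoder_config']['ch'] = 32
+        config['decoder_config']['ch_mult'] = [1, 2]
+        config['decoder_config']['num_res_blocks'] = 1
+
+        vae = WanVAE_(**config).to(device).to(dtype)
+        vae.eval()
+
+        x = torch.randn(1, 3, 8, 64, 64, device=device, dtype=dtype)
+
+        with torch.no_grad():
+            encoded_a = vae.encode(x)
+            encoded_b = vae.encode(x)
+
+        torch.testing.assert_close(encoded_a, encoded_b)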