diff --git a/README.md b/README.md index 52799c5..245e4ff 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,24 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep ## πŸ”₯ Latest Updates +### July 6 2025: WanGP v6.6, WanGP offers you **Vace Multitalk Dual Voices Fusionix Infinite**: +**Vace**, our beloved super Control Net, has been combined with **Multitalk**, the new king in town that can animate two people speaking (**Dual Voices**). It is accelerated by the **Fusionix** model, and thanks to *Sliding Windows* support and *Adaptive Projected Guidance* (much slower but should reduce the reddish effect with long videos) your two people will be able to talk for a very long time (which is an **Infinite** amount of time in the field of video generation). + +Of course you will also get plain vanilla *Multitalk* and *Multitalk 720p* as a bonus. + +And since I am Mister Nice Guy, I have included as an exclusive an *Audio Separator* that will save you time isolating each voice when using Multitalk with two people (a quick usage sketch is shown at the end of this update). + +As I feel like resting a bit, I haven't produced a nice sample video to illustrate all these new capabilities. But here is the thing: I am sure you will publish your Masterpieces in the *Share Your Best Video* channel. The best one will be added to the *Announcements Channel* and will bring eternal fame to its author. + +But wait, there is more: +- Sliding Windows support has been added everywhere with Wan models, so now that text2video was upgraded in 6.5 into a video2video, you can upsample very long videos regardless of your VRAM. The good old image2video model can now reuse the last image to produce new videos (as requested by many of you) +- I have also added the capability to transfer the audio of the original control video, and an option to preserve its fps in the generated video, so from now on you will be able to upsample / restore your old family videos and keep the audio and the original pace. Be aware that the duration will be limited to 1000 frames as I still need to add streaming support for unlimited video sizes. + +Also of interest: +- Extract video info from videos that have not been generated by WanGP; even better, you can also apply post processing (Upsampling / MMAudio) on non-WanGP videos +- Force the generated video fps to your liking, works very well with Vace when using a Control Video +- Ability to chain URLs of Finetune models (for instance put the URLs of a model in your main finetune and reference this finetune in other finetune models to save time) +
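As a quick illustration of the new Audio Separator, here is a minimal usage sketch based on the `extract_dual_audio` helper added in `preprocessing/speakers_separator.py`; the file names below are placeholders, and the helper expects the pyannote checkpoints under `ckpts/pyannote/`:

```python
# Minimal sketch (assumed file names): isolate the two voices of a conversation
# so that each one can be fed to Multitalk separately.
from preprocessing.speakers_separator import extract_dual_audio

extract_dual_audio(
    "conversation.wav",   # input audio with two speakers (placeholder path)
    "speaker0.wav",       # output: isolated voice of the first speaker
    "speaker1.wav",       # output: isolated voice of the second speaker
    verbose=True,         # print diarization progress
)
```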
### July 2 2025: WanGP v6.5.1, WanGP takes care of you: lots of quality of life features: - View directly inside WanGP the properties (seed, resolution, length, most settings...) of past generations - In one click use the newly generated video as a Control Video or Source Video to be continued diff --git a/configs/flf2v_720p.json b/configs/flf2v_720p.json index f5a12b2..2ec6691 100644 --- a/configs/flf2v_720p.json +++ b/configs/flf2v_720p.json @@ -10,5 +10,6 @@ "num_heads": 40, "num_layers": 40, "out_dim": 16, - "text_len": 512 + "text_len": 512, + "flf": true } diff --git a/configs/multitalk.json b/configs/multitalk.json new file mode 100644 index 0000000..2724759 --- /dev/null +++ b/configs/multitalk.json @@ -0,0 +1,15 @@ +{ + "_class_name": "WanModel", + "_diffusers_version": "0.30.0", + "dim": 5120, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_dim": 36, + "model_type": "i2v", + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "text_len": 512, + "multitalk_output_dim": 768 +} diff --git a/configs/vace_multitalk_14B.json b/configs/vace_multitalk_14B.json new file mode 100644 index 0000000..17a9615 --- /dev/null +++ b/configs/vace_multitalk_14B.json @@ -0,0 +1,17 @@ +{ + "_class_name": "VaceWanModel", + "_diffusers_version": "0.30.0", + "dim": 5120, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_dim": 16, + "model_type": "t2v", + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "text_len": 512, + "vace_layers": [0, 5, 10, 15, 20, 25, 30, 35], + "vace_in_dim": 96, + "multitalk_output_dim": 768 +} diff --git a/finetunes/fantasy.json b/finetunes/fantasy.json new file mode 100644 index 0000000..dbab1b2 --- /dev/null +++ b/finetunes/fantasy.json @@ -0,0 +1,12 @@ +{ + "model": + { + "name": "Fantasy Talking 720p", + "architecture" : "fantasy", + "modules": ["fantasy"], + "description": "The Fantasy Talking model corresponds to the original Wan image 2 video model combined with the Fantasy Speaking module to process an audio input.", + "URLs": "i2v_720p", + "teacache_coefficients" : [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683] + }, + "resolution": "1280x720" +} diff --git a/finetunes/flf2v_720p.json b/finetunes/flf2v_720p.json new file mode 100644 index 0000000..88b5387 --- /dev/null +++ b/finetunes/flf2v_720p.json @@ -0,0 +1,16 @@ +{ + "model": + { + "name": "First Last Frame to Video 720p (FLF2V) 14B", + "architecture" : "flf2v_720p", + "visible" : false, + "description": "The First Last Frame 2 Video model is the official Image 2 Video model that supports Start and End frames.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_FLF2V_720p_14B_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_FLF2V_720p_14B_quanto_fp16_int8.safetensors" + ], + "auto_quantize": true + }, + "resolution": "1280x720" +} \ No newline at end of file diff --git a/finetunes/moviigen.json b/finetunes/moviigen.json new file mode 100644 index 0000000..96a04f8 --- /dev/null +++ b/finetunes/moviigen.json @@ -0,0 +1,16 @@ +{ + "model": + { + "name": "MoviiGen 1080p 14B", + "architecture" : "t2v", + "description": "MoviiGen 1.1, a cutting-edge video generation model that excels in cinematic aesthetics and visual quality. Use it to generate videos in 720p or 1080p in the 21:9 ratio.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_moviigen1.1_14B_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_moviigen1.1_14B_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_moviigen1.1_14B_quanto_mfp16_int8.safetensors" + ], + "auto_quantize": true + }, + "resolution": "1280x720", + "video_length": 81 +} \ No newline at end of file
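The finetune definitions above use the new URL-chaining convention mentioned in the release notes: `"URLs"` may hold either an explicit list of checkpoints or a string that references another model / finetune id (for example `"i2v_720p"` above, or `"t2v_fusionix"` further down). Below is a hedged sketch of generating such a definition programmatically; the model name, file path and referenced id are only examples:

```python
import json

# Illustrative only: a finetune that reuses the checkpoints of another model
# by referencing its id instead of repeating a URL list.
finetune = {
    "model": {
        "name": "My FusioniX Variant (example)",   # placeholder name
        "architecture": "t2v",
        "description": "Example finetune chaining its URLs to t2v_fusionix.",
        "URLs": "t2v_fusionix",                    # reference instead of a URL list
    }
}

with open("finetunes/my_fusionix_variant.json", "w", encoding="utf-8") as f:
    json.dump(finetune, f, indent=4)
```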
diff --git a/finetunes/multitalk.json b/finetunes/multitalk.json new file mode 100644 index 0000000..9c389d5 --- /dev/null +++ b/finetunes/multitalk.json @@ -0,0 +1,11 @@ +{ + "model": + { + "name": "Multitalk 480p", + "architecture" : "multitalk", + "modules": ["multitalk"], + "description": "The Multitalk model corresponds to the original Wan image 2 video model combined with the Multitalk module. It lets up to two people have a conversation.", + "URLs": "i2v", + "teacache_coefficients" : [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01] + } +} \ No newline at end of file diff --git a/finetunes/multitalk_720p.json b/finetunes/multitalk_720p.json new file mode 100644 index 0000000..4bdaabc --- /dev/null +++ b/finetunes/multitalk_720p.json @@ -0,0 +1,13 @@ +{ + "model": + { + "name": "Multitalk 720p", + "architecture" : "multitalk", + "modules": ["multitalk"], + "description": "The Multitalk model corresponds to the original Wan image 2 video 720p model combined with the Multitalk module. It lets up to two people have a conversation.", + "URLs": "i2v_720p", + "teacache_coefficients" : [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683], + "auto_quantize": true + }, + "resolution": "1280x720" +} diff --git a/finetunes/vace_14B.json b/finetunes/vace_14B.json new file mode 100644 index 0000000..139bad4 --- /dev/null +++ b/finetunes/vace_14B.json @@ -0,0 +1,11 @@ +{ + "model": { + "name": "Vace ControlNet 14B", + "architecture": "vace_14B", + "modules": [ + "vace_14B" + ], + "description": "The Vace ControlNet model is a powerful model that allows you to control the content of the generated video based on additional custom data: pose or depth video, images or objects you want to see in the video.", + "URLs": "t2v" + } +} \ No newline at end of file diff --git a/finetunes/vace_14B_fusionix.json b/finetunes/vace_14B_fusionix.json index c73cd3d..99b07d1 100644 --- a/finetunes/vace_14B_fusionix.json +++ b/finetunes/vace_14B_fusionix.json @@ -6,12 +6,7 @@ "vace_14B" ], "description": "Vace control model enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.", - "URLs": [ - "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_fp16.safetensors", - "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_bf16_int8.safetensors", - "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors" - ], - "auto_quantize": true + "URLs": "t2v_fusionix" }, "negative_prompt": "", "prompt": "", diff --git a/finetunes/vace_multitalk_14B.json b/finetunes/vace_multitalk_14B.json new file mode 100644 index 0000000..c35a048 --- /dev/null +++ b/finetunes/vace_multitalk_14B.json @@ -0,0 +1,41 @@ +{ + "model": { + "name": "Vace Multitalk FusioniX 14B", + "architecture": "vace_multitalk_14B", + "modules": [ + "vace_14B", + "multitalk" + ], + "description": "Vace control model enhanced
using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail. And if that's not sufficient, Vace is combined with Multitalk.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_fp16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_bf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors" + ], + "auto_quantize": true + }, + "negative_prompt": "", + "prompt": "", + "resolution": "832x480", + "video_length": 81, + "seed": -1, + "num_inference_steps": 10, + "guidance_scale": 1, + "flow_shift": 5, + "embedded_guidance_scale": 6, + "repeat_generation": 1, + "multi_images_gen_type": 0, + "tea_cache_setting": 0, + "tea_cache_start_step_perc": 0, + "loras_multipliers": "", + "temporal_upsampling": "", + "spatial_upsampling": "", + "RIFLEx_setting": 0, + "slg_switch": 0, + "slg_start_perc": 10, + "slg_end_perc": 90, + "cfg_star_switch": 0, + "cfg_zero_step": -1, + "prompt_enhancer": "", + "activated_loras": [] +} \ No newline at end of file diff --git a/postprocessing/mmaudio/data/av_utils.py b/postprocessing/mmaudio/data/av_utils.py index 9780ab0..6fd0b1d 100644 --- a/postprocessing/mmaudio/data/av_utils.py +++ b/postprocessing/mmaudio/data/av_utils.py @@ -131,24 +131,14 @@ from pathlib import Path import torch def remux_with_audio(video_path: Path, output_path: Path, audio: torch.Tensor, sampling_rate: int): - """Remux video with new audio using FFmpeg.""" + from wan.utils.utils import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f: temp_path = Path(f.name) - - try: - # Write audio as WAV - import torchaudio - torchaudio.save(str(temp_path), audio.unsqueeze(0) if audio.dim() == 1 else audio, sampling_rate) - - # Remux with FFmpeg - subprocess.run([ - 'ffmpeg', '-i', str(video_path), '-i', str(temp_path), - '-c:v', 'copy', '-c:a', 'aac', '-map', '0:v', '-map', '1:a', - '-shortest', '-y', str(output_path) - ], check=True, capture_output=True) - - finally: - temp_path.unlink(missing_ok=True) + temp_path_str= str(temp_path) + import torchaudio + torchaudio.save(temp_path_str, audio.unsqueeze(0) if audio.dim() == 1 else audio, sampling_rate) + combine_video_with_audio_tracks(video_path, [temp_path_str], output_path ) + temp_path.unlink(missing_ok=True) def remux_with_audio_old(video_path: Path, audio: torch.Tensor, output_path: Path, sampling_rate: int): """ diff --git a/postprocessing/mmaudio/utils/logger.py b/postprocessing/mmaudio/utils/logger.py index bd8cea2..6a170c6 100644 --- a/postprocessing/mmaudio/utils/logger.py +++ b/postprocessing/mmaudio/utils/logger.py @@ -9,7 +9,8 @@ import os from collections import defaultdict from pathlib import Path from typing import Optional, Union - +import matplotlib +matplotlib.use('TkAgg') import matplotlib.pyplot as plt import numpy as np import torch diff --git a/preprocessing/matanyone/tools/interact_tools.py b/preprocessing/matanyone/tools/interact_tools.py index c70b8c4..c5d39b6 100644 --- a/preprocessing/matanyone/tools/interact_tools.py +++ b/preprocessing/matanyone/tools/interact_tools.py @@ -5,6 +5,8 @@ from PIL import Image, ImageDraw, ImageOps import numpy as np from typing import Union from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator +import matplotlib +matplotlib.use('TkAgg') import
matplotlib.pyplot as plt import PIL from .mask_painter import mask_painter as mask_painter2 diff --git a/preprocessing/speakers_separator.py b/preprocessing/speakers_separator.py new file mode 100644 index 0000000..5ab911a --- /dev/null +++ b/preprocessing/speakers_separator.py @@ -0,0 +1,922 @@ +import torch +import torchaudio +import numpy as np +import os +import warnings +from pathlib import Path +from typing import Dict, List, Tuple +import argparse +from concurrent.futures import ThreadPoolExecutor +import gc +import logging + +verbose_output = True + +# Suppress specific warnings before importing pyannote +warnings.filterwarnings("ignore", category=UserWarning, module="pyannote.audio.models.blocks.pooling") +warnings.filterwarnings("ignore", message=".*TensorFloat-32.*", category=UserWarning) +warnings.filterwarnings("ignore", message=".*std\\(\\): degrees of freedom.*", category=UserWarning) +warnings.filterwarnings("ignore", message=".*speechbrain.pretrained.*was deprecated.*", category=UserWarning) +warnings.filterwarnings("ignore", message=".*Module 'speechbrain.pretrained'.*", category=UserWarning) +# logging.getLogger('speechbrain').setLevel(logging.WARNING) +# logging.getLogger('speechbrain.utils.checkpoints').setLevel(logging.WARNING) +os.environ["SB_LOG_LEVEL"] = "WARNING" +import speechbrain + +def xprint(t = None): + if verbose_output: + print(t) + +# Configure TF32 before any CUDA operations to avoid reproducibility warnings +if torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + +try: + from pyannote.audio import Pipeline + PYANNOTE_AVAILABLE = True +except ImportError: + PYANNOTE_AVAILABLE = False + print("Install: pip install pyannote.audio") + + +class OptimizedPyannote31SpeakerSeparator: + def __init__(self, hf_token: str = None, local_model_path: str = None, + vad_onset: float = 0.2, vad_offset: float = 0.8): + """ + Initialize with Pyannote 3.1 pipeline with tunable VAD sensitivity. 
+ """ + embedding_path = "ckpts/pyannote/pyannote_model_wespeaker-voxceleb-resnet34-LM.bin" + segmentation_path = "ckpts/pyannote/pytorch_model_segmentation-3.0.bin" + + + xprint(f"Loading segmentation model from: {segmentation_path}") + xprint(f"Loading embedding model from: {embedding_path}") + + try: + from pyannote.audio import Model + from pyannote.audio.pipelines import SpeakerDiarization + + # Load models directly + segmentation_model = Model.from_pretrained(segmentation_path) + embedding_model = Model.from_pretrained(embedding_path) + xprint("Models loaded successfully!") + + # Create pipeline manually + self.pipeline = SpeakerDiarization( + segmentation=segmentation_model, + embedding=embedding_model, + clustering='AgglomerativeClustering' + ) + + # Instantiate with default parameters + self.pipeline.instantiate({ + 'clustering': { + 'method': 'centroid', + 'min_cluster_size': 12, + 'threshold': 0.7045654963945799 + }, + 'segmentation': { + 'min_duration_off': 0.0 + } + }) + xprint("Pipeline instantiated successfully!") + + # Send to GPU if available + if torch.cuda.is_available(): + xprint("CUDA available, moving pipeline to GPU...") + self.pipeline.to(torch.device("cuda")) + else: + xprint("CUDA not available, using CPU...") + + except Exception as e: + xprint(f"Error loading pipeline: {e}") + xprint(f"Error type: {type(e)}") + import traceback + traceback.print_exc() + raise + + + self.hf_token = hf_token + self._overlap_pipeline = None + + def separate_audio(self, audio_path: str, output1, output2 ) -> Dict[str, str]: + """Optimized main separation function with memory management.""" + xprint("Starting optimized audio separation...") + self._current_audio_path = os.path.abspath(audio_path) + + # Suppress warnings during processing + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + + # Load audio + waveform, sample_rate = self.load_audio(audio_path) + + # Perform diarization + diarization = self.perform_optimized_diarization(audio_path) + + # Create masks + masks = self.create_optimized_speaker_masks(diarization, waveform.shape[1], sample_rate) + + # Apply background preservation + final_masks = self.apply_optimized_background_preservation(masks, waveform.shape[1]) + + # Clear intermediate results + del masks + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + # Save outputs efficiently + output_paths = self._save_outputs_optimized(waveform, final_masks, sample_rate, audio_path, output1, output2) + + return output_paths + + def _extract_both_speaking_regions( + self, + diarization, + audio_length: int, + sample_rate: int + ) -> np.ndarray: + """ + Detect regions where β‰₯2 speakers talk simultaneously + using pyannote/overlapped-speech-detection. + Falls back to manual pair-wise detection if the model + is unavailable.
+ """ + xprint("Extracting overlap with dedicated pipeline…") + both_speaking_mask = np.zeros(audio_length, dtype=bool) + + # ── 1) try the proper overlap model ──────────────────────────────── + overlap_pipeline = self._get_overlap_pipeline() + + # try the path stored by separate_audio – otherwise whatever the + # diarization object carries (may be None) + audio_uri = getattr(self, "_current_audio_path", None) \ + or getattr(diarization, "uri", None) + if overlap_pipeline and audio_uri: + try: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + overlap_annotation = overlap_pipeline(audio_uri) + + for seg in overlap_annotation.get_timeline().support(): + s = max(0, int(seg.start * sample_rate)) + e = min(audio_length, int(seg.end * sample_rate)) + if s < e: + both_speaking_mask[s:e] = True + t = np.sum(both_speaking_mask) / sample_rate + xprint(f" Found {t:.1f}s of overlapped speech (model) ") + return both_speaking_mask + except Exception as e: + xprint(f" ⚠ Overlap model failed: {e}") + + # ── 2) fallback = brute-force pairwise intersection ──────────────── + xprint(" Falling back to manual overlap detection…") + timeline_tracks = list(diarization.itertracks(yield_label=True)) + for i, (turn1, _, spk1) in enumerate(timeline_tracks): + for j, (turn2, _, spk2) in enumerate(timeline_tracks): + if i >= j or spk1 == spk2: + continue + o_start, o_end = max(turn1.start, turn2.start), min(turn1.end, turn2.end) + if o_start < o_end: + s = max(0, int(o_start * sample_rate)) + e = min(audio_length, int(o_end * sample_rate)) + if s < e: + both_speaking_mask[s:e] = True + t = np.sum(both_speaking_mask) / sample_rate + xprint(f" Found {t:.1f}s of overlapped speech (manual) ") + return both_speaking_mask + + def _configure_vad(self, vad_onset: float, vad_offset: float): + """Configure VAD parameters efficiently.""" + xprint("Applying more sensitive VAD parameters...") + try: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + + if hasattr(self.pipeline, '_vad'): + self.pipeline._vad.instantiate({ + "onset": vad_onset, + "offset": vad_offset, + "min_duration_on": 0.1, + "min_duration_off": 0.1, + "pad_onset": 0.1, + "pad_offset": 0.1, + }) + xprint(f"βœ“ VAD parameters updated: onset={vad_onset}, offset={vad_offset}") + else: + xprint("⚠ Could not access VAD component directly") + except Exception as e: + xprint(f"⚠ Could not modify VAD parameters: {e}") + + def _get_overlap_pipeline(self): + """ + Build a pyannote-3-native OverlappedSpeechDetection pipeline. 
+ + β€’ uses the open-licence `pyannote/segmentation-3.0` checkpoint + β€’ only `min_duration_on/off` can be tuned (API 3.x) + """ + if self._overlap_pipeline is not None: + return None if self._overlap_pipeline is False else self._overlap_pipeline + + try: + from pyannote.audio.pipelines import OverlappedSpeechDetection + + xprint("Building OverlappedSpeechDetection with segmentation-3.0…") + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + + # 1) constructor β†’ segmentation model ONLY + ods = OverlappedSpeechDetection( + segmentation="pyannote/segmentation-3.0" + ) + + # 2) instantiate β†’ **single dict** with the two valid knobs + ods.instantiate({ + "min_duration_on": 0.06, # β‰ˆ your previous 0.055 s + "min_duration_off": 0.10, # β‰ˆ your previous 0.098 s + }) + + if torch.cuda.is_available(): + ods.to(torch.device("cuda")) + + self._overlap_pipeline = ods + xprint("βœ“ Overlap pipeline ready (segmentation-3.0)") + return ods + + except Exception as e: + xprint(f"⚠ Could not build overlap pipeline ({e}). " + "Falling back to manual pair-wise detection.") + self._overlap_pipeline = False + return None + + def _xprint_setup_instructions(self): + """xprint setup instructions.""" + xprint("\nTo use Pyannote 3.1:") + xprint("1. Get token: https://huggingface.co/settings/tokens") + xprint("2. Accept terms: https://huggingface.co/pyannote/speaker-diarization-3.1") + xprint("3. Run with: --token YOUR_TOKEN") + + def load_audio(self, audio_path: str) -> Tuple[torch.Tensor, int]: + """Load and preprocess audio efficiently.""" + xprint(f"Loading audio: {audio_path}") + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + waveform, sample_rate = torchaudio.load(audio_path) + + # Convert to mono efficiently + if waveform.shape[0] > 1: + waveform = torch.mean(waveform, dim=0, keepdim=True) + + xprint(f"Audio: {waveform.shape[1]} samples at {sample_rate}Hz") + return waveform, sample_rate + + def perform_optimized_diarization(self, audio_path: str) -> object: + """ + Optimized diarization with efficient parameter testing. + """ + xprint("Running optimized Pyannote 3.1 diarization...") + + # Optimized strategy order - most likely to succeed first + strategies = [ + {"min_speakers": 2, "max_speakers": 2}, # Most common case + {"num_speakers": 2}, # Direct specification + {"min_speakers": 2, "max_speakers": 3}, # Slight flexibility + {"min_speakers": 1, "max_speakers": 2}, # Fallback + {"min_speakers": 2, "max_speakers": 4}, # More flexibility + {} # No constraints + ] + + for i, params in enumerate(strategies): + try: + xprint(f"Strategy {i+1}: {params}") + + # Clear GPU memory before each attempt + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + diarization = self.pipeline(audio_path, **params) + + speakers = list(diarization.labels()) + speaker_count = len(speakers) + + xprint(f" β†’ Detected {speaker_count} speakers: {speakers}") + + # Accept first successful result with 2+ speakers + if speaker_count >= 2: + xprint(f"βœ“ Success with strategy {i+1}! 
Using {speaker_count} speakers") + return diarization + elif speaker_count == 1 and i == 0: + # Store first result as fallback + fallback_diarization = diarization + + except Exception as e: + xprint(f" Strategy {i+1} failed: {e}") + continue + + # If we only got 1 speaker, try one aggressive attempt + if 'fallback_diarization' in locals(): + xprint("Attempting aggressive clustering for single speaker...") + try: + aggressive_diarization = self._try_aggressive_clustering(audio_path) + if aggressive_diarization and len(list(aggressive_diarization.labels())) >= 2: + return aggressive_diarization + except Exception as e: + xprint(f"Aggressive clustering failed: {e}") + + xprint("Using single speaker result") + return fallback_diarization + + # Last resort - run without constraints + xprint("Last resort: running without constraints...") + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + return self.pipeline(audio_path) + + def _try_aggressive_clustering(self, audio_path: str) -> object: + """Try aggressive clustering parameters.""" + try: + from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + + # Create aggressive pipeline + temp_pipeline = SpeakerDiarization( + segmentation=self.pipeline.segmentation, + embedding=self.pipeline.embedding, + clustering="AgglomerativeClustering" + ) + + temp_pipeline.instantiate({ + "clustering": { + "method": "centroid", + "min_cluster_size": 1, + "threshold": 0.1, + }, + "segmentation": { + "min_duration_off": 0.0, + "min_duration_on": 0.1, + } + }) + + return temp_pipeline(audio_path, min_speakers=2) + + except Exception as e: + xprint(f"Aggressive clustering setup failed: {e}") + return None + + def create_optimized_speaker_masks(self, diarization, audio_length: int, sample_rate: int) -> Dict[str, np.ndarray]: + """Optimized mask creation using vectorized operations.""" + xprint("Creating optimized speaker masks...") + + speakers = list(diarization.labels()) + xprint(f"Processing speakers: {speakers}") + + # Handle edge cases + if len(speakers) == 0: + xprint("⚠ No speakers detected, creating dummy masks") + return self._create_dummy_masks(audio_length) + + if len(speakers) == 1: + xprint("⚠ Only 1 speaker detected, creating temporal split") + return self._create_optimized_temporal_split(diarization, audio_length, sample_rate) + + # Extract both-speaking regions from diarization timeline + both_speaking_regions = self._extract_both_speaking_regions(diarization, audio_length, sample_rate) + + # Optimized mask creation for multiple speakers + masks = {} + + # Batch process all speakers + for speaker in speakers: + # Get all segments for this speaker at once + segments = [] + speaker_timeline = diarization.label_timeline(speaker) + for segment in speaker_timeline: + start_sample = max(0, int(segment.start * sample_rate)) + end_sample = min(audio_length, int(segment.end * sample_rate)) + if start_sample < end_sample: + segments.append((start_sample, end_sample)) + + # Vectorized mask creation + if segments: + mask = self._create_mask_vectorized(segments, audio_length) + masks[speaker] = mask + speaking_time = np.sum(mask) / sample_rate + xprint(f" {speaker}: {speaking_time:.1f}s speaking time") + else: + masks[speaker] = np.zeros(audio_length, dtype=np.float32) + + # Store both-speaking info for later use + self._both_speaking_regions = both_speaking_regions + + return masks + + def 
_create_mask_vectorized(self, segments: List[Tuple[int, int]], audio_length: int) -> np.ndarray: + """Create mask using vectorized operations.""" + mask = np.zeros(audio_length, dtype=np.float32) + + if not segments: + return mask + + # Convert segments to arrays for vectorized operations + segments_array = np.array(segments) + starts = segments_array[:, 0] + ends = segments_array[:, 1] + + # Use advanced indexing for bulk assignment + for start, end in zip(starts, ends): + mask[start:end] = 1.0 + + return mask + + def _create_dummy_masks(self, audio_length: int) -> Dict[str, np.ndarray]: + """Create dummy masks for edge cases.""" + return { + "SPEAKER_00": np.ones(audio_length, dtype=np.float32) * 0.5, + "SPEAKER_01": np.ones(audio_length, dtype=np.float32) * 0.5 + } + + def _create_optimized_temporal_split(self, diarization, audio_length: int, sample_rate: int) -> Dict[str, np.ndarray]: + """Optimized temporal split with vectorized operations.""" + xprint("Creating optimized temporal split...") + + # Extract all segments at once + segments = [] + for turn, _, speaker in diarization.itertracks(yield_label=True): + segments.append((turn.start, turn.end)) + + segments.sort() + xprint(f"Found {len(segments)} speech segments") + + if len(segments) <= 1: + # Single segment or no segments - simple split + return self._create_simple_split(audio_length) + + # Vectorized gap analysis + segment_array = np.array(segments) + gaps = segment_array[1:, 0] - segment_array[:-1, 1] # Vectorized gap calculation + + if len(gaps) > 0: + longest_gap_idx = np.argmax(gaps) + longest_gap_duration = gaps[longest_gap_idx] + + xprint(f"Longest gap: {longest_gap_duration:.1f}s after segment {longest_gap_idx+1}") + + if longest_gap_duration > 1.0: + # Split at natural break + split_point = longest_gap_idx + 1 + xprint(f"Splitting at natural break: segments 1-{split_point} vs {split_point+1}-{len(segments)}") + + return self._create_split_masks(segments, split_point, audio_length, sample_rate) + + # Fallback: alternating assignment + xprint("Using alternating assignment...") + return self._create_alternating_masks(segments, audio_length, sample_rate) + + def _create_simple_split(self, audio_length: int) -> Dict[str, np.ndarray]: + """Simple temporal split in half.""" + mid_point = audio_length // 2 + masks = { + "SPEAKER_00": np.zeros(audio_length, dtype=np.float32), + "SPEAKER_01": np.zeros(audio_length, dtype=np.float32) + } + masks["SPEAKER_00"][:mid_point] = 1.0 + masks["SPEAKER_01"][mid_point:] = 1.0 + return masks + + def _create_split_masks(self, segments: List[Tuple[float, float]], split_point: int, + audio_length: int, sample_rate: int) -> Dict[str, np.ndarray]: + """Create masks with split at specific point.""" + masks = { + "SPEAKER_00": np.zeros(audio_length, dtype=np.float32), + "SPEAKER_01": np.zeros(audio_length, dtype=np.float32) + } + + # Vectorized segment processing + for i, (start_time, end_time) in enumerate(segments): + start_sample = max(0, int(start_time * sample_rate)) + end_sample = min(audio_length, int(end_time * sample_rate)) + + if start_sample < end_sample: + speaker_key = "SPEAKER_00" if i < split_point else "SPEAKER_01" + masks[speaker_key][start_sample:end_sample] = 1.0 + + return masks + + def _create_alternating_masks(self, segments: List[Tuple[float, float]], + audio_length: int, sample_rate: int) -> Dict[str, np.ndarray]: + """Create masks with alternating assignment.""" + masks = { + "SPEAKER_00": np.zeros(audio_length, dtype=np.float32), + "SPEAKER_01": np.zeros(audio_length, 
dtype=np.float32) + } + + for i, (start_time, end_time) in enumerate(segments): + start_sample = max(0, int(start_time * sample_rate)) + end_sample = min(audio_length, int(end_time * sample_rate)) + + if start_sample < end_sample: + speaker_key = f"SPEAKER_0{i % 2}" + masks[speaker_key][start_sample:end_sample] = 1.0 + + return masks + + def apply_optimized_background_preservation(self, masks: Dict[str, np.ndarray], + audio_length: int) -> Dict[str, np.ndarray]: + """ + Heavily optimized background preservation using pure vectorized operations. + """ + xprint("Applying optimized voice separation logic...") + + # Ensure exactly 2 speakers + speaker_keys = self._get_top_speakers(masks, audio_length) + + # Pre-allocate final masks + final_masks = { + speaker: np.zeros(audio_length, dtype=np.float32) + for speaker in speaker_keys + } + + # Get active masks (vectorized) + active_0 = masks.get(speaker_keys[0], np.zeros(audio_length)) > 0.5 + active_1 = masks.get(speaker_keys[1], np.zeros(audio_length)) > 0.5 + + # Vectorized mask assignment + both_active = active_0 & active_1 + only_0 = active_0 & ~active_1 + only_1 = ~active_0 & active_1 + neither = ~active_0 & ~active_1 + + # Apply assignments (all vectorized) + final_masks[speaker_keys[0]][both_active] = 1.0 + final_masks[speaker_keys[1]][both_active] = 1.0 + + final_masks[speaker_keys[0]][only_0] = 1.0 + final_masks[speaker_keys[1]][only_0] = 0.0 + + final_masks[speaker_keys[0]][only_1] = 0.0 + final_masks[speaker_keys[1]][only_1] = 1.0 + + # Handle ambiguous regions efficiently + if np.any(neither): + ambiguous_assignments = self._compute_ambiguous_assignments_vectorized( + masks, speaker_keys, neither, audio_length + ) + + # Apply ambiguous assignments + final_masks[speaker_keys[0]][neither] = (ambiguous_assignments == 0).astype(np.float32) * 0.5 + final_masks[speaker_keys[1]][neither] = (ambiguous_assignments == 1).astype(np.float32) * 0.5 + + # xprint statistics (vectorized) + sample_rate = 16000 # Assume 16kHz for timing + xprint(f" Both speaking clearly: {np.sum(both_active)/sample_rate:.1f}s") + xprint(f" {speaker_keys[0]} only: {np.sum(only_0)/sample_rate:.1f}s") + xprint(f" {speaker_keys[1]} only: {np.sum(only_1)/sample_rate:.1f}s") + xprint(f" Ambiguous (assigned): {np.sum(neither)/sample_rate:.1f}s") + + # Apply minimum duration smoothing to prevent rapid switching + final_masks = self._apply_minimum_duration_smoothing(final_masks, sample_rate) + + return final_masks + + def _get_top_speakers(self, masks: Dict[str, np.ndarray], audio_length: int) -> List[str]: + """Get top 2 speakers by speaking time.""" + speaker_keys = list(masks.keys()) + + if len(speaker_keys) > 2: + # Vectorized speaking time calculation + speaking_times = {k: np.sum(v) for k, v in masks.items()} + speaker_keys = sorted(speaking_times.keys(), key=lambda x: speaking_times[x], reverse=True)[:2] + xprint(f"Keeping top 2 speakers: {speaker_keys}") + elif len(speaker_keys) == 1: + speaker_keys.append("SPEAKER_SILENT") + + return speaker_keys + + def _compute_ambiguous_assignments_vectorized(self, masks: Dict[str, np.ndarray], + speaker_keys: List[str], + ambiguous_mask: np.ndarray, + audio_length: int) -> np.ndarray: + """Compute speaker assignments for ambiguous regions using vectorized operations.""" + ambiguous_indices = np.where(ambiguous_mask)[0] + + if len(ambiguous_indices) == 0: + return np.array([]) + + # Get speaker segments efficiently + speaker_segments = {} + for speaker in speaker_keys: + if speaker in masks and speaker != "SPEAKER_SILENT": + mask = 
masks[speaker] > 0.5 + # Find segments using vectorized operations + diff = np.diff(np.concatenate(([False], mask, [False])).astype(int)) + starts = np.where(diff == 1)[0] + ends = np.where(diff == -1)[0] + speaker_segments[speaker] = np.column_stack([starts, ends]) + else: + speaker_segments[speaker] = np.array([]).reshape(0, 2) + + # Vectorized distance calculations + distances = {} + for speaker in speaker_keys: + segments = speaker_segments[speaker] + if len(segments) == 0: + distances[speaker] = np.full(len(ambiguous_indices), np.inf) + else: + # Compute distances to all segments at once + distances[speaker] = self._compute_distances_to_segments(ambiguous_indices, segments) + + # Assign based on minimum distance with late-audio bias + assignments = self._assign_based_on_distance( + distances, speaker_keys, ambiguous_indices, audio_length + ) + + return assignments + + def _apply_minimum_duration_smoothing(self, masks: Dict[str, np.ndarray], + sample_rate: int, min_duration_ms: int = 600) -> Dict[str, np.ndarray]: + """ + Apply minimum duration smoothing with STRICT timer enforcement. + Uses original both-speaking regions from diarization. + """ + xprint(f"Applying STRICT minimum duration smoothing ({min_duration_ms}ms)...") + + min_samples = int(min_duration_ms * sample_rate / 1000) + speaker_keys = list(masks.keys()) + + if len(speaker_keys) != 2: + return masks + + mask0 = masks[speaker_keys[0]] + mask1 = masks[speaker_keys[1]] + + # Use original both-speaking regions from diarization + both_speaking_original = getattr(self, '_both_speaking_regions', np.zeros(len(mask0), dtype=bool)) + + # Identify regions based on original diarization info + ambiguous_original = (mask0 < 0.3) & (mask1 < 0.3) & ~both_speaking_original + + # Clear dominance: one speaker higher, and not both-speaking or ambiguous + remaining_mask = ~both_speaking_original & ~ambiguous_original + speaker0_dominant = (mask0 > mask1) & remaining_mask + speaker1_dominant = (mask1 > mask0) & remaining_mask + + # Create preference signal including both-speaking as valid state + # -1=ambiguous, 0=speaker0, 1=speaker1, 2=both_speaking + preference_signal = np.full(len(mask0), -1, dtype=int) + preference_signal[speaker0_dominant] = 0 + preference_signal[speaker1_dominant] = 1 + preference_signal[both_speaking_original] = 2 + + # STRICT state machine enforcement + smoothed_assignment = np.full(len(mask0), -1, dtype=int) + corrections = 0 + + # State variables + current_state = -1 # -1=unset, 0=speaker0, 1=speaker1, 2=both_speaking + samples_remaining = 0 # Samples remaining in current state's lock period + + # Process each sample with STRICT enforcement + for i in range(len(preference_signal)): + preference = preference_signal[i] + + # If we're in a lock period, enforce the current state + if samples_remaining > 0: + # Force current state regardless of preference + smoothed_assignment[i] = current_state + samples_remaining -= 1 + + # Count corrections if this differs from preference + if preference >= 0 and preference != current_state: + corrections += 1 + + else: + # Lock period expired - can consider new state + + if preference >= 0: + # Clear preference available (including both-speaking) + if current_state != preference: + # Switch to new state and start new lock period + current_state = preference + samples_remaining = min_samples - 1 # -1 because we use this sample + + smoothed_assignment[i] = current_state + + else: + # Ambiguous preference + if current_state >= 0: + # Continue with current state if we have one + 
smoothed_assignment[i] = current_state + else: + # No current state and ambiguous - leave as ambiguous + smoothed_assignment[i] = -1 + + # Convert back to masks based on smoothed assignment + smoothed_masks = {} + + for i, speaker in enumerate(speaker_keys): + new_mask = np.zeros_like(mask0) + + # Assign regions where this speaker is dominant + speaker_regions = smoothed_assignment == i + new_mask[speaker_regions] = 1.0 + + # Assign both-speaking regions (state 2) to both speakers + both_speaking_regions = smoothed_assignment == 2 + new_mask[both_speaking_regions] = 1.0 + + # Handle ambiguous regions that remain unassigned + unassigned_ambiguous = smoothed_assignment == -1 + if np.any(unassigned_ambiguous): + # Use original ambiguous values only for truly unassigned regions + original_ambiguous_mask = ambiguous_original & unassigned_ambiguous + new_mask[original_ambiguous_mask] = masks[speaker][original_ambiguous_mask] + + smoothed_masks[speaker] = new_mask + + # Calculate and xprint statistics + both_speaking_time = np.sum(smoothed_assignment == 2) / sample_rate + speaker0_time = np.sum(smoothed_assignment == 0) / sample_rate + speaker1_time = np.sum(smoothed_assignment == 1) / sample_rate + ambiguous_time = np.sum(smoothed_assignment == -1) / sample_rate + + xprint(f" Both speaking clearly: {both_speaking_time:.1f}s") + xprint(f" {speaker_keys[0]} only: {speaker0_time:.1f}s") + xprint(f" {speaker_keys[1]} only: {speaker1_time:.1f}s") + xprint(f" Ambiguous (assigned): {ambiguous_time:.1f}s") + xprint(f" Enforced minimum duration on {corrections} samples ({corrections/sample_rate:.2f}s)") + + return smoothed_masks + + def _compute_distances_to_segments(self, indices: np.ndarray, segments: np.ndarray) -> np.ndarray: + """Compute minimum distances from indices to segments (vectorized).""" + if len(segments) == 0: + return np.full(len(indices), np.inf) + + # Broadcast for vectorized computation + indices_expanded = indices[:, np.newaxis] # Shape: (n_indices, 1) + starts = segments[:, 0] # Shape: (n_segments,) + ends = segments[:, 1] # Shape: (n_segments,) + + # Compute distances to all segments + dist_to_start = np.maximum(0, starts - indices_expanded) # Shape: (n_indices, n_segments) + dist_from_end = np.maximum(0, indices_expanded - ends) # Shape: (n_indices, n_segments) + + # Minimum of distance to start or from end for each segment + distances = np.minimum(dist_to_start, dist_from_end) + + # Return minimum distance to any segment for each index + return np.min(distances, axis=1) + + def _assign_based_on_distance(self, distances: Dict[str, np.ndarray], + speaker_keys: List[str], + ambiguous_indices: np.ndarray, + audio_length: int) -> np.ndarray: + """Assign speakers based on distance with late-audio bias.""" + speaker_0_distances = distances[speaker_keys[0]] + speaker_1_distances = distances[speaker_keys[1]] + + # Basic assignment by minimum distance + assignments = (speaker_1_distances < speaker_0_distances).astype(int) + + # Apply late-audio bias (vectorized) + late_threshold = int(audio_length * 0.6) + late_indices = ambiguous_indices > late_threshold + + if np.any(late_indices) and len(speaker_keys) > 1: + # Simple late-audio bias: prefer speaker 1 in later parts + assignments[late_indices] = 1 + + return assignments + + def _save_outputs_optimized(self, waveform: torch.Tensor, masks: Dict[str, np.ndarray], + sample_rate: int, audio_path: str, output1, output2) -> Dict[str, str]: + """Optimized output saving with parallel processing.""" + output_paths = {} + + def 
save_speaker_audio(speaker_mask_pair, output): + speaker, mask = speaker_mask_pair + # Convert mask to tensor efficiently + mask_tensor = torch.from_numpy(mask).unsqueeze(0) + + # Apply mask + masked_audio = waveform * mask_tensor + + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + torchaudio.save(output, masked_audio, sample_rate) + + xprint(f"βœ“ Saved {speaker}: {output}") + return speaker, output + + # Use ThreadPoolExecutor for parallel saving + with ThreadPoolExecutor(max_workers=2) as executor: + results = list(executor.map(save_speaker_audio, masks.items(), [output1, output2])) + + output_paths = dict(results) + return output_paths + + def print_summary(self, audio_path: str): + """xprint diarization summary.""" + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + diarization = self.perform_optimized_diarization(audio_path) + + xprint("\n=== Diarization Summary ===") + for turn, _, speaker in diarization.itertracks(yield_label=True): + xprint(f"{speaker}: {turn.start:.1f}s - {turn.end:.1f}s") + +def extract_dual_audio(audio, output1, output2, verbose = False): + global verbose_output + verbose_output = verbose + separator = OptimizedPyannote31SpeakerSeparator( + None, + None, + vad_onset=0.2, + vad_offset=0.8 + ) + # Separate audio + import time + start_time = time.time() + + outputs = separator.separate_audio(audio, output1, output2) + + elapsed_time = time.time() - start_time + xprint(f"\n=== SUCCESS (completed in {elapsed_time:.2f}s) ===") + for speaker, path in outputs.items(): + xprint(f"{speaker}: {path}") + +def main(): + + parser = argparse.ArgumentParser(description="Optimized Pyannote 3.1 Speaker Separator") + parser.add_argument("--audio", required=True, help="Input audio file") + parser.add_argument("--output", required=True, help="Output directory") + parser.add_argument("--token", help="Hugging Face token") + parser.add_argument("--local-model", help="Path to local 3.1 model") + parser.add_argument("--summary", action="store_true", help="xprint summary") + + # VAD sensitivity parameters + parser.add_argument("--vad-onset", type=float, default=0.2, + help="VAD onset threshold (lower = more sensitive to speech start, default: 0.2)") + parser.add_argument("--vad-offset", type=float, default=0.8, + help="VAD offset threshold (higher = keeps speech longer, default: 0.8)") + + args = parser.parse_args() + + xprint("=== Optimized Pyannote 3.1 Speaker Separator ===") + xprint("Performance optimizations: vectorized operations, memory management, parallel processing") + xprint(f"Audio: {args.audio}") + xprint(f"Output: {args.output}") + xprint(f"VAD onset: {args.vad_onset}") + xprint(f"VAD offset: {args.vad_offset}") + xprint() + + if not os.path.exists(args.audio): + xprint(f"ERROR: Audio file not found: {args.audio}") + return + + try: + # Initialize with VAD parameters + separator = OptimizedPyannote31SpeakerSeparator( + args.token, + args.local_model, + vad_onset=args.vad_onset, + vad_offset=args.vad_offset + ) + + # print summary if requested + if args.summary: + separator.print_summary(args.audio) + + # Separate audio + import time + start_time = time.time() + + audio_name = Path(args.audio).stem + output_filename = f"{audio_name}_speaker0.wav" + output_filename1 = f"{audio_name}_speaker1.wav" + output_path = os.path.join(args.output, output_filename) + output_path1 = os.path.join(args.output, output_filename1) + + outputs = separator.separate_audio(args.audio, output_path, 
output_path1) + + elapsed_time = time.time() - start_time + xprint(f"\n=== SUCCESS (completed in {elapsed_time:.2f}s) ===") + for speaker, path in outputs.items(): + xprint(f"{speaker}: {path}") + + except Exception as e: + xprint(f"ERROR: {e}") + return 1 + + +if __name__ == "__main__": + exit(main()) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 8f9be07..5f4429d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -38,7 +38,10 @@ pygame>=2.1.0 sounddevice>=0.4.0 # rembg==2.0.65 torchdiffeq >= 0.2.5 -# 'nitrous-ema', -# 'hydra_colorlog', tensordict >= 0.6.1 -open_clip_torch >= 2.29.0 \ No newline at end of file +open_clip_torch >= 2.29.0 +pyloudnorm +misaki +soundfile +# num2words +# spacy \ No newline at end of file diff --git a/wan/__init__.py b/wan/__init__.py index 54004dd..1688425 100644 --- a/wan/__init__.py +++ b/wan/__init__.py @@ -1,4 +1,3 @@ from . import configs, distributed, modules -from .image2video import WanI2V -from .text2video import WanT2V +from .any2video import WanAny2V from .diffusion_forcing import DTT2V \ No newline at end of file diff --git a/wan/text2video.py b/wan/any2video.py similarity index 55% rename from wan/text2video.py rename to wan/any2video.py index a8373f7..e9a7026 100644 --- a/wan/text2video.py +++ b/wan/any2video.py @@ -13,6 +13,7 @@ import torch import torch.nn as nn import torch.cuda.amp as amp import torch.distributed as dist +import numpy as np from tqdm import tqdm from PIL import Image import torchvision.transforms.functional as TF @@ -21,14 +22,15 @@ from .distributed.fsdp import shard_model from .modules.model import WanModel from .modules.t5 import T5EncoderModel from .modules.vae import WanVAE +from .modules.clip import CLIPModel from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps) from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler from wan.modules.posemb_layers import get_rotary_pos_embed from .utils.vace_preprocessor import VaceVideoProcessor from wan.utils.basic_flowmatch import FlowMatchScheduler -from wan.utils.utils import get_outpainting_frame_location -from wgp import update_loras_slists +from wan.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions +from .multitalk.multitalk_utils import MomentumBuffer, adaptive_projected_guidance def optimized_scale(positive_flat, negative_flat): @@ -43,14 +45,20 @@ def optimized_scale(positive_flat, negative_flat): return st_star - -class WanT2V: +def timestep_transform(t, shift=5.0, num_timesteps=1000 ): + t = t / num_timesteps + # shift the timestep based on ratio + new_t = shift * t / (1 + (shift - 1) * t) + new_t = new_t * num_timesteps + return new_t + + +class WanAny2V: def __init__( self, config, checkpoint_dir, - rank=0, model_filename = None, model_type = None, base_model_type = None, @@ -63,7 +71,7 @@ class WanT2V: ): self.device = torch.device(f"cuda") self.config = config - self.rank = rank + self.VAE_dtype = VAE_dtype self.dtype = dtype self.num_train_timesteps = config.num_train_timesteps self.param_dtype = config.param_dtype @@ -76,6 +84,14 @@ class WanT2V: tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), shard_fn= None) + if hasattr(config, "clip_checkpoint"): + self.clip = CLIPModel( + dtype=config.clip_dtype, + device=self.device, + checkpoint_path=os.path.join(checkpoint_dir , + config.clip_checkpoint), + tokenizer_path=os.path.join(checkpoint_dir , config.clip_tokenizer)) + self.vae_stride = config.vae_stride 
self.patch_size = config.patch_size @@ -83,22 +99,27 @@ class WanT2V: vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype, device=self.device) - logging.info(f"Creating WanModel from {model_filename[-1]}") - from mmgp import offload - # model_filename = "c:/temp/vace1.3/diffusion_pytorch_model.safetensors" - # model_filename = "Vacefusionix_quanto_fp16_int8.safetensors" - # model_filename = "c:/temp/t2v/diffusion_pytorch_model-00001-of-00006.safetensors" - # config_filename= "c:/temp/t2v/t2v.json" + # xmodel_filename = "c:/ml/multitalk/multitalk.safetensors" + # config_filename= "configs/multitalk.json" + # import json + # with open(config_filename, 'r', encoding='utf-8') as f: + # config = json.load(f) + # from mmgp import safetensors2 + # sd = safetensors2.torch_load_file(xmodel_filename) + base_config_file = f"configs/{base_model_type}.json" - forcedConfigPath = base_config_file if len(model_filename) > 1 else None + forcedConfigPath = base_config_file if len(model_filename) > 1 or base_model_type in ["flf2v_720p"] else None + # model_filename[1] = xmodel_filename self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + # self.model = offload.load_model_data(self.model, xmodel_filename ) # offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth") # self.model.to(torch.bfloat16) # self.model.cpu() self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype) - # dtype = torch.bfloat16 - # offload.load_model_data(self.model, "ckpts/Wan14BT2VFusioniX_fp16.safetensors") offload.change_dtype(self.model, dtype, True) + # offload.save_model(self.model, "multitalkbf16.safetensors", config_file_path=base_config_file, filter_sd=sd) + # offload.save_model(self.model, "multitalk_quanto_fp16.safetensors", do_quantize= True, config_file_path=base_config_file, filter_sd=sd) + # offload.save_model(self.model, "wan2.1_selforcing_fp16.safetensors", config_file_path=base_config_file) # offload.save_model(self.model, "wan2.1_text2video_14B_mbf16.safetensors", config_file_path=base_config_file) # offload.save_model(self.model, "wan2.1_text2video_14B_quanto_mfp16_int8.safetensors", do_quantize=True, config_file_path=base_config_file) @@ -109,7 +130,7 @@ class WanT2V: self.sample_neg_prompt = config.sample_neg_prompt - if base_model_type in ["vace_14B", "vace_1.3B"]: + if self.model.config.get("vace_in_dim", None) != None: self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]), min_area=480*832, max_area=480*832, @@ -121,6 +142,9 @@ class WanT2V: self.adapt_vace_model() + self.num_timesteps = 1000 + self.use_timestep_transform = True + def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0, overlapped_latents = None): if ref_images is None: ref_images = [None] * len(frames) @@ -134,10 +158,11 @@ class WanT2V: reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)] inactive = self.vae.encode(inactive, tile_size = tile_size) - if overlapped_latents != None and False : + if overlapped_latents != None and False : # disabled as quality seems worse # inactive[0][:, 0:1] = self.vae.encode([frames[0][:, 0:1]], tile_size = tile_size)[0] # redundant for t in inactive: t[:, 1:overlapped_latents.shape[1] + 1] = overlapped_latents + overlapped_latents[: 0:1] = inactive[0][: 0:1] reactive = 
self.vae.encode(reactive, tile_size = tile_size) latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] @@ -277,12 +302,14 @@ class WanT2V: image_sizes.append(src_video[i].shape[2:]) for k, keep in enumerate(keep_video_guide_frames): if not keep: - src_video[i][:, k:k+1] = 0 - src_mask[i][:, k:k+1] = 1 + pos = prepend_count + k + src_video[i][:, pos:pos+1] = 0 + src_mask[i][:, pos:pos+1] = 1 for k, frame in enumerate(inject_frames): if frame != None: - src_video[i][:, k:k+1], src_mask[i][:, k:k+1] = self.fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims, return_mask= True) + pos = prepend_count + k + src_video[i][:, pos:pos+1], src_mask[i][:, pos:pos+1] = self.fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims, return_mask= True) self.background_mask = None @@ -322,158 +349,63 @@ class WanT2V: ref_vae_latents.append(img_vae_latent[0]) return torch.cat(ref_vae_latents, dim=1) - + + def generate(self, - input_prompt, - input_frames= None, - input_masks = None, - input_ref_images = None, - input_video=None, - denoising_strength = 1.0, - target_camera=None, - context_scale=None, - width = 1280, - height = 720, - fit_into_canvas = True, - frame_num=81, - shift=5.0, - sample_solver='unipc', - sampling_steps=50, - guide_scale=5.0, - n_prompt="", - seed=-1, - offload_model=True, - callback = None, - enable_RIFLEx = None, - VAE_tile_size = 0, - joint_pass = False, - slg_layers = None, - slg_start = 0.0, - slg_end = 1.0, - cfg_star_switch = True, - cfg_zero_step = 5, - overlapped_latents = None, - return_latent_slice = None, - overlap_noise = 0, - conditioning_latents_size = 0, - keep_frames_parsed = [], - model_filename = None, - model_type = None, - loras_slists = None, - **bbargs + input_prompt, + input_frames= None, + input_masks = None, + input_ref_images = None, + input_video=None, + image_start = None, + image_end = None, + denoising_strength = 1.0, + target_camera=None, + context_scale=None, + width = 1280, + height = 720, + fit_into_canvas = True, + frame_num=81, + shift=5.0, + sample_solver='unipc', + sampling_steps=50, + guide_scale=5.0, + n_prompt="", + seed=-1, + callback = None, + enable_RIFLEx = None, + VAE_tile_size = 0, + joint_pass = False, + slg_layers = None, + slg_start = 0.0, + slg_end = 1.0, + cfg_star_switch = True, + cfg_zero_step = 5, + audio_scale=None, + audio_cfg_scale=None, + audio_proj=None, + audio_context_lens=None, + overlapped_latents = None, + return_latent_slice = None, + overlap_noise = 0, + conditioning_latents_size = 0, + keep_frames_parsed = [], + model_type = None, + loras_slists = None, + offloadobj = None, + apg_switch = False, + **bbargs ): - r""" - Generates video frames from text prompt using diffusion process. - - Args: - input_prompt (`str`): - Text prompt for content generation - size (tupele[`int`], *optional*, defaults to (1280,720)): - Controls video resolution, (width,height). - frame_num (`int`, *optional*, defaults to 81): - How many frames to sample from a video. The number should be 4n+1 - shift (`float`, *optional*, defaults to 5.0): - Noise schedule shift parameter. Affects temporal dynamics - sample_solver (`str`, *optional*, defaults to 'unipc'): - Solver used to sample the video. - sampling_steps (`int`, *optional*, defaults to 40): - Number of diffusion sampling steps. Higher values improve quality but slow generation - guide_scale (`float`, *optional*, defaults 5.0): - Classifier-free guidance scale. Controls prompt adherence vs. 
creativity - n_prompt (`str`, *optional*, defaults to ""): - Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` - seed (`int`, *optional*, defaults to -1): - Random seed for noise generation. If -1, use random seed. - offload_model (`bool`, *optional*, defaults to True): - If True, offloads models to CPU during generation to save VRAM - - Returns: - torch.Tensor: - Generated video frames tensor. Dimensions: (C, N H, W) where: - - C: Color channels (3 for RGB) - - N: Number of frames (81) - - H: Frame height (from size) - - W: Frame width from size) - """ - # preprocess - vace = "Vace" in model_filename - - if n_prompt == "": - n_prompt = self.sample_neg_prompt - seed = seed if seed >= 0 else random.randint(0, sys.maxsize) - seed_g = torch.Generator(device=self.device) - seed_g.manual_seed(seed) - - if self._interrupt: - return None - context = self.text_encoder([input_prompt], self.device)[0] - context_null = self.text_encoder([n_prompt], self.device)[0] - context = context.to(self.dtype) - context_null = context_null.to(self.dtype) - input_ref_images_neg = None - phantom = False - - if target_camera != None: - width = input_video.shape[2] - height = input_video.shape[1] - input_video = input_video.to(dtype=self.dtype , device=self.device) - input_video = input_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.) - source_latents = self.vae.encode([input_video])[0] #.to(dtype=self.dtype, device=self.device) - del input_video - # Process target camera (recammaster) - from wan.utils.cammmaster_tools import get_camera_embedding - cam_emb = get_camera_embedding(target_camera) - cam_emb = cam_emb.to(dtype=self.dtype, device=self.device) - - if denoising_strength < 1. and input_frames != None: - height, width = input_frames.shape[-2:] - source_latents = self.vae.encode([input_frames])[0] - - if vace : - # vace context encode - input_frames = [u.to(self.device) for u in input_frames] - input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images] - input_masks = [u.to(self.device) for u in input_masks] - if self.background_mask != None: self.background_mask = [m.to(self.device) for m in self.background_mask] - previous_latents = None - # if overlapped_latents != None: - # input_ref_images = [u[-1:] for u in input_ref_images] - z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents ) - m0 = self.vace_encode_masks(input_masks, input_ref_images) - if self.background_mask != None: - zbg = self.vace_encode_frames([ref_img[0] for ref_img in input_ref_images], None, masks=self.background_mask, tile_size = VAE_tile_size ) - mbg = self.vace_encode_masks(self.background_mask, None) - for zz0, mm0, zzbg, mmbg in zip(z0, m0, zbg, mbg): - zz0[:, 0:1] = zzbg - mm0[:, 0:1] = mmbg - - self.background_mask = zz0 = mm0 = zzbg = mmbg = None - z = self.vace_latent(z0, m0) - - target_shape = list(z0[0].shape) - target_shape[0] = int(target_shape[0] / 2) - else: - if input_ref_images != None: # Phantom Ref images - phantom = True - input_ref_images = self.get_vae_latents(input_ref_images, self.device) - input_ref_images_neg = torch.zeros_like(input_ref_images) - F = frame_num - target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + (input_ref_images.shape[1] if input_ref_images != None else 0), - height // self.vae_stride[1], - width // self.vae_stride[2]) - - seq_len = math.ceil((target_shape[2] * target_shape[3]) / - (self.patch_size[1] * 
self.patch_size[2]) * - target_shape[1]) - - if self._interrupt: - return None - - noise = [ torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) ] - - # evaluation mode - - if sample_solver == 'causvid': + + if sample_solver =="euler": + # prepare timesteps + timesteps = list(np.linspace(self.num_timesteps, 1, sampling_steps, dtype=np.float32)) + timesteps.append(0.) + timesteps = [torch.tensor([t], device=self.device) for t in timesteps] + if self.use_timestep_transform: + timesteps = [timestep_transform(t, shift=shift, num_timesteps=self.num_timesteps) for t in timesteps][:-1] + sample_scheduler = None + elif sample_solver == 'causvid': sample_scheduler = FlowMatchScheduler(num_inference_steps=sampling_steps, shift=shift, sigma_min=0, extra_one_step=True) timesteps = torch.tensor([1000, 934, 862, 756, 603, 410, 250, 140, 74])[:sampling_steps].to(self.device) sample_scheduler.timesteps =timesteps @@ -496,55 +428,238 @@ class WanT2V: else: raise NotImplementedError(f"Unsupported Scheduler {sample_solver}") + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) - # sample videos - latents = noise[0] - del noise + kwargs = {'pipeline': self, 'callback': callback} - injection_denoising_step = 0 - inject_from_start = False - if denoising_strength < 1 and input_frames != None: - if len(keep_frames_parsed) == 0 or all(keep_frames_parsed): keep_frames_parsed = [] - injection_denoising_step = int(sampling_steps * (1. - denoising_strength) ) - latent_keep_frames = [] - if source_latents.shape[1] < latents.shape[1] or len(keep_frames_parsed) > 0: - inject_from_start = True - if len(keep_frames_parsed) >0 : - latent_keep_frames =[keep_frames_parsed[0]] - for i in range(1, len(keep_frames_parsed), 4): - latent_keep_frames.append(all(keep_frames_parsed[i:i+4])) + if self._interrupt: + return None + + # Text Encoder + if n_prompt == "": + n_prompt = self.sample_neg_prompt + context = self.text_encoder([input_prompt], self.device)[0] + context_null = self.text_encoder([n_prompt], self.device)[0] + context = context.to(self.dtype) + context_null = context_null.to(self.dtype) + # from mmgp import offload + # offloadobj.unload_all() + + if self._interrupt: + return None + + vace = model_type in ["vace_1.3B","vace_14B", "vace_multitalk_14B"] + phantom = model_type in ["phantom_1.3B", "phantom_14B"] + fantasy = model_type in ["fantasy"] + multitalk = model_type in ["multitalk", "vace_multitalk_14B"] + + ref_images_count = 0 + trim_frames = 0 + extended_overlapped_latents = None + + # image2video + lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1 + if image_start != None: + any_end_frame = False + if input_frames != None: + _ , preframes_count, height, width = input_frames.shape + lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2] + clip_context = self.clip.visual([input_frames[:, -1:]]) #.to(self.param_dtype) + input_frames = input_frames.to(device=self.device).to(dtype= self.VAE_dtype) + enc = torch.concat( [input_frames, torch.zeros( (3, frame_num-preframes_count, height, width), + device=self.device, dtype= self.VAE_dtype)], + dim = 1).to(self.device) + input_frames = None else: - timesteps = timesteps[injection_denoising_step:] - if hasattr(sample_scheduler, "timesteps"): sample_scheduler.timesteps = timesteps - if hasattr(sample_scheduler, "sigmas"): sample_scheduler.sigmas= sample_scheduler.sigmas[injection_denoising_step:] - injection_denoising_step = 0 + preframes_count = 1 + image_start = TF.to_tensor(image_start) 
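+                # Image2video conditioning: the start (and optional end) PIL images are converted to
+                # tensors, resized to a latent-aligned resolution and VAE-encoded together with
+                # zero-filled placeholder frames; the mask `msk` built below marks conditioning
+                # frames with 1 and frames to generate with 0. When an end image is supplied on a
+                # model that supports it, one extra frame is appended (frame_num += 1) and trimmed
+                # again after decoding (trim_frames = 1).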
+ any_end_frame = image_end != None + add_frames_for_end_image = any_end_frame and model_type not in ["fun_inp_1.3B", "fun_inp", "i2v_720p"] + if any_end_frame: + image_end = TF.to_tensor(image_end) + if add_frames_for_end_image: + frame_num +=1 + lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2) + trim_frames = 1 + + h, w = image_start.shape[1:] + + h, w = calculate_new_dimensions(height, width, h, w, fit_into_canvas) + width, height = w, h + + lat_h = round( + h // self.vae_stride[1] // + self.patch_size[1] * self.patch_size[1]) + lat_w = round( + w // self.vae_stride[2] // + self.patch_size[2] * self.patch_size[2]) + h = lat_h * self.vae_stride[1] + w = lat_w * self.vae_stride[2] + clip_image_size = self.clip.model.image_size + img_interpolated = resize_lanczos(image_start, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype + image_start = resize_lanczos(image_start, clip_image_size, clip_image_size) + image_start = image_start.sub_(0.5).div_(0.5).to(self.device) #, self.dtype + if image_end!= None: + img_interpolated2 = resize_lanczos(image_end, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype + image_end = resize_lanczos(image_end, clip_image_size, clip_image_size) + image_end = image_end.sub_(0.5).div_(0.5).to(self.device) #, self.dtype + if image_end != None and model_type == "flf2v_720p": + clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :]]) + else: + clip_context = self.clip.visual([image_start[:, None, :, :]]) + + if any_end_frame: + enc= torch.concat([ + img_interpolated, + torch.zeros( (3, frame_num-2, h, w), device=self.device, dtype= self.VAE_dtype), + img_interpolated2, + ], dim=1).to(self.device) + else: + enc= torch.concat([ + img_interpolated, + torch.zeros( (3, frame_num-1, h, w), device=self.device, dtype= self.VAE_dtype) + ], dim=1).to(self.device) + + image_start = image_end = img_interpolated = img_interpolated2 = None + + msk = torch.ones(1, frame_num, lat_h, lat_w, device=self.device) + if any_end_frame: + msk[:, preframes_count: -1] = 0 + if add_frames_for_end_image: + msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:-1], torch.repeat_interleave(msk[:, -1:], repeats=4, dim=1) ], dim=1) + else: + msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] ], dim=1) + else: + msk[:, preframes_count:] = 0 + msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] ], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2)[0] + lat_y = self.vae.encode([enc], VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0] + overlapped_latents_frames_num = int(1 + (preframes_count-1) // 4) + if overlapped_latents != None: + # disabled because looks worse + if False and overlapped_latents_frames_num > 1: lat_y[:, 1:overlapped_latents_frames_num] = overlapped_latents[:, 1:] + extended_overlapped_latents = lat_y[:, :overlapped_latents_frames_num].clone() + y = torch.concat([msk, lat_y]) + lat_y = None + kwargs.update({'clip_fea': clip_context, 'y': y}) + + # Recam Master + if target_camera != None: + width = input_video.shape[2] + height = input_video.shape[1] + input_video = input_video.to(dtype=self.dtype , device=self.device) + input_video = input_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.) 
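+            # ReCamMaster conditioning: the control video is normalised from [0, 255] to [-1, 1],
+            # VAE-encoded into `source_latents` (concatenated with the denoised latents along the
+            # frame axis at every denoising step), and the requested camera trajectory is converted
+            # into an embedding that reaches the transformer through kwargs['cam_emb'].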
+ source_latents = self.vae.encode([input_video])[0] #.to(dtype=self.dtype, device=self.device) + del input_video + # Process target camera (recammaster) + from wan.utils.cammmaster_tools import get_camera_embedding + cam_emb = get_camera_embedding(target_camera) + cam_emb = cam_emb.to(dtype=self.dtype, device=self.device) + kwargs['cam_emb'] = cam_emb + + # Video 2 Video + if denoising_strength < 1. and input_frames != None: + height, width = input_frames.shape[-2:] + source_latents = self.vae.encode([input_frames])[0] + injection_denoising_step = 0 + inject_from_start = False + if input_frames != None and denoising_strength < 1 : + if overlapped_latents != None: + overlapped_latents_frames_num = overlapped_latents.shape[1] + overlapped_frames_num = (overlapped_latents_frames_num-1) * 4 + 1 + else: + overlapped_latents_frames_num = overlapped_frames_num = 0 + if len(keep_frames_parsed) == 0 or (overlapped_frames_num + len(keep_frames_parsed)) == input_frames.shape[1] and all(keep_frames_parsed) : keep_frames_parsed = [] + injection_denoising_step = int(sampling_steps * (1. - denoising_strength) ) + latent_keep_frames = [] + if source_latents.shape[1] < lat_frames or len(keep_frames_parsed) > 0: + inject_from_start = True + if len(keep_frames_parsed) >0 : + if overlapped_frames_num > 0: keep_frames_parsed = [True] * overlapped_frames_num + keep_frames_parsed + latent_keep_frames =[keep_frames_parsed[0]] + for i in range(1, len(keep_frames_parsed), 4): + latent_keep_frames.append(all(keep_frames_parsed[i:i+4])) + else: + timesteps = timesteps[injection_denoising_step:] + if hasattr(sample_scheduler, "timesteps"): sample_scheduler.timesteps = timesteps + if hasattr(sample_scheduler, "sigmas"): sample_scheduler.sigmas= sample_scheduler.sigmas[injection_denoising_step:] + injection_denoising_step = 0 + + # Phantom + if phantom: + input_ref_images_neg = None + if input_ref_images != None: # Phantom Ref images + input_ref_images = self.get_vae_latents(input_ref_images, self.device) + input_ref_images_neg = torch.zeros_like(input_ref_images) + ref_images_count = input_ref_images.shape[1] if input_ref_images != None else 0 + + # Vace + if vace : + # vace context encode + input_frames = [u.to(self.device) for u in input_frames] + input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images] + input_masks = [u.to(self.device) for u in input_masks] + if self.background_mask != None: self.background_mask = [m.to(self.device) for m in self.background_mask] + z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents ) + m0 = self.vace_encode_masks(input_masks, input_ref_images) + if self.background_mask != None: + zbg = self.vace_encode_frames([ref_img[0] for ref_img in input_ref_images], None, masks=self.background_mask, tile_size = VAE_tile_size ) + mbg = self.vace_encode_masks(self.background_mask, None) + for zz0, mm0, zzbg, mmbg in zip(z0, m0, zbg, mbg): + zz0[:, 0:1] = zzbg + mm0[:, 0:1] = mmbg + + self.background_mask = zz0 = mm0 = zzbg = mmbg = None + z = self.vace_latent(z0, m0) + + ref_images_count = len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0 + context_scale = context_scale if context_scale != None else [1.0] * len(z) + kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale, "ref_images_count": ref_images_count }) + if overlapped_latents != None : + overlapped_latents_size = overlapped_latents.shape[1] 
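+                # Sliding-window continuation: keep a copy of the latents shared with the previous
+                # window (plus the reference-image latents) so they can be partially re-noised to
+                # the current timestep and re-injected into `latents` and `z` inside the denoising
+                # loop below.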
+ extended_overlapped_latents = z[0][0:16, 0:overlapped_latents_size + ref_images_count].clone() + + target_shape = list(z0[0].shape) + target_shape[0] = int(target_shape[0] / 2) + lat_h, lat_w = target_shape[-2:] + height = self.vae_stride[1] * lat_h + width = self.vae_stride[2] * lat_w + + else: + target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, height // self.vae_stride[1], width // self.vae_stride[2]) + + if multitalk and audio_proj != None: + from wan.multitalk.multitalk import get_target_masks + audio_proj = [audio.to(self.dtype) for audio in audio_proj] + human_no = len(audio_proj[0]) + token_ref_target_masks = get_target_masks(human_no, lat_h, lat_w, height, width, face_scale = 0.05, bbox = None).to(self.dtype) if human_no > 1 else None + + if fantasy and audio_proj != None: + kwargs.update({ "audio_proj": audio_proj.to(self.dtype), "audio_context_lens": audio_context_lens, }) + + + if self._interrupt: + return None + + # Ropes batch_size = 1 if target_camera != None: - shape = list(latents.shape[1:]) + shape = list(target_shape[1:]) shape[0] *= 2 freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False) else: - freqs = get_rotary_pos_embed(latents.shape[1:], enable_RIFLEx= enable_RIFLEx) + freqs = get_rotary_pos_embed(target_shape[1:], enable_RIFLEx= enable_RIFLEx) - kwargs = {'freqs': freqs, 'pipeline': self, 'callback': callback} - - if target_camera != None: - kwargs.update({'cam_emb': cam_emb}) - - if vace: - ref_images_count = len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0 - context_scale = context_scale if context_scale != None else [1.0] * len(z) - kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale}) - if overlapped_latents != None : - overlapped_latents_size = overlapped_latents.shape[1] + 1 - # overlapped_latents_size = 3 - z_reactive = [ zz[0:16, 0:overlapped_latents_size + ref_images_count].clone() for zz in z] + kwargs["freqs"] = freqs + # Steps Skipping cache_type = self.model.enable_cache if cache_type != None: - x_count = 3 if phantom else 2 + x_count = 3 if phantom or fantasy or multitalk else 2 self.model.previous_residual = [None] * x_count if cache_type == "tea": self.model.compute_teacache_threshold(self.model.cache_start_step, timesteps, self.model.cache_multiplier) @@ -552,6 +667,7 @@ class WanT2V: self.model.compute_magcache_threshold(self.model.cache_start_step, timesteps, self.model.cache_multiplier) self.model.accumulated_err, self.model.accumulated_steps, self.model.accumulated_ratio = [0.0] * x_count, [0] * x_count, [1.0] * x_count self.model.one_for_all = x_count > 2 + if callback != None: callback(-1, None, True) @@ -560,15 +676,29 @@ class WanT2V: if chipmunk: self.model.setup_chipmunk() - updated_num_steps= len(timesteps) + # init denoising + updated_num_steps= len(timesteps) if callback != None: + from wgp import update_loras_slists update_loras_slists(self.model, loras_slists, updated_num_steps) callback(-1, None, True, override_num_inference_steps = updated_num_steps) - scheduler_kwargs = {} if isinstance(sample_scheduler, FlowMatchScheduler) else {"generator": seed_g} + if sample_scheduler != None: + scheduler_kwargs = {} if isinstance(sample_scheduler, FlowMatchScheduler) else {"generator": seed_g} + latents = torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) + if apg_switch != 0: + apg_momentum = -0.75 + apg_norm_threshold = 55 + text_momentumbuffer = MomentumBuffer(apg_momentum) + audio_momentumbuffer = 
MomentumBuffer(apg_momentum) + + # denoising for i, t in enumerate(tqdm(timesteps)): - timestep = [t] + offload.set_step_no_for_lora(self.model, i) + timestep = torch.stack([t]) + kwargs.update({"t": timestep, "current_step": i}) + kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None if denoising_strength < 1 and input_frames != None and i <= injection_denoising_step: sigma = t / 1000 @@ -585,98 +715,119 @@ class WanT2V: latents = noise * sigma + (1 - sigma) * source_latents noise = None - if overlapped_latents != None : - overlap_noise_factor = overlap_noise / 1000 + if extended_overlapped_latents != None: latent_noise_factor = t / 1000 - for zz, zz_r, ll in zip(z, z_reactive, [latents, None]): # extra None for second control net - zz[0:16, ref_images_count:overlapped_latents_size + ref_images_count] = zz_r[:, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(zz_r[:, ref_images_count:] ) * overlap_noise_factor - if ll != None: - ll[:, 0:overlapped_latents_size + ref_images_count] = zz_r * (1.0 - latent_noise_factor) + torch.randn_like(zz_r ) * latent_noise_factor + latents[:, 0:extended_overlapped_latents.shape[1]] = extended_overlapped_latents * (1.0 - latent_noise_factor) + torch.randn_like(extended_overlapped_latents ) * latent_noise_factor + if vace: + overlap_noise_factor = overlap_noise / 1000 + for zz in z: + zz[0:16, ref_images_count:extended_overlapped_latents.shape[1] ] = extended_overlapped_latents[:, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(extended_overlapped_latents[:, ref_images_count:] ) * overlap_noise_factor if target_camera != None: latent_model_input = torch.cat([latents, source_latents], dim=1) else: latent_model_input = latents - kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None - offload.set_step_no_for_lora(self.model, i) - timestep = torch.stack(timestep) - kwargs["current_step"] = i - kwargs["t"] = timestep - if guide_scale == 1: - noise_pred = self.model( [latent_model_input], x_id = 0, context = [context], **kwargs)[0] - if self._interrupt: - return None - elif joint_pass: - if phantom: - pos_it, pos_i, neg = self.model( - [ torch.cat([latent_model_input[:,:-input_ref_images.shape[1]], input_ref_images], dim=1) ] * 2 + - [ torch.cat([latent_model_input[:,:-input_ref_images_neg.shape[1]], input_ref_images_neg], dim=1)], - context = [context, context_null, context_null], **kwargs) - else: - noise_pred_cond, noise_pred_uncond = self.model( - [latent_model_input, latent_model_input], context = [context, context_null], **kwargs) - if self._interrupt: - return None + if phantom: + gen_args = { + "x" : ([ torch.cat([latent_model_input[:,:-ref_images_count], input_ref_images], dim=1) ] * 2 + + [ torch.cat([latent_model_input[:,:-ref_images_count], input_ref_images_neg], dim=1)]), + "context": [context, context_null, context_null] , + } + elif fantasy: + gen_args = { + "x" : [latent_model_input, latent_model_input, latent_model_input], + "context" : [context, context_null, context_null], + "audio_scale": [audio_scale, None, None ] + } + elif multitalk: + gen_args = { + "x" : [latent_model_input, latent_model_input, latent_model_input], + "context" : [context, context_null, context_null], + "multitalk_audio": [audio_proj, audio_proj, [torch.zeros_like(audio_proj[0][-1:]), torch.zeros_like(audio_proj[1][-1:])]], + "multitalk_masks": [token_ref_target_masks, token_ref_target_masks, None] + } else: - 
if phantom: - pos_it = self.model( - [ torch.cat([latent_model_input[:,:-input_ref_images.shape[1]], input_ref_images], dim=1) ], x_id = 0, context = [context], **kwargs - )[0] - if self._interrupt: - return None - pos_i = self.model( - [ torch.cat([latent_model_input[:,:-input_ref_images.shape[1]], input_ref_images], dim=1) ], x_id = 1, context = [context_null],**kwargs - )[0] - if self._interrupt: - return None - neg = self.model( - [ torch.cat([latent_model_input[:,:-input_ref_images_neg.shape[1]], input_ref_images_neg], dim=1) ], x_id = 2, context = [context_null], **kwargs - )[0] - if self._interrupt: - return None - else: - noise_pred_cond = self.model( - [latent_model_input], x_id = 0, context = [context], **kwargs)[0] - if self._interrupt: - return None - noise_pred_uncond = self.model( - [latent_model_input], x_id = 1, context = [context_null], **kwargs)[0] - if self._interrupt: - return None + gen_args = { + "x" : [latent_model_input, latent_model_input], + "context": [context, context_null] + } - # del latent_model_input - - # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/ + if joint_pass and guide_scale > 1: + ret_values = self.model( **gen_args , **kwargs) + if self._interrupt: + return None + else: + size = 1 if guide_scale == 1 else len(gen_args["x"]) + ret_values = [None] * size + for x_id in range(size): + sub_gen_args = {k : [v[x_id]] for k, v in gen_args.items() } + ret_values[x_id] = self.model( **sub_gen_args, x_id= x_id , **kwargs)[0] + if self._interrupt: + return None + sub_gen_args = None if guide_scale == 1: - pass + noise_pred = ret_values[0] elif phantom: guide_scale_img= 5.0 - guide_scale_text= guide_scale #7.5 + guide_scale_text= guide_scale #7.5 + pos_it, pos_i, neg = ret_values noise_pred = neg + guide_scale_img * (pos_i - neg) + guide_scale_text * (pos_it - pos_i) + pos_it = pos_i = neg = None + elif fantasy: + noise_pred_cond, noise_pred_noaudio, noise_pred_uncond = ret_values + noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio) + noise_pred_noaudio = None + elif multitalk: + noise_pred_cond, noise_pred_drop_text, noise_pred_uncond = ret_values + if apg_switch != 0: + noise_pred = noise_pred_cond + (guide_scale - 1) * adaptive_projected_guidance(noise_pred_cond - noise_pred_drop_text, + noise_pred_cond, + momentum_buffer=text_momentumbuffer, + norm_threshold=apg_norm_threshold) \ + + (audio_cfg_scale - 1) * adaptive_projected_guidance(noise_pred_drop_text - noise_pred_uncond, + noise_pred_cond, + momentum_buffer=audio_momentumbuffer, + norm_threshold=apg_norm_threshold) + else: + noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_drop_text) + audio_cfg_scale * (noise_pred_drop_text - noise_pred_uncond) + noise_pred_uncond = noise_pred_cond = noise_pred_drop_text = None else: - noise_pred_text = noise_pred_cond - if cfg_star_switch: - positive_flat = noise_pred_text.view(batch_size, -1) - negative_flat = noise_pred_uncond.view(batch_size, -1) + noise_pred_cond, noise_pred_uncond = ret_values + if apg_switch != 0: + noise_pred = noise_pred_cond + (guide_scale - 1) * adaptive_projected_guidance(noise_pred_cond - noise_pred_uncond, + noise_pred_cond, + momentum_buffer=text_momentumbuffer, + norm_threshold=apg_norm_threshold) + else: + noise_pred_text = noise_pred_cond + if cfg_star_switch: + # CFG Zero *. 
Thanks to https://github.com/WeichenFan/CFG-Zero-star/ + positive_flat = noise_pred_text.view(batch_size, -1) + negative_flat = noise_pred_uncond.view(batch_size, -1) - alpha = optimized_scale(positive_flat,negative_flat) - alpha = alpha.view(batch_size, 1, 1, 1) + alpha = optimized_scale(positive_flat,negative_flat) + alpha = alpha.view(batch_size, 1, 1, 1) - if (i <= cfg_zero_step): - noise_pred = noise_pred_text*0. # it would be faster not to compute noise_pred... - else: - noise_pred_uncond *= alpha - noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond) - noise_pred_uncond, noise_pred_cond, noise_pred_text, pos_it, pos_i, neg = None, None, None, None, None, None - temp_x0 = sample_scheduler.step( - noise_pred[:, :target_shape[1]].unsqueeze(0), - t, - latents.unsqueeze(0), - # return_dict=False, - **scheduler_kwargs)[0] - latents = temp_x0.squeeze(0) - del temp_x0 + if (i <= cfg_zero_step): + noise_pred = noise_pred_text*0. # it would be faster not to compute noise_pred... + else: + noise_pred_uncond *= alpha + noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond) + ret_values = noise_pred_uncond = noise_pred_cond = noise_pred_text = neg = None + + if sample_solver == "euler": + dt = timesteps[i] if i == len(timesteps)-1 else (timesteps[i] - timesteps[i + 1]) + dt = dt / self.num_timesteps + latents = latents - noise_pred * dt[:, None, None, None] + else: + temp_x0 = sample_scheduler.step( + noise_pred[:, :target_shape[1]].unsqueeze(0), + t, + latents.unsqueeze(0), + **scheduler_kwargs)[0] + latents = temp_x0.squeeze(0) + del temp_x0 if callback is not None: callback(i, latents, False) @@ -684,23 +835,19 @@ class WanT2V: x0 = [latents] if chipmunk: - self.model.release_chipmunk() # need to add it at every exit when in prof + self.model.release_chipmunk() # need to add it at every exit when in prod if return_latent_slice != None: - if overlapped_latents != None: - # latents [:, 1:] = self.toto - for zz, zz_r, ll in zip(z, z_reactive, [latents]): - ll[:, 0:overlapped_latents_size + ref_images_count] = zz_r - latent_slice = latents[:, return_latent_slice].clone() - if input_frames == None: - if phantom: - # phantom post processing - x0 = [x0_[:,:-input_ref_images.shape[1]] for x0_ in x0] - videos = self.vae.decode(x0, VAE_tile_size) - else: + if vace: # vace post processing videos = self.decode_latent(x0, input_ref_images, VAE_tile_size) + else: + if phantom and input_ref_images != None: + trim_frames = input_ref_images.shape[1] + if trim_frames > 0: x0 = [x0_[:,:-trim_frames] for x0_ in x0] + videos = self.vae.decode(x0, VAE_tile_size) + if return_latent_slice != None: return { "x" : videos[0], "latent_slice" : latent_slice } return videos[0] diff --git a/wan/image2video.py b/wan/image2video.py deleted file mode 100644 index 7897134..0000000 --- a/wan/image2video.py +++ /dev/null @@ -1,439 +0,0 @@ -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
-import gc -import logging -import math -import os -import random -import sys -import types -from contextlib import contextmanager -from functools import partial -import json -import numpy as np -import torch -import torch.cuda.amp as amp -import torch.distributed as dist -import torchvision.transforms.functional as TF -from tqdm import tqdm - -from .distributed.fsdp import shard_model -from .modules.clip import CLIPModel -from .modules.model import WanModel -from .modules.t5 import T5EncoderModel -from .modules.vae import WanVAE -from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, - get_sampling_sigmas, retrieve_timesteps) -from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler -from wan.modules.posemb_layers import get_rotary_pos_embed -from wan.utils.utils import resize_lanczos, calculate_new_dimensions -from wan.utils.basic_flowmatch import FlowMatchScheduler - -def optimized_scale(positive_flat, negative_flat): - - # Calculate dot production - dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) - - # Squared norm of uncondition - squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 - - # st_star = v_cond^T * v_uncond / ||v_uncond||^2 - st_star = dot_product / squared_norm - - return st_star - - - -class WanI2V: - - def __init__( - self, - config, - checkpoint_dir, - model_filename = None, - model_type = None, - base_model_type= None, - text_encoder_filename= None, - quantizeTransformer = False, - dtype = torch.bfloat16, - VAE_dtype = torch.float32, - save_quantized = False, - mixed_precision_transformer = False - ): - self.device = torch.device(f"cuda") - self.config = config - self.dtype = dtype - self.VAE_dtype = VAE_dtype - self.num_train_timesteps = config.num_train_timesteps - self.param_dtype = config.param_dtype - # shard_fn = partial(shard_model, device_id=device_id) - self.text_encoder = T5EncoderModel( - text_len=config.text_len, - dtype=config.t5_dtype, - device=torch.device('cpu'), - checkpoint_path=text_encoder_filename, - tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), - shard_fn=None, - ) - - self.vae_stride = config.vae_stride - self.patch_size = config.patch_size - self.vae = WanVAE( - vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype = VAE_dtype, - device=self.device) - - self.clip = CLIPModel( - dtype=config.clip_dtype, - device=self.device, - checkpoint_path=os.path.join(checkpoint_dir , - config.clip_checkpoint), - tokenizer_path=os.path.join(checkpoint_dir , config.clip_tokenizer)) - - logging.info(f"Creating WanModel from {model_filename[-1]}") - from mmgp import offload - - # fantasy = torch.load("c:/temp/fantasy.ckpt") - # proj_model = fantasy["proj_model"] - # audio_processor = fantasy["audio_processor"] - # offload.safetensors2.torch_write_file(proj_model, "proj_model.safetensors") - # offload.safetensors2.torch_write_file(audio_processor, "audio_processor.safetensors") - # for k,v in audio_processor.items(): - # audio_processor[k] = v.to(torch.bfloat16) - # with open("fantasy_config.json", "r", encoding="utf-8") as reader: - # config_text = reader.read() - # config_json = json.loads(config_text) - # offload.safetensors2.torch_write_file(audio_processor, "audio_processor_bf16.safetensors", config=config_json) - # model_filename = [model_filename, "audio_processor_bf16.safetensors"] - # model_filename = "c:/temp/i2v480p/diffusion_pytorch_model-00001-of-00007.safetensors" - # dtype = torch.float16 - base_config_file = f"configs/{base_model_type}.json" - 
forcedConfigPath = base_config_file if len(model_filename) > 1 else None - self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath= base_config_file, forcedConfigPath= forcedConfigPath) - self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype) - offload.change_dtype(self.model, dtype, True) - # offload.save_model(self.model, "wan2.1_image2video_720p_14B_mbf16.safetensors", config_file_path="c:/temp/i2v720p/config.json") - # offload.save_model(self.model, "wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors",do_quantize=True, config_file_path="c:/temp/i2v720p/config.json") - # offload.save_model(self.model, "wan2.1_image2video_720p_14B_quanto_mfp16_int8.safetensors",do_quantize=True, config_file_path="c:/temp/i2v720p/config.json") - - # offload.save_model(self.model, "wan2.1_Fun_InP_1.3B_bf16_bis.safetensors") - self.model.eval().requires_grad_(False) - if save_quantized: - from wgp import save_quantized_model - save_quantized_model(self.model, model_type, model_filename[0], dtype, base_config_file) - - - self.sample_neg_prompt = config.sample_neg_prompt - - def generate(self, - input_prompt, - image_start, - image_end = None, - height =720, - width = 1280, - fit_into_canvas = True, - frame_num=81, - shift=5.0, - sample_solver='unipc', - sampling_steps=40, - guide_scale=5.0, - n_prompt="", - seed=-1, - callback = None, - enable_RIFLEx = False, - VAE_tile_size= 0, - joint_pass = False, - slg_layers = None, - slg_start = 0.0, - slg_end = 1.0, - cfg_star_switch = True, - cfg_zero_step = 5, - audio_scale=None, - audio_cfg_scale=None, - audio_proj=None, - audio_context_lens=None, - model_filename = None, - offloadobj = None, - **bbargs - ): - r""" - Generates video frames from input image and text prompt using diffusion process. - - Args: - input_prompt (`str`): - Text prompt for content generation. - image_start (PIL.Image.Image): - Input image tensor. Shape: [3, H, W] - max_area (`int`, *optional*, defaults to 720*1280): - Maximum pixel area for latent space calculation. Controls video resolution scaling - frame_num (`int`, *optional*, defaults to 81): - How many frames to sample from a video. The number should be 4n+1 - shift (`float`, *optional*, defaults to 5.0): - Noise schedule shift parameter. Affects temporal dynamics - [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0. - sample_solver (`str`, *optional*, defaults to 'unipc'): - Solver used to sample the video. - sampling_steps (`int`, *optional*, defaults to 40): - Number of diffusion sampling steps. Higher values improve quality but slow generation - guide_scale (`float`, *optional*, defaults 5.0): - Classifier-free guidance scale. Controls prompt adherence vs. creativity - n_prompt (`str`, *optional*, defaults to ""): - Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` - seed (`int`, *optional*, defaults to -1): - Random seed for noise generation. If -1, use random seed - offload_model (`bool`, *optional*, defaults to True): - If True, offloads models to CPU during generation to save VRAM - - Returns: - torch.Tensor: - Generated video frames tensor. 
Dimensions: (C, N H, W) where: - - C: Color channels (3 for RGB) - - N: Number of frames (81) - - H: Frame height (from max_area) - - W: Frame width from max_area) - """ - - add_frames_for_end_image = "image2video" in model_filename or "fantasy" in model_filename - - image_start = TF.to_tensor(image_start) - lat_frames = int((frame_num - 1) // self.vae_stride[0] + 1) - any_end_frame = image_end !=None - if any_end_frame: - any_end_frame = True - image_end = TF.to_tensor(image_end) - if add_frames_for_end_image: - frame_num +=1 - lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2) - - h, w = image_start.shape[1:] - - h, w = calculate_new_dimensions(height, width, h, w, fit_into_canvas) - - lat_h = round( - h // self.vae_stride[1] // - self.patch_size[1] * self.patch_size[1]) - lat_w = round( - w // self.vae_stride[2] // - self.patch_size[2] * self.patch_size[2]) - h = lat_h * self.vae_stride[1] - w = lat_w * self.vae_stride[2] - - clip_image_size = self.clip.model.image_size - img_interpolated = resize_lanczos(image_start, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype - image_start = resize_lanczos(image_start, clip_image_size, clip_image_size) - image_start = image_start.sub_(0.5).div_(0.5).to(self.device) #, self.dtype - if image_end!= None: - img_interpolated2 = resize_lanczos(image_end, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype - image_end = resize_lanczos(image_end, clip_image_size, clip_image_size) - image_end = image_end.sub_(0.5).div_(0.5).to(self.device) #, self.dtype - - max_seq_len = lat_frames * lat_h * lat_w // ( self.patch_size[1] * self.patch_size[2]) - - seed = seed if seed >= 0 else random.randint(0, sys.maxsize) - seed_g = torch.Generator(device=self.device) - seed_g.manual_seed(seed) - noise = torch.randn(16, lat_frames, lat_h, lat_w, dtype=torch.float32, generator=seed_g, device=self.device) - - msk = torch.ones(1, frame_num, lat_h, lat_w, device=self.device) - if any_end_frame: - msk[:, 1: -1] = 0 - if add_frames_for_end_image: - msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:-1], torch.repeat_interleave(msk[:, -1:], repeats=4, dim=1) ], dim=1) - else: - msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] ], dim=1) - - else: - msk[:, 1:] = 0 - msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] ], dim=1) - msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) - msk = msk.transpose(1, 2)[0] - - if n_prompt == "": - n_prompt = self.sample_neg_prompt - - if self._interrupt: - return None - - # preprocess - context = self.text_encoder([input_prompt], self.device)[0] - context_null = self.text_encoder([n_prompt], self.device)[0] - context = context.to(self.dtype) - context_null = context_null.to(self.dtype) - - if self._interrupt: - return None - - clip_context = self.clip.visual([image_start[:, None, :, :]]) - - from mmgp import offload - offloadobj.unload_all() - if any_end_frame: - mean2 = 0 - enc= torch.concat([ - img_interpolated, - torch.full( (3, frame_num-2, h, w), mean2, device=self.device, dtype= self.VAE_dtype), - img_interpolated2, - ], dim=1).to(self.device) - else: - enc= torch.concat([ - img_interpolated, - torch.zeros(3, frame_num-1, h, w, device=self.device, dtype= self.VAE_dtype) - ], dim=1).to(self.device) - image_start, image_end, img_interpolated, img_interpolated2 = None, None, None, None - - lat_y = self.vae.encode([enc], VAE_tile_size, any_end_frame= 
any_end_frame and add_frames_for_end_image)[0] - y = torch.concat([msk, lat_y]) - lat_y = None - - - # evaluation mode - if sample_solver == 'causvid': - sample_scheduler = FlowMatchScheduler(num_inference_steps=sampling_steps, shift=shift, sigma_min=0, extra_one_step=True) - timesteps = torch.tensor([1000, 934, 862, 756, 603, 410, 250, 140, 74])[:sampling_steps].to(self.device) - sample_scheduler.timesteps =timesteps - sample_scheduler.sigmas = torch.cat([sample_scheduler.timesteps / 1000, torch.tensor([0.], device=self.device)]) - elif sample_solver == 'unipc' or sample_solver == "": - sample_scheduler = FlowUniPCMultistepScheduler( - num_train_timesteps=self.num_train_timesteps, - shift=1, - use_dynamic_shifting=False) - sample_scheduler.set_timesteps( - sampling_steps, device=self.device, shift=shift) - timesteps = sample_scheduler.timesteps - elif sample_solver == 'dpm++': - sample_scheduler = FlowDPMSolverMultistepScheduler( - num_train_timesteps=self.num_train_timesteps, - shift=1, - use_dynamic_shifting=False) - sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) - timesteps, _ = retrieve_timesteps( - sample_scheduler, - device=self.device, - sigmas=sampling_sigmas) - else: - raise NotImplementedError("Unsupported scheduler.") - - # sample videos - latent = noise - batch_size = 1 - freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx= enable_RIFLEx) - - kwargs = { 'clip_fea': clip_context, 'y': y, 'freqs' : freqs, 'pipeline' : self, 'callback' : callback } - - if audio_proj != None: - kwargs.update({ - "audio_proj": audio_proj.to(self.dtype), - "audio_context_lens": audio_context_lens, - }) - cache_type = self.model.enable_cache - if cache_type != None: - x_count = 3 if audio_cfg_scale !=None else 2 - self.model.previous_residual = [None] * x_count - if cache_type == "tea": - self.model.compute_teacache_threshold(self.model.cache_start_step, timesteps, self.model.cache_multiplier) - else: - self.model.compute_magcache_threshold(self.model.cache_start_step, timesteps, self.model.cache_multiplier) - self.model.accumulated_err, self.model.accumulated_steps, self.model.accumulated_ratio = [0.0] * x_count, [0] * x_count, [1.0] * x_count - self.model.one_for_all = x_count > 2 - - # self.model.to(self.device) - if callback != None: - callback(-1, None, True) - latent = latent.to(self.device) - for i, t in enumerate(tqdm(timesteps)): - offload.set_step_no_for_lora(self.model, i) - kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None - latent_model_input = latent - timestep = [t] - - timestep = torch.stack(timestep).to(self.device) - kwargs.update({ - 't' :timestep, - 'current_step' :i, - }) - - if guide_scale == 1: - noise_pred = self.model( [latent_model_input], context=[context], audio_scale = None if audio_scale == None else [audio_scale], x_id=0, **kwargs, )[0] - if self._interrupt: - return None - elif joint_pass: - if audio_proj == None: - noise_pred_cond, noise_pred_uncond = self.model( - [latent_model_input, latent_model_input], - context=[context, context_null], - **kwargs) - else: - noise_pred_cond, noise_pred_noaudio, noise_pred_uncond = self.model( - [latent_model_input, latent_model_input, latent_model_input], - context=[context, context, context_null], - audio_scale = [audio_scale, None, None ], - **kwargs) - - if self._interrupt: - return None - else: - noise_pred_cond = self.model( [latent_model_input], context=[context], audio_scale = None if audio_scale == None else [audio_scale], x_id=0, 
**kwargs, )[0] - if self._interrupt: - return None - - if audio_proj != None: - noise_pred_noaudio = self.model( - [latent_model_input], - x_id=1, - context=[context], - **kwargs, - )[0] - if self._interrupt: - return None - - noise_pred_uncond = self.model( - [latent_model_input], - x_id=1 if audio_scale == None else 2, - context=[context_null], - **kwargs, - )[0] - if self._interrupt: - return None - del latent_model_input - - if guide_scale > 1: - # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/ - if cfg_star_switch: - positive_flat = noise_pred_cond.view(batch_size, -1) - negative_flat = noise_pred_uncond.view(batch_size, -1) - - alpha = optimized_scale(positive_flat,negative_flat) - alpha = alpha.view(batch_size, 1, 1, 1) - - if (i <= cfg_zero_step): - noise_pred = noise_pred_cond*0. # it would be faster not to compute noise_pred... - else: - noise_pred_uncond *= alpha - if audio_scale == None: - noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond) - else: - noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio) - - noise_pred_uncond, noise_pred_noaudio = None, None - temp_x0 = sample_scheduler.step( - noise_pred.unsqueeze(0), - t, - latent.unsqueeze(0), - return_dict=False, - generator=seed_g)[0] - latent = temp_x0.squeeze(0) - del temp_x0 - del timestep - - if callback is not None: - callback(i, latent, False) - - x0 = [latent] - video = self.vae.decode(x0, VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0] - - if any_end_frame and add_frames_for_end_image: - # video[:, -1:] = img_interpolated2 - video = video[:, :-1] - - del noise, latent - del sample_scheduler - - return video diff --git a/wan/modules/clip.py b/wan/modules/clip.py index fc41d85..da91a00 100644 --- a/wan/modules/clip.py +++ b/wan/modules/clip.py @@ -531,7 +531,7 @@ class CLIPModel: seq_len=self.model.max_text_len - 2, clean='whitespace') - def visual(self, videos): + def visual(self, videos,): # preprocess size = (self.model.image_size,) * 2 videos = torch.cat([ diff --git a/wan/modules/model.py b/wan/modules/model.py index d276b2f..ab352f6 100644 --- a/wan/modules/model.py +++ b/wan/modules/model.py @@ -11,6 +11,7 @@ from typing import Union,Optional from mmgp import offload from .attention import pay_attention from torch.backends.cuda import sdp_kernel +from wan.multitalk.multitalk_utils import get_attn_map_with_target __all__ = ['WanModel'] @@ -175,7 +176,7 @@ class WanSelfAttention(nn.Module): self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() - def forward(self, xlist, grid_sizes, freqs, block_mask = None): + def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks = None, ref_images_count = 0): r""" Args: x(Tensor): Shape [B, L, num_heads, C / num_heads] @@ -190,7 +191,7 @@ class WanSelfAttention(nn.Module): # query, key, value function q = self.q(x) self.norm_q(q) - q = q.view(b, s, n, d) # !!! 
+ q = q.view(b, s, n, d) k = self.k(x) self.norm_k(k) k = k.view(b, s, n, d) @@ -200,6 +201,12 @@ class WanSelfAttention(nn.Module): del q,k q,k = apply_rotary_emb(qklist, freqs, head_first=False) + + if ref_target_masks != None: + x_ref_attn_map = get_attn_map_with_target(q, k , grid_sizes, ref_target_masks=ref_target_masks, ref_images_count = ref_images_count) + else: + x_ref_attn_map = None + chipmunk = offload.shared_state.get("_chipmunk", False) if chipmunk and self.__class__ == WanSelfAttention: q = q.transpose(1,2) @@ -225,30 +232,10 @@ class WanSelfAttention(nn.Module): ) del q,k,v - # if not self._flag_ar_attention: - # q = rope_apply(q, grid_sizes, freqs) - # k = rope_apply(k, grid_sizes, freqs) - # x = flash_attention(q=q, k=k, v=v, window_size=self.window_size) - # else: - # q = rope_apply(q, grid_sizes, freqs) - # k = rope_apply(k, grid_sizes, freqs) - # q = q.to(torch.bfloat16) - # k = k.to(torch.bfloat16) - # v = v.to(torch.bfloat16) - # with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): - # x = ( - # torch.nn.functional.scaled_dot_product_attention( - # q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask - # ) - # .transpose(1, 2) - # .contiguous() - # ) - - # output x = x.flatten(2) x = self.o(x) - return x + return x, x_ref_attn_map class WanT2VCrossAttention(WanSelfAttention): @@ -375,7 +362,11 @@ class WanAttentionBlock(nn.Module): cross_attn_norm=False, eps=1e-6, block_id=None, - block_no = 0 + block_no = 0, + output_dim=0, + norm_input_visual=True, + class_range=24, + class_interval=4, ): super().__init__() self.dim = dim @@ -409,6 +400,22 @@ class WanAttentionBlock(nn.Module): self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) self.block_id = block_id + if output_dim > 0: + from wan.multitalk.attention import SingleStreamMutiAttention + # init audio module + self.audio_cross_attn = SingleStreamMutiAttention( + dim=dim, + encoder_hidden_states_dim=output_dim, + num_heads=num_heads, + qk_norm=False, + qkv_bias=True, + eps=eps, + norm_layer=WanRMSNorm, + class_range=class_range, + class_interval=class_interval + ) + self.norm_x = WanLayerNorm(dim, eps, elementwise_affine=True) if norm_input_visual else nn.Identity() + def forward( self, x, @@ -423,6 +430,9 @@ class WanAttentionBlock(nn.Module): audio_proj= None, audio_context_lens= None, audio_scale=None, + multitalk_audio=None, + multitalk_masks=None, + ref_images_count=0, ): r""" Args: @@ -466,11 +476,10 @@ class WanAttentionBlock(nn.Module): xlist = [x_mod.to(attention_dtype)] del x_mod - y = self.self_attn( xlist, grid_sizes, freqs, block_mask) + y, x_ref_attn_map = self.self_attn( xlist, grid_sizes, freqs, block_mask = block_mask, ref_target_masks = multitalk_masks, ref_images_count = ref_images_count) y = y.to(dtype) - if cam_emb != None: - y = self.projector(y) + if cam_emb != None: y = self.projector(y) x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames) x.addcmul_(y, e[2]) @@ -482,6 +491,25 @@ class WanAttentionBlock(nn.Module): del y x += self.cross_attn(ylist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens).to(dtype) + if multitalk_audio != None: + # cross attn of multitalk audio + y = self.norm_x(x) + y = y.to(attention_dtype) + if ref_images_count == 0: + x += self.audio_cross_attn(y, encoder_hidden_states=multitalk_audio, shape=grid_sizes, x_ref_attn_map=x_ref_attn_map) + else: + y_shape = y.shape + y = y.reshape(y_shape[0], grid_sizes[0], -1) + y = y[:, ref_images_count:] + y = 
y.reshape(y_shape[0], -1, y_shape[-1]) + grid_sizes_alt = [grid_sizes[0]-ref_images_count, *grid_sizes[1:]] + y = self.audio_cross_attn(y, encoder_hidden_states=multitalk_audio, shape=grid_sizes_alt, x_ref_attn_map=x_ref_attn_map) + y = y.reshape(y_shape[0], grid_sizes[0]-ref_images_count, -1) + x = x.reshape(y_shape[0], grid_sizes[0], -1) + x[:, ref_images_count:] += y + x = x.reshape(y_shape[0], -1, y_shape[-1]) + del y + y = self.norm2(x) y = reshape_latent(y , latent_frames) @@ -518,6 +546,71 @@ class WanAttentionBlock(nn.Module): x.add_(hint, alpha= scale) return x +class AudioProjModel(ModelMixin, ConfigMixin): + def __init__( + self, + seq_len=5, + seq_len_vf=12, + blocks=12, + channels=768, + intermediate_dim=512, + output_dim=768, + context_tokens=32, + norm_output_audio=False, + ): + super().__init__() + + self.seq_len = seq_len + self.blocks = blocks + self.channels = channels + self.input_dim = seq_len * blocks * channels + self.input_dim_vf = seq_len_vf * blocks * channels + self.intermediate_dim = intermediate_dim + self.context_tokens = context_tokens + self.output_dim = output_dim + + # define multiple linear layers + self.proj1 = nn.Linear(self.input_dim, intermediate_dim) + self.proj1_vf = nn.Linear(self.input_dim_vf, intermediate_dim) + self.proj2 = nn.Linear(intermediate_dim, intermediate_dim) + self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim) + self.norm = nn.LayerNorm(output_dim) if norm_output_audio else nn.Identity() + + def forward(self, audio_embeds, audio_embeds_vf): + video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1] + B, _, _, S, C = audio_embeds.shape + + # process audio of first frame + audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c") + batch_size, window_size, blocks, channels = audio_embeds.shape + audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels) + + # process audio of latter frame + audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c") + batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape + audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf) + + # first projection + audio_embeds = torch.relu(self.proj1(audio_embeds)) + audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf)) + audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B) + audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B) + audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1) + audio_embeds_vf = audio_embeds = None + batch_size_c, N_t, C_a = audio_embeds_c.shape + audio_embeds_c = audio_embeds_c.view(batch_size_c*N_t, C_a) + + # second projection + audio_embeds_c = torch.relu(self.proj2(audio_embeds_c)) + + context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c*N_t, self.context_tokens, self.output_dim) + audio_embeds_c = None + # normalization and reshape + context_tokens = self.norm(context_tokens) + context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length) + + return context_tokens + class VaceWanAttentionBlock(WanAttentionBlock): @@ -595,19 +688,27 @@ class Head(nn.Module): class MLPProj(torch.nn.Module): - def __init__(self, in_dim, out_dim): + def __init__(self, in_dim, out_dim, flf_pos_emb=False): super().__init__() self.proj = torch.nn.Sequential( torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim), torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim), torch.nn.LayerNorm(out_dim)) + + if flf_pos_emb: # NOTE: we only use 
this for `flf2v` + FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER = 257 * 2 + self.emb_pos = nn.Parameter( + torch.zeros(1, FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER, 1280)) def forward(self, image_embeds): + if hasattr(self, 'emb_pos'): + bs, n, d = image_embeds.shape + image_embeds = image_embeds.view(-1, 2 * n, d) + image_embeds = image_embeds + self.emb_pos clip_extra_context_tokens = self.proj(image_embeds) return clip_extra_context_tokens - class WanModel(ModelMixin, ConfigMixin): def setup_chipmunk(self): # from chipmunk.util import LayerCounter @@ -696,45 +797,18 @@ class WanModel(ModelMixin, ConfigMixin): qk_norm=True, cross_attn_norm=True, eps=1e-6, + flf = False, recammaster = False, inject_sample_info = False, fantasytalking_dim = 0, + multitalk_output_dim = 0, + audio_window=5, + intermediate_dim=512, + context_tokens=32, + vae_scale=4, # vae timedownsample scale + norm_input_visual=True, + norm_output_audio=True, ): - r""" - Initialize the diffusion model backbone. - - Args: - model_type (`str`, *optional*, defaults to 't2v'): - Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) - patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): - 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) - text_len (`int`, *optional*, defaults to 512): - Fixed length for text embeddings - in_dim (`int`, *optional*, defaults to 16): - Input video channels (C_in) - dim (`int`, *optional*, defaults to 2048): - Hidden dimension of the transformer - ffn_dim (`int`, *optional*, defaults to 8192): - Intermediate dimension in feed-forward network - freq_dim (`int`, *optional*, defaults to 256): - Dimension for sinusoidal time embeddings - text_dim (`int`, *optional*, defaults to 4096): - Input dimension for text embeddings - out_dim (`int`, *optional*, defaults to 16): - Output video channels (C_out) - num_heads (`int`, *optional*, defaults to 16): - Number of attention heads - num_layers (`int`, *optional*, defaults to 32): - Number of transformer blocks - window_size (`tuple`, *optional*, defaults to (-1, -1)): - Window size for local attention (-1 indicates global attention) - qk_norm (`bool`, *optional*, defaults to True): - Enable query/key normalization - cross_attn_norm (`bool`, *optional*, defaults to False): - Enable cross-attention normalization - eps (`float`, *optional*, defaults to 1e-6): - Epsilon value for normalization layers - """ super().__init__() @@ -760,6 +834,14 @@ class WanModel(ModelMixin, ConfigMixin): self.block_mask = None self.inject_sample_info = inject_sample_info + self.norm_output_audio = norm_output_audio + self.audio_window = audio_window + self.intermediate_dim = intermediate_dim + self.vae_scale = vae_scale + + multitalk = multitalk_output_dim > 0 + self.multitalk = multitalk + # embeddings self.patch_embedding = nn.Conv3d( in_dim, dim, kernel_size=patch_size, stride=patch_size) @@ -780,7 +862,7 @@ class WanModel(ModelMixin, ConfigMixin): cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn' self.blocks = nn.ModuleList([ WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, - window_size, qk_norm, cross_attn_norm, eps, block_no =i) + window_size, qk_norm, cross_attn_norm, eps, block_no =i, output_dim=multitalk_output_dim, norm_input_visual=norm_input_visual) for i in range(num_layers) ]) @@ -790,7 +872,18 @@ class WanModel(ModelMixin, ConfigMixin): # buffers (don't use register_buffer otherwise dtype will be changed in to()) if model_type == 'i2v': - self.img_emb = MLPProj(1280, dim) + self.img_emb = MLPProj(1280, 
dim, flf_pos_emb = flf) + + if multitalk : + # init audio adapter + self.audio_proj = AudioProjModel( + seq_len=audio_window, + seq_len_vf=audio_window+vae_scale-1, + intermediate_dim=intermediate_dim, + output_dim=multitalk_output_dim, + context_tokens=context_tokens, + norm_output_audio=norm_output_audio, + ) # initialize weights self.init_weights() @@ -806,7 +899,10 @@ class WanModel(ModelMixin, ConfigMixin): self.blocks = nn.ModuleList([ WanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, self.cross_attn_norm, self.eps, block_no =i, - block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None) + block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None, + output_dim=multitalk_output_dim, + norm_input_visual=norm_input_visual, + ) for i in range(self.num_layers) ]) @@ -847,6 +943,10 @@ class WanModel(ModelMixin, ConfigMixin): for block in self.blocks: layer_list2 += [block.norm3] + if hasattr(self, "audio_proj"): + for block in self.blocks: + layer_list2 += [block.norm_x] + if hasattr(self, "fps_embedding"): layer_list2 += [self.fps_embedding, self.fps_projection, self.fps_projection[0], self.fps_projection[2]] @@ -1006,6 +1106,9 @@ class WanModel(ModelMixin, ConfigMixin): audio_proj=None, audio_context_lens=None, audio_scale=None, + multitalk_audio = None, + multitalk_masks = None, + ref_images_count = 0, ): # patch_dtype = self.patch_embedding.weight.dtype @@ -1090,6 +1193,21 @@ class WanModel(ModelMixin, ConfigMixin): context_clip = self.img_emb(clip_fea) # bs x 257 x dim context = [ torch.cat( [context_clip, u ], dim=1 ) for u in context ] + if multitalk_audio != None: + multitalk_audio_list = [] + for audio in multitalk_audio: + audio = self.audio_proj(*audio) + audio = torch.concat(audio.split(1), dim=2).to(context[0]) + multitalk_audio_list.append(audio) + audio = None + else: + multitalk_audio_list = [None] * len(x_list) + + if multitalk_masks != None: + multitalk_masks_list = multitalk_masks + else: + multitalk_masks_list = [None] * len(x_list) + context_list = context if audio_scale != None: audio_scale_list = audio_scale @@ -1105,6 +1223,7 @@ class WanModel(ModelMixin, ConfigMixin): block_mask = block_mask, audio_proj=audio_proj, audio_context_lens=audio_context_lens, + ref_images_count=ref_images_count, ) if vace_context == None: @@ -1137,7 +1256,7 @@ class WanModel(ModelMixin, ConfigMixin): if self.accumulated_err[cur_x_id] 0: +def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, ignore_first, fit_into_canvas = False ): + if rm_background: session = new_session() output_list =[] @@ -183,7 +183,7 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg new_height = int( round(height * scale / 16) * 16) new_width = int( round(width * scale / 16) * 16) resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS) - if rm_background == 1 or rm_background == 2 and i > 0 : + if rm_background and not (ignore_first and i == 0) : # resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') output_list.append(resized_image) #alpha_matting_background_threshold = 30, 
alpha_foreground_background_threshold = 200, @@ -406,3 +406,137 @@ def create_progress_hook(filename): return progress_hook(block_num, block_size, total_size, filename) return hook +import ffmpeg +import os +import tempfile + +def extract_audio_tracks(source_video, verbose=False, query_only= False): + """ + Extract all audio tracks from source video to temporary files. + + Args: + source_video: Path to video with audio to extract + verbose: Enable verbose output (default: False) + + Returns: + List of temporary audio file paths, or empty list if no audio tracks + """ + try: + # Check if source video has audio + probe = ffmpeg.probe(source_video) + audio_streams = [s for s in probe['streams'] if s['codec_type'] == 'audio'] + + if not audio_streams: + if query_only: return 0 + if verbose: + print(f"No audio track found in {source_video}") + return [] + if query_only: return len(audio_streams) + if verbose: + print(f"Found {len(audio_streams)} audio track(s)") + + # Create temporary audio files for each track + temp_audio_files = [] + for i in range(len(audio_streams)): + fd, temp_path = tempfile.mkstemp(suffix=f'_track{i}.aac', prefix='audio_') + os.close(fd) # Close file descriptor immediately + temp_audio_files.append(temp_path) + + # Extract each audio track + for i, temp_path in enumerate(temp_audio_files): + (ffmpeg + .input(source_video) + .output(temp_path, **{f'map': f'0:a:{i}', 'acodec': 'aac'}) + .overwrite_output() + .run(quiet=not verbose)) + + return temp_audio_files + + except ffmpeg.Error as e: + print(f"FFmpeg error during audio extraction: {e}") + return [] + except Exception as e: + print(f"Error during audio extraction: {e}") + return [] + +def combine_video_with_audio_tracks(target_video, audio_tracks, output_video, verbose=False): + """ + Combine video with audio tracks. Output duration matches video length exactly. + + Args: + target_video: Path to video to receive the audio + audio_tracks: List of audio file paths to combine + output_video: Path for the output video + verbose: Enable verbose output (default: False) + + Returns: + True if successful, False otherwise + """ + if not audio_tracks: + if verbose: + print("No audio tracks to combine") + return False + + try: + # Get video duration to ensure exact alignment + video_probe = ffmpeg.probe(target_video) + video_duration = float(video_probe['streams'][0]['duration']) + + if verbose: + print(f"Target video duration: {video_duration:.3f} seconds") + + # Combine target video with all audio tracks, force video duration + video = ffmpeg.input(target_video).video + audio_inputs = [ffmpeg.input(audio_path).audio for audio_path in audio_tracks] + + # Create output with video duration as master timing + inputs = [video] + audio_inputs + (ffmpeg + .output(*inputs, output_video, + vcodec='copy', + acodec='copy', + t=video_duration) # Force exact video duration + .overwrite_output() + .run(quiet=not verbose)) + + if verbose: + print(f"Successfully created {output_video} with {len(audio_tracks)} audio track(s) aligned to video duration") + return True + + except ffmpeg.Error as e: + print(f"FFmpeg error during video combination: {e}") + return False + except Exception as e: + print(f"Error during video combination: {e}") + return False + +def cleanup_temp_audio_files(audio_tracks, verbose=False): + """ + Clean up temporary audio files. 
+ + Args: + audio_tracks: List of audio file paths to delete + verbose: Enable verbose output (default: False) + + Returns: + Number of files successfully deleted + """ + deleted_count = 0 + + for audio_path in audio_tracks: + try: + if os.path.exists(audio_path): + os.unlink(audio_path) + deleted_count += 1 + if verbose: + print(f"Cleaned up {audio_path}") + except PermissionError: + print(f"Warning: Could not delete {audio_path} (file may be in use)") + except Exception as e: + print(f"Warning: Error deleting {audio_path}: {e}") + + if verbose and deleted_count > 0: + print(f"Successfully deleted {deleted_count} temporary audio file(s)") + + return deleted_count + diff --git a/wgp.py b/wgp.py index 7781579..cac3cff 100644 --- a/wgp.py +++ b/wgp.py @@ -16,7 +16,9 @@ import json import wan from wan.utils import notification_sound from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES, VACE_SIZE_CONFIGS -from wan.utils.utils import cache_video, convert_tensor_to_image, save_image +from wan.utils.utils import cache_video, convert_tensor_to_image, save_image, get_video_info, get_file_creation_date +from wan.utils.utils import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files + from wan.modules.attention import get_attention_modes, get_supported_attention_modes from huggingface_hub import hf_hub_download, snapshot_download import torch @@ -48,8 +50,9 @@ AUTOSAVE_FILENAME = "queue.zip" PROMPT_VARS_MAX = 10 target_mmgp_version = "3.5.1" -WanGP_version = "6.51" -settings_version = 2.1 +WanGP_version = "6.6" +settings_version = 2.21 +max_source_video_frames = 1000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None from importlib.metadata import version @@ -175,8 +178,11 @@ def process_prompt_and_add_tasks(state, model_choice): if mode == "edit": edit_video_source =gen.get("edit_video_source", None) edit_overrides =gen.get("edit_overrides", None) - - for k in ["image_start", "image_end", "image_refs", "video_guide", "audio_guide", "video_mask"]: + _ , _ , _, frames_count = get_video_info(edit_video_source) + if frames_count > max_source_video_frames: + gr.Info(f"Post processing is not supported on videos longer than {max_source_video_frames} frames. 
Output Video will be truncated") + # return + for k in ["image_start", "image_end", "image_refs", "video_guide", "audio_guide", "audio_guide2", "video_mask"]: + inputs[k] = None + inputs.update(edit_overrides) + del gen["edit_video_source"], gen["edit_overrides"] @@ -204,7 +210,9 @@ def process_prompt_and_add_tasks(state, model_choice): queue= gen.get("queue", []) return update_queue_data(queue) - + if inputs.get("cfg_star_switch", 0) != 0 and inputs.get("apg_switch", 0) != 0: + gr.Info("Adaptive Projected Guidance and Classifier Free Guidance Star cannot be enabled at the same time") + return prompt = inputs["prompt"] if len(prompt) ==0: gr.Info("Prompt cannot be empty.") @@ -230,11 +238,14 @@ def process_prompt_and_add_tasks(state, model_choice): image_start = inputs["image_start"] image_end = inputs["image_end"] image_refs = inputs["image_refs"] - audio_guide = inputs["audio_guide"] image_prompt_type = inputs["image_prompt_type"] + audio_prompt_type = inputs["audio_prompt_type"] if image_prompt_type == None: image_prompt_type = "" video_prompt_type = inputs["video_prompt_type"] if video_prompt_type == None: video_prompt_type = "" + force_fps = inputs["force_fps"] + audio_guide = inputs["audio_guide"] + audio_guide2 = inputs["audio_guide2"] video_guide = inputs["video_guide"] video_mask = inputs["video_mask"] video_source = inputs["video_source"] @@ -268,7 +279,7 @@ def process_prompt_and_add_tasks(state, model_choice): gr.Info(f"Invalid Frame Position '{pos_str}'") return pos = int(pos_str) - if pos <1 or pos > 1000: + if pos <1 or pos > max_source_video_frames: gr.Info(f"Invalid Frame Position Value'{pos_str}'") return else: @@ -288,6 +299,28 @@ def process_prompt_and_add_tasks(state, model_choice): else: video_source = None + if "A" in audio_prompt_type: + if audio_guide == None: + gr.Info("You must provide an Audio Source") + return + if "B" in audio_prompt_type: + if audio_guide2 == None: + gr.Info("You must provide a second Audio Source") + return + else: + audio_guide2 = None + else: + audio_guide = None + audio_guide2 = None + + if model_type in ["vace_multitalk_14B"] and ("B" in audio_prompt_type or "X" in audio_prompt_type): + if not "I" in video_prompt_type: + gr.Info("To get good results with Multitalk and two people speaking, it is recommended to set a Reference Frame that contains the two people, one on each side") + + if "R" in audio_prompt_type and len(filter_letters(image_prompt_type, "VLG")) > 0 : + gr.Info("Remuxing is not yet supported if there is a video source") + audio_prompt_type = audio_prompt_type.replace("R", "") + if model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_audio", "hunyuan_avatar"]: if image_refs == None : gr.Info("You must provide an Image Reference") return @@ -365,14 +398,17 @@ def process_prompt_and_add_tasks(state, model_choice): if test_any_sliding_window(model_type): if video_length > sliding_window_size: - no_windows = compute_sliding_window_no(video_length, sliding_window_size, sliding_window_discard_last_frames, sliding_window_overlap) - gr.Info(f"The Number of Frames to generate ({video_length}) is greater than the Sliding Window Size ({sliding_window_size}), {no_windows} Windows will be generated") + full_video_length = video_length if video_source is None else video_length + sliding_window_overlap + extra = "" if full_video_length == video_length else f" including {sliding_window_overlap} added for Video Continuation" + no_windows = compute_sliding_window_no(full_video_length, sliding_window_size, sliding_window_discard_last_frames,
sliding_window_overlap) + gr.Info(f"The Number of Frames to generate ({video_length}{extra}) is greater than the Sliding Window Size ({sliding_window_size}), {no_windows} Windows will be generated") if "recam" in model_filename: if video_source == None: gr.Info("You must provide a Source Video") return - frames = get_resampled_video(video_source, 0, 81, 16) + + frames = get_resampled_video(video_source, 0, 81, get_computed_fps(force_fps, model_type , video_guide, video_source )) if len(frames)<81: gr.Info("Recammaster source video should be at least 81 frames once the resampling at 16 fps has been done") return @@ -388,10 +424,6 @@ def process_prompt_and_add_tasks(state, model_choice): gr.Info("Filtering Frames with this model is not supported") return - if "hunyuan_video_avatar" in model_filename and audio_guide == None: - gr.Info("You must provide an audio file") - return - if inputs["multi_prompts_gen_type"] != 0: if image_start != None and len(image_start) > 1: gr.Info("Only one Start Image must be provided if multiple prompts are used for different windows") @@ -406,6 +438,7 @@ def process_prompt_and_add_tasks(state, model_choice): "image_end": image_end[0] if image_end !=None and len(image_end) > 0 else None, "image_refs": image_refs, "audio_guide": audio_guide, + "audio_guide2": audio_guide2, "video_guide": video_guide, "video_mask": video_mask, "video_source": video_source, @@ -415,6 +448,7 @@ def process_prompt_and_add_tasks(state, model_choice): "denoising_strength": denoising_strength, "image_prompt_type": image_prompt_type, "video_prompt_type": video_prompt_type, + "audio_prompt_type": audio_prompt_type, } if inputs["multi_prompts_gen_type"] == 0: @@ -612,7 +646,7 @@ def save_queue_action(state): task_id_s = task.get('id', f"task_{task_index}") image_keys = ["image_start", "image_end", "image_refs"] - video_keys = ["video_guide", "video_mask", "video_source", "audio_guide"] + video_keys = ["video_guide", "video_mask", "video_source", "audio_guide", "audio_guide2"] for key in image_keys: images_pil = params_copy.get(key) @@ -788,7 +822,7 @@ def load_queue_action(filepath, state, evt:gr.EventData): params['state'] = state image_keys = ["image_start", "image_end", "image_refs"] - video_keys = ["video_guide", "video_mask", "video_source", "audio_guide"] + video_keys = ["video_guide", "video_mask", "video_source", "audio_guide", "audio_guide2"] loaded_pil_images = {} loaded_video_paths = {} @@ -1008,7 +1042,7 @@ def autosave_queue(): task_id_s = task.get('id', f"task_{task_index}") image_keys = ["image_start", "image_end", "image_refs"] - video_keys = ["video_guide", "video_mask", "video_source", "audio_guide"] + video_keys = ["video_guide", "video_mask", "video_source", "audio_guide", "audio_guide2"] for key in image_keys: images_pil = params_copy.get(key) @@ -1179,6 +1213,12 @@ def _parse_args(): help="save proprocessed masks for debugging or editing" ) + parser.add_argument( + "--save-speakers", + action="store_true", + help="save proprocessed audio track with extract speakers for debugging or editing" + ) + parser.add_argument( "--share", action="store_true", @@ -1525,7 +1565,6 @@ lock_ui_compile = False force_profile_no = int(args.profile) verbose_level = int(args.verbose) check_loras = args.check_loras ==1 -advanced = args.advanced server_config_filename = "wgp_config.json" if not os.path.isdir("settings"): @@ -1595,16 +1634,22 @@ wan_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_ "ckpts/wan2.1_recammaster_1.3B_bf16.safetensors", 
"ckpts/sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_quanto_fp16_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_720p_14B_mbf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_720p_14B_quanto_mbf16_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_720p_14B_quanto_mfp16_int8.safetensors", - "ckpts/wan2.1_Vace_1.3B_mbf16.safetensors", "ckpts/wan2.1_Vace_14B_module_mbf16.safetensors", "ckpts/wan2.1_Vace_14B_module_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_Vace_14B_module_quanto_mfp16_int8.safetensors", - "ckpts/wan2.1_moviigen1.1_14B_mbf16.safetensors", "ckpts/wan2.1_moviigen1.1_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_moviigen1.1_14B_quanto_mfp16_int8.safetensors", + "ckpts/wan2.1_Vace_1.3B_mbf16.safetensors", "ckpts/wan2_1_phantom_1.3B_mbf16.safetensors", "ckpts/wan2.1_phantom_14B_mbf16.safetensors", "ckpts/wan2.1_phantom_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_phantom_14B_quanto_mfp16_int8.safetensors", ] wan_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_mbf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_mfp16_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_mbf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_mfp16_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", - "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_fp16_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_bf16.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_fp16_int8.safetensors", - "ckpts/wan2.1_fantasy_speaking_14B_bf16.safetensors"] + "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_fp16_int8.safetensors", + ] ltxv_choices= ["ckpts/ltxv_0.9.7_13B_dev_bf16.safetensors", "ckpts/ltxv_0.9.7_13B_dev_quanto_bf16_int8.safetensors", "ckpts/ltxv_0.9.7_13B_distilled_lora128_bf16.safetensors"] +modules_files = { + "vace_14B" : ["ckpts/wan2.1_Vace_14B_module_mbf16.safetensors", "ckpts/wan2.1_Vace_14B_module_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_Vace_14B_module_quanto_mfp16_int8.safetensors"], + "fantasy": ["ckpts/wan2.1_fantasy_speaking_14B_bf16.safetensors"], + "multitalk": ["ckpts/wan2.1_multitalk_14B_mbf16.safetensors", "ckpts/wan2.1_multitalk_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_multitalk_14B_quanto_mfp16_int8.safetensors"] +} + + hunyuan_choices= ["ckpts/hunyuan_video_720_bf16.safetensors", "ckpts/hunyuan_video_720_quanto_int8.safetensors", "ckpts/hunyuan_video_i2v_720_bf16v2.safetensors", "ckpts/hunyuan_video_i2v_720_quanto_int8v2.safetensors", "ckpts/hunyuan_video_custom_720_bf16.safetensors", "ckpts/hunyuan_video_custom_720_quanto_bf16_int8.safetensors", "ckpts/hunyuan_video_custom_audio_720_bf16.safetensors", "ckpts/hunyuan_video_custom_audio_720_quanto_bf16_int8.safetensors", @@ -1614,33 +1659,29 @@ hunyuan_choices= ["ckpts/hunyuan_video_720_bf16.safetensors", "ckpts/hunyuan_vid transformer_choices = wan_choices_t2v + wan_choices_i2v + ltxv_choices + hunyuan_choices def get_dependent_models(model_type, quantization, dtype_policy ): - if model_type == "fantasy": - dependent_model_type = "i2v_720p" - elif model_type == "ltxv_13B_distilled": + # if model_type 
== "fantasy": + # dependent_model_type = "i2v_720p" + if model_type == "ltxv_13B_distilled": dependent_model_type = "ltxv_13B" - elif model_type == "vace_14B": - dependent_model_type = "t2v" + # elif model_type == "vace_14B": + # dependent_model_type = "t2v" else: return [], [] return [get_model_filename(dependent_model_type, quantization, dtype_policy)], [dependent_model_type] -model_types = [ "t2v_1.3B", "t2v", "i2v", "i2v_720p", "flf2v_720p", "vace_1.3B","vace_14B","moviigen", "phantom_1.3B", "phantom_14B", "fantasy", +abstract_model_types = ["multitalk", "fantasy", "vace_14B", "vace_multitalk_14B", "flf2v_720p"] +model_types = [ "t2v_1.3B", "t2v", "i2v", "i2v_720p", "vace_1.3B", "phantom_1.3B", "phantom_14B", "fun_inp_1.3B", "fun_inp", "recam_1.3B", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "ltxv_13B", "ltxv_13B_distilled", - "hunyuan", "hunyuan_i2v", "hunyuan_custom", "hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_avatar"] + "hunyuan", "hunyuan_i2v", "hunyuan_custom", "hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_avatar"] + model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B", - "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "vace_14B" : "Vace_14B","recam_1.3B": "recammaster_1.3B", - "flf2v_720p" : "FLF2V_720p", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B", - "sky_df_720p_14B" : "sky_reels2_diffusion_forcing_720p_14B", "moviigen" :"moviigen", - "phantom_1.3B" : "phantom_1.3B", "phantom_14B" : "phantom_14B", "fantasy" : "fantasy", "ltxv_13B" : "ltxv_0.9.7_13B_dev", "ltxv_13B_distilled" : "ltxv_0.9.7_13B_distilled", + "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B","recam_1.3B": "recammaster_1.3B", + "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B", + "sky_df_720p_14B" : "sky_reels2_diffusion_forcing_720p_14B", + "phantom_1.3B" : "phantom_1.3B", "phantom_14B" : "phantom_14B", "ltxv_13B" : "ltxv_0.9.7_13B_dev", "ltxv_13B_distilled" : "ltxv_0.9.7_13B_distilled", "hunyuan" : "hunyuan_video_720", "hunyuan_i2v" : "hunyuan_video_i2v_720", "hunyuan_custom" : "hunyuan_video_custom_720", "hunyuan_custom_audio" : "hunyuan_video_custom_audio", "hunyuan_custom_edit" : "hunyuan_video_custom_edit", "hunyuan_avatar" : "hunyuan_video_avatar" } -def are_model_types_compatible(model_type1, model_type2): - return get_base_model_type(model_type1) == get_base_model_type(model_type2) - -def get_model_finetune_def(model_type): - return finetunes.get(model_type, None ) - def get_base_model_type(model_type): finetune_def = get_model_finetune_def(model_type) if finetune_def == None: @@ -1649,6 +1690,38 @@ def get_base_model_type(model_type): else: return finetune_def["architecture"] +def are_model_types_compatible(imported_model_type, current_model_type): + imported_base_model_type = get_base_model_type(imported_model_type) + curent_base_model_type = get_base_model_type(current_model_type) + if imported_base_model_type == curent_base_model_type: + return True + + eqv_map = { + "i2v_720p" : "i2v", + "flf2v_720p" : "i2v", + "t2v_1.3B" : "t2v", + "sky_df_1.3B" : "sky_df_14B", + "sky_df_720p_14B" : "sky_df_14B", + } + if imported_base_model_type in eqv_map: + imported_base_model_type = eqv_map[imported_base_model_type] + comp_map = { + "vace_14B" : [ "vace_multitalk_14B"], + "t2v" : [ "vace_14B", "vace_1.3B" 
"vace_multitalk_14B", "t2v_1.3B", "phantom_1.3B","phantom_14B"], + "i2v" : [ "fantasy", "multitalk", "i2v_720p","flf2v_720p" ], + "ltxv_13B_distilled": ["ltxv_13B"], + "fantasy": ["multitalk"], + "sky_df_14B": ["sky_df_1.3B", "sky_df_720p_14B"], + "hunyuan_custom": ["hunyuan_custom_edit", "hunyuan_custom_audio"], + } + comp_list= comp_map.get(imported_base_model_type, None) + if comp_list == None: return False + return curent_base_model_type in comp_list + +def get_model_finetune_def(model_type): + return finetunes.get(model_type, None ) + + def get_model_type(model_filename): for model_type, signature in model_signatures.items(): @@ -1666,14 +1739,53 @@ def get_model_family(model_type): else: return "wan" +def test_class_i2v(model_type): + model_type = get_base_model_type(model_type) + return model_type in ["i2v", "i2v_720p", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "hunyuan_i2v", "multitalk" ] + +def test_vace_module(model_type): + model_type = get_base_model_type(model_type) + return model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B"] def test_any_sliding_window(model_type): model_type = get_base_model_type(model_type) - return model_type in ["vace_1.3B","vace_14B","sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "ltxv_13B", "ltxv_13B_distilled"] + return test_vace_module(model_type) or model_type in ["sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "ltxv_13B", "ltxv_13B_distilled", "multitalk", "t2v", "fantasy"] or test_class_i2v(model_type) -def test_class_i2v(model_type): +def get_model_min_frames_and_step(model_type): model_type = get_base_model_type(model_type) - return model_type in ["i2v", "i2v_720p", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "hunyuan_i2v" ] + if model_type in ["sky_df_14B", "sky_df_720p_14B"]: + return 17, 20 + elif model_type in ["ltxv_13B", "ltxv_13B_distilled"]: + return 17, 8 + elif test_vace_module(model_type): + return 17, 4 + else: + return 5, 4 + +def get_model_fps(model_type): + model_type = get_base_model_type(model_type) + if model_type in ["hunyuan_avatar", "hunyuan_custom_audio", "multitalk", "vace_multitalk_14B"]: + fps = 25 + elif model_type in ["sky_df_14B", "sky_df_720p_14B", "hunyuan", "hunyuan_i2v", "hunyuan_custom_edit", "hunyuan_custom"]: + fps = 24 + elif model_type in ["fantasy"]: + fps = 23 + elif model_type in ["ltxv_13B", "ltxv_13B_distilled"]: + fps = 30 + else: + fps = 16 + return fps + +def get_computed_fps(force_fps, base_model_type , video_guide, video_source ): + if force_fps == "control" and video_guide != None: + fps, _, _, _ = get_video_info(video_guide) + elif force_fps == "source" and video_source != None: + fps, _, _, _ = get_video_info(video_source) + elif len(force_fps) > 0 and is_integer(force_fps) : + fps = int(force_fps) + else: + fps = get_model_fps(base_model_type) + return fps def get_model_name(model_type, description_container = [""]): finetune_def = get_model_finetune_def(model_type) @@ -1703,11 +1815,6 @@ def get_model_name(model_type, description_container = [""]): model_name = "ReCamMaster" model_name += " 14B" if "14B" in model_filename else " 1.3B" description = "The Recam Master in theory should allow you to replay a video by applying a different camera movement. 
The model supports only video that are at least 81 frames long (any frame beyond will be ignored)" - elif "FLF2V" in model_filename: - model_name = "Wan2.1 FLF2V" - model_name += " 720p" if "720p" in model_filename else " 480p" - model_name += " 14B" - description = "The First Last Frame 2 Video model is the official model Image 2 Video model that support Start and End frames." elif "sky_reels2_diffusion_forcing" in model_filename: model_name = "SkyReels2 Diffusion Forcing" if "720p" in model_filename : @@ -1717,20 +1824,13 @@ def get_model_name(model_type, description_container = [""]): model_name += " 14B" if "14B" in model_filename else " 1.3B" description = "The SkyReels 2 Diffusion Forcing model has been designed to generate very long videos that exceeds the usual 5s limit. You can also use this model to extend any existing video." elif "phantom" in model_filename: - model_name = "Wan2.1 Phantom" + model_name = "Phantom" if "14B" in model_filename: model_name += " 14B" description = "The Phantom model is specialized to transfer people or objects of your choice into a generated Video. It seems to produce better results if you keep the original background of the Image Referendes." else: model_name += " 1.3B" description = "The Phantom model is specialized to transfer people or objects of your choice into a generated Video. It produces very nice results when used at 720p." - elif "fantasy" in model_filename: - model_name = "Wan2.1 Fantasy Speaking 720p" - model_name += " 14B" if "14B" in model_filename else " 1.3B" - description = "The Fantasy Speaking model corresponds to the original Wan image 2 video model combined with the Fantasy Speaking extension to process an audio Input." - elif "movii" in model_filename: - model_name = "Wan2.1 MoviiGen 1080p 14B" - description = "MoviiGen 1.1, a cutting-edge video generation model that excels in cinematic aesthetics and visual quality. Use it to generate videos in 720p or 1080p in the 21:9 ratio." elif "ltxv_0.9.7_13B_dev" in model_filename: model_name = "LTX Video 0.9.7 13B" description = "LTX Video is a fast model that can be used to generate long videos (up to 260 frames).It is recommended to keep the number of steps to 30 or you will need to update the file 'ltxv_video/configs/ltxv-13b-0.9.7-dev.yaml'.The LTX Video model expects very long prompts, so don't hesitate to use the Prompt Enhancer." 
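The new audio helpers added above in `wan/utils/utils.py` (`extract_audio_tracks`, `combine_video_with_audio_tracks`, `cleanup_temp_audio_files`) are meant to be chained: probe the source, extract each track to a temporary file, remux the tracks onto the generated video, then delete the temporary files. The sketch below shows that flow; the `transfer_audio` wrapper and the file names are hypothetical and only illustrate the intended usage.

```python
from wan.utils.utils import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files

def transfer_audio(source_video: str, generated_video: str, output_video: str) -> bool:
    """Copy every audio track of source_video onto generated_video (illustrative sketch)."""
    # query_only=True only counts the audio tracks, nothing is extracted
    if extract_audio_tracks(source_video, query_only=True) == 0:
        return False
    temp_tracks = extract_audio_tracks(source_video)  # temporary .aac files, one per track
    try:
        # the helper clamps the output duration to the target video length
        return combine_video_with_audio_tracks(generated_video, temp_tracks, output_video)
    finally:
        cleanup_temp_audio_files(temp_tracks)  # always remove the temporary files

# transfer_audio("source.mp4", "generated.mp4", "generated_with_audio.mp4")  # hypothetical paths
```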
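Similarly, the fps precedence introduced with `force_fps` and `get_computed_fps()` can be summarised with a small standalone stub. This only re-states the branch order; the real helper reads the fps with `get_video_info()` and the fallback values come from `get_model_fps()`.

```python
from typing import Optional

def resolve_fps(force_fps: str, model_fps: int,
                control_fps: Optional[int], source_fps: Optional[int]) -> int:
    # illustrative stub mirroring the branch order of get_computed_fps() in wgp.py
    if force_fps == "control" and control_fps is not None:
        return control_fps          # keep the pace of the Control Video
    if force_fps == "source" and source_fps is not None:
        return source_fps           # keep the pace of the Source Video
    if force_fps.isdigit():
        return int(force_fps)       # explicit fps forced by the user
    return model_fps                # model default, e.g. 16 for base Wan, 25 for Multitalk

assert resolve_fps("control", 16, control_fps=30, source_fps=None) == 30
assert resolve_fps("24", 16, control_fps=None, source_fps=None) == 24
assert resolve_fps("", 16, control_fps=None, source_fps=None) == 16
```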
@@ -1763,14 +1863,25 @@ def get_model_name(model_type, description_container = [""]): description_container[0] = description return model_name +def get_model_record(model_name): + return f"WanGP v{WanGP_version} by DeepBeepMeep - " + model_name -def get_model_filename(model_type, quantization ="int8", dtype_policy = ""): - finetune_def = finetunes.get(model_type, None) - if finetune_def != None: - choices = [ ("ckpts/" + os.path.basename(path) if path.startswith("http") else path) for path in finetune_def["URLs"] ] +def get_model_filename(model_type, quantization ="int8", dtype_policy = "", is_module = False, stack=[]): + if is_module: + choices = modules_files.get(model_type, None) + if choices == None: raise Exception(f"Invalid Module Id '{model_type}'") else: - signature = model_signatures[model_type] - choices = [ name for name in transformer_choices if signature in name] + finetune_def = finetunes.get(model_type, None) + if finetune_def != None: + URLs = finetune_def["URLs"] + if isinstance(URLs, str): + if len(stack) > 10: raise Exception(f"Circular Reference in Model URLs dependencies: {stack}") + return get_model_filename(URLs, quantization=quantization, dtype_policy=dtype_policy, stack = stack + [URLs]) + else: + choices = [ ("ckpts/" + os.path.basename(path) if path.startswith("http") else path) for path in URLs ] + else: + signature = model_signatures[model_type] + choices = [ name for name in transformer_choices if signature in name] if len(quantization) == 0: quantization = "bf16" @@ -1816,6 +1927,8 @@ def get_settings_file_name(model_type): def fix_settings(model_type, ui_defaults): video_settings_version = ui_defaults.get("settings_version", 0) + model_type = get_base_model_type(model_type) + prompts = ui_defaults.get("prompts", "") if len(prompts) > 0: ui_defaults["prompt"] = prompts @@ -1823,23 +1936,47 @@ def fix_settings(model_type, ui_defaults): if image_prompt_type != None : if not isinstance(image_prompt_type, str): image_prompt_type = "S" if image_prompt_type == 0 else "SE" - + if model_type == "flf2v_720p" and not "E" in image_prompt_type: + image_prompt_type = "SE" if video_settings_version <= 2: image_prompt_type = image_prompt_type.replace("G","") ui_defaults["image_prompt_type"] = image_prompt_type if "lset_name" in ui_defaults: del ui_defaults["lset_name"] - model_type = get_base_model_type(model_type) + if model_type == None: return - + + audio_prompt_type = ui_defaults.get("audio_prompt_type", None) + if video_settings_version < 2.2: + if not model_type in ["vace_1.3B","vace_14B", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "ltxv_13B", "ltxv_13B_distilled"]: + for p in ["sliding_window_size", "sliding_window_overlap", "sliding_window_overlap_noise", "sliding_window_discard_last_frames"]: + if p in ui_defaults: del ui_defaults[p] + + if audio_prompt_type == None : + if any_audio_track(model_type): + audio_prompt_type ="A" + ui_defaults["audio_prompt_type"] = audio_prompt_type + + video_prompt_type = ui_defaults.get("video_prompt_type", "") if model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar", "phantom_14B", "phantom_1.3B"]: if not "I" in video_prompt_type: # workaround for settings corruption video_prompt_type += "I" if model_type in ["hunyuan"]: video_prompt_type = video_prompt_type.replace("I", "") + + + remove_background_images_ref = ui_defaults.get("remove_background_images_ref", 0) + if video_settings_version < 2.21: + if "I" in video_prompt_type: + if remove_background_images_ref == 2: + 
video_prompt_type = video_prompt_type.replace("I", "KI") + if remove_background_images_ref != 0: + remove_background_images_ref = 1 + ui_defaults["remove_background_images_ref"] = remove_background_images_ref + ui_defaults["video_prompt_type"] = video_prompt_type tea_cache_setting = ui_defaults.get("tea_cache_setting", None) @@ -1883,7 +2020,6 @@ def get_default_settings(model_type): "multi_images_gen_type": 0, "guidance_scale": 5.0, "embedded_guidance_scale" : 6.0, - "audio_guidance_scale": 5.0, "flow_shift": 7.0 if not "720" in model_type and i2v else 5.0, "negative_prompt": "", "activated_loras": [], @@ -1896,13 +2032,24 @@ def get_default_settings(model_type): "slg_start_perc": 10, "slg_end_perc": 90 } + if model_type in ["fantasy"]: + ui_defaults["audio_guidance_scale"] = 5.0 + elif model_type in ["multitalk"]: + ui_defaults.update({ + "guidance_scale": 5.0, + "flow_shift": 7, # 11 for 720p + "audio_guidance_scale": 4, + "sliding_window_discard_last_frames" : 4, + "sample_solver" : "euler", + "adaptive_switch" : 1, + }) - if model_type in ["hunyuan","hunyuan_i2v"]: + elif model_type in ["hunyuan","hunyuan_i2v"]: ui_defaults.update({ "guidance_scale": 7.0, }) - if model_type in ["sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"]: + elif model_type in ["sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"]: ui_defaults.update({ "guidance_scale": 6.0, "flow_shift": 8, @@ -1915,7 +2062,7 @@ def get_default_settings(model_type): }) - if model_type in ["phantom_1.3B", "phantom_14B"]: + elif model_type in ["phantom_1.3B", "phantom_14B"]: ui_defaults.update({ "guidance_scale": 7.5, "flow_shift": 5, @@ -1952,7 +2099,7 @@ def get_default_settings(model_type): "video_length": 129, "video_prompt_type": "I", }) - elif model_type in ["vace_14B"]: + elif model_type in ["vace_14B", "vace_multitalk_14B"]: ui_defaults.update({ "sliding_window_discard_last_frames": 0, }) @@ -1991,11 +2138,16 @@ for file_path in finetunes_paths: finetune_def["settings"] = json_def finetunes[finetune_id] = finetune_def -model_types += finetunes.keys() +model_types += [model_type for model_type, finetune in finetunes.items() if finetune.get("visible", True)] +displayed_model_types= model_types +# model_types += [model_type for model_type in abstract_model_types if model_type not in finetunes] transformer_types = server_config.get("transformer_types", []) transformer_type = server_config.get("last_model_type", None) +advanced = server_config.get("last_advanced_choice", False) +if args.advanced: advanced = True + if transformer_type != None and not transformer_type in model_types and not transformer_type in finetunes: transformer_type = None if transformer_type == None: transformer_type = transformer_types[0] if len(transformer_types) > 0 else model_types[0] @@ -2060,7 +2212,12 @@ if args.compile: #args.fastest or # compile = "transformer" def save_quantized_model(model, model_type, model_filename, dtype, config_file): - if "quanto" in model_filename: + if "quanto" in model_filename: return + finetune_def = get_model_finetune_def(model_type) + if finetune_def == None: return + URLs= finetune_def["URLs"] + if isinstance(URLs, str): + print("Unable to create a quantized model for a finetune that references external files") return from mmgp import offload if dtype == torch.bfloat16: @@ -2081,21 +2238,18 @@ def save_quantized_model(model, model_type, model_filename, dtype, config_file) else: offload.save_model(model, model_filename, do_quantize= True, config_file_path=config_file) print(f"New quantized file '{model_filename}' had been 
created for finetune Id '{model_type}'.") - finetune_def = get_model_finetune_def(model_type) - if finetune_def != None: - URLs= finetune_def["URLs"] - if not model_filename in URLs: - URLs.append(model_filename) - finetune_def = finetune_def.copy() - if "settings" in finetune_def: - saved_def = typing.OrderedDict() - saved_def["model"] = finetune_def - saved_def.update(finetune_def["settings"]) - del finetune_def["settings"] - finetune_file = os.path.join("finetunes" , model_type + ".json") - with open(finetune_file, "w", encoding="utf-8") as writer: - writer.write(json.dumps(saved_def, indent=4)) - print(f"The '{finetune_file}' definition file has been automatically updated with the local path to the new quantized model.") + if not model_filename in URLs: + URLs.append(model_filename) + finetune_def = finetune_def.copy() + if "settings" in finetune_def: + saved_def = typing.OrderedDict() + saved_def["model"] = finetune_def + saved_def.update(finetune_def["settings"]) + del finetune_def["settings"] + finetune_file = os.path.join("finetunes" , model_type + ".json") + with open(finetune_file, "w", encoding="utf-8") as writer: + writer.write(json.dumps(saved_def, indent=4)) + print(f"The '{finetune_file}' definition file has been automatically updated with the local path to the new quantized model.") def get_loras_preprocessor(transformer, model_type): preprocessor = getattr(transformer, "preprocess_loras", None) @@ -2177,9 +2331,12 @@ def download_models(model_filename, model_type): shared_def = { "repoId" : "DeepBeepMeep/Wan2.1", - "sourceFolderList" : [ "pose", "scribble", "flow", "depth", "mask", "wav2vec", "" ], - "fileList" : [ ["dw-ll_ucoco_384.onnx", "yolox_l.onnx"],["netG_A_latest.pth"], ["raft-things.pth"], ["depth_anything_v2_vitl.pth","depth_anything_v2_vitb.pth"], ["sam_vit_h_4b8939_fp16.safetensors"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"], - [ "flownet.pkl" ] ] + "sourceFolderList" : [ "pose", "scribble", "flow", "depth", "mask", "wav2vec", "chinese-wav2vec2-base", "pyannote", "" ], + "fileList" : [ ["dw-ll_ucoco_384.onnx", "yolox_l.onnx"],["netG_A_latest.pth"], ["raft-things.pth"], + ["depth_anything_v2_vitl.pth","depth_anything_v2_vitb.pth"], ["sam_vit_h_4b8939_fp16.safetensors"], + ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"], + ["config.json", "pytorch_model.bin", "preprocessor_config.json"], + ["pyannote_model_wespeaker-voxceleb-resnet34-LM.bin", "pytorch_model_segmentation-3.0.bin"], [ "flownet.pkl" ] ] } process_files_def(**shared_def) @@ -2368,14 +2525,13 @@ def setup_loras(model_type, transformer, lora_dir, lora_preselected_preset, spl def load_wan_model(model_filename, model_type, base_model_type, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False): if test_class_i2v(base_model_type): cfg = WAN_CONFIGS['i2v-14B'] - model_factory = wan.WanI2V else: cfg = WAN_CONFIGS['t2v-14B'] # cfg = WAN_CONFIGS['t2v-1.3B'] - if base_model_type in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"): - model_factory = wan.DTT2V - else: - model_factory = wan.WanT2V + if base_model_type in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"): + model_factory = wan.DTT2V + else: + model_factory = wan.WanAny2V wan_model = model_factory( config=cfg, @@ -2491,9 +2647,9 @@
def load_models(model_type): model_type_list = dependent_models_types + [model_type] new_transformer_filename = model_file_list[-1] for module_type in modules: - model_file_list.append(get_model_filename(module_type, transformer_quantization, transformer_dtype)) - model_type_list.append(module_type) + model_file_list.append(get_model_filename(module_type, transformer_quantization, transformer_dtype, is_module= True)) + model_type_list.append(module_type) for filename, file_model_type in zip(model_file_list, model_type_list): download_models(filename, file_model_type) VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float @@ -2616,7 +2772,7 @@ def apply_changes( state, preload_in_VRAM_choice = 0, depth_anything_v2_variant_choice = "vitl", notification_sound_enabled_choice = 1, - notification_sound_volume_choice = 50 + notification_sound_volume_choice = 50, ): if args.lock_config: return @@ -2647,7 +2803,8 @@ def apply_changes( state, "depth_anything_v2_variant": depth_anything_v2_variant_choice, "notification_sound_enabled" : notification_sound_enabled_choice, "notification_sound_volume" : notification_sound_volume_choice, - "last_model_type" : state["model_type"] + "last_model_type" : state["model_type"], + "last_advanced_choice": state["advanced"], } if Path(server_config_filename).is_file(): @@ -2807,8 +2964,7 @@ def refresh_gallery(state): #, msg params = task["params"] model_type = params["model_type"] model_type = get_base_model_type(model_type) - onemorewindow_visible = model_type in ("vace_1.3B","vace_14B","sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "ltxv_13B", "ltxv_13B_distilled", "hunyuan_custom_edit") - + onemorewindow_visible = test_any_sliding_window(model_type) enhanced = False if prompt.startswith("!enhanced!\n"): enhanced = True @@ -2893,9 +3049,8 @@ def get_file_list(state, input_file_list): for file_path in input_file_list: if isinstance(file_path, tuple): file_path = file_path[0] file_settings, _ = get_settings_from_file(state, file_path, False, False, False) - if file_settings != None: - file_list.append(file_path) - file_settings_list.append(file_settings) + file_list.append(file_path) + file_settings_list.append(file_settings) gen["file_list"] = file_list gen["file_settings_list"] = file_settings_list @@ -2916,89 +3071,123 @@ def select_video(state, input_file_list, event_data: gr.EventData): if len(file_list) > 0: configs = file_settings_list[choice] - from wan.utils.utils import get_video_info, get_file_creation_date file_name = file_list[choice] - fps, width, height, frames_count = get_video_info(file_name) - video_model_name = configs.get("type", "Unknown model") - if "-" in video_model_name: video_model_name = video_model_name[video_model_name.find("-")+2:] - video_prompt = configs.get("prompt", "")[:200] - video_video_prompt_type = configs.get("video_prompt_type", "") - video_image_prompt_type = configs.get("image_prompt_type", "") - map_video_prompt = {"V" : "Control Video", "A" : "Mask Video", "I" : "Reference Images"} - map_image_prompt = {"V" : "Source Video", "L" : "Last Video", "S" : "Start Image", "E" : "End Image"} - video_other_prompts = [ v for s,v in map_image_prompt.items() if s in video_image_prompt_type] + [ v for s,v in map_video_prompt.items() if s in video_video_prompt_type] - video_model_type = configs.get("model_type", "t2v") - if any_audio_track(video_model_type): video_other_prompts += ["Audio Source"] - video_other_prompts = ", ".join(video_other_prompts) - video_resolution = 
configs.get("resolution", "") + f" (real: {width}x{height})" - video_length = configs.get("video_length", 0) - original_fps= int(video_length/frames_count*fps) - video_length_summary = f"{video_length} frames" - video_window_no = configs.get("window_no", 0) - if video_window_no > 0: video_length_summary +=f", Window no {video_window_no }" - video_length_summary += " (" - if video_length != frames_count: video_length_summary += f"real: {frames_count} frames, " - video_length_summary += f"{frames_count/fps:.1f}s, {round(fps)} fps)" - video_seed = configs.get("seed", -1) - video_MMAudio_seed = configs.get("MMAudio_seed", video_seed) - video_guidance_scale = configs.get("video_guidance_scale", 1) - video_embedded_guidance_scale = configs.get("video_embedded_guidance_scale ", 1) - if get_model_family(video_model_type) == "hunyuan": - video_guidance_scale = video_embedded_guidance_scale - video_guidance_label = "Embedded Guidance Scale" - else: - video_guidance_label = "Guidance Scale" - video_flow_shift = configs.get("flow_shift", 1) - video_video_guide_outpainting = configs.get("video_guide_outpainting", "") - video_outpainting = "" - if len(video_video_guide_outpainting) > 0 and not video_video_guide_outpainting.startswith("#"): - video_video_guide_outpainting = video_video_guide_outpainting.split(" ") - video_outpainting = f"Top={video_video_guide_outpainting[0]}%, Bottom={video_video_guide_outpainting[1]}%, Left={video_video_guide_outpainting[2]}%, Right={video_video_guide_outpainting[3]}%" - video_num_inference_steps = configs.get("num_inference_steps", 0) - video_creation_date = str(get_file_creation_date(file_name)) - if "." in video_creation_date: video_creation_date = video_creation_date[:video_creation_date.rfind(".")] - video_generation_time = str(configs.get("generation_time", "0")) + "s" - video_activated_loras = "
".join(configs.get("activated_loras", [])) - video_temporal_upsampling = configs.get("temporal_upsampling", "") - video_spatial_upsampling = configs.get("spatial_upsampling", "") - video_MMAudio_setting = configs.get("MMAudio_setting", 0) - video_MMAudio_prompt = configs.get("MMAudio_prompt", "") - video_MMAudio_neg_prompt = configs.get("MMAudio_neg_prompt", "") - values = [video_model_name, video_prompt] - labels = ["Model", "Text Prompt"] - if len(video_other_prompts) >0 : - values += [video_other_prompts] - labels += ["Other Prompts"] - if len(video_outpainting) >0 : - values += [video_outpainting] - labels += ["Outpainting"] - values += [video_resolution, video_length_summary, video_seed, video_guidance_scale, video_flow_shift, video_num_inference_steps] - labels += [ "Resolution", "Video Length", "Seed", video_guidance_label, "Flow Shift", "Num Inference steps"] + fps, width, height, frames_count = get_video_info(file_name) + nb_audio_tracks = extract_audio_tracks(file_name,query_only = True) + values = [ os.path.basename(file_name)] + labels = [ "File Name"] + misc_values= [] + misc_labels = [] + pp_values= [] + pp_labels = [] + if configs != None: + video_model_name = configs.get("type", "Unknown model") + if "-" in video_model_name: video_model_name = video_model_name[video_model_name.find("-")+2:] + misc_values += [video_model_name] + misc_labels += ["Model"] + video_temporal_upsampling = configs.get("temporal_upsampling", "") + video_spatial_upsampling = configs.get("spatial_upsampling", "") + video_MMAudio_setting = configs.get("MMAudio_setting", 0) + video_MMAudio_prompt = configs.get("MMAudio_prompt", "") + video_MMAudio_neg_prompt = configs.get("MMAudio_neg_prompt", "") + if len(video_spatial_upsampling) > 0: + video_temporal_upsampling += " " + video_spatial_upsampling + if len(video_temporal_upsampling) > 0: + pp_values += [ video_temporal_upsampling ] + pp_labels += [ "Upsampling" ] + if video_MMAudio_setting != 0: + pp_values += [ f'Prompt="{video_MMAudio_prompt}", Neg Prompt="{video_MMAudio_neg_prompt}", Seed={video_MMAudio_seed}' ] + pp_labels += [ "MMAudio" ] - video_skip_steps_cache_type = configs.get("skip_steps_cache_type", "") - video_skip_steps_multiplier = configs.get("skip_steps_multiplier", 0) - video_skip_steps_cache_start_step_perc = configs.get("skip_steps_start_step_perc", 0) - if len(video_skip_steps_cache_type) > 0: - video_skip_steps_cache = "TeaCache" if video_skip_steps_cache_type == "tea" else "MagCache" - video_skip_steps_cache += f" x{video_skip_steps_multiplier }" - if video_skip_steps_cache_start_step_perc >0: video_skip_steps_cache += f", Start from {video_skip_steps_cache_start_step_perc}%" - values += [ video_skip_steps_cache ] - labels += [ "Skip Steps" ] - if len(video_spatial_upsampling) > 0: - video_temporal_upsampling += " " + video_spatial_upsampling - if len(video_temporal_upsampling) > 0: - values += [ video_temporal_upsampling ] - labels += [ "Upsampling" ] - if video_MMAudio_setting != 0: - values += [ f'Prompt="{video_MMAudio_prompt}", Neg Prompt="{video_MMAudio_neg_prompt}", Seed={video_MMAudio_seed}' ] - labels += [ "MMAudio" ] + if configs == None or not "seed" in configs: + values += misc_values + labels += misc_labels + video_creation_date = str(get_file_creation_date(file_name)) + if "." 
in video_creation_date: video_creation_date = video_creation_date[:video_creation_date.rfind(".")] + values += [f"{width}x{height}", f"{frames_count} frames (duration={frames_count/fps:.1f} s, fps={round(fps)})"] + labels += ["Resolution", "Frames"] + if nb_audio_tracks > 0: + values +=[nb_audio_tracks] + labels +=["Nb Audio Tracks"] - if len(video_activated_loras) > 0: - values += [video_activated_loras] - labels += ["Loras"] - values += [ video_creation_date, video_generation_time ] - labels += [ "Creation Date", "Generation Time" ] + values += pp_values + labels += pp_labels + + values +=[video_creation_date] + labels +=["Creation Date"] + else: + video_prompt = configs.get("prompt", "")[:200] + video_video_prompt_type = configs.get("video_prompt_type", "") + video_image_prompt_type = configs.get("image_prompt_type", "") + video_audio_prompt_type = configs.get("audio_prompt_type", "") + map_video_prompt = {"V" : "Control Video", "A" : "Mask Video", "I" : "Reference Images"} + map_image_prompt = {"V" : "Source Video", "L" : "Last Video", "S" : "Start Image", "E" : "End Image"} + map_audio_prompt = {"A" : "Audio Source", "B" : "Audio Source #2"} + video_other_prompts = [ v for s,v in map_image_prompt.items() if s in video_image_prompt_type] + [ v for s,v in map_video_prompt.items() if s in video_video_prompt_type] + [ v for s,v in map_audio_prompt.items() if s in video_audio_prompt_type] + video_model_type = configs.get("model_type", "t2v") + video_other_prompts = ", ".join(video_other_prompts) + video_resolution = configs.get("resolution", "") + f" (real: {width}x{height})" + video_length = configs.get("video_length", 0) + original_fps= int(video_length/frames_count*fps) + video_length_summary = f"{video_length} frames" + video_window_no = configs.get("window_no", 0) + if video_window_no > 0: video_length_summary +=f", Window no {video_window_no }" + video_length_summary += " (" + if video_length != frames_count: video_length_summary += f"real: {frames_count} frames, " + video_length_summary += f"{frames_count/fps:.1f}s, {round(fps)} fps)" + video_seed = configs.get("seed", -1) + video_MMAudio_seed = configs.get("MMAudio_seed", video_seed) + video_guidance_scale = configs.get("video_guidance_scale", 1) + video_embedded_guidance_scale = configs.get("video_embedded_guidance_scale ", 1) + if get_model_family(video_model_type) == "hunyuan": + video_guidance_scale = video_embedded_guidance_scale + video_guidance_label = "Embedded Guidance Scale" + else: + video_guidance_label = "Guidance Scale" + video_flow_shift = configs.get("flow_shift", 1) + video_video_guide_outpainting = configs.get("video_guide_outpainting", "") + video_outpainting = "" + if len(video_video_guide_outpainting) > 0 and not video_video_guide_outpainting.startswith("#"): + video_video_guide_outpainting = video_video_guide_outpainting.split(" ") + video_outpainting = f"Top={video_video_guide_outpainting[0]}%, Bottom={video_video_guide_outpainting[1]}%, Left={video_video_guide_outpainting[2]}%, Right={video_video_guide_outpainting[3]}%" + video_num_inference_steps = configs.get("num_inference_steps", 0) + video_creation_date = str(get_file_creation_date(file_name)) + if "." in video_creation_date: video_creation_date = video_creation_date[:video_creation_date.rfind(".")] + video_generation_time = str(configs.get("generation_time", "0")) + "s" + video_activated_loras = "
".join(configs.get("activated_loras", [])) + values += misc_values + [video_prompt] + labels += misc_labels + ["Text Prompt"] + if len(video_other_prompts) >0 : + values += [video_other_prompts] + labels += ["Other Prompts"] + if len(video_outpainting) >0 : + values += [video_outpainting] + labels += ["Outpainting"] + values += [video_resolution, video_length_summary, video_seed, video_guidance_scale, video_flow_shift, video_num_inference_steps] + labels += [ "Resolution", "Video Length", "Seed", video_guidance_label, "Flow Shift", "Num Inference steps"] + + video_skip_steps_cache_type = configs.get("skip_steps_cache_type", "") + video_skip_steps_multiplier = configs.get("skip_steps_multiplier", 0) + video_skip_steps_cache_start_step_perc = configs.get("skip_steps_start_step_perc", 0) + if len(video_skip_steps_cache_type) > 0: + video_skip_steps_cache = "TeaCache" if video_skip_steps_cache_type == "tea" else "MagCache" + video_skip_steps_cache += f" x{video_skip_steps_multiplier }" + if video_skip_steps_cache_start_step_perc >0: video_skip_steps_cache += f", Start from {video_skip_steps_cache_start_step_perc}%" + values += [ video_skip_steps_cache ] + labels += [ "Skip Steps" ] + + values += pp_values + labels += pp_labels + + if len(video_activated_loras) > 0: + values += [video_activated_loras] + labels += ["Loras"] + if nb_audio_tracks > 0: + values +=[nb_audio_tracks] + labels +=["Nb Audio Tracks"] + values += [ video_creation_date, video_generation_time ] + labels += [ "Creation Date", "Generation Time" ] table_style = """