mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-12-16 12:13:27 +00:00
multitalk and more
This commit is contained in:
parent
eb811e0c52
commit
4908c3c243
18
README.md
18
README.md
@ -20,6 +20,24 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
|
|||||||
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
|
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
|
||||||
|
|
||||||
## 🔥 Latest Updates
|
## 🔥 Latest Updates
|
||||||
|
### July 6 2025: WanGP v6.6, WanGP offers you **Vace Multitalk Dual Voices Fusionix Infinite** :
|
||||||
|
**Vace** our beloved super Control Net has been combined with **Multitalk** the new king in town that can animate two people speaking (**Dual Voices**). It is accelerated by the **Fusionix** model and thanks to *Sliding Windows* support and *Adaptive Projected Guidance* (much slower but should reduce the reddish effect with long videos) your two people will be able to talk an for very a long time (which **Infinite** amount of time in the field of video generation).
|
||||||
|
|
||||||
|
Of course you will get as well plain *Multitalk* vanilla and also *Multitalk 720p* as a bonus.
|
||||||
|
|
||||||
|
And since I am mister nice guy I had enclosed as an exclusivity an *Audio Separator* that will save you time to isolate each voice when using Multitalk with two people.
|
||||||
|
|
||||||
|
As I feel like a resting a bit I haven't produced a nice sample Video to illustrate all these new capabilities. But here is the thing, I ams sure you will publish in the *Share Your Best Video* channel your Master Pieces. The best one will be added to the *Announcements Channel* and will bring eternal fame to its author.
|
||||||
|
|
||||||
|
But wait, there is more:
|
||||||
|
- Sliding Windows support has been added anywhere with Wan models, so imagine now with text2video upgraded in 6.5 into a video2video, you can upsample very long videos regardless of your VRAM. The good old image2video model can now reuse the last image to produce new videos (as requested by many of you)
|
||||||
|
- I have added also the capability to transfer the audio of the original control video and an option to preserve the fps into the generated video, so from now on you will be to upsample / restore your old families video and keep the audio and the original pace. Be aware that the duration will be limited 1000 frames as I still need to add streaming support for unlimited video sizes.
|
||||||
|
|
||||||
|
Also, of interest too:
|
||||||
|
- Extract video info from Videos that have not been generated by WanGP, even better you can also apply post processing (Upsampling / MMAudio) on non WanGP videos
|
||||||
|
- Force the generated video fps to your liking, works wery well with Vace when using a Control Video
|
||||||
|
- Ability to chain URLs of Finetune models (for instance put the URLs of a model in your main finetune and reference this finetune in other finetune models to save time)
|
||||||
|
|
||||||
### July 2 2025: WanGP v6.5.1, WanGP takes care of you: lots of quality of life features:
|
### July 2 2025: WanGP v6.5.1, WanGP takes care of you: lots of quality of life features:
|
||||||
- View directly inside WanGP the properties (seed, resolutions, length, most settings...) of the past generations
|
- View directly inside WanGP the properties (seed, resolutions, length, most settings...) of the past generations
|
||||||
- In one click use the newly generated video as a Control Video or Source Video to be continued
|
- In one click use the newly generated video as a Control Video or Source Video to be continued
|
||||||
|
|||||||
@ -10,5 +10,6 @@
|
|||||||
"num_heads": 40,
|
"num_heads": 40,
|
||||||
"num_layers": 40,
|
"num_layers": 40,
|
||||||
"out_dim": 16,
|
"out_dim": 16,
|
||||||
"text_len": 512
|
"text_len": 512,
|
||||||
|
"flf": true
|
||||||
}
|
}
|
||||||
|
|||||||
15
configs/multitalk.json
Normal file
15
configs/multitalk.json
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
{
|
||||||
|
"_class_name": "WanModel",
|
||||||
|
"_diffusers_version": "0.30.0",
|
||||||
|
"dim": 5120,
|
||||||
|
"eps": 1e-06,
|
||||||
|
"ffn_dim": 13824,
|
||||||
|
"freq_dim": 256,
|
||||||
|
"in_dim": 36,
|
||||||
|
"model_type": "i2v",
|
||||||
|
"num_heads": 40,
|
||||||
|
"num_layers": 40,
|
||||||
|
"out_dim": 16,
|
||||||
|
"text_len": 512,
|
||||||
|
"multitalk_output_dim": 768
|
||||||
|
}
|
||||||
17
configs/vace_multitalk_14B.json
Normal file
17
configs/vace_multitalk_14B.json
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
{
|
||||||
|
"_class_name": "VaceWanModel",
|
||||||
|
"_diffusers_version": "0.30.0",
|
||||||
|
"dim": 5120,
|
||||||
|
"eps": 1e-06,
|
||||||
|
"ffn_dim": 13824,
|
||||||
|
"freq_dim": 256,
|
||||||
|
"in_dim": 16,
|
||||||
|
"model_type": "t2v",
|
||||||
|
"num_heads": 40,
|
||||||
|
"num_layers": 40,
|
||||||
|
"out_dim": 16,
|
||||||
|
"text_len": 512,
|
||||||
|
"vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
|
||||||
|
"vace_in_dim": 96,
|
||||||
|
"multitalk_output_dim": 768
|
||||||
|
}
|
||||||
12
finetunes/fantasy.json
Normal file
12
finetunes/fantasy.json
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
{
|
||||||
|
"model":
|
||||||
|
{
|
||||||
|
"name": "Fantasy Talking 720p",
|
||||||
|
"architecture" : "fantasy",
|
||||||
|
"modules": ["fantasy"],
|
||||||
|
"description": "The Fantasy Talking model corresponds to the original Wan image 2 video model combined with the Fantasy Speaking module to process an audio Input.",
|
||||||
|
"URLs": "i2v_720p",
|
||||||
|
"teacache_coefficients" : [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
|
||||||
|
},
|
||||||
|
"resolution": "1280x720"
|
||||||
|
}
|
||||||
16
finetunes/flf2v_720p.json
Normal file
16
finetunes/flf2v_720p.json
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
{
|
||||||
|
"model":
|
||||||
|
{
|
||||||
|
"name": "First Last Frame to Video 720p (FLF2V)14B",
|
||||||
|
"architecture" : "flf2v_720p",
|
||||||
|
"visible" : false,
|
||||||
|
"description": "The First Last Frame 2 Video model is the official model Image 2 Video model that supports Start and End frames.",
|
||||||
|
"URLs": [
|
||||||
|
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_FLF2V_720p_14B_bf16.safetensors",
|
||||||
|
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors",
|
||||||
|
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_FLF2V_720p_14B_quanto_fp16_int8.safetensors"
|
||||||
|
],
|
||||||
|
"auto_quantize": true
|
||||||
|
},
|
||||||
|
"resolution": "1280x720"
|
||||||
|
}
|
||||||
16
finetunes/moviigen.json
Normal file
16
finetunes/moviigen.json
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
{
|
||||||
|
"model":
|
||||||
|
{
|
||||||
|
"name": "MoviiGen 1080p 14B",
|
||||||
|
"architecture" : "t2v",
|
||||||
|
"description": "MoviiGen 1.1, a cutting-edge video generation model that excels in cinematic aesthetics and visual quality. Use it to generate videos in 720p or 1080p in the 21:9 ratio.",
|
||||||
|
"URLs": [
|
||||||
|
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_moviigen1.1_14B_mbf16.safetensors",
|
||||||
|
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_moviigen1.1_14B_quanto_mbf16_int8.safetensors",
|
||||||
|
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_moviigen1.1_14B_quanto_mfp16_int8.safetensors"
|
||||||
|
],
|
||||||
|
"auto_quantize": true
|
||||||
|
},
|
||||||
|
"resolution": "1280x720",
|
||||||
|
"video_length": 81
|
||||||
|
}
|
||||||
11
finetunes/multitalk.json
Normal file
11
finetunes/multitalk.json
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
{
|
||||||
|
"model":
|
||||||
|
{
|
||||||
|
"name": "Multitalk 480p",
|
||||||
|
"architecture" : "multitalk",
|
||||||
|
"modules": ["multitalk"],
|
||||||
|
"description": "The Multitalk model corresponds to the original Wan image 2 video model combined with the Multitalk module. It lets you have up to two people have a conversation.",
|
||||||
|
"URLs": "i2v",
|
||||||
|
"teacache_coefficients" : [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
|
||||||
|
}
|
||||||
|
}
|
||||||
13
finetunes/multitalk_720p.json
Normal file
13
finetunes/multitalk_720p.json
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"model":
|
||||||
|
{
|
||||||
|
"name": "Multitalk 720p",
|
||||||
|
"architecture" : "multitalk",
|
||||||
|
"modules": ["multitalk"],
|
||||||
|
"description": "The Multitalk model corresponds to the original Wan image 2 video 720p model combined with the Multitalk module. It lets you have up to two people have a conversation.",
|
||||||
|
"URLs": "i2v_720p",
|
||||||
|
"teacache_coefficients" : [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683],
|
||||||
|
"auto_quantize": true
|
||||||
|
},
|
||||||
|
"resolution": "1280x720"
|
||||||
|
}
|
||||||
11
finetunes/vace_14B.json
Normal file
11
finetunes/vace_14B.json
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
{
|
||||||
|
"model": {
|
||||||
|
"name": "Vace ControlNet 14B",
|
||||||
|
"architecture": "vace_14B",
|
||||||
|
"modules": [
|
||||||
|
"vace_14B"
|
||||||
|
],
|
||||||
|
"description": "The Vace ControlNet model is a powerful model that allows you to control the content of the generated video based of additional custom data : pose or depth video, images or objects you want to see in the video.",
|
||||||
|
"URLs": "t2v"
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -6,12 +6,7 @@
|
|||||||
"vace_14B"
|
"vace_14B"
|
||||||
],
|
],
|
||||||
"description": "Vace control model enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.",
|
"description": "Vace control model enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.",
|
||||||
"URLs": [
|
"URLs": "t2v_fusionix"
|
||||||
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_fp16.safetensors",
|
|
||||||
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_bf16_int8.safetensors",
|
|
||||||
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors"
|
|
||||||
],
|
|
||||||
"auto_quantize": true
|
|
||||||
},
|
},
|
||||||
"negative_prompt": "",
|
"negative_prompt": "",
|
||||||
"prompt": "",
|
"prompt": "",
|
||||||
|
|||||||
41
finetunes/vace_multitalk_14B.json
Normal file
41
finetunes/vace_multitalk_14B.json
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
{
|
||||||
|
"model": {
|
||||||
|
"name": "Vace Multitalk FusioniX 14B",
|
||||||
|
"architecture": "vace_multitalk_14B",
|
||||||
|
"modules": [
|
||||||
|
"vace_14B",
|
||||||
|
"multitalk"
|
||||||
|
],
|
||||||
|
"description": "Vace control model enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail. And it that's not sufficient Vace is combined with Multitalk.",
|
||||||
|
"URLs": [
|
||||||
|
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_fp16.safetensors",
|
||||||
|
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_bf16_int8.safetensors",
|
||||||
|
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors"
|
||||||
|
],
|
||||||
|
"auto_quantize": true
|
||||||
|
},
|
||||||
|
"negative_prompt": "",
|
||||||
|
"prompt": "",
|
||||||
|
"resolution": "832x480",
|
||||||
|
"video_length": 81,
|
||||||
|
"seed": -1,
|
||||||
|
"num_inference_steps": 10,
|
||||||
|
"guidance_scale": 1,
|
||||||
|
"flow_shift": 5,
|
||||||
|
"embedded_guidance_scale": 6,
|
||||||
|
"repeat_generation": 1,
|
||||||
|
"multi_images_gen_type": 0,
|
||||||
|
"tea_cache_setting": 0,
|
||||||
|
"tea_cache_start_step_perc": 0,
|
||||||
|
"loras_multipliers": "",
|
||||||
|
"temporal_upsampling": "",
|
||||||
|
"spatial_upsampling": "",
|
||||||
|
"RIFLEx_setting": 0,
|
||||||
|
"slg_switch": 0,
|
||||||
|
"slg_start_perc": 10,
|
||||||
|
"slg_end_perc": 90,
|
||||||
|
"cfg_star_switch": 0,
|
||||||
|
"cfg_zero_step": -1,
|
||||||
|
"prompt_enhancer": "",
|
||||||
|
"activated_loras": []
|
||||||
|
}
|
||||||
@ -131,24 +131,14 @@ from pathlib import Path
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
def remux_with_audio(video_path: Path, output_path: Path, audio: torch.Tensor, sampling_rate: int):
|
def remux_with_audio(video_path: Path, output_path: Path, audio: torch.Tensor, sampling_rate: int):
|
||||||
"""Remux video with new audio using FFmpeg."""
|
from wan.utils.utils import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files
|
||||||
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
|
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
|
||||||
temp_path = Path(f.name)
|
temp_path = Path(f.name)
|
||||||
|
temp_path_str= str(temp_path)
|
||||||
try:
|
import torchaudio
|
||||||
# Write audio as WAV
|
torchaudio.save(temp_path_str, audio.unsqueeze(0) if audio.dim() == 1 else audio, sampling_rate)
|
||||||
import torchaudio
|
combine_video_with_audio_tracks(video_path, [temp_path_str], output_path )
|
||||||
torchaudio.save(str(temp_path), audio.unsqueeze(0) if audio.dim() == 1 else audio, sampling_rate)
|
temp_path.unlink(missing_ok=True)
|
||||||
|
|
||||||
# Remux with FFmpeg
|
|
||||||
subprocess.run([
|
|
||||||
'ffmpeg', '-i', str(video_path), '-i', str(temp_path),
|
|
||||||
'-c:v', 'copy', '-c:a', 'aac', '-map', '0:v', '-map', '1:a',
|
|
||||||
'-shortest', '-y', str(output_path)
|
|
||||||
], check=True, capture_output=True)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
temp_path.unlink(missing_ok=True)
|
|
||||||
|
|
||||||
def remux_with_audio_old(video_path: Path, audio: torch.Tensor, output_path: Path, sampling_rate: int):
|
def remux_with_audio_old(video_path: Path, audio: torch.Tensor, output_path: Path, sampling_rate: int):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -9,7 +9,8 @@ import os
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use('TkAgg')
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@ -5,6 +5,8 @@ from PIL import Image, ImageDraw, ImageOps
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
|
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use('TkAgg')
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import PIL
|
import PIL
|
||||||
from .mask_painter import mask_painter as mask_painter2
|
from .mask_painter import mask_painter as mask_painter2
|
||||||
|
|||||||
922
preprocessing/speakers_separator.py
Normal file
922
preprocessing/speakers_separator.py
Normal file
@ -0,0 +1,922 @@
|
|||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
import argparse
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
import gc
|
||||||
|
import logging
|
||||||
|
|
||||||
|
verbose_output = True
|
||||||
|
|
||||||
|
# Suppress specific warnings before importing pyannote
|
||||||
|
warnings.filterwarnings("ignore", category=UserWarning, module="pyannote.audio.models.blocks.pooling")
|
||||||
|
warnings.filterwarnings("ignore", message=".*TensorFloat-32.*", category=UserWarning)
|
||||||
|
warnings.filterwarnings("ignore", message=".*std\\(\\): degrees of freedom.*", category=UserWarning)
|
||||||
|
warnings.filterwarnings("ignore", message=".*speechbrain.pretrained.*was deprecated.*", category=UserWarning)
|
||||||
|
warnings.filterwarnings("ignore", message=".*Module 'speechbrain.pretrained'.*", category=UserWarning)
|
||||||
|
# logging.getLogger('speechbrain').setLevel(logging.WARNING)
|
||||||
|
# logging.getLogger('speechbrain.utils.checkpoints').setLevel(logging.WARNING)
|
||||||
|
os.environ["SB_LOG_LEVEL"] = "WARNING"
|
||||||
|
import speechbrain
|
||||||
|
|
||||||
|
def xprint(t = None):
|
||||||
|
if verbose_output:
|
||||||
|
print(t)
|
||||||
|
|
||||||
|
# Configure TF32 before any CUDA operations to avoid reproducibility warnings
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pyannote.audio import Pipeline
|
||||||
|
PYANNOTE_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
PYANNOTE_AVAILABLE = False
|
||||||
|
print("Install: pip install pyannote.audio")
|
||||||
|
|
||||||
|
|
||||||
|
class OptimizedPyannote31SpeakerSeparator:
|
||||||
|
def __init__(self, hf_token: str = None, local_model_path: str = None,
|
||||||
|
vad_onset: float = 0.2, vad_offset: float = 0.8):
|
||||||
|
"""
|
||||||
|
Initialize with Pyannote 3.1 pipeline with tunable VAD sensitivity.
|
||||||
|
"""
|
||||||
|
embedding_path = "ckpts/pyannote/pyannote_model_wespeaker-voxceleb-resnet34-LM.bin"
|
||||||
|
segmentation_path = "ckpts/pyannote/pytorch_model_segmentation-3.0.bin"
|
||||||
|
|
||||||
|
|
||||||
|
xprint(f"Loading segmentation model from: {segmentation_path}")
|
||||||
|
xprint(f"Loading embedding model from: {embedding_path}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pyannote.audio import Model
|
||||||
|
from pyannote.audio.pipelines import SpeakerDiarization
|
||||||
|
|
||||||
|
# Load models directly
|
||||||
|
segmentation_model = Model.from_pretrained(segmentation_path)
|
||||||
|
embedding_model = Model.from_pretrained(embedding_path)
|
||||||
|
xprint("Models loaded successfully!")
|
||||||
|
|
||||||
|
# Create pipeline manually
|
||||||
|
self.pipeline = SpeakerDiarization(
|
||||||
|
segmentation=segmentation_model,
|
||||||
|
embedding=embedding_model,
|
||||||
|
clustering='AgglomerativeClustering'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Instantiate with default parameters
|
||||||
|
self.pipeline.instantiate({
|
||||||
|
'clustering': {
|
||||||
|
'method': 'centroid',
|
||||||
|
'min_cluster_size': 12,
|
||||||
|
'threshold': 0.7045654963945799
|
||||||
|
},
|
||||||
|
'segmentation': {
|
||||||
|
'min_duration_off': 0.0
|
||||||
|
}
|
||||||
|
})
|
||||||
|
xprint("Pipeline instantiated successfully!")
|
||||||
|
|
||||||
|
# Send to GPU if available
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
xprint("CUDA available, moving pipeline to GPU...")
|
||||||
|
self.pipeline.to(torch.device("cuda"))
|
||||||
|
else:
|
||||||
|
xprint("CUDA not available, using CPU...")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
xprint(f"Error loading pipeline: {e}")
|
||||||
|
xprint(f"Error type: {type(e)}")
|
||||||
|
import traceback
|
||||||
|
traceback.xprint_exc()
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
self.hf_token = hf_token
|
||||||
|
self._overlap_pipeline = None
|
||||||
|
|
||||||
|
def separate_audio(self, audio_path: str, output1, output2 ) -> Dict[str, str]:
|
||||||
|
"""Optimized main separation function with memory management."""
|
||||||
|
xprint("Starting optimized audio separation...")
|
||||||
|
self._current_audio_path = os.path.abspath(audio_path)
|
||||||
|
|
||||||
|
# Suppress warnings during processing
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
|
|
||||||
|
# Load audio
|
||||||
|
waveform, sample_rate = self.load_audio(audio_path)
|
||||||
|
|
||||||
|
# Perform diarization
|
||||||
|
diarization = self.perform_optimized_diarization(audio_path)
|
||||||
|
|
||||||
|
# Create masks
|
||||||
|
masks = self.create_optimized_speaker_masks(diarization, waveform.shape[1], sample_rate)
|
||||||
|
|
||||||
|
# Apply background preservation
|
||||||
|
final_masks = self.apply_optimized_background_preservation(masks, waveform.shape[1])
|
||||||
|
|
||||||
|
# Clear intermediate results
|
||||||
|
del masks
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
# Save outputs efficiently
|
||||||
|
output_paths = self._save_outputs_optimized(waveform, final_masks, sample_rate, audio_path, output1, output2)
|
||||||
|
|
||||||
|
return output_paths
|
||||||
|
|
||||||
|
def _extract_both_speaking_regions(
|
||||||
|
self,
|
||||||
|
diarization,
|
||||||
|
audio_length: int,
|
||||||
|
sample_rate: int
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Detect regions where ≥2 speakers talk simultaneously
|
||||||
|
using pyannote/overlapped-speech-detection.
|
||||||
|
Falls back to manual pair-wise detection if the model
|
||||||
|
is unavailable.
|
||||||
|
"""
|
||||||
|
xprint("Extracting overlap with dedicated pipeline…")
|
||||||
|
both_speaking_mask = np.zeros(audio_length, dtype=bool)
|
||||||
|
|
||||||
|
# ── 1) try the proper overlap model ────────────────────────────────
|
||||||
|
overlap_pipeline = self._get_overlap_pipeline()
|
||||||
|
|
||||||
|
# try the path stored by separate_audio – otherwise whatever the
|
||||||
|
# diarization object carries (may be None)
|
||||||
|
audio_uri = getattr(self, "_current_audio_path", None) \
|
||||||
|
or getattr(diarization, "uri", None)
|
||||||
|
if overlap_pipeline and audio_uri:
|
||||||
|
try:
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
|
overlap_annotation = overlap_pipeline(audio_uri)
|
||||||
|
|
||||||
|
for seg in overlap_annotation.get_timeline().support():
|
||||||
|
s = max(0, int(seg.start * sample_rate))
|
||||||
|
e = min(audio_length, int(seg.end * sample_rate))
|
||||||
|
if s < e:
|
||||||
|
both_speaking_mask[s:e] = True
|
||||||
|
t = np.sum(both_speaking_mask) / sample_rate
|
||||||
|
xprint(f" Found {t:.1f}s of overlapped speech (model) ")
|
||||||
|
return both_speaking_mask
|
||||||
|
except Exception as e:
|
||||||
|
xprint(f" ⚠ Overlap model failed: {e}")
|
||||||
|
|
||||||
|
# ── 2) fallback = brute-force pairwise intersection ────────────────
|
||||||
|
xprint(" Falling back to manual overlap detection…")
|
||||||
|
timeline_tracks = list(diarization.itertracks(yield_label=True))
|
||||||
|
for i, (turn1, _, spk1) in enumerate(timeline_tracks):
|
||||||
|
for j, (turn2, _, spk2) in enumerate(timeline_tracks):
|
||||||
|
if i >= j or spk1 == spk2:
|
||||||
|
continue
|
||||||
|
o_start, o_end = max(turn1.start, turn2.start), min(turn1.end, turn2.end)
|
||||||
|
if o_start < o_end:
|
||||||
|
s = max(0, int(o_start * sample_rate))
|
||||||
|
e = min(audio_length, int(o_end * sample_rate))
|
||||||
|
if s < e:
|
||||||
|
both_speaking_mask[s:e] = True
|
||||||
|
t = np.sum(both_speaking_mask) / sample_rate
|
||||||
|
xprint(f" Found {t:.1f}s of overlapped speech (manual) ")
|
||||||
|
return both_speaking_mask
|
||||||
|
|
||||||
|
def _configure_vad(self, vad_onset: float, vad_offset: float):
|
||||||
|
"""Configure VAD parameters efficiently."""
|
||||||
|
xprint("Applying more sensitive VAD parameters...")
|
||||||
|
try:
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
|
|
||||||
|
if hasattr(self.pipeline, '_vad'):
|
||||||
|
self.pipeline._vad.instantiate({
|
||||||
|
"onset": vad_onset,
|
||||||
|
"offset": vad_offset,
|
||||||
|
"min_duration_on": 0.1,
|
||||||
|
"min_duration_off": 0.1,
|
||||||
|
"pad_onset": 0.1,
|
||||||
|
"pad_offset": 0.1,
|
||||||
|
})
|
||||||
|
xprint(f"✓ VAD parameters updated: onset={vad_onset}, offset={vad_offset}")
|
||||||
|
else:
|
||||||
|
xprint("⚠ Could not access VAD component directly")
|
||||||
|
except Exception as e:
|
||||||
|
xprint(f"⚠ Could not modify VAD parameters: {e}")
|
||||||
|
|
||||||
|
def _get_overlap_pipeline(self):
|
||||||
|
"""
|
||||||
|
Build a pyannote-3-native OverlappedSpeechDetection pipeline.
|
||||||
|
|
||||||
|
• uses the open-licence `pyannote/segmentation-3.0` checkpoint
|
||||||
|
• only `min_duration_on/off` can be tuned (API 3.x)
|
||||||
|
"""
|
||||||
|
if self._overlap_pipeline is not None:
|
||||||
|
return None if self._overlap_pipeline is False else self._overlap_pipeline
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pyannote.audio.pipelines import OverlappedSpeechDetection
|
||||||
|
|
||||||
|
xprint("Building OverlappedSpeechDetection with segmentation-3.0…")
|
||||||
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
|
|
||||||
|
# 1) constructor → segmentation model ONLY
|
||||||
|
ods = OverlappedSpeechDetection(
|
||||||
|
segmentation="pyannote/segmentation-3.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2) instantiate → **single dict** with the two valid knobs
|
||||||
|
ods.instantiate({
|
||||||
|
"min_duration_on": 0.06, # ≈ your previous 0.055 s
|
||||||
|
"min_duration_off": 0.10, # ≈ your previous 0.098 s
|
||||||
|
})
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
ods.to(torch.device("cuda"))
|
||||||
|
|
||||||
|
self._overlap_pipeline = ods
|
||||||
|
xprint("✓ Overlap pipeline ready (segmentation-3.0)")
|
||||||
|
return ods
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
xprint(f"⚠ Could not build overlap pipeline ({e}). "
|
||||||
|
"Falling back to manual pair-wise detection.")
|
||||||
|
self._overlap_pipeline = False
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _xprint_setup_instructions(self):
|
||||||
|
"""xprint setup instructions."""
|
||||||
|
xprint("\nTo use Pyannote 3.1:")
|
||||||
|
xprint("1. Get token: https://huggingface.co/settings/tokens")
|
||||||
|
xprint("2. Accept terms: https://huggingface.co/pyannote/speaker-diarization-3.1")
|
||||||
|
xprint("3. Run with: --token YOUR_TOKEN")
|
||||||
|
|
||||||
|
def load_audio(self, audio_path: str) -> Tuple[torch.Tensor, int]:
|
||||||
|
"""Load and preprocess audio efficiently."""
|
||||||
|
xprint(f"Loading audio: {audio_path}")
|
||||||
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
|
waveform, sample_rate = torchaudio.load(audio_path)
|
||||||
|
|
||||||
|
# Convert to mono efficiently
|
||||||
|
if waveform.shape[0] > 1:
|
||||||
|
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
||||||
|
|
||||||
|
xprint(f"Audio: {waveform.shape[1]} samples at {sample_rate}Hz")
|
||||||
|
return waveform, sample_rate
|
||||||
|
|
||||||
|
def perform_optimized_diarization(self, audio_path: str) -> object:
|
||||||
|
"""
|
||||||
|
Optimized diarization with efficient parameter testing.
|
||||||
|
"""
|
||||||
|
xprint("Running optimized Pyannote 3.1 diarization...")
|
||||||
|
|
||||||
|
# Optimized strategy order - most likely to succeed first
|
||||||
|
strategies = [
|
||||||
|
{"min_speakers": 2, "max_speakers": 2}, # Most common case
|
||||||
|
{"num_speakers": 2}, # Direct specification
|
||||||
|
{"min_speakers": 2, "max_speakers": 3}, # Slight flexibility
|
||||||
|
{"min_speakers": 1, "max_speakers": 2}, # Fallback
|
||||||
|
{"min_speakers": 2, "max_speakers": 4}, # More flexibility
|
||||||
|
{} # No constraints
|
||||||
|
]
|
||||||
|
|
||||||
|
for i, params in enumerate(strategies):
|
||||||
|
try:
|
||||||
|
xprint(f"Strategy {i+1}: {params}")
|
||||||
|
|
||||||
|
# Clear GPU memory before each attempt
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
|
diarization = self.pipeline(audio_path, **params)
|
||||||
|
|
||||||
|
speakers = list(diarization.labels())
|
||||||
|
speaker_count = len(speakers)
|
||||||
|
|
||||||
|
xprint(f" → Detected {speaker_count} speakers: {speakers}")
|
||||||
|
|
||||||
|
# Accept first successful result with 2+ speakers
|
||||||
|
if speaker_count >= 2:
|
||||||
|
xprint(f"✓ Success with strategy {i+1}! Using {speaker_count} speakers")
|
||||||
|
return diarization
|
||||||
|
elif speaker_count == 1 and i == 0:
|
||||||
|
# Store first result as fallback
|
||||||
|
fallback_diarization = diarization
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
xprint(f" Strategy {i+1} failed: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# If we only got 1 speaker, try one aggressive attempt
|
||||||
|
if 'fallback_diarization' in locals():
|
||||||
|
xprint("Attempting aggressive clustering for single speaker...")
|
||||||
|
try:
|
||||||
|
aggressive_diarization = self._try_aggressive_clustering(audio_path)
|
||||||
|
if aggressive_diarization and len(list(aggressive_diarization.labels())) >= 2:
|
||||||
|
return aggressive_diarization
|
||||||
|
except Exception as e:
|
||||||
|
xprint(f"Aggressive clustering failed: {e}")
|
||||||
|
|
||||||
|
xprint("Using single speaker result")
|
||||||
|
return fallback_diarization
|
||||||
|
|
||||||
|
# Last resort - run without constraints
|
||||||
|
xprint("Last resort: running without constraints...")
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
|
return self.pipeline(audio_path)
|
||||||
|
|
||||||
|
def _try_aggressive_clustering(self, audio_path: str) -> object:
|
||||||
|
"""Try aggressive clustering parameters."""
|
||||||
|
try:
|
||||||
|
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
|
||||||
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
|
|
||||||
|
# Create aggressive pipeline
|
||||||
|
temp_pipeline = SpeakerDiarization(
|
||||||
|
segmentation=self.pipeline.segmentation,
|
||||||
|
embedding=self.pipeline.embedding,
|
||||||
|
clustering="AgglomerativeClustering"
|
||||||
|
)
|
||||||
|
|
||||||
|
temp_pipeline.instantiate({
|
||||||
|
"clustering": {
|
||||||
|
"method": "centroid",
|
||||||
|
"min_cluster_size": 1,
|
||||||
|
"threshold": 0.1,
|
||||||
|
},
|
||||||
|
"segmentation": {
|
||||||
|
"min_duration_off": 0.0,
|
||||||
|
"min_duration_on": 0.1,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return temp_pipeline(audio_path, min_speakers=2)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
xprint(f"Aggressive clustering setup failed: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def create_optimized_speaker_masks(self, diarization, audio_length: int, sample_rate: int) -> Dict[str, np.ndarray]:
|
||||||
|
"""Optimized mask creation using vectorized operations."""
|
||||||
|
xprint("Creating optimized speaker masks...")
|
||||||
|
|
||||||
|
speakers = list(diarization.labels())
|
||||||
|
xprint(f"Processing speakers: {speakers}")
|
||||||
|
|
||||||
|
# Handle edge cases
|
||||||
|
if len(speakers) == 0:
|
||||||
|
xprint("⚠ No speakers detected, creating dummy masks")
|
||||||
|
return self._create_dummy_masks(audio_length)
|
||||||
|
|
||||||
|
if len(speakers) == 1:
|
||||||
|
xprint("⚠ Only 1 speaker detected, creating temporal split")
|
||||||
|
return self._create_optimized_temporal_split(diarization, audio_length, sample_rate)
|
||||||
|
|
||||||
|
# Extract both-speaking regions from diarization timeline
|
||||||
|
both_speaking_regions = self._extract_both_speaking_regions(diarization, audio_length, sample_rate)
|
||||||
|
|
||||||
|
# Optimized mask creation for multiple speakers
|
||||||
|
masks = {}
|
||||||
|
|
||||||
|
# Batch process all speakers
|
||||||
|
for speaker in speakers:
|
||||||
|
# Get all segments for this speaker at once
|
||||||
|
segments = []
|
||||||
|
speaker_timeline = diarization.label_timeline(speaker)
|
||||||
|
for segment in speaker_timeline:
|
||||||
|
start_sample = max(0, int(segment.start * sample_rate))
|
||||||
|
end_sample = min(audio_length, int(segment.end * sample_rate))
|
||||||
|
if start_sample < end_sample:
|
||||||
|
segments.append((start_sample, end_sample))
|
||||||
|
|
||||||
|
# Vectorized mask creation
|
||||||
|
if segments:
|
||||||
|
mask = self._create_mask_vectorized(segments, audio_length)
|
||||||
|
masks[speaker] = mask
|
||||||
|
speaking_time = np.sum(mask) / sample_rate
|
||||||
|
xprint(f" {speaker}: {speaking_time:.1f}s speaking time")
|
||||||
|
else:
|
||||||
|
masks[speaker] = np.zeros(audio_length, dtype=np.float32)
|
||||||
|
|
||||||
|
# Store both-speaking info for later use
|
||||||
|
self._both_speaking_regions = both_speaking_regions
|
||||||
|
|
||||||
|
return masks
|
||||||
|
|
||||||
|
def _create_mask_vectorized(self, segments: List[Tuple[int, int]], audio_length: int) -> np.ndarray:
|
||||||
|
"""Create mask using vectorized operations."""
|
||||||
|
mask = np.zeros(audio_length, dtype=np.float32)
|
||||||
|
|
||||||
|
if not segments:
|
||||||
|
return mask
|
||||||
|
|
||||||
|
# Convert segments to arrays for vectorized operations
|
||||||
|
segments_array = np.array(segments)
|
||||||
|
starts = segments_array[:, 0]
|
||||||
|
ends = segments_array[:, 1]
|
||||||
|
|
||||||
|
# Use advanced indexing for bulk assignment
|
||||||
|
for start, end in zip(starts, ends):
|
||||||
|
mask[start:end] = 1.0
|
||||||
|
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def _create_dummy_masks(self, audio_length: int) -> Dict[str, np.ndarray]:
|
||||||
|
"""Create dummy masks for edge cases."""
|
||||||
|
return {
|
||||||
|
"SPEAKER_00": np.ones(audio_length, dtype=np.float32) * 0.5,
|
||||||
|
"SPEAKER_01": np.ones(audio_length, dtype=np.float32) * 0.5
|
||||||
|
}
|
||||||
|
|
||||||
|
def _create_optimized_temporal_split(self, diarization, audio_length: int, sample_rate: int) -> Dict[str, np.ndarray]:
|
||||||
|
"""Optimized temporal split with vectorized operations."""
|
||||||
|
xprint("Creating optimized temporal split...")
|
||||||
|
|
||||||
|
# Extract all segments at once
|
||||||
|
segments = []
|
||||||
|
for turn, _, speaker in diarization.itertracks(yield_label=True):
|
||||||
|
segments.append((turn.start, turn.end))
|
||||||
|
|
||||||
|
segments.sort()
|
||||||
|
xprint(f"Found {len(segments)} speech segments")
|
||||||
|
|
||||||
|
if len(segments) <= 1:
|
||||||
|
# Single segment or no segments - simple split
|
||||||
|
return self._create_simple_split(audio_length)
|
||||||
|
|
||||||
|
# Vectorized gap analysis
|
||||||
|
segment_array = np.array(segments)
|
||||||
|
gaps = segment_array[1:, 0] - segment_array[:-1, 1] # Vectorized gap calculation
|
||||||
|
|
||||||
|
if len(gaps) > 0:
|
||||||
|
longest_gap_idx = np.argmax(gaps)
|
||||||
|
longest_gap_duration = gaps[longest_gap_idx]
|
||||||
|
|
||||||
|
xprint(f"Longest gap: {longest_gap_duration:.1f}s after segment {longest_gap_idx+1}")
|
||||||
|
|
||||||
|
if longest_gap_duration > 1.0:
|
||||||
|
# Split at natural break
|
||||||
|
split_point = longest_gap_idx + 1
|
||||||
|
xprint(f"Splitting at natural break: segments 1-{split_point} vs {split_point+1}-{len(segments)}")
|
||||||
|
|
||||||
|
return self._create_split_masks(segments, split_point, audio_length, sample_rate)
|
||||||
|
|
||||||
|
# Fallback: alternating assignment
|
||||||
|
xprint("Using alternating assignment...")
|
||||||
|
return self._create_alternating_masks(segments, audio_length, sample_rate)
|
||||||
|
|
||||||
|
def _create_simple_split(self, audio_length: int) -> Dict[str, np.ndarray]:
|
||||||
|
"""Simple temporal split in half."""
|
||||||
|
mid_point = audio_length // 2
|
||||||
|
masks = {
|
||||||
|
"SPEAKER_00": np.zeros(audio_length, dtype=np.float32),
|
||||||
|
"SPEAKER_01": np.zeros(audio_length, dtype=np.float32)
|
||||||
|
}
|
||||||
|
masks["SPEAKER_00"][:mid_point] = 1.0
|
||||||
|
masks["SPEAKER_01"][mid_point:] = 1.0
|
||||||
|
return masks
|
||||||
|
|
||||||
|
def _create_split_masks(self, segments: List[Tuple[float, float]], split_point: int,
|
||||||
|
audio_length: int, sample_rate: int) -> Dict[str, np.ndarray]:
|
||||||
|
"""Create masks with split at specific point."""
|
||||||
|
masks = {
|
||||||
|
"SPEAKER_00": np.zeros(audio_length, dtype=np.float32),
|
||||||
|
"SPEAKER_01": np.zeros(audio_length, dtype=np.float32)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Vectorized segment processing
|
||||||
|
for i, (start_time, end_time) in enumerate(segments):
|
||||||
|
start_sample = max(0, int(start_time * sample_rate))
|
||||||
|
end_sample = min(audio_length, int(end_time * sample_rate))
|
||||||
|
|
||||||
|
if start_sample < end_sample:
|
||||||
|
speaker_key = "SPEAKER_00" if i < split_point else "SPEAKER_01"
|
||||||
|
masks[speaker_key][start_sample:end_sample] = 1.0
|
||||||
|
|
||||||
|
return masks
|
||||||
|
|
||||||
|
def _create_alternating_masks(self, segments: List[Tuple[float, float]],
|
||||||
|
audio_length: int, sample_rate: int) -> Dict[str, np.ndarray]:
|
||||||
|
"""Create masks with alternating assignment."""
|
||||||
|
masks = {
|
||||||
|
"SPEAKER_00": np.zeros(audio_length, dtype=np.float32),
|
||||||
|
"SPEAKER_01": np.zeros(audio_length, dtype=np.float32)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, (start_time, end_time) in enumerate(segments):
|
||||||
|
start_sample = max(0, int(start_time * sample_rate))
|
||||||
|
end_sample = min(audio_length, int(end_time * sample_rate))
|
||||||
|
|
||||||
|
if start_sample < end_sample:
|
||||||
|
speaker_key = f"SPEAKER_0{i % 2}"
|
||||||
|
masks[speaker_key][start_sample:end_sample] = 1.0
|
||||||
|
|
||||||
|
return masks
|
||||||
|
|
||||||
|
def apply_optimized_background_preservation(self, masks: Dict[str, np.ndarray],
|
||||||
|
audio_length: int) -> Dict[str, np.ndarray]:
|
||||||
|
"""
|
||||||
|
Heavily optimized background preservation using pure vectorized operations.
|
||||||
|
"""
|
||||||
|
xprint("Applying optimized voice separation logic...")
|
||||||
|
|
||||||
|
# Ensure exactly 2 speakers
|
||||||
|
speaker_keys = self._get_top_speakers(masks, audio_length)
|
||||||
|
|
||||||
|
# Pre-allocate final masks
|
||||||
|
final_masks = {
|
||||||
|
speaker: np.zeros(audio_length, dtype=np.float32)
|
||||||
|
for speaker in speaker_keys
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get active masks (vectorized)
|
||||||
|
active_0 = masks.get(speaker_keys[0], np.zeros(audio_length)) > 0.5
|
||||||
|
active_1 = masks.get(speaker_keys[1], np.zeros(audio_length)) > 0.5
|
||||||
|
|
||||||
|
# Vectorized mask assignment
|
||||||
|
both_active = active_0 & active_1
|
||||||
|
only_0 = active_0 & ~active_1
|
||||||
|
only_1 = ~active_0 & active_1
|
||||||
|
neither = ~active_0 & ~active_1
|
||||||
|
|
||||||
|
# Apply assignments (all vectorized)
|
||||||
|
final_masks[speaker_keys[0]][both_active] = 1.0
|
||||||
|
final_masks[speaker_keys[1]][both_active] = 1.0
|
||||||
|
|
||||||
|
final_masks[speaker_keys[0]][only_0] = 1.0
|
||||||
|
final_masks[speaker_keys[1]][only_0] = 0.0
|
||||||
|
|
||||||
|
final_masks[speaker_keys[0]][only_1] = 0.0
|
||||||
|
final_masks[speaker_keys[1]][only_1] = 1.0
|
||||||
|
|
||||||
|
# Handle ambiguous regions efficiently
|
||||||
|
if np.any(neither):
|
||||||
|
ambiguous_assignments = self._compute_ambiguous_assignments_vectorized(
|
||||||
|
masks, speaker_keys, neither, audio_length
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply ambiguous assignments
|
||||||
|
final_masks[speaker_keys[0]][neither] = (ambiguous_assignments == 0).astype(np.float32) * 0.5
|
||||||
|
final_masks[speaker_keys[1]][neither] = (ambiguous_assignments == 1).astype(np.float32) * 0.5
|
||||||
|
|
||||||
|
# xprint statistics (vectorized)
|
||||||
|
sample_rate = 16000 # Assume 16kHz for timing
|
||||||
|
xprint(f" Both speaking clearly: {np.sum(both_active)/sample_rate:.1f}s")
|
||||||
|
xprint(f" {speaker_keys[0]} only: {np.sum(only_0)/sample_rate:.1f}s")
|
||||||
|
xprint(f" {speaker_keys[1]} only: {np.sum(only_1)/sample_rate:.1f}s")
|
||||||
|
xprint(f" Ambiguous (assigned): {np.sum(neither)/sample_rate:.1f}s")
|
||||||
|
|
||||||
|
# Apply minimum duration smoothing to prevent rapid switching
|
||||||
|
final_masks = self._apply_minimum_duration_smoothing(final_masks, sample_rate)
|
||||||
|
|
||||||
|
return final_masks
|
||||||
|
|
||||||
|
def _get_top_speakers(self, masks: Dict[str, np.ndarray], audio_length: int) -> List[str]:
|
||||||
|
"""Get top 2 speakers by speaking time."""
|
||||||
|
speaker_keys = list(masks.keys())
|
||||||
|
|
||||||
|
if len(speaker_keys) > 2:
|
||||||
|
# Vectorized speaking time calculation
|
||||||
|
speaking_times = {k: np.sum(v) for k, v in masks.items()}
|
||||||
|
speaker_keys = sorted(speaking_times.keys(), key=lambda x: speaking_times[x], reverse=True)[:2]
|
||||||
|
xprint(f"Keeping top 2 speakers: {speaker_keys}")
|
||||||
|
elif len(speaker_keys) == 1:
|
||||||
|
speaker_keys.append("SPEAKER_SILENT")
|
||||||
|
|
||||||
|
return speaker_keys
|
||||||
|
|
||||||
|
def _compute_ambiguous_assignments_vectorized(self, masks: Dict[str, np.ndarray],
|
||||||
|
speaker_keys: List[str],
|
||||||
|
ambiguous_mask: np.ndarray,
|
||||||
|
audio_length: int) -> np.ndarray:
|
||||||
|
"""Compute speaker assignments for ambiguous regions using vectorized operations."""
|
||||||
|
ambiguous_indices = np.where(ambiguous_mask)[0]
|
||||||
|
|
||||||
|
if len(ambiguous_indices) == 0:
|
||||||
|
return np.array([])
|
||||||
|
|
||||||
|
# Get speaker segments efficiently
|
||||||
|
speaker_segments = {}
|
||||||
|
for speaker in speaker_keys:
|
||||||
|
if speaker in masks and speaker != "SPEAKER_SILENT":
|
||||||
|
mask = masks[speaker] > 0.5
|
||||||
|
# Find segments using vectorized operations
|
||||||
|
diff = np.diff(np.concatenate(([False], mask, [False])).astype(int))
|
||||||
|
starts = np.where(diff == 1)[0]
|
||||||
|
ends = np.where(diff == -1)[0]
|
||||||
|
speaker_segments[speaker] = np.column_stack([starts, ends])
|
||||||
|
else:
|
||||||
|
speaker_segments[speaker] = np.array([]).reshape(0, 2)
|
||||||
|
|
||||||
|
# Vectorized distance calculations
|
||||||
|
distances = {}
|
||||||
|
for speaker in speaker_keys:
|
||||||
|
segments = speaker_segments[speaker]
|
||||||
|
if len(segments) == 0:
|
||||||
|
distances[speaker] = np.full(len(ambiguous_indices), np.inf)
|
||||||
|
else:
|
||||||
|
# Compute distances to all segments at once
|
||||||
|
distances[speaker] = self._compute_distances_to_segments(ambiguous_indices, segments)
|
||||||
|
|
||||||
|
# Assign based on minimum distance with late-audio bias
|
||||||
|
assignments = self._assign_based_on_distance(
|
||||||
|
distances, speaker_keys, ambiguous_indices, audio_length
|
||||||
|
)
|
||||||
|
|
||||||
|
return assignments
|
||||||
|
|
||||||
|
def _apply_minimum_duration_smoothing(self, masks: Dict[str, np.ndarray],
|
||||||
|
sample_rate: int, min_duration_ms: int = 600) -> Dict[str, np.ndarray]:
|
||||||
|
"""
|
||||||
|
Apply minimum duration smoothing with STRICT timer enforcement.
|
||||||
|
Uses original both-speaking regions from diarization.
|
||||||
|
"""
|
||||||
|
xprint(f"Applying STRICT minimum duration smoothing ({min_duration_ms}ms)...")
|
||||||
|
|
||||||
|
min_samples = int(min_duration_ms * sample_rate / 1000)
|
||||||
|
speaker_keys = list(masks.keys())
|
||||||
|
|
||||||
|
if len(speaker_keys) != 2:
|
||||||
|
return masks
|
||||||
|
|
||||||
|
mask0 = masks[speaker_keys[0]]
|
||||||
|
mask1 = masks[speaker_keys[1]]
|
||||||
|
|
||||||
|
# Use original both-speaking regions from diarization
|
||||||
|
both_speaking_original = getattr(self, '_both_speaking_regions', np.zeros(len(mask0), dtype=bool))
|
||||||
|
|
||||||
|
# Identify regions based on original diarization info
|
||||||
|
ambiguous_original = (mask0 < 0.3) & (mask1 < 0.3) & ~both_speaking_original
|
||||||
|
|
||||||
|
# Clear dominance: one speaker higher, and not both-speaking or ambiguous
|
||||||
|
remaining_mask = ~both_speaking_original & ~ambiguous_original
|
||||||
|
speaker0_dominant = (mask0 > mask1) & remaining_mask
|
||||||
|
speaker1_dominant = (mask1 > mask0) & remaining_mask
|
||||||
|
|
||||||
|
# Create preference signal including both-speaking as valid state
|
||||||
|
# -1=ambiguous, 0=speaker0, 1=speaker1, 2=both_speaking
|
||||||
|
preference_signal = np.full(len(mask0), -1, dtype=int)
|
||||||
|
preference_signal[speaker0_dominant] = 0
|
||||||
|
preference_signal[speaker1_dominant] = 1
|
||||||
|
preference_signal[both_speaking_original] = 2
|
||||||
|
|
||||||
|
# STRICT state machine enforcement
|
||||||
|
smoothed_assignment = np.full(len(mask0), -1, dtype=int)
|
||||||
|
corrections = 0
|
||||||
|
|
||||||
|
# State variables
|
||||||
|
current_state = -1 # -1=unset, 0=speaker0, 1=speaker1, 2=both_speaking
|
||||||
|
samples_remaining = 0 # Samples remaining in current state's lock period
|
||||||
|
|
||||||
|
# Process each sample with STRICT enforcement
|
||||||
|
for i in range(len(preference_signal)):
|
||||||
|
preference = preference_signal[i]
|
||||||
|
|
||||||
|
# If we're in a lock period, enforce the current state
|
||||||
|
if samples_remaining > 0:
|
||||||
|
# Force current state regardless of preference
|
||||||
|
smoothed_assignment[i] = current_state
|
||||||
|
samples_remaining -= 1
|
||||||
|
|
||||||
|
# Count corrections if this differs from preference
|
||||||
|
if preference >= 0 and preference != current_state:
|
||||||
|
corrections += 1
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Lock period expired - can consider new state
|
||||||
|
|
||||||
|
if preference >= 0:
|
||||||
|
# Clear preference available (including both-speaking)
|
||||||
|
if current_state != preference:
|
||||||
|
# Switch to new state and start new lock period
|
||||||
|
current_state = preference
|
||||||
|
samples_remaining = min_samples - 1 # -1 because we use this sample
|
||||||
|
|
||||||
|
smoothed_assignment[i] = current_state
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Ambiguous preference
|
||||||
|
if current_state >= 0:
|
||||||
|
# Continue with current state if we have one
|
||||||
|
smoothed_assignment[i] = current_state
|
||||||
|
else:
|
||||||
|
# No current state and ambiguous - leave as ambiguous
|
||||||
|
smoothed_assignment[i] = -1
|
||||||
|
|
||||||
|
# Convert back to masks based on smoothed assignment
|
||||||
|
smoothed_masks = {}
|
||||||
|
|
||||||
|
for i, speaker in enumerate(speaker_keys):
|
||||||
|
new_mask = np.zeros_like(mask0)
|
||||||
|
|
||||||
|
# Assign regions where this speaker is dominant
|
||||||
|
speaker_regions = smoothed_assignment == i
|
||||||
|
new_mask[speaker_regions] = 1.0
|
||||||
|
|
||||||
|
# Assign both-speaking regions (state 2) to both speakers
|
||||||
|
both_speaking_regions = smoothed_assignment == 2
|
||||||
|
new_mask[both_speaking_regions] = 1.0
|
||||||
|
|
||||||
|
# Handle ambiguous regions that remain unassigned
|
||||||
|
unassigned_ambiguous = smoothed_assignment == -1
|
||||||
|
if np.any(unassigned_ambiguous):
|
||||||
|
# Use original ambiguous values only for truly unassigned regions
|
||||||
|
original_ambiguous_mask = ambiguous_original & unassigned_ambiguous
|
||||||
|
new_mask[original_ambiguous_mask] = masks[speaker][original_ambiguous_mask]
|
||||||
|
|
||||||
|
smoothed_masks[speaker] = new_mask
|
||||||
|
|
||||||
|
# Calculate and xprint statistics
|
||||||
|
both_speaking_time = np.sum(smoothed_assignment == 2) / sample_rate
|
||||||
|
speaker0_time = np.sum(smoothed_assignment == 0) / sample_rate
|
||||||
|
speaker1_time = np.sum(smoothed_assignment == 1) / sample_rate
|
||||||
|
ambiguous_time = np.sum(smoothed_assignment == -1) / sample_rate
|
||||||
|
|
||||||
|
xprint(f" Both speaking clearly: {both_speaking_time:.1f}s")
|
||||||
|
xprint(f" {speaker_keys[0]} only: {speaker0_time:.1f}s")
|
||||||
|
xprint(f" {speaker_keys[1]} only: {speaker1_time:.1f}s")
|
||||||
|
xprint(f" Ambiguous (assigned): {ambiguous_time:.1f}s")
|
||||||
|
xprint(f" Enforced minimum duration on {corrections} samples ({corrections/sample_rate:.2f}s)")
|
||||||
|
|
||||||
|
return smoothed_masks
|
||||||
|
|
||||||
|
def _compute_distances_to_segments(self, indices: np.ndarray, segments: np.ndarray) -> np.ndarray:
|
||||||
|
"""Compute minimum distances from indices to segments (vectorized)."""
|
||||||
|
if len(segments) == 0:
|
||||||
|
return np.full(len(indices), np.inf)
|
||||||
|
|
||||||
|
# Broadcast for vectorized computation
|
||||||
|
indices_expanded = indices[:, np.newaxis] # Shape: (n_indices, 1)
|
||||||
|
starts = segments[:, 0] # Shape: (n_segments,)
|
||||||
|
ends = segments[:, 1] # Shape: (n_segments,)
|
||||||
|
|
||||||
|
# Compute distances to all segments
|
||||||
|
dist_to_start = np.maximum(0, starts - indices_expanded) # Shape: (n_indices, n_segments)
|
||||||
|
dist_from_end = np.maximum(0, indices_expanded - ends) # Shape: (n_indices, n_segments)
|
||||||
|
|
||||||
|
# Minimum of distance to start or from end for each segment
|
||||||
|
distances = np.minimum(dist_to_start, dist_from_end)
|
||||||
|
|
||||||
|
# Return minimum distance to any segment for each index
|
||||||
|
return np.min(distances, axis=1)
|
||||||
|
|
||||||
|
def _assign_based_on_distance(self, distances: Dict[str, np.ndarray],
|
||||||
|
speaker_keys: List[str],
|
||||||
|
ambiguous_indices: np.ndarray,
|
||||||
|
audio_length: int) -> np.ndarray:
|
||||||
|
"""Assign speakers based on distance with late-audio bias."""
|
||||||
|
speaker_0_distances = distances[speaker_keys[0]]
|
||||||
|
speaker_1_distances = distances[speaker_keys[1]]
|
||||||
|
|
||||||
|
# Basic assignment by minimum distance
|
||||||
|
assignments = (speaker_1_distances < speaker_0_distances).astype(int)
|
||||||
|
|
||||||
|
# Apply late-audio bias (vectorized)
|
||||||
|
late_threshold = int(audio_length * 0.6)
|
||||||
|
late_indices = ambiguous_indices > late_threshold
|
||||||
|
|
||||||
|
if np.any(late_indices) and len(speaker_keys) > 1:
|
||||||
|
# Simple late-audio bias: prefer speaker 1 in later parts
|
||||||
|
assignments[late_indices] = 1
|
||||||
|
|
||||||
|
return assignments
|
||||||
|
|
||||||
|
def _save_outputs_optimized(self, waveform: torch.Tensor, masks: Dict[str, np.ndarray],
|
||||||
|
sample_rate: int, audio_path: str, output1, output2) -> Dict[str, str]:
|
||||||
|
"""Optimized output saving with parallel processing."""
|
||||||
|
output_paths = {}
|
||||||
|
|
||||||
|
def save_speaker_audio(speaker_mask_pair, output):
|
||||||
|
speaker, mask = speaker_mask_pair
|
||||||
|
# Convert mask to tensor efficiently
|
||||||
|
mask_tensor = torch.from_numpy(mask).unsqueeze(0)
|
||||||
|
|
||||||
|
# Apply mask
|
||||||
|
masked_audio = waveform * mask_tensor
|
||||||
|
|
||||||
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
|
torchaudio.save(output, masked_audio, sample_rate)
|
||||||
|
|
||||||
|
xprint(f"✓ Saved {speaker}: {output}")
|
||||||
|
return speaker, output
|
||||||
|
|
||||||
|
# Use ThreadPoolExecutor for parallel saving
|
||||||
|
with ThreadPoolExecutor(max_workers=2) as executor:
|
||||||
|
results = list(executor.map(save_speaker_audio, masks.items(), [output1, output2]))
|
||||||
|
|
||||||
|
output_paths = dict(results)
|
||||||
|
return output_paths
|
||||||
|
|
||||||
|
def print_summary(self, audio_path: str):
|
||||||
|
"""xprint diarization summary."""
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
|
diarization = self.perform_optimized_diarization(audio_path)
|
||||||
|
|
||||||
|
xprint("\n=== Diarization Summary ===")
|
||||||
|
for turn, _, speaker in diarization.itertracks(yield_label=True):
|
||||||
|
xprint(f"{speaker}: {turn.start:.1f}s - {turn.end:.1f}s")
|
||||||
|
|
||||||
|
def extract_dual_audio(audio, output1, output2, verbose = False):
|
||||||
|
global verbose_output
|
||||||
|
verbose_output = verbose
|
||||||
|
separator = OptimizedPyannote31SpeakerSeparator(
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
vad_onset=0.2,
|
||||||
|
vad_offset=0.8
|
||||||
|
)
|
||||||
|
# Separate audio
|
||||||
|
import time
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
outputs = separator.separate_audio(audio, output1, output2)
|
||||||
|
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
xprint(f"\n=== SUCCESS (completed in {elapsed_time:.2f}s) ===")
|
||||||
|
for speaker, path in outputs.items():
|
||||||
|
xprint(f"{speaker}: {path}")
|
||||||
|
|
||||||
|
def main():
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="Optimized Pyannote 3.1 Speaker Separator")
|
||||||
|
parser.add_argument("--audio", required=True, help="Input audio file")
|
||||||
|
parser.add_argument("--output", required=True, help="Output directory")
|
||||||
|
parser.add_argument("--token", help="Hugging Face token")
|
||||||
|
parser.add_argument("--local-model", help="Path to local 3.1 model")
|
||||||
|
parser.add_argument("--summary", action="store_true", help="xprint summary")
|
||||||
|
|
||||||
|
# VAD sensitivity parameters
|
||||||
|
parser.add_argument("--vad-onset", type=float, default=0.2,
|
||||||
|
help="VAD onset threshold (lower = more sensitive to speech start, default: 0.2)")
|
||||||
|
parser.add_argument("--vad-offset", type=float, default=0.8,
|
||||||
|
help="VAD offset threshold (higher = keeps speech longer, default: 0.8)")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
xprint("=== Optimized Pyannote 3.1 Speaker Separator ===")
|
||||||
|
xprint("Performance optimizations: vectorized operations, memory management, parallel processing")
|
||||||
|
xprint(f"Audio: {args.audio}")
|
||||||
|
xprint(f"Output: {args.output}")
|
||||||
|
xprint(f"VAD onset: {args.vad_onset}")
|
||||||
|
xprint(f"VAD offset: {args.vad_offset}")
|
||||||
|
xprint()
|
||||||
|
|
||||||
|
if not os.path.exists(args.audio):
|
||||||
|
xprint(f"ERROR: Audio file not found: {args.audio}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Initialize with VAD parameters
|
||||||
|
separator = OptimizedPyannote31SpeakerSeparator(
|
||||||
|
args.token,
|
||||||
|
args.local_model,
|
||||||
|
vad_onset=args.vad_onset,
|
||||||
|
vad_offset=args.vad_offset
|
||||||
|
)
|
||||||
|
|
||||||
|
# print summary if requested
|
||||||
|
if args.summary:
|
||||||
|
separator.print_summary(args.audio)
|
||||||
|
|
||||||
|
# Separate audio
|
||||||
|
import time
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
audio_name = Path(args.audio).stem
|
||||||
|
output_filename = f"{audio_name}_speaker0.wav"
|
||||||
|
output_filename1 = f"{audio_name}_speaker1.wav"
|
||||||
|
output_path = os.path.join(args.output, output_filename)
|
||||||
|
output_path1 = os.path.join(args.output, output_filename1)
|
||||||
|
|
||||||
|
outputs = separator.separate_audio(args.audio, output_path, output_path1)
|
||||||
|
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
xprint(f"\n=== SUCCESS (completed in {elapsed_time:.2f}s) ===")
|
||||||
|
for speaker, path in outputs.items():
|
||||||
|
xprint(f"{speaker}: {path}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
xprint(f"ERROR: {e}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
exit(main())
|
||||||
@ -38,7 +38,10 @@ pygame>=2.1.0
|
|||||||
sounddevice>=0.4.0
|
sounddevice>=0.4.0
|
||||||
# rembg==2.0.65
|
# rembg==2.0.65
|
||||||
torchdiffeq >= 0.2.5
|
torchdiffeq >= 0.2.5
|
||||||
# 'nitrous-ema',
|
|
||||||
# 'hydra_colorlog',
|
|
||||||
tensordict >= 0.6.1
|
tensordict >= 0.6.1
|
||||||
open_clip_torch >= 2.29.0
|
open_clip_torch >= 2.29.0
|
||||||
|
pyloudnorm
|
||||||
|
misaki
|
||||||
|
soundfile
|
||||||
|
# num2words
|
||||||
|
# spacy
|
||||||
@ -1,4 +1,3 @@
|
|||||||
from . import configs, distributed, modules
|
from . import configs, distributed, modules
|
||||||
from .image2video import WanI2V
|
from .any2video import WanAny2V
|
||||||
from .text2video import WanT2V
|
|
||||||
from .diffusion_forcing import DTT2V
|
from .diffusion_forcing import DTT2V
|
||||||
@ -13,6 +13,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.cuda.amp as amp
|
import torch.cuda.amp as amp
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
import numpy as np
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import torchvision.transforms.functional as TF
|
import torchvision.transforms.functional as TF
|
||||||
@ -21,14 +22,15 @@ from .distributed.fsdp import shard_model
|
|||||||
from .modules.model import WanModel
|
from .modules.model import WanModel
|
||||||
from .modules.t5 import T5EncoderModel
|
from .modules.t5 import T5EncoderModel
|
||||||
from .modules.vae import WanVAE
|
from .modules.vae import WanVAE
|
||||||
|
from .modules.clip import CLIPModel
|
||||||
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
||||||
get_sampling_sigmas, retrieve_timesteps)
|
get_sampling_sigmas, retrieve_timesteps)
|
||||||
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
||||||
from wan.modules.posemb_layers import get_rotary_pos_embed
|
from wan.modules.posemb_layers import get_rotary_pos_embed
|
||||||
from .utils.vace_preprocessor import VaceVideoProcessor
|
from .utils.vace_preprocessor import VaceVideoProcessor
|
||||||
from wan.utils.basic_flowmatch import FlowMatchScheduler
|
from wan.utils.basic_flowmatch import FlowMatchScheduler
|
||||||
from wan.utils.utils import get_outpainting_frame_location
|
from wan.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions
|
||||||
from wgp import update_loras_slists
|
from .multitalk.multitalk_utils import MomentumBuffer, adaptive_projected_guidance
|
||||||
|
|
||||||
def optimized_scale(positive_flat, negative_flat):
|
def optimized_scale(positive_flat, negative_flat):
|
||||||
|
|
||||||
@ -43,14 +45,20 @@ def optimized_scale(positive_flat, negative_flat):
|
|||||||
|
|
||||||
return st_star
|
return st_star
|
||||||
|
|
||||||
|
def timestep_transform(t, shift=5.0, num_timesteps=1000 ):
|
||||||
class WanT2V:
|
t = t / num_timesteps
|
||||||
|
# shift the timestep based on ratio
|
||||||
|
new_t = shift * t / (1 + (shift - 1) * t)
|
||||||
|
new_t = new_t * num_timesteps
|
||||||
|
return new_t
|
||||||
|
|
||||||
|
|
||||||
|
class WanAny2V:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
checkpoint_dir,
|
checkpoint_dir,
|
||||||
rank=0,
|
|
||||||
model_filename = None,
|
model_filename = None,
|
||||||
model_type = None,
|
model_type = None,
|
||||||
base_model_type = None,
|
base_model_type = None,
|
||||||
@ -63,7 +71,7 @@ class WanT2V:
|
|||||||
):
|
):
|
||||||
self.device = torch.device(f"cuda")
|
self.device = torch.device(f"cuda")
|
||||||
self.config = config
|
self.config = config
|
||||||
self.rank = rank
|
self.VAE_dtype = VAE_dtype
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.num_train_timesteps = config.num_train_timesteps
|
self.num_train_timesteps = config.num_train_timesteps
|
||||||
self.param_dtype = config.param_dtype
|
self.param_dtype = config.param_dtype
|
||||||
@ -76,6 +84,14 @@ class WanT2V:
|
|||||||
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
|
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
|
||||||
shard_fn= None)
|
shard_fn= None)
|
||||||
|
|
||||||
|
if hasattr(config, "clip_checkpoint"):
|
||||||
|
self.clip = CLIPModel(
|
||||||
|
dtype=config.clip_dtype,
|
||||||
|
device=self.device,
|
||||||
|
checkpoint_path=os.path.join(checkpoint_dir ,
|
||||||
|
config.clip_checkpoint),
|
||||||
|
tokenizer_path=os.path.join(checkpoint_dir , config.clip_tokenizer))
|
||||||
|
|
||||||
self.vae_stride = config.vae_stride
|
self.vae_stride = config.vae_stride
|
||||||
self.patch_size = config.patch_size
|
self.patch_size = config.patch_size
|
||||||
|
|
||||||
@ -83,22 +99,27 @@ class WanT2V:
|
|||||||
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype,
|
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
|
|
||||||
logging.info(f"Creating WanModel from {model_filename[-1]}")
|
# xmodel_filename = "c:/ml/multitalk/multitalk.safetensors"
|
||||||
from mmgp import offload
|
# config_filename= "configs/multitalk.json"
|
||||||
# model_filename = "c:/temp/vace1.3/diffusion_pytorch_model.safetensors"
|
# import json
|
||||||
# model_filename = "Vacefusionix_quanto_fp16_int8.safetensors"
|
# with open(config_filename, 'r', encoding='utf-8') as f:
|
||||||
# model_filename = "c:/temp/t2v/diffusion_pytorch_model-00001-of-00006.safetensors"
|
# config = json.load(f)
|
||||||
# config_filename= "c:/temp/t2v/t2v.json"
|
# from mmgp import safetensors2
|
||||||
|
# sd = safetensors2.torch_load_file(xmodel_filename)
|
||||||
|
|
||||||
base_config_file = f"configs/{base_model_type}.json"
|
base_config_file = f"configs/{base_model_type}.json"
|
||||||
forcedConfigPath = base_config_file if len(model_filename) > 1 else None
|
forcedConfigPath = base_config_file if len(model_filename) > 1 or base_model_type in ["flf2v_720p"] else None
|
||||||
|
# model_filename[1] = xmodel_filename
|
||||||
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
|
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
|
||||||
|
# self.model = offload.load_model_data(self.model, xmodel_filename )
|
||||||
# offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth")
|
# offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth")
|
||||||
# self.model.to(torch.bfloat16)
|
# self.model.to(torch.bfloat16)
|
||||||
# self.model.cpu()
|
# self.model.cpu()
|
||||||
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
|
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
|
||||||
# dtype = torch.bfloat16
|
|
||||||
# offload.load_model_data(self.model, "ckpts/Wan14BT2VFusioniX_fp16.safetensors")
|
|
||||||
offload.change_dtype(self.model, dtype, True)
|
offload.change_dtype(self.model, dtype, True)
|
||||||
|
# offload.save_model(self.model, "multitalkbf16.safetensors", config_file_path=base_config_file, filter_sd=sd)
|
||||||
|
# offload.save_model(self.model, "multitalk_quanto_fp16.safetensors", do_quantize= True, config_file_path=base_config_file, filter_sd=sd)
|
||||||
|
|
||||||
# offload.save_model(self.model, "wan2.1_selforcing_fp16.safetensors", config_file_path=base_config_file)
|
# offload.save_model(self.model, "wan2.1_selforcing_fp16.safetensors", config_file_path=base_config_file)
|
||||||
# offload.save_model(self.model, "wan2.1_text2video_14B_mbf16.safetensors", config_file_path=base_config_file)
|
# offload.save_model(self.model, "wan2.1_text2video_14B_mbf16.safetensors", config_file_path=base_config_file)
|
||||||
# offload.save_model(self.model, "wan2.1_text2video_14B_quanto_mfp16_int8.safetensors", do_quantize=True, config_file_path=base_config_file)
|
# offload.save_model(self.model, "wan2.1_text2video_14B_quanto_mfp16_int8.safetensors", do_quantize=True, config_file_path=base_config_file)
|
||||||
@ -109,7 +130,7 @@ class WanT2V:
|
|||||||
|
|
||||||
self.sample_neg_prompt = config.sample_neg_prompt
|
self.sample_neg_prompt = config.sample_neg_prompt
|
||||||
|
|
||||||
if base_model_type in ["vace_14B", "vace_1.3B"]:
|
if self.model.config.get("vace_in_dim", None) != None:
|
||||||
self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
|
self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
|
||||||
min_area=480*832,
|
min_area=480*832,
|
||||||
max_area=480*832,
|
max_area=480*832,
|
||||||
@ -121,6 +142,9 @@ class WanT2V:
|
|||||||
|
|
||||||
self.adapt_vace_model()
|
self.adapt_vace_model()
|
||||||
|
|
||||||
|
self.num_timesteps = 1000
|
||||||
|
self.use_timestep_transform = True
|
||||||
|
|
||||||
def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0, overlapped_latents = None):
|
def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0, overlapped_latents = None):
|
||||||
if ref_images is None:
|
if ref_images is None:
|
||||||
ref_images = [None] * len(frames)
|
ref_images = [None] * len(frames)
|
||||||
@ -134,10 +158,11 @@ class WanT2V:
|
|||||||
reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
|
reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
|
||||||
inactive = self.vae.encode(inactive, tile_size = tile_size)
|
inactive = self.vae.encode(inactive, tile_size = tile_size)
|
||||||
|
|
||||||
if overlapped_latents != None and False :
|
if overlapped_latents != None and False : # disabled as quality seems worse
|
||||||
# inactive[0][:, 0:1] = self.vae.encode([frames[0][:, 0:1]], tile_size = tile_size)[0] # redundant
|
# inactive[0][:, 0:1] = self.vae.encode([frames[0][:, 0:1]], tile_size = tile_size)[0] # redundant
|
||||||
for t in inactive:
|
for t in inactive:
|
||||||
t[:, 1:overlapped_latents.shape[1] + 1] = overlapped_latents
|
t[:, 1:overlapped_latents.shape[1] + 1] = overlapped_latents
|
||||||
|
overlapped_latents[: 0:1] = inactive[0][: 0:1]
|
||||||
|
|
||||||
reactive = self.vae.encode(reactive, tile_size = tile_size)
|
reactive = self.vae.encode(reactive, tile_size = tile_size)
|
||||||
latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
|
latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
|
||||||
@ -277,12 +302,14 @@ class WanT2V:
|
|||||||
image_sizes.append(src_video[i].shape[2:])
|
image_sizes.append(src_video[i].shape[2:])
|
||||||
for k, keep in enumerate(keep_video_guide_frames):
|
for k, keep in enumerate(keep_video_guide_frames):
|
||||||
if not keep:
|
if not keep:
|
||||||
src_video[i][:, k:k+1] = 0
|
pos = prepend_count + k
|
||||||
src_mask[i][:, k:k+1] = 1
|
src_video[i][:, pos:pos+1] = 0
|
||||||
|
src_mask[i][:, pos:pos+1] = 1
|
||||||
|
|
||||||
for k, frame in enumerate(inject_frames):
|
for k, frame in enumerate(inject_frames):
|
||||||
if frame != None:
|
if frame != None:
|
||||||
src_video[i][:, k:k+1], src_mask[i][:, k:k+1] = self.fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims, return_mask= True)
|
pos = prepend_count + k
|
||||||
|
src_video[i][:, pos:pos+1], src_mask[i][:, pos:pos+1] = self.fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims, return_mask= True)
|
||||||
|
|
||||||
|
|
||||||
self.background_mask = None
|
self.background_mask = None
|
||||||
@ -322,158 +349,63 @@ class WanT2V:
|
|||||||
ref_vae_latents.append(img_vae_latent[0])
|
ref_vae_latents.append(img_vae_latent[0])
|
||||||
|
|
||||||
return torch.cat(ref_vae_latents, dim=1)
|
return torch.cat(ref_vae_latents, dim=1)
|
||||||
|
|
||||||
|
|
||||||
def generate(self,
|
def generate(self,
|
||||||
input_prompt,
|
input_prompt,
|
||||||
input_frames= None,
|
input_frames= None,
|
||||||
input_masks = None,
|
input_masks = None,
|
||||||
input_ref_images = None,
|
input_ref_images = None,
|
||||||
input_video=None,
|
input_video=None,
|
||||||
denoising_strength = 1.0,
|
image_start = None,
|
||||||
target_camera=None,
|
image_end = None,
|
||||||
context_scale=None,
|
denoising_strength = 1.0,
|
||||||
width = 1280,
|
target_camera=None,
|
||||||
height = 720,
|
context_scale=None,
|
||||||
fit_into_canvas = True,
|
width = 1280,
|
||||||
frame_num=81,
|
height = 720,
|
||||||
shift=5.0,
|
fit_into_canvas = True,
|
||||||
sample_solver='unipc',
|
frame_num=81,
|
||||||
sampling_steps=50,
|
shift=5.0,
|
||||||
guide_scale=5.0,
|
sample_solver='unipc',
|
||||||
n_prompt="",
|
sampling_steps=50,
|
||||||
seed=-1,
|
guide_scale=5.0,
|
||||||
offload_model=True,
|
n_prompt="",
|
||||||
callback = None,
|
seed=-1,
|
||||||
enable_RIFLEx = None,
|
callback = None,
|
||||||
VAE_tile_size = 0,
|
enable_RIFLEx = None,
|
||||||
joint_pass = False,
|
VAE_tile_size = 0,
|
||||||
slg_layers = None,
|
joint_pass = False,
|
||||||
slg_start = 0.0,
|
slg_layers = None,
|
||||||
slg_end = 1.0,
|
slg_start = 0.0,
|
||||||
cfg_star_switch = True,
|
slg_end = 1.0,
|
||||||
cfg_zero_step = 5,
|
cfg_star_switch = True,
|
||||||
overlapped_latents = None,
|
cfg_zero_step = 5,
|
||||||
return_latent_slice = None,
|
audio_scale=None,
|
||||||
overlap_noise = 0,
|
audio_cfg_scale=None,
|
||||||
conditioning_latents_size = 0,
|
audio_proj=None,
|
||||||
keep_frames_parsed = [],
|
audio_context_lens=None,
|
||||||
model_filename = None,
|
overlapped_latents = None,
|
||||||
model_type = None,
|
return_latent_slice = None,
|
||||||
loras_slists = None,
|
overlap_noise = 0,
|
||||||
**bbargs
|
conditioning_latents_size = 0,
|
||||||
|
keep_frames_parsed = [],
|
||||||
|
model_type = None,
|
||||||
|
loras_slists = None,
|
||||||
|
offloadobj = None,
|
||||||
|
apg_switch = False,
|
||||||
|
**bbargs
|
||||||
):
|
):
|
||||||
r"""
|
|
||||||
Generates video frames from text prompt using diffusion process.
|
if sample_solver =="euler":
|
||||||
|
# prepare timesteps
|
||||||
Args:
|
timesteps = list(np.linspace(self.num_timesteps, 1, sampling_steps, dtype=np.float32))
|
||||||
input_prompt (`str`):
|
timesteps.append(0.)
|
||||||
Text prompt for content generation
|
timesteps = [torch.tensor([t], device=self.device) for t in timesteps]
|
||||||
size (tupele[`int`], *optional*, defaults to (1280,720)):
|
if self.use_timestep_transform:
|
||||||
Controls video resolution, (width,height).
|
timesteps = [timestep_transform(t, shift=shift, num_timesteps=self.num_timesteps) for t in timesteps][:-1]
|
||||||
frame_num (`int`, *optional*, defaults to 81):
|
sample_scheduler = None
|
||||||
How many frames to sample from a video. The number should be 4n+1
|
elif sample_solver == 'causvid':
|
||||||
shift (`float`, *optional*, defaults to 5.0):
|
|
||||||
Noise schedule shift parameter. Affects temporal dynamics
|
|
||||||
sample_solver (`str`, *optional*, defaults to 'unipc'):
|
|
||||||
Solver used to sample the video.
|
|
||||||
sampling_steps (`int`, *optional*, defaults to 40):
|
|
||||||
Number of diffusion sampling steps. Higher values improve quality but slow generation
|
|
||||||
guide_scale (`float`, *optional*, defaults 5.0):
|
|
||||||
Classifier-free guidance scale. Controls prompt adherence vs. creativity
|
|
||||||
n_prompt (`str`, *optional*, defaults to ""):
|
|
||||||
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
|
|
||||||
seed (`int`, *optional*, defaults to -1):
|
|
||||||
Random seed for noise generation. If -1, use random seed.
|
|
||||||
offload_model (`bool`, *optional*, defaults to True):
|
|
||||||
If True, offloads models to CPU during generation to save VRAM
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor:
|
|
||||||
Generated video frames tensor. Dimensions: (C, N H, W) where:
|
|
||||||
- C: Color channels (3 for RGB)
|
|
||||||
- N: Number of frames (81)
|
|
||||||
- H: Frame height (from size)
|
|
||||||
- W: Frame width from size)
|
|
||||||
"""
|
|
||||||
# preprocess
|
|
||||||
vace = "Vace" in model_filename
|
|
||||||
|
|
||||||
if n_prompt == "":
|
|
||||||
n_prompt = self.sample_neg_prompt
|
|
||||||
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
|
|
||||||
seed_g = torch.Generator(device=self.device)
|
|
||||||
seed_g.manual_seed(seed)
|
|
||||||
|
|
||||||
if self._interrupt:
|
|
||||||
return None
|
|
||||||
context = self.text_encoder([input_prompt], self.device)[0]
|
|
||||||
context_null = self.text_encoder([n_prompt], self.device)[0]
|
|
||||||
context = context.to(self.dtype)
|
|
||||||
context_null = context_null.to(self.dtype)
|
|
||||||
input_ref_images_neg = None
|
|
||||||
phantom = False
|
|
||||||
|
|
||||||
if target_camera != None:
|
|
||||||
width = input_video.shape[2]
|
|
||||||
height = input_video.shape[1]
|
|
||||||
input_video = input_video.to(dtype=self.dtype , device=self.device)
|
|
||||||
input_video = input_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.)
|
|
||||||
source_latents = self.vae.encode([input_video])[0] #.to(dtype=self.dtype, device=self.device)
|
|
||||||
del input_video
|
|
||||||
# Process target camera (recammaster)
|
|
||||||
from wan.utils.cammmaster_tools import get_camera_embedding
|
|
||||||
cam_emb = get_camera_embedding(target_camera)
|
|
||||||
cam_emb = cam_emb.to(dtype=self.dtype, device=self.device)
|
|
||||||
|
|
||||||
if denoising_strength < 1. and input_frames != None:
|
|
||||||
height, width = input_frames.shape[-2:]
|
|
||||||
source_latents = self.vae.encode([input_frames])[0]
|
|
||||||
|
|
||||||
if vace :
|
|
||||||
# vace context encode
|
|
||||||
input_frames = [u.to(self.device) for u in input_frames]
|
|
||||||
input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
|
|
||||||
input_masks = [u.to(self.device) for u in input_masks]
|
|
||||||
if self.background_mask != None: self.background_mask = [m.to(self.device) for m in self.background_mask]
|
|
||||||
previous_latents = None
|
|
||||||
# if overlapped_latents != None:
|
|
||||||
# input_ref_images = [u[-1:] for u in input_ref_images]
|
|
||||||
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents )
|
|
||||||
m0 = self.vace_encode_masks(input_masks, input_ref_images)
|
|
||||||
if self.background_mask != None:
|
|
||||||
zbg = self.vace_encode_frames([ref_img[0] for ref_img in input_ref_images], None, masks=self.background_mask, tile_size = VAE_tile_size )
|
|
||||||
mbg = self.vace_encode_masks(self.background_mask, None)
|
|
||||||
for zz0, mm0, zzbg, mmbg in zip(z0, m0, zbg, mbg):
|
|
||||||
zz0[:, 0:1] = zzbg
|
|
||||||
mm0[:, 0:1] = mmbg
|
|
||||||
|
|
||||||
self.background_mask = zz0 = mm0 = zzbg = mmbg = None
|
|
||||||
z = self.vace_latent(z0, m0)
|
|
||||||
|
|
||||||
target_shape = list(z0[0].shape)
|
|
||||||
target_shape[0] = int(target_shape[0] / 2)
|
|
||||||
else:
|
|
||||||
if input_ref_images != None: # Phantom Ref images
|
|
||||||
phantom = True
|
|
||||||
input_ref_images = self.get_vae_latents(input_ref_images, self.device)
|
|
||||||
input_ref_images_neg = torch.zeros_like(input_ref_images)
|
|
||||||
F = frame_num
|
|
||||||
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + (input_ref_images.shape[1] if input_ref_images != None else 0),
|
|
||||||
height // self.vae_stride[1],
|
|
||||||
width // self.vae_stride[2])
|
|
||||||
|
|
||||||
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
|
|
||||||
(self.patch_size[1] * self.patch_size[2]) *
|
|
||||||
target_shape[1])
|
|
||||||
|
|
||||||
if self._interrupt:
|
|
||||||
return None
|
|
||||||
|
|
||||||
noise = [ torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) ]
|
|
||||||
|
|
||||||
# evaluation mode
|
|
||||||
|
|
||||||
if sample_solver == 'causvid':
|
|
||||||
sample_scheduler = FlowMatchScheduler(num_inference_steps=sampling_steps, shift=shift, sigma_min=0, extra_one_step=True)
|
sample_scheduler = FlowMatchScheduler(num_inference_steps=sampling_steps, shift=shift, sigma_min=0, extra_one_step=True)
|
||||||
timesteps = torch.tensor([1000, 934, 862, 756, 603, 410, 250, 140, 74])[:sampling_steps].to(self.device)
|
timesteps = torch.tensor([1000, 934, 862, 756, 603, 410, 250, 140, 74])[:sampling_steps].to(self.device)
|
||||||
sample_scheduler.timesteps =timesteps
|
sample_scheduler.timesteps =timesteps
|
||||||
@ -496,55 +428,238 @@ class WanT2V:
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Unsupported Scheduler {sample_solver}")
|
raise NotImplementedError(f"Unsupported Scheduler {sample_solver}")
|
||||||
|
|
||||||
|
seed_g = torch.Generator(device=self.device)
|
||||||
|
seed_g.manual_seed(seed)
|
||||||
|
|
||||||
# sample videos
|
kwargs = {'pipeline': self, 'callback': callback}
|
||||||
latents = noise[0]
|
|
||||||
del noise
|
|
||||||
|
|
||||||
injection_denoising_step = 0
|
if self._interrupt:
|
||||||
inject_from_start = False
|
return None
|
||||||
if denoising_strength < 1 and input_frames != None:
|
|
||||||
if len(keep_frames_parsed) == 0 or all(keep_frames_parsed): keep_frames_parsed = []
|
# Text Encoder
|
||||||
injection_denoising_step = int(sampling_steps * (1. - denoising_strength) )
|
if n_prompt == "":
|
||||||
latent_keep_frames = []
|
n_prompt = self.sample_neg_prompt
|
||||||
if source_latents.shape[1] < latents.shape[1] or len(keep_frames_parsed) > 0:
|
context = self.text_encoder([input_prompt], self.device)[0]
|
||||||
inject_from_start = True
|
context_null = self.text_encoder([n_prompt], self.device)[0]
|
||||||
if len(keep_frames_parsed) >0 :
|
context = context.to(self.dtype)
|
||||||
latent_keep_frames =[keep_frames_parsed[0]]
|
context_null = context_null.to(self.dtype)
|
||||||
for i in range(1, len(keep_frames_parsed), 4):
|
# from mmgp import offload
|
||||||
latent_keep_frames.append(all(keep_frames_parsed[i:i+4]))
|
# offloadobj.unload_all()
|
||||||
|
|
||||||
|
if self._interrupt:
|
||||||
|
return None
|
||||||
|
|
||||||
|
vace = model_type in ["vace_1.3B","vace_14B", "vace_multitalk_14B"]
|
||||||
|
phantom = model_type in ["phantom_1.3B", "phantom_14B"]
|
||||||
|
fantasy = model_type in ["fantasy"]
|
||||||
|
multitalk = model_type in ["multitalk", "vace_multitalk_14B"]
|
||||||
|
|
||||||
|
ref_images_count = 0
|
||||||
|
trim_frames = 0
|
||||||
|
extended_overlapped_latents = None
|
||||||
|
|
||||||
|
# image2video
|
||||||
|
lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1
|
||||||
|
if image_start != None:
|
||||||
|
any_end_frame = False
|
||||||
|
if input_frames != None:
|
||||||
|
_ , preframes_count, height, width = input_frames.shape
|
||||||
|
lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
|
||||||
|
clip_context = self.clip.visual([input_frames[:, -1:]]) #.to(self.param_dtype)
|
||||||
|
input_frames = input_frames.to(device=self.device).to(dtype= self.VAE_dtype)
|
||||||
|
enc = torch.concat( [input_frames, torch.zeros( (3, frame_num-preframes_count, height, width),
|
||||||
|
device=self.device, dtype= self.VAE_dtype)],
|
||||||
|
dim = 1).to(self.device)
|
||||||
|
input_frames = None
|
||||||
else:
|
else:
|
||||||
timesteps = timesteps[injection_denoising_step:]
|
preframes_count = 1
|
||||||
if hasattr(sample_scheduler, "timesteps"): sample_scheduler.timesteps = timesteps
|
image_start = TF.to_tensor(image_start)
|
||||||
if hasattr(sample_scheduler, "sigmas"): sample_scheduler.sigmas= sample_scheduler.sigmas[injection_denoising_step:]
|
any_end_frame = image_end != None
|
||||||
injection_denoising_step = 0
|
add_frames_for_end_image = any_end_frame and model_type not in ["fun_inp_1.3B", "fun_inp", "i2v_720p"]
|
||||||
|
if any_end_frame:
|
||||||
|
image_end = TF.to_tensor(image_end)
|
||||||
|
if add_frames_for_end_image:
|
||||||
|
frame_num +=1
|
||||||
|
lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
|
||||||
|
trim_frames = 1
|
||||||
|
|
||||||
|
h, w = image_start.shape[1:]
|
||||||
|
|
||||||
|
h, w = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
|
||||||
|
width, height = w, h
|
||||||
|
|
||||||
|
lat_h = round(
|
||||||
|
h // self.vae_stride[1] //
|
||||||
|
self.patch_size[1] * self.patch_size[1])
|
||||||
|
lat_w = round(
|
||||||
|
w // self.vae_stride[2] //
|
||||||
|
self.patch_size[2] * self.patch_size[2])
|
||||||
|
h = lat_h * self.vae_stride[1]
|
||||||
|
w = lat_w * self.vae_stride[2]
|
||||||
|
clip_image_size = self.clip.model.image_size
|
||||||
|
img_interpolated = resize_lanczos(image_start, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype
|
||||||
|
image_start = resize_lanczos(image_start, clip_image_size, clip_image_size)
|
||||||
|
image_start = image_start.sub_(0.5).div_(0.5).to(self.device) #, self.dtype
|
||||||
|
if image_end!= None:
|
||||||
|
img_interpolated2 = resize_lanczos(image_end, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype
|
||||||
|
image_end = resize_lanczos(image_end, clip_image_size, clip_image_size)
|
||||||
|
image_end = image_end.sub_(0.5).div_(0.5).to(self.device) #, self.dtype
|
||||||
|
if image_end != None and model_type == "flf2v_720p":
|
||||||
|
clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :]])
|
||||||
|
else:
|
||||||
|
clip_context = self.clip.visual([image_start[:, None, :, :]])
|
||||||
|
|
||||||
|
if any_end_frame:
|
||||||
|
enc= torch.concat([
|
||||||
|
img_interpolated,
|
||||||
|
torch.zeros( (3, frame_num-2, h, w), device=self.device, dtype= self.VAE_dtype),
|
||||||
|
img_interpolated2,
|
||||||
|
], dim=1).to(self.device)
|
||||||
|
else:
|
||||||
|
enc= torch.concat([
|
||||||
|
img_interpolated,
|
||||||
|
torch.zeros( (3, frame_num-1, h, w), device=self.device, dtype= self.VAE_dtype)
|
||||||
|
], dim=1).to(self.device)
|
||||||
|
|
||||||
|
image_start = image_end = img_interpolated = img_interpolated2 = None
|
||||||
|
|
||||||
|
msk = torch.ones(1, frame_num, lat_h, lat_w, device=self.device)
|
||||||
|
if any_end_frame:
|
||||||
|
msk[:, preframes_count: -1] = 0
|
||||||
|
if add_frames_for_end_image:
|
||||||
|
msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:-1], torch.repeat_interleave(msk[:, -1:], repeats=4, dim=1) ], dim=1)
|
||||||
|
else:
|
||||||
|
msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] ], dim=1)
|
||||||
|
else:
|
||||||
|
msk[:, preframes_count:] = 0
|
||||||
|
msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] ], dim=1)
|
||||||
|
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
|
||||||
|
msk = msk.transpose(1, 2)[0]
|
||||||
|
|
||||||
|
|
||||||
|
lat_y = self.vae.encode([enc], VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
|
||||||
|
overlapped_latents_frames_num = int(1 + (preframes_count-1) // 4)
|
||||||
|
if overlapped_latents != None:
|
||||||
|
# disabled because looks worse
|
||||||
|
if False and overlapped_latents_frames_num > 1: lat_y[:, 1:overlapped_latents_frames_num] = overlapped_latents[:, 1:]
|
||||||
|
extended_overlapped_latents = lat_y[:, :overlapped_latents_frames_num].clone()
|
||||||
|
y = torch.concat([msk, lat_y])
|
||||||
|
lat_y = None
|
||||||
|
kwargs.update({'clip_fea': clip_context, 'y': y})
|
||||||
|
|
||||||
|
# Recam Master
|
||||||
|
if target_camera != None:
|
||||||
|
width = input_video.shape[2]
|
||||||
|
height = input_video.shape[1]
|
||||||
|
input_video = input_video.to(dtype=self.dtype , device=self.device)
|
||||||
|
input_video = input_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.)
|
||||||
|
source_latents = self.vae.encode([input_video])[0] #.to(dtype=self.dtype, device=self.device)
|
||||||
|
del input_video
|
||||||
|
# Process target camera (recammaster)
|
||||||
|
from wan.utils.cammmaster_tools import get_camera_embedding
|
||||||
|
cam_emb = get_camera_embedding(target_camera)
|
||||||
|
cam_emb = cam_emb.to(dtype=self.dtype, device=self.device)
|
||||||
|
kwargs['cam_emb'] = cam_emb
|
||||||
|
|
||||||
|
# Video 2 Video
|
||||||
|
if denoising_strength < 1. and input_frames != None:
|
||||||
|
height, width = input_frames.shape[-2:]
|
||||||
|
source_latents = self.vae.encode([input_frames])[0]
|
||||||
|
injection_denoising_step = 0
|
||||||
|
inject_from_start = False
|
||||||
|
if input_frames != None and denoising_strength < 1 :
|
||||||
|
if overlapped_latents != None:
|
||||||
|
overlapped_latents_frames_num = overlapped_latents.shape[1]
|
||||||
|
overlapped_frames_num = (overlapped_latents_frames_num-1) * 4 + 1
|
||||||
|
else:
|
||||||
|
overlapped_latents_frames_num = overlapped_frames_num = 0
|
||||||
|
if len(keep_frames_parsed) == 0 or (overlapped_frames_num + len(keep_frames_parsed)) == input_frames.shape[1] and all(keep_frames_parsed) : keep_frames_parsed = []
|
||||||
|
injection_denoising_step = int(sampling_steps * (1. - denoising_strength) )
|
||||||
|
latent_keep_frames = []
|
||||||
|
if source_latents.shape[1] < lat_frames or len(keep_frames_parsed) > 0:
|
||||||
|
inject_from_start = True
|
||||||
|
if len(keep_frames_parsed) >0 :
|
||||||
|
if overlapped_frames_num > 0: keep_frames_parsed = [True] * overlapped_frames_num + keep_frames_parsed
|
||||||
|
latent_keep_frames =[keep_frames_parsed[0]]
|
||||||
|
for i in range(1, len(keep_frames_parsed), 4):
|
||||||
|
latent_keep_frames.append(all(keep_frames_parsed[i:i+4]))
|
||||||
|
else:
|
||||||
|
timesteps = timesteps[injection_denoising_step:]
|
||||||
|
if hasattr(sample_scheduler, "timesteps"): sample_scheduler.timesteps = timesteps
|
||||||
|
if hasattr(sample_scheduler, "sigmas"): sample_scheduler.sigmas= sample_scheduler.sigmas[injection_denoising_step:]
|
||||||
|
injection_denoising_step = 0
|
||||||
|
|
||||||
|
# Phantom
|
||||||
|
if phantom:
|
||||||
|
input_ref_images_neg = None
|
||||||
|
if input_ref_images != None: # Phantom Ref images
|
||||||
|
input_ref_images = self.get_vae_latents(input_ref_images, self.device)
|
||||||
|
input_ref_images_neg = torch.zeros_like(input_ref_images)
|
||||||
|
ref_images_count = input_ref_images.shape[1] if input_ref_images != None else 0
|
||||||
|
|
||||||
|
# Vace
|
||||||
|
if vace :
|
||||||
|
# vace context encode
|
||||||
|
input_frames = [u.to(self.device) for u in input_frames]
|
||||||
|
input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
|
||||||
|
input_masks = [u.to(self.device) for u in input_masks]
|
||||||
|
if self.background_mask != None: self.background_mask = [m.to(self.device) for m in self.background_mask]
|
||||||
|
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents )
|
||||||
|
m0 = self.vace_encode_masks(input_masks, input_ref_images)
|
||||||
|
if self.background_mask != None:
|
||||||
|
zbg = self.vace_encode_frames([ref_img[0] for ref_img in input_ref_images], None, masks=self.background_mask, tile_size = VAE_tile_size )
|
||||||
|
mbg = self.vace_encode_masks(self.background_mask, None)
|
||||||
|
for zz0, mm0, zzbg, mmbg in zip(z0, m0, zbg, mbg):
|
||||||
|
zz0[:, 0:1] = zzbg
|
||||||
|
mm0[:, 0:1] = mmbg
|
||||||
|
|
||||||
|
self.background_mask = zz0 = mm0 = zzbg = mmbg = None
|
||||||
|
z = self.vace_latent(z0, m0)
|
||||||
|
|
||||||
|
ref_images_count = len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0
|
||||||
|
context_scale = context_scale if context_scale != None else [1.0] * len(z)
|
||||||
|
kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale, "ref_images_count": ref_images_count })
|
||||||
|
if overlapped_latents != None :
|
||||||
|
overlapped_latents_size = overlapped_latents.shape[1]
|
||||||
|
extended_overlapped_latents = z[0][0:16, 0:overlapped_latents_size + ref_images_count].clone()
|
||||||
|
|
||||||
|
target_shape = list(z0[0].shape)
|
||||||
|
target_shape[0] = int(target_shape[0] / 2)
|
||||||
|
lat_h, lat_w = target_shape[-2:]
|
||||||
|
height = self.vae_stride[1] * lat_h
|
||||||
|
width = self.vae_stride[2] * lat_w
|
||||||
|
|
||||||
|
else:
|
||||||
|
target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, height // self.vae_stride[1], width // self.vae_stride[2])
|
||||||
|
|
||||||
|
if multitalk and audio_proj != None:
|
||||||
|
from wan.multitalk.multitalk import get_target_masks
|
||||||
|
audio_proj = [audio.to(self.dtype) for audio in audio_proj]
|
||||||
|
human_no = len(audio_proj[0])
|
||||||
|
token_ref_target_masks = get_target_masks(human_no, lat_h, lat_w, height, width, face_scale = 0.05, bbox = None).to(self.dtype) if human_no > 1 else None
|
||||||
|
|
||||||
|
if fantasy and audio_proj != None:
|
||||||
|
kwargs.update({ "audio_proj": audio_proj.to(self.dtype), "audio_context_lens": audio_context_lens, })
|
||||||
|
|
||||||
|
|
||||||
|
if self._interrupt:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Ropes
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
if target_camera != None:
|
if target_camera != None:
|
||||||
shape = list(latents.shape[1:])
|
shape = list(target_shape[1:])
|
||||||
shape[0] *= 2
|
shape[0] *= 2
|
||||||
freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
|
freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
|
||||||
else:
|
else:
|
||||||
freqs = get_rotary_pos_embed(latents.shape[1:], enable_RIFLEx= enable_RIFLEx)
|
freqs = get_rotary_pos_embed(target_shape[1:], enable_RIFLEx= enable_RIFLEx)
|
||||||
|
|
||||||
kwargs = {'freqs': freqs, 'pipeline': self, 'callback': callback}
|
kwargs["freqs"] = freqs
|
||||||
|
|
||||||
if target_camera != None:
|
|
||||||
kwargs.update({'cam_emb': cam_emb})
|
|
||||||
|
|
||||||
if vace:
|
|
||||||
ref_images_count = len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0
|
|
||||||
context_scale = context_scale if context_scale != None else [1.0] * len(z)
|
|
||||||
kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale})
|
|
||||||
if overlapped_latents != None :
|
|
||||||
overlapped_latents_size = overlapped_latents.shape[1] + 1
|
|
||||||
# overlapped_latents_size = 3
|
|
||||||
z_reactive = [ zz[0:16, 0:overlapped_latents_size + ref_images_count].clone() for zz in z]
|
|
||||||
|
|
||||||
|
# Steps Skipping
|
||||||
cache_type = self.model.enable_cache
|
cache_type = self.model.enable_cache
|
||||||
if cache_type != None:
|
if cache_type != None:
|
||||||
x_count = 3 if phantom else 2
|
x_count = 3 if phantom or fantasy or multitalk else 2
|
||||||
self.model.previous_residual = [None] * x_count
|
self.model.previous_residual = [None] * x_count
|
||||||
if cache_type == "tea":
|
if cache_type == "tea":
|
||||||
self.model.compute_teacache_threshold(self.model.cache_start_step, timesteps, self.model.cache_multiplier)
|
self.model.compute_teacache_threshold(self.model.cache_start_step, timesteps, self.model.cache_multiplier)
|
||||||
@ -552,6 +667,7 @@ class WanT2V:
|
|||||||
self.model.compute_magcache_threshold(self.model.cache_start_step, timesteps, self.model.cache_multiplier)
|
self.model.compute_magcache_threshold(self.model.cache_start_step, timesteps, self.model.cache_multiplier)
|
||||||
self.model.accumulated_err, self.model.accumulated_steps, self.model.accumulated_ratio = [0.0] * x_count, [0] * x_count, [1.0] * x_count
|
self.model.accumulated_err, self.model.accumulated_steps, self.model.accumulated_ratio = [0.0] * x_count, [0] * x_count, [1.0] * x_count
|
||||||
self.model.one_for_all = x_count > 2
|
self.model.one_for_all = x_count > 2
|
||||||
|
|
||||||
if callback != None:
|
if callback != None:
|
||||||
callback(-1, None, True)
|
callback(-1, None, True)
|
||||||
|
|
||||||
@ -560,15 +676,29 @@ class WanT2V:
|
|||||||
if chipmunk:
|
if chipmunk:
|
||||||
self.model.setup_chipmunk()
|
self.model.setup_chipmunk()
|
||||||
|
|
||||||
updated_num_steps= len(timesteps)
|
# init denoising
|
||||||
|
updated_num_steps= len(timesteps)
|
||||||
if callback != None:
|
if callback != None:
|
||||||
|
from wgp import update_loras_slists
|
||||||
update_loras_slists(self.model, loras_slists, updated_num_steps)
|
update_loras_slists(self.model, loras_slists, updated_num_steps)
|
||||||
callback(-1, None, True, override_num_inference_steps = updated_num_steps)
|
callback(-1, None, True, override_num_inference_steps = updated_num_steps)
|
||||||
|
|
||||||
scheduler_kwargs = {} if isinstance(sample_scheduler, FlowMatchScheduler) else {"generator": seed_g}
|
if sample_scheduler != None:
|
||||||
|
scheduler_kwargs = {} if isinstance(sample_scheduler, FlowMatchScheduler) else {"generator": seed_g}
|
||||||
|
|
||||||
|
latents = torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g)
|
||||||
|
if apg_switch != 0:
|
||||||
|
apg_momentum = -0.75
|
||||||
|
apg_norm_threshold = 55
|
||||||
|
text_momentumbuffer = MomentumBuffer(apg_momentum)
|
||||||
|
audio_momentumbuffer = MomentumBuffer(apg_momentum)
|
||||||
|
|
||||||
|
# denoising
|
||||||
for i, t in enumerate(tqdm(timesteps)):
|
for i, t in enumerate(tqdm(timesteps)):
|
||||||
timestep = [t]
|
offload.set_step_no_for_lora(self.model, i)
|
||||||
|
timestep = torch.stack([t])
|
||||||
|
kwargs.update({"t": timestep, "current_step": i})
|
||||||
|
kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None
|
||||||
|
|
||||||
if denoising_strength < 1 and input_frames != None and i <= injection_denoising_step:
|
if denoising_strength < 1 and input_frames != None and i <= injection_denoising_step:
|
||||||
sigma = t / 1000
|
sigma = t / 1000
|
||||||
@ -585,98 +715,119 @@ class WanT2V:
|
|||||||
latents = noise * sigma + (1 - sigma) * source_latents
|
latents = noise * sigma + (1 - sigma) * source_latents
|
||||||
noise = None
|
noise = None
|
||||||
|
|
||||||
if overlapped_latents != None :
|
if extended_overlapped_latents != None:
|
||||||
overlap_noise_factor = overlap_noise / 1000
|
|
||||||
latent_noise_factor = t / 1000
|
latent_noise_factor = t / 1000
|
||||||
for zz, zz_r, ll in zip(z, z_reactive, [latents, None]): # extra None for second control net
|
latents[:, 0:extended_overlapped_latents.shape[1]] = extended_overlapped_latents * (1.0 - latent_noise_factor) + torch.randn_like(extended_overlapped_latents ) * latent_noise_factor
|
||||||
zz[0:16, ref_images_count:overlapped_latents_size + ref_images_count] = zz_r[:, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(zz_r[:, ref_images_count:] ) * overlap_noise_factor
|
if vace:
|
||||||
if ll != None:
|
overlap_noise_factor = overlap_noise / 1000
|
||||||
ll[:, 0:overlapped_latents_size + ref_images_count] = zz_r * (1.0 - latent_noise_factor) + torch.randn_like(zz_r ) * latent_noise_factor
|
for zz in z:
|
||||||
|
zz[0:16, ref_images_count:extended_overlapped_latents.shape[1] ] = extended_overlapped_latents[:, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(extended_overlapped_latents[:, ref_images_count:] ) * overlap_noise_factor
|
||||||
|
|
||||||
if target_camera != None:
|
if target_camera != None:
|
||||||
latent_model_input = torch.cat([latents, source_latents], dim=1)
|
latent_model_input = torch.cat([latents, source_latents], dim=1)
|
||||||
else:
|
else:
|
||||||
latent_model_input = latents
|
latent_model_input = latents
|
||||||
kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None
|
|
||||||
|
|
||||||
offload.set_step_no_for_lora(self.model, i)
|
if phantom:
|
||||||
timestep = torch.stack(timestep)
|
gen_args = {
|
||||||
kwargs["current_step"] = i
|
"x" : ([ torch.cat([latent_model_input[:,:-ref_images_count], input_ref_images], dim=1) ] * 2 +
|
||||||
kwargs["t"] = timestep
|
[ torch.cat([latent_model_input[:,:-ref_images_count], input_ref_images_neg], dim=1)]),
|
||||||
if guide_scale == 1:
|
"context": [context, context_null, context_null] ,
|
||||||
noise_pred = self.model( [latent_model_input], x_id = 0, context = [context], **kwargs)[0]
|
}
|
||||||
if self._interrupt:
|
elif fantasy:
|
||||||
return None
|
gen_args = {
|
||||||
elif joint_pass:
|
"x" : [latent_model_input, latent_model_input, latent_model_input],
|
||||||
if phantom:
|
"context" : [context, context_null, context_null],
|
||||||
pos_it, pos_i, neg = self.model(
|
"audio_scale": [audio_scale, None, None ]
|
||||||
[ torch.cat([latent_model_input[:,:-input_ref_images.shape[1]], input_ref_images], dim=1) ] * 2 +
|
}
|
||||||
[ torch.cat([latent_model_input[:,:-input_ref_images_neg.shape[1]], input_ref_images_neg], dim=1)],
|
elif multitalk:
|
||||||
context = [context, context_null, context_null], **kwargs)
|
gen_args = {
|
||||||
else:
|
"x" : [latent_model_input, latent_model_input, latent_model_input],
|
||||||
noise_pred_cond, noise_pred_uncond = self.model(
|
"context" : [context, context_null, context_null],
|
||||||
[latent_model_input, latent_model_input], context = [context, context_null], **kwargs)
|
"multitalk_audio": [audio_proj, audio_proj, [torch.zeros_like(audio_proj[0][-1:]), torch.zeros_like(audio_proj[1][-1:])]],
|
||||||
if self._interrupt:
|
"multitalk_masks": [token_ref_target_masks, token_ref_target_masks, None]
|
||||||
return None
|
}
|
||||||
else:
|
else:
|
||||||
if phantom:
|
gen_args = {
|
||||||
pos_it = self.model(
|
"x" : [latent_model_input, latent_model_input],
|
||||||
[ torch.cat([latent_model_input[:,:-input_ref_images.shape[1]], input_ref_images], dim=1) ], x_id = 0, context = [context], **kwargs
|
"context": [context, context_null]
|
||||||
)[0]
|
}
|
||||||
if self._interrupt:
|
|
||||||
return None
|
|
||||||
pos_i = self.model(
|
|
||||||
[ torch.cat([latent_model_input[:,:-input_ref_images.shape[1]], input_ref_images], dim=1) ], x_id = 1, context = [context_null],**kwargs
|
|
||||||
)[0]
|
|
||||||
if self._interrupt:
|
|
||||||
return None
|
|
||||||
neg = self.model(
|
|
||||||
[ torch.cat([latent_model_input[:,:-input_ref_images_neg.shape[1]], input_ref_images_neg], dim=1) ], x_id = 2, context = [context_null], **kwargs
|
|
||||||
)[0]
|
|
||||||
if self._interrupt:
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
noise_pred_cond = self.model(
|
|
||||||
[latent_model_input], x_id = 0, context = [context], **kwargs)[0]
|
|
||||||
if self._interrupt:
|
|
||||||
return None
|
|
||||||
noise_pred_uncond = self.model(
|
|
||||||
[latent_model_input], x_id = 1, context = [context_null], **kwargs)[0]
|
|
||||||
if self._interrupt:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# del latent_model_input
|
if joint_pass and guide_scale > 1:
|
||||||
|
ret_values = self.model( **gen_args , **kwargs)
|
||||||
# CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
|
if self._interrupt:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
size = 1 if guide_scale == 1 else len(gen_args["x"])
|
||||||
|
ret_values = [None] * size
|
||||||
|
for x_id in range(size):
|
||||||
|
sub_gen_args = {k : [v[x_id]] for k, v in gen_args.items() }
|
||||||
|
ret_values[x_id] = self.model( **sub_gen_args, x_id= x_id , **kwargs)[0]
|
||||||
|
if self._interrupt:
|
||||||
|
return None
|
||||||
|
sub_gen_args = None
|
||||||
if guide_scale == 1:
|
if guide_scale == 1:
|
||||||
pass
|
noise_pred = ret_values[0]
|
||||||
elif phantom:
|
elif phantom:
|
||||||
guide_scale_img= 5.0
|
guide_scale_img= 5.0
|
||||||
guide_scale_text= guide_scale #7.5
|
guide_scale_text= guide_scale #7.5
|
||||||
|
pos_it, pos_i, neg = ret_values
|
||||||
noise_pred = neg + guide_scale_img * (pos_i - neg) + guide_scale_text * (pos_it - pos_i)
|
noise_pred = neg + guide_scale_img * (pos_i - neg) + guide_scale_text * (pos_it - pos_i)
|
||||||
|
pos_it = pos_i = neg = None
|
||||||
|
elif fantasy:
|
||||||
|
noise_pred_cond, noise_pred_noaudio, noise_pred_uncond = ret_values
|
||||||
|
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio)
|
||||||
|
noise_pred_noaudio = None
|
||||||
|
elif multitalk:
|
||||||
|
noise_pred_cond, noise_pred_drop_text, noise_pred_uncond = ret_values
|
||||||
|
if apg_switch != 0:
|
||||||
|
noise_pred = noise_pred_cond + (guide_scale - 1) * adaptive_projected_guidance(noise_pred_cond - noise_pred_drop_text,
|
||||||
|
noise_pred_cond,
|
||||||
|
momentum_buffer=text_momentumbuffer,
|
||||||
|
norm_threshold=apg_norm_threshold) \
|
||||||
|
+ (audio_cfg_scale - 1) * adaptive_projected_guidance(noise_pred_drop_text - noise_pred_uncond,
|
||||||
|
noise_pred_cond,
|
||||||
|
momentum_buffer=audio_momentumbuffer,
|
||||||
|
norm_threshold=apg_norm_threshold)
|
||||||
|
else:
|
||||||
|
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_drop_text) + audio_cfg_scale * (noise_pred_drop_text - noise_pred_uncond)
|
||||||
|
noise_pred_uncond = noise_pred_cond = noise_pred_drop_text = None
|
||||||
else:
|
else:
|
||||||
noise_pred_text = noise_pred_cond
|
noise_pred_cond, noise_pred_uncond = ret_values
|
||||||
if cfg_star_switch:
|
if apg_switch != 0:
|
||||||
positive_flat = noise_pred_text.view(batch_size, -1)
|
noise_pred = noise_pred_cond + (guide_scale - 1) * adaptive_projected_guidance(noise_pred_cond - noise_pred_uncond,
|
||||||
negative_flat = noise_pred_uncond.view(batch_size, -1)
|
noise_pred_cond,
|
||||||
|
momentum_buffer=text_momentumbuffer,
|
||||||
|
norm_threshold=apg_norm_threshold)
|
||||||
|
else:
|
||||||
|
noise_pred_text = noise_pred_cond
|
||||||
|
if cfg_star_switch:
|
||||||
|
# CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
|
||||||
|
positive_flat = noise_pred_text.view(batch_size, -1)
|
||||||
|
negative_flat = noise_pred_uncond.view(batch_size, -1)
|
||||||
|
|
||||||
alpha = optimized_scale(positive_flat,negative_flat)
|
alpha = optimized_scale(positive_flat,negative_flat)
|
||||||
alpha = alpha.view(batch_size, 1, 1, 1)
|
alpha = alpha.view(batch_size, 1, 1, 1)
|
||||||
|
|
||||||
if (i <= cfg_zero_step):
|
if (i <= cfg_zero_step):
|
||||||
noise_pred = noise_pred_text*0. # it would be faster not to compute noise_pred...
|
noise_pred = noise_pred_text*0. # it would be faster not to compute noise_pred...
|
||||||
else:
|
else:
|
||||||
noise_pred_uncond *= alpha
|
noise_pred_uncond *= alpha
|
||||||
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond)
|
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond)
|
||||||
noise_pred_uncond, noise_pred_cond, noise_pred_text, pos_it, pos_i, neg = None, None, None, None, None, None
|
ret_values = noise_pred_uncond = noise_pred_cond = noise_pred_text = neg = None
|
||||||
temp_x0 = sample_scheduler.step(
|
|
||||||
noise_pred[:, :target_shape[1]].unsqueeze(0),
|
if sample_solver == "euler":
|
||||||
t,
|
dt = timesteps[i] if i == len(timesteps)-1 else (timesteps[i] - timesteps[i + 1])
|
||||||
latents.unsqueeze(0),
|
dt = dt / self.num_timesteps
|
||||||
# return_dict=False,
|
latents = latents - noise_pred * dt[:, None, None, None]
|
||||||
**scheduler_kwargs)[0]
|
else:
|
||||||
latents = temp_x0.squeeze(0)
|
temp_x0 = sample_scheduler.step(
|
||||||
del temp_x0
|
noise_pred[:, :target_shape[1]].unsqueeze(0),
|
||||||
|
t,
|
||||||
|
latents.unsqueeze(0),
|
||||||
|
**scheduler_kwargs)[0]
|
||||||
|
latents = temp_x0.squeeze(0)
|
||||||
|
del temp_x0
|
||||||
|
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
callback(i, latents, False)
|
callback(i, latents, False)
|
||||||
@ -684,23 +835,19 @@ class WanT2V:
|
|||||||
x0 = [latents]
|
x0 = [latents]
|
||||||
|
|
||||||
if chipmunk:
|
if chipmunk:
|
||||||
self.model.release_chipmunk() # need to add it at every exit when in prof
|
self.model.release_chipmunk() # need to add it at every exit when in prod
|
||||||
|
|
||||||
if return_latent_slice != None:
|
if return_latent_slice != None:
|
||||||
if overlapped_latents != None:
|
|
||||||
# latents [:, 1:] = self.toto
|
|
||||||
for zz, zz_r, ll in zip(z, z_reactive, [latents]):
|
|
||||||
ll[:, 0:overlapped_latents_size + ref_images_count] = zz_r
|
|
||||||
|
|
||||||
latent_slice = latents[:, return_latent_slice].clone()
|
latent_slice = latents[:, return_latent_slice].clone()
|
||||||
if input_frames == None:
|
if vace:
|
||||||
if phantom:
|
|
||||||
# phantom post processing
|
|
||||||
x0 = [x0_[:,:-input_ref_images.shape[1]] for x0_ in x0]
|
|
||||||
videos = self.vae.decode(x0, VAE_tile_size)
|
|
||||||
else:
|
|
||||||
# vace post processing
|
# vace post processing
|
||||||
videos = self.decode_latent(x0, input_ref_images, VAE_tile_size)
|
videos = self.decode_latent(x0, input_ref_images, VAE_tile_size)
|
||||||
|
else:
|
||||||
|
if phantom and input_ref_images != None:
|
||||||
|
trim_frames = input_ref_images.shape[1]
|
||||||
|
if trim_frames > 0: x0 = [x0_[:,:-trim_frames] for x0_ in x0]
|
||||||
|
videos = self.vae.decode(x0, VAE_tile_size)
|
||||||
|
|
||||||
if return_latent_slice != None:
|
if return_latent_slice != None:
|
||||||
return { "x" : videos[0], "latent_slice" : latent_slice }
|
return { "x" : videos[0], "latent_slice" : latent_slice }
|
||||||
return videos[0]
|
return videos[0]
|
||||||
@ -1,439 +0,0 @@
|
|||||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
|
||||||
import gc
|
|
||||||
import logging
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
import random
|
|
||||||
import sys
|
|
||||||
import types
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from functools import partial
|
|
||||||
import json
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.cuda.amp as amp
|
|
||||||
import torch.distributed as dist
|
|
||||||
import torchvision.transforms.functional as TF
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from .distributed.fsdp import shard_model
|
|
||||||
from .modules.clip import CLIPModel
|
|
||||||
from .modules.model import WanModel
|
|
||||||
from .modules.t5 import T5EncoderModel
|
|
||||||
from .modules.vae import WanVAE
|
|
||||||
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
|
||||||
get_sampling_sigmas, retrieve_timesteps)
|
|
||||||
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
|
||||||
from wan.modules.posemb_layers import get_rotary_pos_embed
|
|
||||||
from wan.utils.utils import resize_lanczos, calculate_new_dimensions
|
|
||||||
from wan.utils.basic_flowmatch import FlowMatchScheduler
|
|
||||||
|
|
||||||
def optimized_scale(positive_flat, negative_flat):
|
|
||||||
|
|
||||||
# Calculate dot production
|
|
||||||
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
|
||||||
|
|
||||||
# Squared norm of uncondition
|
|
||||||
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
|
|
||||||
|
|
||||||
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
|
|
||||||
st_star = dot_product / squared_norm
|
|
||||||
|
|
||||||
return st_star
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class WanI2V:
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config,
|
|
||||||
checkpoint_dir,
|
|
||||||
model_filename = None,
|
|
||||||
model_type = None,
|
|
||||||
base_model_type= None,
|
|
||||||
text_encoder_filename= None,
|
|
||||||
quantizeTransformer = False,
|
|
||||||
dtype = torch.bfloat16,
|
|
||||||
VAE_dtype = torch.float32,
|
|
||||||
save_quantized = False,
|
|
||||||
mixed_precision_transformer = False
|
|
||||||
):
|
|
||||||
self.device = torch.device(f"cuda")
|
|
||||||
self.config = config
|
|
||||||
self.dtype = dtype
|
|
||||||
self.VAE_dtype = VAE_dtype
|
|
||||||
self.num_train_timesteps = config.num_train_timesteps
|
|
||||||
self.param_dtype = config.param_dtype
|
|
||||||
# shard_fn = partial(shard_model, device_id=device_id)
|
|
||||||
self.text_encoder = T5EncoderModel(
|
|
||||||
text_len=config.text_len,
|
|
||||||
dtype=config.t5_dtype,
|
|
||||||
device=torch.device('cpu'),
|
|
||||||
checkpoint_path=text_encoder_filename,
|
|
||||||
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
|
|
||||||
shard_fn=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.vae_stride = config.vae_stride
|
|
||||||
self.patch_size = config.patch_size
|
|
||||||
self.vae = WanVAE(
|
|
||||||
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype = VAE_dtype,
|
|
||||||
device=self.device)
|
|
||||||
|
|
||||||
self.clip = CLIPModel(
|
|
||||||
dtype=config.clip_dtype,
|
|
||||||
device=self.device,
|
|
||||||
checkpoint_path=os.path.join(checkpoint_dir ,
|
|
||||||
config.clip_checkpoint),
|
|
||||||
tokenizer_path=os.path.join(checkpoint_dir , config.clip_tokenizer))
|
|
||||||
|
|
||||||
logging.info(f"Creating WanModel from {model_filename[-1]}")
|
|
||||||
from mmgp import offload
|
|
||||||
|
|
||||||
# fantasy = torch.load("c:/temp/fantasy.ckpt")
|
|
||||||
# proj_model = fantasy["proj_model"]
|
|
||||||
# audio_processor = fantasy["audio_processor"]
|
|
||||||
# offload.safetensors2.torch_write_file(proj_model, "proj_model.safetensors")
|
|
||||||
# offload.safetensors2.torch_write_file(audio_processor, "audio_processor.safetensors")
|
|
||||||
# for k,v in audio_processor.items():
|
|
||||||
# audio_processor[k] = v.to(torch.bfloat16)
|
|
||||||
# with open("fantasy_config.json", "r", encoding="utf-8") as reader:
|
|
||||||
# config_text = reader.read()
|
|
||||||
# config_json = json.loads(config_text)
|
|
||||||
# offload.safetensors2.torch_write_file(audio_processor, "audio_processor_bf16.safetensors", config=config_json)
|
|
||||||
# model_filename = [model_filename, "audio_processor_bf16.safetensors"]
|
|
||||||
# model_filename = "c:/temp/i2v480p/diffusion_pytorch_model-00001-of-00007.safetensors"
|
|
||||||
# dtype = torch.float16
|
|
||||||
base_config_file = f"configs/{base_model_type}.json"
|
|
||||||
forcedConfigPath = base_config_file if len(model_filename) > 1 else None
|
|
||||||
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath= base_config_file, forcedConfigPath= forcedConfigPath)
|
|
||||||
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
|
|
||||||
offload.change_dtype(self.model, dtype, True)
|
|
||||||
# offload.save_model(self.model, "wan2.1_image2video_720p_14B_mbf16.safetensors", config_file_path="c:/temp/i2v720p/config.json")
|
|
||||||
# offload.save_model(self.model, "wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors",do_quantize=True, config_file_path="c:/temp/i2v720p/config.json")
|
|
||||||
# offload.save_model(self.model, "wan2.1_image2video_720p_14B_quanto_mfp16_int8.safetensors",do_quantize=True, config_file_path="c:/temp/i2v720p/config.json")
|
|
||||||
|
|
||||||
# offload.save_model(self.model, "wan2.1_Fun_InP_1.3B_bf16_bis.safetensors")
|
|
||||||
self.model.eval().requires_grad_(False)
|
|
||||||
if save_quantized:
|
|
||||||
from wgp import save_quantized_model
|
|
||||||
save_quantized_model(self.model, model_type, model_filename[0], dtype, base_config_file)
|
|
||||||
|
|
||||||
|
|
||||||
self.sample_neg_prompt = config.sample_neg_prompt
|
|
||||||
|
|
||||||
def generate(self,
|
|
||||||
input_prompt,
|
|
||||||
image_start,
|
|
||||||
image_end = None,
|
|
||||||
height =720,
|
|
||||||
width = 1280,
|
|
||||||
fit_into_canvas = True,
|
|
||||||
frame_num=81,
|
|
||||||
shift=5.0,
|
|
||||||
sample_solver='unipc',
|
|
||||||
sampling_steps=40,
|
|
||||||
guide_scale=5.0,
|
|
||||||
n_prompt="",
|
|
||||||
seed=-1,
|
|
||||||
callback = None,
|
|
||||||
enable_RIFLEx = False,
|
|
||||||
VAE_tile_size= 0,
|
|
||||||
joint_pass = False,
|
|
||||||
slg_layers = None,
|
|
||||||
slg_start = 0.0,
|
|
||||||
slg_end = 1.0,
|
|
||||||
cfg_star_switch = True,
|
|
||||||
cfg_zero_step = 5,
|
|
||||||
audio_scale=None,
|
|
||||||
audio_cfg_scale=None,
|
|
||||||
audio_proj=None,
|
|
||||||
audio_context_lens=None,
|
|
||||||
model_filename = None,
|
|
||||||
offloadobj = None,
|
|
||||||
**bbargs
|
|
||||||
):
|
|
||||||
r"""
|
|
||||||
Generates video frames from input image and text prompt using diffusion process.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_prompt (`str`):
|
|
||||||
Text prompt for content generation.
|
|
||||||
image_start (PIL.Image.Image):
|
|
||||||
Input image tensor. Shape: [3, H, W]
|
|
||||||
max_area (`int`, *optional*, defaults to 720*1280):
|
|
||||||
Maximum pixel area for latent space calculation. Controls video resolution scaling
|
|
||||||
frame_num (`int`, *optional*, defaults to 81):
|
|
||||||
How many frames to sample from a video. The number should be 4n+1
|
|
||||||
shift (`float`, *optional*, defaults to 5.0):
|
|
||||||
Noise schedule shift parameter. Affects temporal dynamics
|
|
||||||
[NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
|
|
||||||
sample_solver (`str`, *optional*, defaults to 'unipc'):
|
|
||||||
Solver used to sample the video.
|
|
||||||
sampling_steps (`int`, *optional*, defaults to 40):
|
|
||||||
Number of diffusion sampling steps. Higher values improve quality but slow generation
|
|
||||||
guide_scale (`float`, *optional*, defaults 5.0):
|
|
||||||
Classifier-free guidance scale. Controls prompt adherence vs. creativity
|
|
||||||
n_prompt (`str`, *optional*, defaults to ""):
|
|
||||||
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
|
|
||||||
seed (`int`, *optional*, defaults to -1):
|
|
||||||
Random seed for noise generation. If -1, use random seed
|
|
||||||
offload_model (`bool`, *optional*, defaults to True):
|
|
||||||
If True, offloads models to CPU during generation to save VRAM
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor:
|
|
||||||
Generated video frames tensor. Dimensions: (C, N H, W) where:
|
|
||||||
- C: Color channels (3 for RGB)
|
|
||||||
- N: Number of frames (81)
|
|
||||||
- H: Frame height (from max_area)
|
|
||||||
- W: Frame width from max_area)
|
|
||||||
"""
|
|
||||||
|
|
||||||
add_frames_for_end_image = "image2video" in model_filename or "fantasy" in model_filename
|
|
||||||
|
|
||||||
image_start = TF.to_tensor(image_start)
|
|
||||||
lat_frames = int((frame_num - 1) // self.vae_stride[0] + 1)
|
|
||||||
any_end_frame = image_end !=None
|
|
||||||
if any_end_frame:
|
|
||||||
any_end_frame = True
|
|
||||||
image_end = TF.to_tensor(image_end)
|
|
||||||
if add_frames_for_end_image:
|
|
||||||
frame_num +=1
|
|
||||||
lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
|
|
||||||
|
|
||||||
h, w = image_start.shape[1:]
|
|
||||||
|
|
||||||
h, w = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
|
|
||||||
|
|
||||||
lat_h = round(
|
|
||||||
h // self.vae_stride[1] //
|
|
||||||
self.patch_size[1] * self.patch_size[1])
|
|
||||||
lat_w = round(
|
|
||||||
w // self.vae_stride[2] //
|
|
||||||
self.patch_size[2] * self.patch_size[2])
|
|
||||||
h = lat_h * self.vae_stride[1]
|
|
||||||
w = lat_w * self.vae_stride[2]
|
|
||||||
|
|
||||||
clip_image_size = self.clip.model.image_size
|
|
||||||
img_interpolated = resize_lanczos(image_start, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype
|
|
||||||
image_start = resize_lanczos(image_start, clip_image_size, clip_image_size)
|
|
||||||
image_start = image_start.sub_(0.5).div_(0.5).to(self.device) #, self.dtype
|
|
||||||
if image_end!= None:
|
|
||||||
img_interpolated2 = resize_lanczos(image_end, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype
|
|
||||||
image_end = resize_lanczos(image_end, clip_image_size, clip_image_size)
|
|
||||||
image_end = image_end.sub_(0.5).div_(0.5).to(self.device) #, self.dtype
|
|
||||||
|
|
||||||
max_seq_len = lat_frames * lat_h * lat_w // ( self.patch_size[1] * self.patch_size[2])
|
|
||||||
|
|
||||||
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
|
|
||||||
seed_g = torch.Generator(device=self.device)
|
|
||||||
seed_g.manual_seed(seed)
|
|
||||||
noise = torch.randn(16, lat_frames, lat_h, lat_w, dtype=torch.float32, generator=seed_g, device=self.device)
|
|
||||||
|
|
||||||
msk = torch.ones(1, frame_num, lat_h, lat_w, device=self.device)
|
|
||||||
if any_end_frame:
|
|
||||||
msk[:, 1: -1] = 0
|
|
||||||
if add_frames_for_end_image:
|
|
||||||
msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:-1], torch.repeat_interleave(msk[:, -1:], repeats=4, dim=1) ], dim=1)
|
|
||||||
else:
|
|
||||||
msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] ], dim=1)
|
|
||||||
|
|
||||||
else:
|
|
||||||
msk[:, 1:] = 0
|
|
||||||
msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] ], dim=1)
|
|
||||||
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
|
|
||||||
msk = msk.transpose(1, 2)[0]
|
|
||||||
|
|
||||||
if n_prompt == "":
|
|
||||||
n_prompt = self.sample_neg_prompt
|
|
||||||
|
|
||||||
if self._interrupt:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# preprocess
|
|
||||||
context = self.text_encoder([input_prompt], self.device)[0]
|
|
||||||
context_null = self.text_encoder([n_prompt], self.device)[0]
|
|
||||||
context = context.to(self.dtype)
|
|
||||||
context_null = context_null.to(self.dtype)
|
|
||||||
|
|
||||||
if self._interrupt:
|
|
||||||
return None
|
|
||||||
|
|
||||||
clip_context = self.clip.visual([image_start[:, None, :, :]])
|
|
||||||
|
|
||||||
from mmgp import offload
|
|
||||||
offloadobj.unload_all()
|
|
||||||
if any_end_frame:
|
|
||||||
mean2 = 0
|
|
||||||
enc= torch.concat([
|
|
||||||
img_interpolated,
|
|
||||||
torch.full( (3, frame_num-2, h, w), mean2, device=self.device, dtype= self.VAE_dtype),
|
|
||||||
img_interpolated2,
|
|
||||||
], dim=1).to(self.device)
|
|
||||||
else:
|
|
||||||
enc= torch.concat([
|
|
||||||
img_interpolated,
|
|
||||||
torch.zeros(3, frame_num-1, h, w, device=self.device, dtype= self.VAE_dtype)
|
|
||||||
], dim=1).to(self.device)
|
|
||||||
image_start, image_end, img_interpolated, img_interpolated2 = None, None, None, None
|
|
||||||
|
|
||||||
lat_y = self.vae.encode([enc], VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
|
|
||||||
y = torch.concat([msk, lat_y])
|
|
||||||
lat_y = None
|
|
||||||
|
|
||||||
|
|
||||||
# evaluation mode
|
|
||||||
if sample_solver == 'causvid':
|
|
||||||
sample_scheduler = FlowMatchScheduler(num_inference_steps=sampling_steps, shift=shift, sigma_min=0, extra_one_step=True)
|
|
||||||
timesteps = torch.tensor([1000, 934, 862, 756, 603, 410, 250, 140, 74])[:sampling_steps].to(self.device)
|
|
||||||
sample_scheduler.timesteps =timesteps
|
|
||||||
sample_scheduler.sigmas = torch.cat([sample_scheduler.timesteps / 1000, torch.tensor([0.], device=self.device)])
|
|
||||||
elif sample_solver == 'unipc' or sample_solver == "":
|
|
||||||
sample_scheduler = FlowUniPCMultistepScheduler(
|
|
||||||
num_train_timesteps=self.num_train_timesteps,
|
|
||||||
shift=1,
|
|
||||||
use_dynamic_shifting=False)
|
|
||||||
sample_scheduler.set_timesteps(
|
|
||||||
sampling_steps, device=self.device, shift=shift)
|
|
||||||
timesteps = sample_scheduler.timesteps
|
|
||||||
elif sample_solver == 'dpm++':
|
|
||||||
sample_scheduler = FlowDPMSolverMultistepScheduler(
|
|
||||||
num_train_timesteps=self.num_train_timesteps,
|
|
||||||
shift=1,
|
|
||||||
use_dynamic_shifting=False)
|
|
||||||
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
|
|
||||||
timesteps, _ = retrieve_timesteps(
|
|
||||||
sample_scheduler,
|
|
||||||
device=self.device,
|
|
||||||
sigmas=sampling_sigmas)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("Unsupported scheduler.")
|
|
||||||
|
|
||||||
# sample videos
|
|
||||||
latent = noise
|
|
||||||
batch_size = 1
|
|
||||||
freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx= enable_RIFLEx)
|
|
||||||
|
|
||||||
kwargs = { 'clip_fea': clip_context, 'y': y, 'freqs' : freqs, 'pipeline' : self, 'callback' : callback }
|
|
||||||
|
|
||||||
if audio_proj != None:
|
|
||||||
kwargs.update({
|
|
||||||
"audio_proj": audio_proj.to(self.dtype),
|
|
||||||
"audio_context_lens": audio_context_lens,
|
|
||||||
})
|
|
||||||
cache_type = self.model.enable_cache
|
|
||||||
if cache_type != None:
|
|
||||||
x_count = 3 if audio_cfg_scale !=None else 2
|
|
||||||
self.model.previous_residual = [None] * x_count
|
|
||||||
if cache_type == "tea":
|
|
||||||
self.model.compute_teacache_threshold(self.model.cache_start_step, timesteps, self.model.cache_multiplier)
|
|
||||||
else:
|
|
||||||
self.model.compute_magcache_threshold(self.model.cache_start_step, timesteps, self.model.cache_multiplier)
|
|
||||||
self.model.accumulated_err, self.model.accumulated_steps, self.model.accumulated_ratio = [0.0] * x_count, [0] * x_count, [1.0] * x_count
|
|
||||||
self.model.one_for_all = x_count > 2
|
|
||||||
|
|
||||||
# self.model.to(self.device)
|
|
||||||
if callback != None:
|
|
||||||
callback(-1, None, True)
|
|
||||||
latent = latent.to(self.device)
|
|
||||||
for i, t in enumerate(tqdm(timesteps)):
|
|
||||||
offload.set_step_no_for_lora(self.model, i)
|
|
||||||
kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None
|
|
||||||
latent_model_input = latent
|
|
||||||
timestep = [t]
|
|
||||||
|
|
||||||
timestep = torch.stack(timestep).to(self.device)
|
|
||||||
kwargs.update({
|
|
||||||
't' :timestep,
|
|
||||||
'current_step' :i,
|
|
||||||
})
|
|
||||||
|
|
||||||
if guide_scale == 1:
|
|
||||||
noise_pred = self.model( [latent_model_input], context=[context], audio_scale = None if audio_scale == None else [audio_scale], x_id=0, **kwargs, )[0]
|
|
||||||
if self._interrupt:
|
|
||||||
return None
|
|
||||||
elif joint_pass:
|
|
||||||
if audio_proj == None:
|
|
||||||
noise_pred_cond, noise_pred_uncond = self.model(
|
|
||||||
[latent_model_input, latent_model_input],
|
|
||||||
context=[context, context_null],
|
|
||||||
**kwargs)
|
|
||||||
else:
|
|
||||||
noise_pred_cond, noise_pred_noaudio, noise_pred_uncond = self.model(
|
|
||||||
[latent_model_input, latent_model_input, latent_model_input],
|
|
||||||
context=[context, context, context_null],
|
|
||||||
audio_scale = [audio_scale, None, None ],
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
if self._interrupt:
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
noise_pred_cond = self.model( [latent_model_input], context=[context], audio_scale = None if audio_scale == None else [audio_scale], x_id=0, **kwargs, )[0]
|
|
||||||
if self._interrupt:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if audio_proj != None:
|
|
||||||
noise_pred_noaudio = self.model(
|
|
||||||
[latent_model_input],
|
|
||||||
x_id=1,
|
|
||||||
context=[context],
|
|
||||||
**kwargs,
|
|
||||||
)[0]
|
|
||||||
if self._interrupt:
|
|
||||||
return None
|
|
||||||
|
|
||||||
noise_pred_uncond = self.model(
|
|
||||||
[latent_model_input],
|
|
||||||
x_id=1 if audio_scale == None else 2,
|
|
||||||
context=[context_null],
|
|
||||||
**kwargs,
|
|
||||||
)[0]
|
|
||||||
if self._interrupt:
|
|
||||||
return None
|
|
||||||
del latent_model_input
|
|
||||||
|
|
||||||
if guide_scale > 1:
|
|
||||||
# CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
|
|
||||||
if cfg_star_switch:
|
|
||||||
positive_flat = noise_pred_cond.view(batch_size, -1)
|
|
||||||
negative_flat = noise_pred_uncond.view(batch_size, -1)
|
|
||||||
|
|
||||||
alpha = optimized_scale(positive_flat,negative_flat)
|
|
||||||
alpha = alpha.view(batch_size, 1, 1, 1)
|
|
||||||
|
|
||||||
if (i <= cfg_zero_step):
|
|
||||||
noise_pred = noise_pred_cond*0. # it would be faster not to compute noise_pred...
|
|
||||||
else:
|
|
||||||
noise_pred_uncond *= alpha
|
|
||||||
if audio_scale == None:
|
|
||||||
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
|
|
||||||
else:
|
|
||||||
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio)
|
|
||||||
|
|
||||||
noise_pred_uncond, noise_pred_noaudio = None, None
|
|
||||||
temp_x0 = sample_scheduler.step(
|
|
||||||
noise_pred.unsqueeze(0),
|
|
||||||
t,
|
|
||||||
latent.unsqueeze(0),
|
|
||||||
return_dict=False,
|
|
||||||
generator=seed_g)[0]
|
|
||||||
latent = temp_x0.squeeze(0)
|
|
||||||
del temp_x0
|
|
||||||
del timestep
|
|
||||||
|
|
||||||
if callback is not None:
|
|
||||||
callback(i, latent, False)
|
|
||||||
|
|
||||||
x0 = [latent]
|
|
||||||
video = self.vae.decode(x0, VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
|
|
||||||
|
|
||||||
if any_end_frame and add_frames_for_end_image:
|
|
||||||
# video[:, -1:] = img_interpolated2
|
|
||||||
video = video[:, :-1]
|
|
||||||
|
|
||||||
del noise, latent
|
|
||||||
del sample_scheduler
|
|
||||||
|
|
||||||
return video
|
|
||||||
@ -531,7 +531,7 @@ class CLIPModel:
|
|||||||
seq_len=self.model.max_text_len - 2,
|
seq_len=self.model.max_text_len - 2,
|
||||||
clean='whitespace')
|
clean='whitespace')
|
||||||
|
|
||||||
def visual(self, videos):
|
def visual(self, videos,):
|
||||||
# preprocess
|
# preprocess
|
||||||
size = (self.model.image_size,) * 2
|
size = (self.model.image_size,) * 2
|
||||||
videos = torch.cat([
|
videos = torch.cat([
|
||||||
|
|||||||
@ -11,6 +11,7 @@ from typing import Union,Optional
|
|||||||
from mmgp import offload
|
from mmgp import offload
|
||||||
from .attention import pay_attention
|
from .attention import pay_attention
|
||||||
from torch.backends.cuda import sdp_kernel
|
from torch.backends.cuda import sdp_kernel
|
||||||
|
from wan.multitalk.multitalk_utils import get_attn_map_with_target
|
||||||
|
|
||||||
__all__ = ['WanModel']
|
__all__ = ['WanModel']
|
||||||
|
|
||||||
@ -175,7 +176,7 @@ class WanSelfAttention(nn.Module):
|
|||||||
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
||||||
|
|
||||||
|
|
||||||
def forward(self, xlist, grid_sizes, freqs, block_mask = None):
|
def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks = None, ref_images_count = 0):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
||||||
@ -190,7 +191,7 @@ class WanSelfAttention(nn.Module):
|
|||||||
# query, key, value function
|
# query, key, value function
|
||||||
q = self.q(x)
|
q = self.q(x)
|
||||||
self.norm_q(q)
|
self.norm_q(q)
|
||||||
q = q.view(b, s, n, d) # !!!
|
q = q.view(b, s, n, d)
|
||||||
k = self.k(x)
|
k = self.k(x)
|
||||||
self.norm_k(k)
|
self.norm_k(k)
|
||||||
k = k.view(b, s, n, d)
|
k = k.view(b, s, n, d)
|
||||||
@ -200,6 +201,12 @@ class WanSelfAttention(nn.Module):
|
|||||||
del q,k
|
del q,k
|
||||||
|
|
||||||
q,k = apply_rotary_emb(qklist, freqs, head_first=False)
|
q,k = apply_rotary_emb(qklist, freqs, head_first=False)
|
||||||
|
|
||||||
|
if ref_target_masks != None:
|
||||||
|
x_ref_attn_map = get_attn_map_with_target(q, k , grid_sizes, ref_target_masks=ref_target_masks, ref_images_count = ref_images_count)
|
||||||
|
else:
|
||||||
|
x_ref_attn_map = None
|
||||||
|
|
||||||
chipmunk = offload.shared_state.get("_chipmunk", False)
|
chipmunk = offload.shared_state.get("_chipmunk", False)
|
||||||
if chipmunk and self.__class__ == WanSelfAttention:
|
if chipmunk and self.__class__ == WanSelfAttention:
|
||||||
q = q.transpose(1,2)
|
q = q.transpose(1,2)
|
||||||
@ -225,30 +232,10 @@ class WanSelfAttention(nn.Module):
|
|||||||
)
|
)
|
||||||
del q,k,v
|
del q,k,v
|
||||||
|
|
||||||
# if not self._flag_ar_attention:
|
|
||||||
# q = rope_apply(q, grid_sizes, freqs)
|
|
||||||
# k = rope_apply(k, grid_sizes, freqs)
|
|
||||||
# x = flash_attention(q=q, k=k, v=v, window_size=self.window_size)
|
|
||||||
# else:
|
|
||||||
# q = rope_apply(q, grid_sizes, freqs)
|
|
||||||
# k = rope_apply(k, grid_sizes, freqs)
|
|
||||||
# q = q.to(torch.bfloat16)
|
|
||||||
# k = k.to(torch.bfloat16)
|
|
||||||
# v = v.to(torch.bfloat16)
|
|
||||||
|
|
||||||
# with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
|
|
||||||
# x = (
|
|
||||||
# torch.nn.functional.scaled_dot_product_attention(
|
|
||||||
# q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
|
|
||||||
# )
|
|
||||||
# .transpose(1, 2)
|
|
||||||
# .contiguous()
|
|
||||||
# )
|
|
||||||
|
|
||||||
# output
|
|
||||||
x = x.flatten(2)
|
x = x.flatten(2)
|
||||||
x = self.o(x)
|
x = self.o(x)
|
||||||
return x
|
return x, x_ref_attn_map
|
||||||
|
|
||||||
|
|
||||||
class WanT2VCrossAttention(WanSelfAttention):
|
class WanT2VCrossAttention(WanSelfAttention):
|
||||||
@ -375,7 +362,11 @@ class WanAttentionBlock(nn.Module):
|
|||||||
cross_attn_norm=False,
|
cross_attn_norm=False,
|
||||||
eps=1e-6,
|
eps=1e-6,
|
||||||
block_id=None,
|
block_id=None,
|
||||||
block_no = 0
|
block_no = 0,
|
||||||
|
output_dim=0,
|
||||||
|
norm_input_visual=True,
|
||||||
|
class_range=24,
|
||||||
|
class_interval=4,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
@ -409,6 +400,22 @@ class WanAttentionBlock(nn.Module):
|
|||||||
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
||||||
self.block_id = block_id
|
self.block_id = block_id
|
||||||
|
|
||||||
|
if output_dim > 0:
|
||||||
|
from wan.multitalk.attention import SingleStreamMutiAttention
|
||||||
|
# init audio module
|
||||||
|
self.audio_cross_attn = SingleStreamMutiAttention(
|
||||||
|
dim=dim,
|
||||||
|
encoder_hidden_states_dim=output_dim,
|
||||||
|
num_heads=num_heads,
|
||||||
|
qk_norm=False,
|
||||||
|
qkv_bias=True,
|
||||||
|
eps=eps,
|
||||||
|
norm_layer=WanRMSNorm,
|
||||||
|
class_range=class_range,
|
||||||
|
class_interval=class_interval
|
||||||
|
)
|
||||||
|
self.norm_x = WanLayerNorm(dim, eps, elementwise_affine=True) if norm_input_visual else nn.Identity()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x,
|
x,
|
||||||
@ -423,6 +430,9 @@ class WanAttentionBlock(nn.Module):
|
|||||||
audio_proj= None,
|
audio_proj= None,
|
||||||
audio_context_lens= None,
|
audio_context_lens= None,
|
||||||
audio_scale=None,
|
audio_scale=None,
|
||||||
|
multitalk_audio=None,
|
||||||
|
multitalk_masks=None,
|
||||||
|
ref_images_count=0,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@ -466,11 +476,10 @@ class WanAttentionBlock(nn.Module):
|
|||||||
|
|
||||||
xlist = [x_mod.to(attention_dtype)]
|
xlist = [x_mod.to(attention_dtype)]
|
||||||
del x_mod
|
del x_mod
|
||||||
y = self.self_attn( xlist, grid_sizes, freqs, block_mask)
|
y, x_ref_attn_map = self.self_attn( xlist, grid_sizes, freqs, block_mask = block_mask, ref_target_masks = multitalk_masks, ref_images_count = ref_images_count)
|
||||||
y = y.to(dtype)
|
y = y.to(dtype)
|
||||||
|
|
||||||
if cam_emb != None:
|
if cam_emb != None: y = self.projector(y)
|
||||||
y = self.projector(y)
|
|
||||||
|
|
||||||
x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames)
|
x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames)
|
||||||
x.addcmul_(y, e[2])
|
x.addcmul_(y, e[2])
|
||||||
@ -482,6 +491,25 @@ class WanAttentionBlock(nn.Module):
|
|||||||
del y
|
del y
|
||||||
x += self.cross_attn(ylist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens).to(dtype)
|
x += self.cross_attn(ylist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens).to(dtype)
|
||||||
|
|
||||||
|
if multitalk_audio != None:
|
||||||
|
# cross attn of multitalk audio
|
||||||
|
y = self.norm_x(x)
|
||||||
|
y = y.to(attention_dtype)
|
||||||
|
if ref_images_count == 0:
|
||||||
|
x += self.audio_cross_attn(y, encoder_hidden_states=multitalk_audio, shape=grid_sizes, x_ref_attn_map=x_ref_attn_map)
|
||||||
|
else:
|
||||||
|
y_shape = y.shape
|
||||||
|
y = y.reshape(y_shape[0], grid_sizes[0], -1)
|
||||||
|
y = y[:, ref_images_count:]
|
||||||
|
y = y.reshape(y_shape[0], -1, y_shape[-1])
|
||||||
|
grid_sizes_alt = [grid_sizes[0]-ref_images_count, *grid_sizes[1:]]
|
||||||
|
y = self.audio_cross_attn(y, encoder_hidden_states=multitalk_audio, shape=grid_sizes_alt, x_ref_attn_map=x_ref_attn_map)
|
||||||
|
y = y.reshape(y_shape[0], grid_sizes[0]-ref_images_count, -1)
|
||||||
|
x = x.reshape(y_shape[0], grid_sizes[0], -1)
|
||||||
|
x[:, ref_images_count:] += y
|
||||||
|
x = x.reshape(y_shape[0], -1, y_shape[-1])
|
||||||
|
del y
|
||||||
|
|
||||||
y = self.norm2(x)
|
y = self.norm2(x)
|
||||||
|
|
||||||
y = reshape_latent(y , latent_frames)
|
y = reshape_latent(y , latent_frames)
|
||||||
@ -518,6 +546,71 @@ class WanAttentionBlock(nn.Module):
|
|||||||
x.add_(hint, alpha= scale)
|
x.add_(hint, alpha= scale)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
class AudioProjModel(ModelMixin, ConfigMixin):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
seq_len=5,
|
||||||
|
seq_len_vf=12,
|
||||||
|
blocks=12,
|
||||||
|
channels=768,
|
||||||
|
intermediate_dim=512,
|
||||||
|
output_dim=768,
|
||||||
|
context_tokens=32,
|
||||||
|
norm_output_audio=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.seq_len = seq_len
|
||||||
|
self.blocks = blocks
|
||||||
|
self.channels = channels
|
||||||
|
self.input_dim = seq_len * blocks * channels
|
||||||
|
self.input_dim_vf = seq_len_vf * blocks * channels
|
||||||
|
self.intermediate_dim = intermediate_dim
|
||||||
|
self.context_tokens = context_tokens
|
||||||
|
self.output_dim = output_dim
|
||||||
|
|
||||||
|
# define multiple linear layers
|
||||||
|
self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
|
||||||
|
self.proj1_vf = nn.Linear(self.input_dim_vf, intermediate_dim)
|
||||||
|
self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
|
||||||
|
self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
|
||||||
|
self.norm = nn.LayerNorm(output_dim) if norm_output_audio else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, audio_embeds, audio_embeds_vf):
|
||||||
|
video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1]
|
||||||
|
B, _, _, S, C = audio_embeds.shape
|
||||||
|
|
||||||
|
# process audio of first frame
|
||||||
|
audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
|
||||||
|
batch_size, window_size, blocks, channels = audio_embeds.shape
|
||||||
|
audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
|
||||||
|
|
||||||
|
# process audio of latter frame
|
||||||
|
audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c")
|
||||||
|
batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape
|
||||||
|
audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf)
|
||||||
|
|
||||||
|
# first projection
|
||||||
|
audio_embeds = torch.relu(self.proj1(audio_embeds))
|
||||||
|
audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf))
|
||||||
|
audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B)
|
||||||
|
audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B)
|
||||||
|
audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1)
|
||||||
|
audio_embeds_vf = audio_embeds = None
|
||||||
|
batch_size_c, N_t, C_a = audio_embeds_c.shape
|
||||||
|
audio_embeds_c = audio_embeds_c.view(batch_size_c*N_t, C_a)
|
||||||
|
|
||||||
|
# second projection
|
||||||
|
audio_embeds_c = torch.relu(self.proj2(audio_embeds_c))
|
||||||
|
|
||||||
|
context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c*N_t, self.context_tokens, self.output_dim)
|
||||||
|
audio_embeds_c = None
|
||||||
|
# normalization and reshape
|
||||||
|
context_tokens = self.norm(context_tokens)
|
||||||
|
context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
|
||||||
|
|
||||||
|
return context_tokens
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class VaceWanAttentionBlock(WanAttentionBlock):
|
class VaceWanAttentionBlock(WanAttentionBlock):
|
||||||
@ -595,19 +688,27 @@ class Head(nn.Module):
|
|||||||
|
|
||||||
class MLPProj(torch.nn.Module):
|
class MLPProj(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_dim, out_dim):
|
def __init__(self, in_dim, out_dim, flf_pos_emb=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.proj = torch.nn.Sequential(
|
self.proj = torch.nn.Sequential(
|
||||||
torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
|
torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
|
||||||
torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
|
torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
|
||||||
torch.nn.LayerNorm(out_dim))
|
torch.nn.LayerNorm(out_dim))
|
||||||
|
|
||||||
|
if flf_pos_emb: # NOTE: we only use this for `flf2v`
|
||||||
|
FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER = 257 * 2
|
||||||
|
self.emb_pos = nn.Parameter(
|
||||||
|
torch.zeros(1, FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER, 1280))
|
||||||
|
|
||||||
def forward(self, image_embeds):
|
def forward(self, image_embeds):
|
||||||
|
if hasattr(self, 'emb_pos'):
|
||||||
|
bs, n, d = image_embeds.shape
|
||||||
|
image_embeds = image_embeds.view(-1, 2 * n, d)
|
||||||
|
image_embeds = image_embeds + self.emb_pos
|
||||||
clip_extra_context_tokens = self.proj(image_embeds)
|
clip_extra_context_tokens = self.proj(image_embeds)
|
||||||
return clip_extra_context_tokens
|
return clip_extra_context_tokens
|
||||||
|
|
||||||
|
|
||||||
class WanModel(ModelMixin, ConfigMixin):
|
class WanModel(ModelMixin, ConfigMixin):
|
||||||
def setup_chipmunk(self):
|
def setup_chipmunk(self):
|
||||||
# from chipmunk.util import LayerCounter
|
# from chipmunk.util import LayerCounter
|
||||||
@ -696,45 +797,18 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
qk_norm=True,
|
qk_norm=True,
|
||||||
cross_attn_norm=True,
|
cross_attn_norm=True,
|
||||||
eps=1e-6,
|
eps=1e-6,
|
||||||
|
flf = False,
|
||||||
recammaster = False,
|
recammaster = False,
|
||||||
inject_sample_info = False,
|
inject_sample_info = False,
|
||||||
fantasytalking_dim = 0,
|
fantasytalking_dim = 0,
|
||||||
|
multitalk_output_dim = 0,
|
||||||
|
audio_window=5,
|
||||||
|
intermediate_dim=512,
|
||||||
|
context_tokens=32,
|
||||||
|
vae_scale=4, # vae timedownsample scale
|
||||||
|
norm_input_visual=True,
|
||||||
|
norm_output_audio=True,
|
||||||
):
|
):
|
||||||
r"""
|
|
||||||
Initialize the diffusion model backbone.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_type (`str`, *optional*, defaults to 't2v'):
|
|
||||||
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
|
|
||||||
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
|
|
||||||
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
|
|
||||||
text_len (`int`, *optional*, defaults to 512):
|
|
||||||
Fixed length for text embeddings
|
|
||||||
in_dim (`int`, *optional*, defaults to 16):
|
|
||||||
Input video channels (C_in)
|
|
||||||
dim (`int`, *optional*, defaults to 2048):
|
|
||||||
Hidden dimension of the transformer
|
|
||||||
ffn_dim (`int`, *optional*, defaults to 8192):
|
|
||||||
Intermediate dimension in feed-forward network
|
|
||||||
freq_dim (`int`, *optional*, defaults to 256):
|
|
||||||
Dimension for sinusoidal time embeddings
|
|
||||||
text_dim (`int`, *optional*, defaults to 4096):
|
|
||||||
Input dimension for text embeddings
|
|
||||||
out_dim (`int`, *optional*, defaults to 16):
|
|
||||||
Output video channels (C_out)
|
|
||||||
num_heads (`int`, *optional*, defaults to 16):
|
|
||||||
Number of attention heads
|
|
||||||
num_layers (`int`, *optional*, defaults to 32):
|
|
||||||
Number of transformer blocks
|
|
||||||
window_size (`tuple`, *optional*, defaults to (-1, -1)):
|
|
||||||
Window size for local attention (-1 indicates global attention)
|
|
||||||
qk_norm (`bool`, *optional*, defaults to True):
|
|
||||||
Enable query/key normalization
|
|
||||||
cross_attn_norm (`bool`, *optional*, defaults to False):
|
|
||||||
Enable cross-attention normalization
|
|
||||||
eps (`float`, *optional*, defaults to 1e-6):
|
|
||||||
Epsilon value for normalization layers
|
|
||||||
"""
|
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -760,6 +834,14 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
self.block_mask = None
|
self.block_mask = None
|
||||||
self.inject_sample_info = inject_sample_info
|
self.inject_sample_info = inject_sample_info
|
||||||
|
|
||||||
|
self.norm_output_audio = norm_output_audio
|
||||||
|
self.audio_window = audio_window
|
||||||
|
self.intermediate_dim = intermediate_dim
|
||||||
|
self.vae_scale = vae_scale
|
||||||
|
|
||||||
|
multitalk = multitalk_output_dim > 0
|
||||||
|
self.multitalk = multitalk
|
||||||
|
|
||||||
# embeddings
|
# embeddings
|
||||||
self.patch_embedding = nn.Conv3d(
|
self.patch_embedding = nn.Conv3d(
|
||||||
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
||||||
@ -780,7 +862,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
|
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
|
||||||
self.blocks = nn.ModuleList([
|
self.blocks = nn.ModuleList([
|
||||||
WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
|
WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
|
||||||
window_size, qk_norm, cross_attn_norm, eps, block_no =i)
|
window_size, qk_norm, cross_attn_norm, eps, block_no =i, output_dim=multitalk_output_dim, norm_input_visual=norm_input_visual)
|
||||||
for i in range(num_layers)
|
for i in range(num_layers)
|
||||||
])
|
])
|
||||||
|
|
||||||
@ -790,7 +872,18 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
# buffers (don't use register_buffer otherwise dtype will be changed in to())
|
# buffers (don't use register_buffer otherwise dtype will be changed in to())
|
||||||
|
|
||||||
if model_type == 'i2v':
|
if model_type == 'i2v':
|
||||||
self.img_emb = MLPProj(1280, dim)
|
self.img_emb = MLPProj(1280, dim, flf_pos_emb = flf)
|
||||||
|
|
||||||
|
if multitalk :
|
||||||
|
# init audio adapter
|
||||||
|
self.audio_proj = AudioProjModel(
|
||||||
|
seq_len=audio_window,
|
||||||
|
seq_len_vf=audio_window+vae_scale-1,
|
||||||
|
intermediate_dim=intermediate_dim,
|
||||||
|
output_dim=multitalk_output_dim,
|
||||||
|
context_tokens=context_tokens,
|
||||||
|
norm_output_audio=norm_output_audio,
|
||||||
|
)
|
||||||
|
|
||||||
# initialize weights
|
# initialize weights
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
@ -806,7 +899,10 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
self.blocks = nn.ModuleList([
|
self.blocks = nn.ModuleList([
|
||||||
WanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
|
WanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
|
||||||
self.cross_attn_norm, self.eps, block_no =i,
|
self.cross_attn_norm, self.eps, block_no =i,
|
||||||
block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None)
|
block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None,
|
||||||
|
output_dim=multitalk_output_dim,
|
||||||
|
norm_input_visual=norm_input_visual,
|
||||||
|
)
|
||||||
for i in range(self.num_layers)
|
for i in range(self.num_layers)
|
||||||
])
|
])
|
||||||
|
|
||||||
@ -847,6 +943,10 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
layer_list2 += [block.norm3]
|
layer_list2 += [block.norm3]
|
||||||
|
|
||||||
|
if hasattr(self, "audio_proj"):
|
||||||
|
for block in self.blocks:
|
||||||
|
layer_list2 += [block.norm_x]
|
||||||
|
|
||||||
if hasattr(self, "fps_embedding"):
|
if hasattr(self, "fps_embedding"):
|
||||||
layer_list2 += [self.fps_embedding, self.fps_projection, self.fps_projection[0], self.fps_projection[2]]
|
layer_list2 += [self.fps_embedding, self.fps_projection, self.fps_projection[0], self.fps_projection[2]]
|
||||||
|
|
||||||
@ -1006,6 +1106,9 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
audio_proj=None,
|
audio_proj=None,
|
||||||
audio_context_lens=None,
|
audio_context_lens=None,
|
||||||
audio_scale=None,
|
audio_scale=None,
|
||||||
|
multitalk_audio = None,
|
||||||
|
multitalk_masks = None,
|
||||||
|
ref_images_count = 0,
|
||||||
|
|
||||||
):
|
):
|
||||||
# patch_dtype = self.patch_embedding.weight.dtype
|
# patch_dtype = self.patch_embedding.weight.dtype
|
||||||
@ -1090,6 +1193,21 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||||
context = [ torch.cat( [context_clip, u ], dim=1 ) for u in context ]
|
context = [ torch.cat( [context_clip, u ], dim=1 ) for u in context ]
|
||||||
|
|
||||||
|
if multitalk_audio != None:
|
||||||
|
multitalk_audio_list = []
|
||||||
|
for audio in multitalk_audio:
|
||||||
|
audio = self.audio_proj(*audio)
|
||||||
|
audio = torch.concat(audio.split(1), dim=2).to(context[0])
|
||||||
|
multitalk_audio_list.append(audio)
|
||||||
|
audio = None
|
||||||
|
else:
|
||||||
|
multitalk_audio_list = [None] * len(x_list)
|
||||||
|
|
||||||
|
if multitalk_masks != None:
|
||||||
|
multitalk_masks_list = multitalk_masks
|
||||||
|
else:
|
||||||
|
multitalk_masks_list = [None] * len(x_list)
|
||||||
|
|
||||||
context_list = context
|
context_list = context
|
||||||
if audio_scale != None:
|
if audio_scale != None:
|
||||||
audio_scale_list = audio_scale
|
audio_scale_list = audio_scale
|
||||||
@ -1105,6 +1223,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
block_mask = block_mask,
|
block_mask = block_mask,
|
||||||
audio_proj=audio_proj,
|
audio_proj=audio_proj,
|
||||||
audio_context_lens=audio_context_lens,
|
audio_context_lens=audio_context_lens,
|
||||||
|
ref_images_count=ref_images_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
if vace_context == None:
|
if vace_context == None:
|
||||||
@ -1137,7 +1256,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
if self.accumulated_err[cur_x_id]<self.magcache_thresh and self.accumulated_steps[cur_x_id]<=self.magcache_K:
|
if self.accumulated_err[cur_x_id]<self.magcache_thresh and self.accumulated_steps[cur_x_id]<=self.magcache_K:
|
||||||
skip_forward = True
|
skip_forward = True
|
||||||
if i == 0 and x_id == 0: self.cache_skipped_steps += 1
|
if i == 0 and x_id == 0: self.cache_skipped_steps += 1
|
||||||
print(f"skip: step={current_step} for x_id={cur_x_id}, accum error {self.accumulated_err[cur_x_id]}")
|
# print(f"skip: step={current_step} for x_id={cur_x_id}, accum error {self.accumulated_err[cur_x_id]}")
|
||||||
else:
|
else:
|
||||||
skip_forward = False
|
skip_forward = False
|
||||||
self.accumulated_err[cur_x_id], self.accumulated_steps[cur_x_id], self.accumulated_ratio[cur_x_id] = 0, 0, 1.0
|
self.accumulated_err[cur_x_id], self.accumulated_steps[cur_x_id], self.accumulated_ratio[cur_x_id] = 0, 0, 1.0
|
||||||
@ -1209,11 +1328,11 @@ class WanModel(ModelMixin, ConfigMixin):
|
|||||||
continue
|
continue
|
||||||
x_list[0] = block(x_list[0], context = context_list[0], audio_scale= audio_scale_list[0], e= e0, **kwargs)
|
x_list[0] = block(x_list[0], context = context_list[0], audio_scale= audio_scale_list[0], e= e0, **kwargs)
|
||||||
else:
|
else:
|
||||||
for i, (x, context, hints, audio_scale, should_calc) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, x_should_calc)):
|
for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc)):
|
||||||
if should_calc:
|
if should_calc:
|
||||||
x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, e= e0, **kwargs)
|
x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, **kwargs)
|
||||||
del x
|
del x
|
||||||
context, hints = None, None
|
context = hints = audio_embedding = None
|
||||||
|
|
||||||
if self.enable_cache != None:
|
if self.enable_cache != None:
|
||||||
if joint_pass:
|
if joint_pass:
|
||||||
|
|||||||
@ -11,7 +11,8 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use('TkAgg')
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from flask import Flask, request, jsonify, render_template
|
from flask import Flask, request, jsonify, render_template
|
||||||
import os
|
import os
|
||||||
@ -21,7 +22,6 @@ import torch
|
|||||||
import yaml
|
import yaml
|
||||||
import matplotlib
|
import matplotlib
|
||||||
import argparse
|
import argparse
|
||||||
matplotlib.use('Agg')
|
|
||||||
|
|
||||||
app = Flask(__name__, static_folder='static', template_folder='templates')
|
app = Flask(__name__, static_folder='static', template_folder='templates')
|
||||||
|
|
||||||
|
|||||||
@ -14,6 +14,9 @@ from PIL import Image
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from rembg import remove, new_session
|
from rembg import remove, new_session
|
||||||
import random
|
import random
|
||||||
|
import ffmpeg
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
__all__ = ['cache_video', 'cache_image', 'str2bool']
|
__all__ = ['cache_video', 'cache_image', 'str2bool']
|
||||||
|
|
||||||
@ -33,9 +36,6 @@ def seed_everything(seed: int):
|
|||||||
def resample(video_fps, video_frames_count, max_target_frames_count, target_fps, start_target_frame ):
|
def resample(video_fps, video_frames_count, max_target_frames_count, target_fps, start_target_frame ):
|
||||||
import math
|
import math
|
||||||
|
|
||||||
if video_fps < target_fps :
|
|
||||||
video_fps = target_fps
|
|
||||||
|
|
||||||
video_frame_duration = 1 /video_fps
|
video_frame_duration = 1 /video_fps
|
||||||
target_frame_duration = 1 / target_fps
|
target_frame_duration = 1 / target_fps
|
||||||
|
|
||||||
@ -160,8 +160,8 @@ def calculate_new_dimensions(canvas_height, canvas_width, height, width, fit_int
|
|||||||
new_width = round( width * scale / block_size) * block_size
|
new_width = round( width * scale / block_size) * block_size
|
||||||
return new_height, new_width
|
return new_height, new_width
|
||||||
|
|
||||||
def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, fit_into_canvas = False ):
|
def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, ignore_first, fit_into_canvas = False ):
|
||||||
if rm_background > 0:
|
if rm_background:
|
||||||
session = new_session()
|
session = new_session()
|
||||||
|
|
||||||
output_list =[]
|
output_list =[]
|
||||||
@ -183,7 +183,7 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg
|
|||||||
new_height = int( round(height * scale / 16) * 16)
|
new_height = int( round(height * scale / 16) * 16)
|
||||||
new_width = int( round(width * scale / 16) * 16)
|
new_width = int( round(width * scale / 16) * 16)
|
||||||
resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
|
resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
|
||||||
if rm_background == 1 or rm_background == 2 and i > 0 :
|
if rm_background and not (ignore_first and i == 0) :
|
||||||
# resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
|
# resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
|
||||||
resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
|
resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
|
||||||
output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200,
|
output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200,
|
||||||
@ -406,3 +406,137 @@ def create_progress_hook(filename):
|
|||||||
return progress_hook(block_num, block_size, total_size, filename)
|
return progress_hook(block_num, block_size, total_size, filename)
|
||||||
return hook
|
return hook
|
||||||
|
|
||||||
|
import ffmpeg
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
def extract_audio_tracks(source_video, verbose=False, query_only= False):
|
||||||
|
"""
|
||||||
|
Extract all audio tracks from source video to temporary files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source_video: Path to video with audio to extract
|
||||||
|
verbose: Enable verbose output (default: False)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of temporary audio file paths, or empty list if no audio tracks
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Check if source video has audio
|
||||||
|
probe = ffmpeg.probe(source_video)
|
||||||
|
audio_streams = [s for s in probe['streams'] if s['codec_type'] == 'audio']
|
||||||
|
|
||||||
|
if not audio_streams:
|
||||||
|
if query_only: return 0
|
||||||
|
if verbose:
|
||||||
|
print(f"No audio track found in {source_video}")
|
||||||
|
return []
|
||||||
|
if query_only: return len(audio_streams)
|
||||||
|
if verbose:
|
||||||
|
print(f"Found {len(audio_streams)} audio track(s)")
|
||||||
|
|
||||||
|
# Create temporary audio files for each track
|
||||||
|
temp_audio_files = []
|
||||||
|
for i in range(len(audio_streams)):
|
||||||
|
fd, temp_path = tempfile.mkstemp(suffix=f'_track{i}.aac', prefix='audio_')
|
||||||
|
os.close(fd) # Close file descriptor immediately
|
||||||
|
temp_audio_files.append(temp_path)
|
||||||
|
|
||||||
|
# Extract each audio track
|
||||||
|
for i, temp_path in enumerate(temp_audio_files):
|
||||||
|
(ffmpeg
|
||||||
|
.input(source_video)
|
||||||
|
.output(temp_path, **{f'map': f'0:a:{i}', 'acodec': 'aac'})
|
||||||
|
.overwrite_output()
|
||||||
|
.run(quiet=not verbose))
|
||||||
|
|
||||||
|
return temp_audio_files
|
||||||
|
|
||||||
|
except ffmpeg.Error as e:
|
||||||
|
print(f"FFmpeg error during audio extraction: {e}")
|
||||||
|
return []
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error during audio extraction: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def combine_video_with_audio_tracks(target_video, audio_tracks, output_video, verbose=False):
|
||||||
|
"""
|
||||||
|
Combine video with audio tracks. Output duration matches video length exactly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_video: Path to video to receive the audio
|
||||||
|
audio_tracks: List of audio file paths to combine
|
||||||
|
output_video: Path for the output video
|
||||||
|
verbose: Enable verbose output (default: False)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful, False otherwise
|
||||||
|
"""
|
||||||
|
if not audio_tracks:
|
||||||
|
if verbose:
|
||||||
|
print("No audio tracks to combine")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get video duration to ensure exact alignment
|
||||||
|
video_probe = ffmpeg.probe(target_video)
|
||||||
|
video_duration = float(video_probe['streams'][0]['duration'])
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(f"Target video duration: {video_duration:.3f} seconds")
|
||||||
|
|
||||||
|
# Combine target video with all audio tracks, force video duration
|
||||||
|
video = ffmpeg.input(target_video).video
|
||||||
|
audio_inputs = [ffmpeg.input(audio_path).audio for audio_path in audio_tracks]
|
||||||
|
|
||||||
|
# Create output with video duration as master timing
|
||||||
|
inputs = [video] + audio_inputs
|
||||||
|
(ffmpeg
|
||||||
|
.output(*inputs, output_video,
|
||||||
|
vcodec='copy',
|
||||||
|
acodec='copy',
|
||||||
|
t=video_duration) # Force exact video duration
|
||||||
|
.overwrite_output()
|
||||||
|
.run(quiet=not verbose))
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(f"Successfully created {output_video} with {len(audio_tracks)} audio track(s) aligned to video duration")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except ffmpeg.Error as e:
|
||||||
|
print(f"FFmpeg error during video combination: {e}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error during video combination: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def cleanup_temp_audio_files(audio_tracks, verbose=False):
|
||||||
|
"""
|
||||||
|
Clean up temporary audio files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_tracks: List of audio file paths to delete
|
||||||
|
verbose: Enable verbose output (default: False)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of files successfully deleted
|
||||||
|
"""
|
||||||
|
deleted_count = 0
|
||||||
|
|
||||||
|
for audio_path in audio_tracks:
|
||||||
|
try:
|
||||||
|
if os.path.exists(audio_path):
|
||||||
|
os.unlink(audio_path)
|
||||||
|
deleted_count += 1
|
||||||
|
if verbose:
|
||||||
|
print(f"Cleaned up {audio_path}")
|
||||||
|
except PermissionError:
|
||||||
|
print(f"Warning: Could not delete {audio_path} (file may be in use)")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Error deleting {audio_path}: {e}")
|
||||||
|
|
||||||
|
if verbose and deleted_count > 0:
|
||||||
|
print(f"Successfully deleted {deleted_count} temporary audio file(s)")
|
||||||
|
|
||||||
|
return deleted_count
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user