mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-04 06:15:17 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			647 lines
		
	
	
		
			23 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			647 lines
		
	
	
		
			23 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 | 
						|
import argparse
 | 
						|
import binascii
 | 
						|
import os
 | 
						|
import os.path as osp
 | 
						|
import torchvision.transforms.functional as TF
 | 
						|
import torch.nn.functional as F
 | 
						|
import cv2
 | 
						|
import tempfile
 | 
						|
import imageio
 | 
						|
import torch
 | 
						|
import decord
 | 
						|
import torchvision
 | 
						|
from PIL import Image
 | 
						|
import numpy as np
 | 
						|
from rembg import remove, new_session
 | 
						|
import random
 | 
						|
import ffmpeg
 | 
						|
import os
 | 
						|
import tempfile
 | 
						|
import subprocess
 | 
						|
import json
 | 
						|
 | 
						|
__all__ = ['cache_video', 'cache_image', 'str2bool']
 | 
						|
 | 
						|
 | 
						|
 | 
						|
from PIL import Image
 | 
						|
 | 
						|
def seed_everything(seed: int):
    """Seed every random number generator used by the pipeline.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch's CPU RNG, then
    the CUDA and MPS generators when those backends are available.

    Args:
        seed: The seed value passed to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)
 | 
						|
 | 
						|
def resample(video_fps, video_frames_count, max_target_frames_count, target_fps, start_target_frame):
    """Map output frames at ``target_fps`` onto source frame indices at ``video_fps``.

    Starting at output frame ``start_target_frame``, the function steps forward
    one target-frame duration at a time. For each target timestamp it selects
    the first source frame whose timestamp is not earlier than it, within a
    1e-5-frame rounding tolerance. Indices repeat when upsampling
    (``target_fps > video_fps``) and are skipped when downsampling.

    Args:
        video_fps: Frame rate of the source video.
        video_frames_count: Number of frames in the source video; no index
            ``>= video_frames_count`` is returned.
        max_target_frames_count: Maximum number of indices to return. ``0``
            means no limit: stop only at the end of the source video.
        target_fps: Desired output frame rate.
        start_target_frame: Index, in target-fps frames, of the first output frame.

    Returns:
        List of source frame indices, one per output frame.
    """
    import math

    video_frame_duration = 1 / video_fps
    target_frame_duration = 1 / target_fps

    target_time = start_target_frame * target_frame_duration
    frame_no = math.ceil(target_time / video_frame_duration)
    cur_time = frame_no * video_frame_duration
    frame_ids = []
    while True:
        if max_target_frames_count != 0 and len(frame_ids) >= max_target_frames_count:
            break
        # Distance, in source frames, from the current source frame to the next
        # target timestamp. Rounding absorbs float drift so that exact multiples
        # are not pushed to the following frame by ceil().
        diff = round((target_time - cur_time) / video_frame_duration, 5)
        add_frames_count = math.ceil(diff)
        frame_no += add_frames_count
        if frame_no >= video_frames_count:
            break
        frame_ids.append(frame_no)
        cur_time += add_frames_count * video_frame_duration
        target_time += target_frame_duration
    # The loop already enforces the limit. A final
    # `frame_ids[:max_target_frames_count]` slice would wrongly return [] when
    # the limit is 0 ("unlimited").
    return frame_ids
 | 
						|
 | 
						|
import os
 | 
						|
from datetime import datetime
 | 
						|
 | 
						|
def get_file_creation_date(file_path):
    """Return a best-effort creation time of ``file_path`` as a naive local ``datetime``.

    On Windows (``os.name == 'nt'``) this is ``os.path.getctime``. Elsewhere it
    is ``st_birthtime`` when the platform's ``os.stat`` result provides it, and
    otherwise the last modification time (``st_mtime``).
    """
    if os.name == 'nt':
        return datetime.fromtimestamp(os.path.getctime(file_path))

    file_stat = os.stat(file_path)
    timestamp = getattr(file_stat, 'st_birthtime', file_stat.st_mtime)
    return datetime.fromtimestamp(timestamp)
 | 
						|
 | 
						|
def truncate_for_filesystem(s, max_bytes=255):
    """Return the longest prefix of ``s`` whose UTF-8 encoding fits in ``max_bytes`` bytes.

    Multi-byte characters are never split. A non-positive limit yields ``''``
    whenever ``s`` is non-empty.
    """
    encoded = s.encode('utf-8')
    if len(encoded) <= max_bytes:
        return s
    limit = max(int(max_bytes), 0)
    # Cutting the byte string may leave a partial trailing character;
    # 'ignore' drops exactly those dangling bytes.
    return encoded[:limit].decode('utf-8', errors='ignore')
 | 
						|
 | 
						|
def get_video_info(video_path):
    """Read basic stream properties of a video file through OpenCV.

    Returns:
        ``(fps, width, height, frame_count)``. ``fps`` is rounded to the nearest
        integer; the other values are truncated to ``int``. The function does
        not check whether the file was opened successfully.
    """
    import cv2

    capture = cv2.VideoCapture(video_path)
    fps = round(capture.get(cv2.CAP_PROP_FPS))
    width, height, frame_count = (
        int(capture.get(prop))
        for prop in (cv2.CAP_PROP_FRAME_WIDTH,
                     cv2.CAP_PROP_FRAME_HEIGHT,
                     cv2.CAP_PROP_FRAME_COUNT)
    )
    capture.release()
    return fps, width, height, frame_count
 | 
						|
 | 
						|
def get_video_frame(file_name, frame_no):
    """Decode frame ``frame_no`` of ``file_name`` with decord and return it as a PIL image.

    Note: this switches decord's global bridge to ``'torch'`` as a side effect.
    """
    decord.bridge.set_bridge('torch')
    batch = decord.VideoReader(file_name).get_batch([frame_no])
    pixels = batch[0].numpy().astype(np.uint8)
    return Image.fromarray(pixels)
 | 
						|
 | 
						|
def convert_image_to_video(image):
    """Write a single image as a one-frame 30 fps MP4 file and return its path.

    Args:
        image: A PIL image or a ``numpy`` array (RGB, RGBA or grayscale), or None.

    Returns:
        Path of a new temporary ``.mp4`` file, or None when ``image`` is None.
        The file is created with ``delete=False``, so the caller is responsible
        for removing it.
    """
    if image is None:
        return None

    if isinstance(image, np.ndarray):
        img_array = image
    elif hasattr(image, 'convert'):
        # PIL images may be RGBA, palette or grayscale; normalise them to RGB
        # so the colour conversion below always receives three channels.
        img_array = np.array(image.convert('RGB'))
    else:
        img_array = np.array(image)

    # Gradio/PIL images are RGB(A) or grayscale, while OpenCV expects BGR.
    if img_array.ndim == 2 or img_array.shape[2] == 1:
        img_bgr = cv2.cvtColor(img_array, cv2.COLOR_GRAY2BGR)
    elif img_array.shape[2] == 4:
        img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGBA2BGR)
    else:
        img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)

    height, width = img_bgr.shape[:2]

    # Reserve a unique path, then close our handle before OpenCV opens the
    # same path for writing.
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_video:
        video_path = temp_video.name

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(video_path, fourcc, 30.0, (width, height))
    try:
        out.write(img_bgr)
    finally:
        out.release()
    return video_path
 | 
						|
    
 | 
						|
def resize_lanczos(img, h, w):
    """Resize a channel-first image tensor with values in [-1, 1] to ``(h, w)`` using Lanczos.

    The tensor is mapped to uint8 [0, 255] (clipped, then truncated) and resized
    with PIL. It is returned as a channel-first float32 tensor, mapped back to
    [-1, 1].
    """
    scaled = (img + 1).float() * 127.5
    as_uint8 = np.clip(scaled.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)
    resized = Image.fromarray(as_uint8).resize((w, h), resample=Image.Resampling.LANCZOS)
    result = torch.from_numpy(np.array(resized).astype(np.float32)).movedim(-1, 0)
    return result / 127.5 - 1
 | 
						|
 | 
						|
def remove_background(img, session=None):
    """Remove the background of a channel-first image tensor with rembg.

    ``img`` is multiplied by 255 and clipped, so values are presumably in
    [0, 1]. The result is a channel-first RGB float32 tensor divided by 255.
    A new rembg session is created when ``session`` is None.
    """
    session = new_session() if session is None else session
    pil_image = Image.fromarray(
        np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
    cutout = remove(pil_image, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0])
    rgb = np.array(cutout.convert('RGB')).astype(np.float32) / 255.0
    return torch.from_numpy(rgb).movedim(-1, 0)
 | 
						|
 | 
						|
def convert_tensor_to_image(t, frame_no = -1):
    """Convert an image tensor with values in [-1, 1] to a PIL image.

    Args:
        t: Channel-first image tensor ``(C, H, W)``. When ``frame_no >= 0`` it
            is indexed as ``t[:, frame_no]``, i.e. presumably ``(C, T, H, W)``.
        frame_no: Frame to extract along dim 1; a negative value means ``t`` is
            already a single image.

    Returns:
        A PIL image built from the uint8 ``(H, W, C)`` array.
    """
    t = t[:, frame_no] if frame_no >= 0 else t
    # Map [-1, 1] -> [0, 255]. Clamp so values slightly outside the nominal
    # range (common for decoder outputs) do not wrap around in the uint8 cast.
    pixels = t.clone().add_(1.).mul_(127.5).clamp_(0, 255)
    return Image.fromarray(pixels.permute(1, 2, 0).to(torch.uint8).cpu().numpy())
 | 
						|
 | 
						|
def save_image(tensor_image, name, frame_no = -1):
    """Save ``tensor_image`` (optionally one frame of it) as an image file at ``name``.

    The tensor is converted with ``convert_tensor_to_image``; PIL picks the
    format from the file name.
    """
    pil_image = convert_tensor_to_image(tensor_image, frame_no)
    pil_image.save(name)
 | 
						|
 | 
						|
def get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_dims):
    """Return the canvas size after outpainting a frame.

    Args:
        frame_height, frame_width: Size of the original frame.
        outpainting_dims: ``(top, bottom, left, right)`` amounts to add, as
            percentages of the original frame size.

    Returns:
        ``(height, width)`` of the enlarged canvas, truncated to ``int``.
    """
    top, bottom, left, right = outpainting_dims

    def _grow(size, before, after):
        # Same arithmetic order as size * (100 + before + after) / 100.
        return int(size * (100 + before + after) / 100)

    return _grow(frame_height, top, bottom), _grow(frame_width, left, right)
 | 
						|
 | 
						|
def get_outpainting_frame_location(final_height, final_width, outpainting_dims, block_size = 8):
    """Locate the original frame inside an outpainted canvas.

    Args:
        final_height, final_width: Size of the full (outpainted) canvas.
        outpainting_dims: ``(top, bottom, left, right)`` outpainting amounts,
            as percentages of the original frame size.
        block_size: The inner frame's height and width are rounded down to a
            multiple of this value.

    Returns:
        ``(height, width, margin_top, margin_left)``: the inner frame size and
        its offset from the canvas' top-left corner. When ``bottom`` (or
        ``right``) is 0, the frame is pinned to the bottom (or right) edge.
    """
    outpainting_top, outpainting_bottom, outpainting_left, outpainting_right = outpainting_dims

    raw_height = int(final_height / ((100 + outpainting_top + outpainting_bottom) / 100))
    height = int(raw_height / block_size) * block_size
    extra_height = raw_height - height

    raw_width = int(final_width / ((100 + outpainting_left + outpainting_right) / 100))
    width = int(raw_width / block_size) * block_size
    extra_width = raw_width - width

    margin_top = int(outpainting_top / (100 + outpainting_top + outpainting_bottom) * final_height)
    if extra_height != 0 and (outpainting_top + outpainting_bottom) != 0:
        # Share the rows lost to block rounding between top and bottom in
        # proportion to their outpainting amounts.
        margin_top += int(outpainting_top / (outpainting_top + outpainting_bottom) * extra_height)
    if (margin_top + height) > final_height or outpainting_bottom == 0:
        margin_top = final_height - height

    margin_left = int(outpainting_left / (100 + outpainting_left + outpainting_right) * final_width)
    if extra_width != 0 and (outpainting_left + outpainting_right) != 0:
        # Same for the columns lost to rounding: this must use extra_width,
        # not extra_height.
        margin_left += int(outpainting_left / (outpainting_left + outpainting_right) * extra_width)
    if (margin_left + width) > final_width or outpainting_right == 0:
        margin_left = final_width - width

    return height, width, margin_top, margin_left
 | 
						|
 | 
						|
def calculate_new_dimensions(canvas_height, canvas_width, height, width, fit_into_canvas, block_size = 16):
    """Scale ``(height, width)`` against a canvas and snap both sides to ``block_size``.

    Args:
        canvas_height, canvas_width: Target canvas size.
        height, width: Source size.
        fit_into_canvas: ``None`` returns the source size unchanged. A truthy
            value scales to fit inside the canvas, either as given or rotated
            90°, whichever allows the larger scale. A falsy value scales so the
            area matches the canvas area.
        block_size: Both output sides are rounded to a multiple of this value.

    Returns:
        ``(new_height, new_width)``.
    """
    if fit_into_canvas is None:
        return height, width

    if not fit_into_canvas:
        scale = (canvas_height * canvas_width / (height * width)) ** (1 / 2)
    else:
        upright = min(canvas_height / height, canvas_width / width)
        rotated = min(canvas_width / height, canvas_height / width)
        scale = max(upright, rotated)

    def _snap(size):
        return round(size * scale / block_size) * block_size

    return _snap(height), _snap(width)
 | 
						|
 | 
						|
def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, ignore_first, fit_into_canvas = False ):
    """Resize PIL images to a size budget and optionally strip their backgrounds.

    Args:
        img_list: PIL images to process.
        budget_width, budget_height: Target canvas size / pixel budget.
        rm_background: If true, run rembg on each resized image.
        ignore_first: If true, never remove the background of the first image.
        fit_into_canvas: If true, letterbox each image (scaled to fit) centred
            on a white ``budget_height x budget_width`` canvas. Otherwise,
            rescale to roughly the budget's area with both sides snapped to
            multiples of 16.

    Returns:
        List of processed PIL images, in input order.
    """
    session = new_session() if rm_background else None

    output_list = []
    for index, img in enumerate(img_list):
        width, height = img.size

        if fit_into_canvas:
            scale = min(budget_height / height, budget_width / width)
            new_height = int(height * scale)
            new_width = int(width * scale)
            scaled = img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
            canvas = np.full((budget_height, budget_width, 3), 255, dtype=np.uint8)
            top = (budget_height - new_height) // 2
            left = (budget_width - new_width) // 2
            canvas[top:top + new_height, left:left + new_width] = np.array(scaled)
            resized_image = Image.fromarray(canvas)
        else:
            scale = (budget_height * budget_width / (height * width)) ** (1 / 2)
            new_height = int(round(height * scale / 16) * 16)
            new_width = int(round(width * scale / 16) * 16)
            resized_image = img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)

        skip_removal = ignore_first and index == 0
        if rm_background and not skip_removal:
            resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,
                                   alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
        output_list.append(resized_image)
    return output_list
 | 
						|
 | 
						|
 | 
						|
def rand_name(length=8, suffix=''):
    """Return ``2 * length`` random lowercase hex characters, plus an optional extension.

    Args:
        length: Number of random bytes to draw from ``os.urandom``.
        suffix: Optional extension; a leading dot is added if missing.
    """
    stem = os.urandom(length).hex()
    if not suffix:
        return stem
    extension = suffix if suffix.startswith('.') else '.' + suffix
    return stem + extension
 | 
						|
 | 
						|
 | 
						|
def cache_video(tensor,
                save_file=None,
                fps=30,
                suffix='.mp4',
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """Encode a video tensor to an H.264 file, retrying on failure.

    Frames are taken by unbinding dim 2 of ``tensor``. Each frame slice is
    tiled with ``torchvision.utils.make_grid``, so the input is presumably
    ``(B, C, T, H, W)``; confirm against callers.

    Args:
        tensor: Video tensor; values are clamped to ``value_range``.
        save_file: Output path. If None, a random file name with ``suffix``
            is created in the system temporary directory.
        fps: Output frame rate.
        suffix: Extension used for the generated name when ``save_file`` is None.
        nrow, normalize, value_range: Forwarded to ``make_grid``.
        retry: Number of attempts.

    Returns:
        The written path, or None if every attempt failed. In that case the
        last error is printed.
    """
    # gettempdir() instead of a hard-coded '/tmp' so this also works on Windows.
    cache_file = osp.join(tempfile.gettempdir(), rand_name(
        suffix=suffix)) if save_file is None else save_file

    error = None
    for _ in range(retry):
        writer = None
        try:
            # Keep the converted frames in a separate variable. Reassigning
            # `tensor` would leave a uint8, permuted tensor for the next retry.
            frames = tensor.clamp(min(value_range), max(value_range))
            frames = torch.stack([
                torchvision.utils.make_grid(
                    u, nrow=nrow, normalize=normalize, value_range=value_range)
                for u in frames.unbind(2)
            ], dim=1).permute(1, 2, 3, 0)
            frames = (frames * 255).type(torch.uint8).cpu()

            writer = imageio.get_writer(
                cache_file, fps=fps, codec='libx264', quality=8)
            for frame in frames.numpy():
                writer.append_data(frame)
            writer.close()
            writer = None
            return cache_file
        except Exception as e:
            error = e
        finally:
            if writer is not None:
                # Best-effort release of a writer left open by a failed attempt;
                # the original error is already recorded above.
                try:
                    writer.close()
                except Exception:
                    pass

    print(f'cache_video failed, error: {error}', flush=True)
    return None
 | 
						|
 | 
						|
 | 
						|
def cache_image(tensor,
                save_file,
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """Save an image tensor (tiled into a grid when batched) to ``save_file``.

    The tensor is clamped to ``value_range`` and written with
    ``torchvision.utils.save_image``. The write is attempted up to ``retry`` times.

    Args:
        tensor: Image tensor accepted by ``torchvision.utils.save_image``.
        save_file: Destination path; its extension selects the image format.
        nrow, normalize, value_range: Forwarded to ``torchvision.utils.save_image``.
        retry: Number of attempts.

    Returns:
        ``save_file`` on success. None if every attempt failed; in that case the
        last error is printed, as ``cache_video`` does, instead of failing silently.
    """
    error = None
    for _ in range(retry):
        try:
            clamped = tensor.clamp(min(value_range), max(value_range))
            torchvision.utils.save_image(
                clamped,
                save_file,
                nrow=nrow,
                normalize=normalize,
                value_range=value_range)
            return save_file
        except Exception as e:
            error = e

    print(f'cache_image failed, error: {error}', flush=True)
    return None
 | 
						|
 | 
						|
 | 
						|
def str2bool(v):
    """Parse a command-line style boolean.

    Booleans pass through unchanged. Strings are matched case-insensitively:
    'yes', 'true', 't', 'y', '1' map to True; 'no', 'false', 'f', 'n', '0'
    map to False.

    Args:
        v (str | bool): Value to convert.

    Returns:
        bool: The parsed value.

    Raises:
        argparse.ArgumentTypeError: If the string is not a recognised boolean.
    """
    if isinstance(v, bool):
        return v
    lookup = {
        **dict.fromkeys(('yes', 'true', 't', 'y', '1'), True),
        **dict.fromkeys(('no', 'false', 'f', 'n', '0'), False),
    }
    parsed = lookup.get(v.lower())
    if parsed is None:
        raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
    return parsed
 | 
						|
 | 
						|
 | 
						|
import sys, time
 | 
						|
 | 
						|
# Global variables to track download progress
 | 
						|
_start_time = None
 | 
						|
_last_time = None
 | 
						|
_last_downloaded = 0
 | 
						|
_speed_history = []
 | 
						|
_update_interval = 0.5  # Update speed every 0.5 seconds
 | 
						|
 | 
						|
def progress_hook(block_num, block_size, total_size, filename=None):
 | 
						|
    """
 | 
						|
    Simple progress bar hook for urlretrieve
 | 
						|
    
 | 
						|
    Args:
 | 
						|
        block_num: Number of blocks downloaded so far
 | 
						|
        block_size: Size of each block in bytes
 | 
						|
        total_size: Total size of the file in bytes
 | 
						|
        filename: Name of the file being downloaded (optional)
 | 
						|
    """
 | 
						|
    global _start_time, _last_time, _last_downloaded, _speed_history, _update_interval
 | 
						|
    
 | 
						|
    current_time = time.time()
 | 
						|
    downloaded = block_num * block_size
 | 
						|
    
 | 
						|
    # Initialize timing on first call
 | 
						|
    if _start_time is None or block_num == 0:
 | 
						|
        _start_time = current_time
 | 
						|
        _last_time = current_time
 | 
						|
        _last_downloaded = 0
 | 
						|
        _speed_history = []
 | 
						|
    
 | 
						|
    # Calculate download speed only at specified intervals
 | 
						|
    speed = 0
 | 
						|
    if current_time - _last_time >= _update_interval:
 | 
						|
        if _last_time > 0:
 | 
						|
            current_speed = (downloaded - _last_downloaded) / (current_time - _last_time)
 | 
						|
            _speed_history.append(current_speed)
 | 
						|
            # Keep only last 5 speed measurements for smoothing
 | 
						|
            if len(_speed_history) > 5:
 | 
						|
                _speed_history.pop(0)
 | 
						|
            # Average the recent speeds for smoother display
 | 
						|
            speed = sum(_speed_history) / len(_speed_history)
 | 
						|
        
 | 
						|
        _last_time = current_time
 | 
						|
        _last_downloaded = downloaded
 | 
						|
    elif _speed_history:
 | 
						|
        # Use the last calculated average speed
 | 
						|
        speed = sum(_speed_history) / len(_speed_history)
 | 
						|
    # Format file sizes and speed
 | 
						|
    def format_bytes(bytes_val):
 | 
						|
        for unit in ['B', 'KB', 'MB', 'GB']:
 | 
						|
            if bytes_val < 1024:
 | 
						|
                return f"{bytes_val:.1f}{unit}"
 | 
						|
            bytes_val /= 1024
 | 
						|
        return f"{bytes_val:.1f}TB"
 | 
						|
    
 | 
						|
    file_display = filename if filename else "Unknown file"
 | 
						|
    
 | 
						|
    if total_size <= 0:
 | 
						|
        # If total size is unknown, show downloaded bytes
 | 
						|
        speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else ""
 | 
						|
        line = f"\r{file_display}: {format_bytes(downloaded)}{speed_str}"
 | 
						|
        # Clear any trailing characters by padding with spaces
 | 
						|
        sys.stdout.write(line.ljust(80))
 | 
						|
        sys.stdout.flush()
 | 
						|
        return
 | 
						|
    
 | 
						|
    downloaded = block_num * block_size
 | 
						|
    percent = min(100, (downloaded / total_size) * 100)
 | 
						|
    
 | 
						|
    # Create progress bar (40 characters wide to leave room for other info)
 | 
						|
    bar_length = 40
 | 
						|
    filled = int(bar_length * percent / 100)
 | 
						|
    bar = '█' * filled + '░' * (bar_length - filled)
 | 
						|
    
 | 
						|
    # Format file sizes and speed
 | 
						|
    def format_bytes(bytes_val):
 | 
						|
        for unit in ['B', 'KB', 'MB', 'GB']:
 | 
						|
            if bytes_val < 1024:
 | 
						|
                return f"{bytes_val:.1f}{unit}"
 | 
						|
            bytes_val /= 1024
 | 
						|
        return f"{bytes_val:.1f}TB"
 | 
						|
    
 | 
						|
    speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else ""
 | 
						|
    
 | 
						|
    # Display progress with filename first
 | 
						|
    line = f"\r{file_display}: [{bar}] {percent:.1f}% ({format_bytes(downloaded)}/{format_bytes(total_size)}){speed_str}"
 | 
						|
    # Clear any trailing characters by padding with spaces
 | 
						|
    sys.stdout.write(line.ljust(100))
 | 
						|
    sys.stdout.flush()
 | 
						|
    
 | 
						|
    # Print newline when complete
 | 
						|
    if percent >= 100:
 | 
						|
        print()
 | 
						|
 | 
						|
def create_progress_hook(filename):
    """Return a ``urlretrieve``-style reporthook that labels progress with ``filename``.

    Also resets the module-level timing state, so speed measurements from a
    previous download do not carry over into this one.
    """
    global _start_time, _last_time, _last_downloaded, _speed_history
    _start_time, _last_time = None, None
    _last_downloaded, _speed_history = 0, []

    def _report(block_num, block_size, total_size):
        return progress_hook(block_num, block_size, total_size, filename)

    return _report
 | 
						|
 | 
						|
 | 
						|
import tempfile, os
 | 
						|
import ffmpeg
 | 
						|
 | 
						|
def extract_audio_tracks(source_video, verbose=False, query_only=False):
    """
    Extract all audio tracks from a source video into temporary AAC files.

    Each track is re-encoded as AAC at 128 kbit/s into its own file created
    with ``tempfile.mkstemp``. The caller owns these files and should remove
    them, e.g. with ``cleanup_temp_audio_files``.

    Args:
        source_video: Path of the input video.
        verbose: Print progress and let ffmpeg write its output.
        query_only: Only count the audio streams; nothing is extracted.

    Returns:
        If ``query_only``: the number of audio streams (int).
        Otherwise a tuple:
          - List of temp file paths for extracted audio tracks
          - List of corresponding metadata dicts:
              {'codec', 'sample_rate', 'channels', 'duration', 'language'}
              where 'duration' is set to container duration (for consistency).

    Raises:
        Any exception from ffmpeg probing or extraction propagates. Temp files
        already created by this call are removed first, so they do not leak.
    """
    probe = ffmpeg.probe(source_video)
    audio_streams = [s for s in probe['streams'] if s['codec_type'] == 'audio']
    container_duration = float(probe['format'].get('duration', 0.0))

    if not audio_streams:
        if query_only:
            return 0
        if verbose:
            print(f"No audio track found in {source_video}")
        return [], []

    if query_only:
        return len(audio_streams)

    if verbose:
        print(f"Found {len(audio_streams)} audio track(s), container duration = {container_duration:.3f}s")

    file_paths = []
    metadata = []

    try:
        for i, stream in enumerate(audio_streams):
            fd, temp_path = tempfile.mkstemp(suffix=f'_track{i}.aac', prefix='audio_')
            os.close(fd)

            file_paths.append(temp_path)
            metadata.append({
                'codec': stream.get('codec_name'),
                'sample_rate': int(stream.get('sample_rate', 0)),
                'channels': int(stream.get('channels', 0)),
                'duration': container_duration,
                'language': stream.get('tags', {}).get('language', None)
            })

            ffmpeg.input(source_video).output(
                temp_path,
                **{'map': f'0:a:{i}', 'acodec': 'aac', 'b:a': '128k'}
            ).overwrite_output().run(quiet=not verbose)
    except Exception:
        # Do not leak the temp files created for tracks processed so far.
        for path in file_paths:
            try:
                os.remove(path)
            except OSError:
                pass
        raise

    return file_paths, metadata
 | 
						|
 | 
						|
 | 
						|
import subprocess
 | 
						|
 | 
						|
import subprocess
 | 
						|
 | 
						|
def combine_and_concatenate_video_with_audio_tracks(
    save_path_tmp, video_path,
    source_audio_tracks, new_audio_tracks,
    source_audio_duration, audio_sampling_rate,
    new_audio_from_start=False,
    source_audio_metadata=None,
    audio_bitrate='128k',
    audio_codec='aac',
    verbose = False
):
    """Mux ``video_path`` with audio built from source and/or new audio tracks.

    Each output audio track ``i`` is built as follows:
      * ``source_audio_duration == 0``: new track ``i`` (or silence), padded
        with 100 s of silence.
      * otherwise: source track ``i`` (or silence) padded/trimmed to exactly
        ``source_audio_duration`` seconds, then concatenated with the new track.
        The new track is trimmed to start at ``source_audio_duration`` unless
        ``new_audio_from_start`` is set. Without a new track, it is padded
        with 100 s of silence instead.

    The number of output tracks is ``len(new_audio_tracks)`` when
    ``source_audio_duration == 0``; otherwise it is the larger of the two list
    lengths, and at least 1. A single source track is reused for every new
    track, and ``new_audio_tracks[0]`` is reused when the new-track count does
    not match. All audio is encoded mono at ``audio_sampling_rate`` with
    ``audio_codec``/``audio_bitrate``. The video stream is copied, and
    ``-shortest`` is passed so the padded audio does not extend the output.

    Raises:
        Exception: with ffmpeg's stderr if ffmpeg exits with an error.
    """
    inputs, filters, maps, idx = ['-i', video_path], [], ['-map', '0:v'], 1
    metadata_args = []
    sources = source_audio_tracks or []
    news = new_audio_tracks or []

    duplicate_source = len(sources) == 1 and len(news) > 1
    N = len(news) if source_audio_duration == 0 else (max(len(sources), len(news)) or 1)

    for i in range(N):
        s = (sources[i] if i < len(sources)
             else sources[0] if duplicate_source else None)
        n = news[i] if len(news) == N else (news[0] if news else None)

        if source_audio_duration == 0:
            if n:
                inputs += ['-i', n]
                filters.append(f'[{idx}:a]apad=pad_dur=100[aout{i}]')
                idx += 1
            else:
                filters.append(f'anullsrc=r={audio_sampling_rate}:cl=mono,apad=pad_dur=100[aout{i}]')
        else:
            if s:
                inputs += ['-i', s]
                meta = source_audio_metadata[i] if source_audio_metadata and i < len(source_audio_metadata) else {}
                # Resample/remix only when the source differs from the target format.
                needs_filter = (
                    meta.get('codec') != audio_codec or
                    meta.get('sample_rate') != audio_sampling_rate or
                    meta.get('channels') != 1 or
                    meta.get('duration', 0) < source_audio_duration
                )
                if needs_filter:
                    filters.append(
                        f'[{idx}:a]aresample={audio_sampling_rate},aformat=channel_layouts=mono,'
                        f'apad=pad_dur={source_audio_duration},atrim=0:{source_audio_duration},asetpts=PTS-STARTPTS[s{i}]')
                else:
                    filters.append(
                        f'[{idx}:a]apad=pad_dur={source_audio_duration},atrim=0:{source_audio_duration},asetpts=PTS-STARTPTS[s{i}]')
                if lang := meta.get('language'):
                    metadata_args += ['-metadata:s:a:' + str(i), f'language={lang}']
                idx += 1
            else:
                filters.append(
                    f'anullsrc=r={audio_sampling_rate}:cl=mono,atrim=0:{source_audio_duration},asetpts=PTS-STARTPTS[s{i}]')

            if n:
                inputs += ['-i', n]
                start = '0' if new_audio_from_start else source_audio_duration
                filters.append(
                    f'[{idx}:a]aresample={audio_sampling_rate},aformat=channel_layouts=mono,'
                    f'atrim=start={start},asetpts=PTS-STARTPTS[n{i}]')
                filters.append(f'[s{i}][n{i}]concat=n=2:v=0:a=1[aout{i}]')
                idx += 1
            else:
                filters.append(f'[s{i}]apad=pad_dur=100[aout{i}]')

        maps += ['-map', f'[aout{i}]']

    # With no audio tracks there is nothing to filter; avoid passing an empty
    # filter graph to ffmpeg.
    filter_args = ['-filter_complex', ';'.join(filters)] if filters else []

    cmd = ['ffmpeg', '-y', *inputs,
           *filter_args,
           *maps, *metadata_args,
           '-c:v', 'copy',
           '-c:a', audio_codec,
           '-b:a', audio_bitrate,
           '-ar', str(audio_sampling_rate),
           '-ac', '1',
           '-shortest', save_path_tmp]

    if verbose:
        print(f"ffmpeg command: {cmd}")
    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        raise Exception(f"FFmpeg error: {e.stderr}") from e
 | 
						|
 | 
						|
 | 
						|
import ffmpeg
 | 
						|
 | 
						|
 | 
						|
import subprocess
 | 
						|
import ffmpeg
 | 
						|
 | 
						|
def combine_video_with_audio_tracks(target_video, audio_tracks, output_video,
                                    audio_metadata=None, verbose=False):
    """Mux pre-encoded audio tracks into ``target_video`` without re-encoding.

    Video and audio streams are copied. The output is cut to the duration of
    the target's video stream, and per-track ``language`` tags from
    ``audio_metadata`` are applied.

    Args:
        target_video: Input video whose video stream is kept.
        audio_tracks: Paths of audio files; each becomes one output audio track.
        output_video: Destination path (overwritten).
        audio_metadata: Optional list of dicts with a ``'language'`` key per track.
        verbose: Print progress and let ffmpeg write to the console.

    Returns:
        False if ``audio_tracks`` is empty (nothing is written), True on success.

    Raises:
        ValueError: if ``target_video`` has no video stream.
        Exception: with ffmpeg's stderr if ffmpeg exits with a non-zero status.
    """
    if not audio_tracks:
        # Return regardless of `verbose`. With the print and the return on one
        # `if verbose:` line, the return only happened in verbose mode.
        if verbose:
            print("No audio tracks to combine.")
        return False

    probe = ffmpeg.probe(target_video)
    video_stream = next((s for s in probe['streams'] if s['codec_type'] == 'video'), None)
    if video_stream is None:
        raise ValueError(f"No video stream found in {target_video}")
    # Fall back to the container duration when the stream does not report one.
    dur = float(video_stream.get('duration') or probe['format']['duration'])
    if verbose:
        print(f"Video duration: {dur:.3f}s")

    cmd = ['ffmpeg', '-y', '-i', target_video]
    for path in audio_tracks:
        cmd += ['-i', path]

    cmd += ['-map', '0:v']
    for i in range(len(audio_tracks)):
        cmd += ['-map', f'{i+1}:a']

    for i, meta in enumerate(audio_metadata or []):
        if (lang := meta.get('language')):
            cmd += ['-metadata:s:a:' + str(i), f'language={lang}']

    cmd += ['-c:v', 'copy', '-c:a', 'copy', '-t', str(dur), output_video]

    result = subprocess.run(cmd, capture_output=not verbose, text=True)
    if result.returncode != 0:
        raise Exception(f"FFmpeg error:\n{result.stderr}")
    if verbose:
        print(f"Created {output_video} with {len(audio_tracks)} audio track(s)")
    return True
 | 
						|
 | 
						|
 | 
						|
def cleanup_temp_audio_files(audio_tracks, verbose=False):
    """
    Delete temporary audio files, warning (not raising) on failures.

    Paths that do not exist are skipped silently.

    Args:
        audio_tracks: Iterable of audio file paths to delete
        verbose: Print each deletion and a final summary (default: False)

    Returns:
        Number of files actually deleted
    """
    removed = 0
    for path in audio_tracks:
        try:
            if not os.path.exists(path):
                continue
            os.unlink(path)
            removed += 1
            if verbose:
                print(f"Cleaned up {path}")
        except PermissionError:
            print(f"Warning: Could not delete {path} (file may be in use)")
        except Exception as e:
            print(f"Warning: Error deleting {path}: {e}")

    if removed > 0 and verbose:
        print(f"Successfully deleted {removed} temporary audio file(s)")
    return removed
 | 
						|
 |