mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-04 06:15:17 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			421 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			421 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import subprocess
 | 
						|
import tempfile, os
 | 
						|
import ffmpeg
 | 
						|
import torchvision.transforms.functional as TF
 | 
						|
import torch.nn.functional as F
 | 
						|
import cv2
 | 
						|
import tempfile
 | 
						|
import imageio
 | 
						|
import binascii
 | 
						|
import torchvision
 | 
						|
import torch
 | 
						|
from PIL import Image
 | 
						|
import os.path as osp
 | 
						|
import json
 | 
						|
 | 
						|
def rand_name(length=8, suffix=''):
    """Return a random hexadecimal name, optionally followed by an extension.

    Args:
        length: Number of random bytes to draw; the hex stem has twice as
            many characters.
        suffix: Optional extension. A leading dot is added when missing.

    Returns:
        The hex stem, with the normalized suffix appended when one is given.
    """
    stem = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
    if not suffix:
        return stem
    extension = suffix if suffix.startswith('.') else '.' + suffix
    return stem + extension
 | 
						|
 | 
						|
 | 
						|
 | 
						|
def extract_audio_tracks(source_video, verbose=False, query_only=False):
    """
    Extract all audio tracks from a source video into temporary AAC files.

    Args:
        source_video: Path of the video to probe and extract from.
        verbose: Print progress messages and run ffmpeg with quiet=False.
        query_only: When True, only count the audio streams; no file is
            created.

    Returns:
        When query_only is True: the number of audio streams (int).
        Otherwise a tuple:
          - List of temp file paths for extracted audio tracks
          - List of corresponding metadata dicts:
              {'codec', 'sample_rate', 'channels', 'duration', 'language'}
              where 'duration' is set to container duration (for consistency).
        Both lists are empty when the video has no audio stream.

    Raises:
        Whatever ffmpeg.probe / ffmpeg's run raise. Temp files created by
        this call are removed before the error propagates; on success the
        caller owns the returned files.
    """
    probe = ffmpeg.probe(source_video)
    audio_streams = [s for s in probe['streams'] if s['codec_type'] == 'audio']

    if query_only:
        return len(audio_streams)

    if not audio_streams:
        if verbose:
            print(f"No audio track found in {source_video}")
        return [], []

    container_duration = float(probe['format'].get('duration', 0.0))
    if verbose:
        print(f"Found {len(audio_streams)} audio track(s), container duration = {container_duration:.3f}s")

    file_paths = []
    metadata = []

    try:
        for i, stream in enumerate(audio_streams):
            # mkstemp reserves a unique name; ffmpeg overwrites the empty file.
            fd, temp_path = tempfile.mkstemp(suffix=f'_track{i}.aac', prefix='audio_')
            os.close(fd)

            file_paths.append(temp_path)
            metadata.append({
                'codec': stream.get('codec_name'),
                'sample_rate': int(stream.get('sample_rate', 0)),
                'channels': int(stream.get('channels', 0)),
                'duration': container_duration,
                'language': stream.get('tags', {}).get('language', None)
            })

            # '0:a:{i}' selects the i-th audio stream of the single input.
            ffmpeg.input(source_video).output(
                temp_path,
                **{'map': f'0:a:{i}', 'acodec': 'aac', 'b:a': '128k'}
            ).overwrite_output().run(quiet=not verbose)
    except BaseException:
        # Do not leave half-extracted temp files behind on failure/interrupt.
        for path in file_paths:
            try:
                os.unlink(path)
            except OSError:
                pass
        raise

    return file_paths, metadata
 | 
						|
 | 
						|
 | 
						|
 | 
						|
def combine_and_concatenate_video_with_audio_tracks(
    save_path_tmp, video_path,
    source_audio_tracks, new_audio_tracks,
    source_audio_duration, audio_sampling_rate,
    new_audio_from_start=False,
    source_audio_metadata=None,
    audio_bitrate='128k',
    audio_codec='aac',
    verbose = False
):
    """Mux a video with audio tracks built from source audio followed by new audio.

    For each output track the first `source_audio_duration` seconds come from
    a source track (padded/trimmed to exactly that length, or silence when no
    source track is available) and are concatenated with a new track. When
    `source_audio_duration` is 0 only the new audio is used. Missing new audio
    is replaced by padding/silence. The video stream is copied; audio is
    encoded as mono with the given codec, bitrate and sampling rate, and the
    output is cut with `-shortest`.

    Args:
        save_path_tmp: Output file path (overwritten, ffmpeg `-y`).
        video_path: Input video; its video stream is mapped as `0:v`.
        source_audio_tracks: List of source audio file paths, or None.
        new_audio_tracks: List of new audio file paths, or None.
        source_audio_duration: Seconds of source audio preceding the new
            audio; 0 means "no source part".
        audio_sampling_rate: Output (and resampling) sample rate.
        new_audio_from_start: If True the new audio is used from its start;
            otherwise it is trimmed to begin at `source_audio_duration`.
        source_audio_metadata: Optional per-source-track dicts with the keys
            'codec', 'sample_rate', 'channels', 'duration', 'language'.
        audio_bitrate: Value passed to `-b:a`.
        audio_codec: Value passed to `-c:a`.
        verbose: Print the ffmpeg command before running it.

    Raises:
        Exception: if ffmpeg exits with a non-zero status; the message holds
            ffmpeg's stderr.
    """
    inputs, filters, maps, idx = ['-i', video_path], [], ['-map', '0:v'], 1
    metadata_args = []
    sources = source_audio_tracks or []
    news = new_audio_tracks or []

    # A single source track is reused for every new track.
    duplicate_source = len(sources) == 1 and len(news) > 1
    if source_audio_duration == 0:
        track_count = len(news)
    else:
        track_count = max(len(sources), len(news))
    # Always emit at least one (possibly silent) track. The fallback must
    # apply to both cases above, otherwise the filter graph can end up empty.
    N = track_count or 1

    for i in range(N):
        s = (sources[i] if i < len(sources)
             else sources[0] if duplicate_source else None)
        n = news[i] if len(news) == N else (news[0] if news else None)

        if source_audio_duration == 0:
            if n:
                inputs += ['-i', n]
                filters.append(f'[{idx}:a]apad=pad_dur=100[aout{i}]')
                idx += 1
            else:
                # No new audio at all: generate silence.
                filters.append(f'anullsrc=r={audio_sampling_rate}:cl=mono,apad=pad_dur=100[aout{i}]')
        else:
            if s:
                inputs += ['-i', s]
                meta = source_audio_metadata[i] if source_audio_metadata and i < len(source_audio_metadata) else {}
                needs_filter = (
                    meta.get('codec') != audio_codec or
                    meta.get('sample_rate') != audio_sampling_rate or
                    meta.get('channels') != 1 or
                    meta.get('duration', 0) < source_audio_duration
                )
                if needs_filter:
                    # Normalize rate/layout, then force the exact duration.
                    filters.append(
                        f'[{idx}:a]aresample={audio_sampling_rate},aformat=channel_layouts=mono,'
                        f'apad=pad_dur={source_audio_duration},atrim=0:{source_audio_duration},asetpts=PTS-STARTPTS[s{i}]')
                else:
                    filters.append(
                        f'[{idx}:a]apad=pad_dur={source_audio_duration},atrim=0:{source_audio_duration},asetpts=PTS-STARTPTS[s{i}]')
                if lang := meta.get('language'):
                    metadata_args += ['-metadata:s:a:' + str(i), f'language={lang}']
                idx += 1
            else:
                # Silent placeholder of the source duration.
                filters.append(
                    f'anullsrc=r={audio_sampling_rate}:cl=mono,atrim=0:{source_audio_duration},asetpts=PTS-STARTPTS[s{i}]')

            if n:
                inputs += ['-i', n]
                start = '0' if new_audio_from_start else source_audio_duration
                filters.append(
                    f'[{idx}:a]aresample={audio_sampling_rate},aformat=channel_layouts=mono,'
                    f'atrim=start={start},asetpts=PTS-STARTPTS[n{i}]')
                filters.append(f'[s{i}][n{i}]concat=n=2:v=0:a=1[aout{i}]')
                idx += 1
            else:
                filters.append(f'[s{i}]apad=pad_dur=100[aout{i}]')

        maps += ['-map', f'[aout{i}]']

    cmd = ['ffmpeg', '-y', *inputs,
           '-filter_complex', ';'.join(filters),
           *maps, *metadata_args,
           '-c:v', 'copy',
           '-c:a', audio_codec,
           '-b:a', audio_bitrate,
           '-ar', str(audio_sampling_rate),
           '-ac', '1',
           '-shortest', save_path_tmp]

    if verbose:
        print(f"ffmpeg command: {cmd}")
    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        raise Exception(f"FFmpeg error: {e.stderr}") from e
 | 
						|
 | 
						|
 | 
						|
def combine_video_with_audio_tracks(target_video, audio_tracks, output_video,
                                     audio_metadata=None, verbose=False):
    """Mux a video with one or more audio files without re-encoding.

    The video stream of `target_video` and the audio of every file in
    `audio_tracks` are stream-copied into `output_video`, limited to the
    video duration (`-t`).

    Args:
        target_video: Path of the video providing the video stream.
        audio_tracks: List of audio file paths, one output track each.
        output_video: Output path (overwritten, ffmpeg `-y`).
        audio_metadata: Optional list of dicts; a 'language' entry at index i
            is written as the language tag of audio track i.
        verbose: Print progress and let ffmpeg write to the console.

    Returns:
        False when there is nothing to combine, True on success.

    Raises:
        Exception: if ffmpeg exits with a non-zero status.
    """
    if not audio_tracks:
        # The return must not depend on `verbose`.
        if verbose:
            print("No audio tracks to combine.")
        return False

    probe = ffmpeg.probe(target_video)
    video_stream = next(s for s in probe['streams'] if s['codec_type'] == 'video')
    duration = video_stream.get('duration')
    if duration is None:
        # Some containers report the duration only at container level.
        duration = probe['format']['duration']
    dur = float(duration)
    if verbose:
        print(f"Video duration: {dur:.3f}s")

    cmd = ['ffmpeg', '-y', '-i', target_video]
    for path in audio_tracks:
        cmd += ['-i', path]

    cmd += ['-map', '0:v']
    # Input 0 is the video, so audio file i is ffmpeg input i+1.
    for i in range(len(audio_tracks)):
        cmd += ['-map', f'{i+1}:a']

    for i, meta in enumerate(audio_metadata or []):
        if (lang := meta.get('language')):
            cmd += ['-metadata:s:a:' + str(i), f'language={lang}']

    cmd += ['-c:v', 'copy', '-c:a', 'copy', '-t', str(dur), output_video]

    result = subprocess.run(cmd, capture_output=not verbose, text=True)
    if result.returncode != 0:
        raise Exception(f"FFmpeg error:\n{result.stderr}")
    if verbose:
        print(f"Created {output_video} with {len(audio_tracks)} audio track(s)")
    return True
 | 
						|
 | 
						|
 | 
						|
def cleanup_temp_audio_files(audio_tracks, verbose=False):
    """
    Remove temporary audio files, warning instead of raising on failure.

    Args:
        audio_tracks: Iterable of audio file paths to delete. Paths that do
            not exist are skipped silently.
        verbose: Print a line per deleted file and a final summary.

    Returns:
        Number of files successfully deleted
    """
    removed = 0

    for track_file in audio_tracks:
        try:
            if not os.path.exists(track_file):
                continue
            os.unlink(track_file)
            removed += 1
            if verbose:
                print(f"Cleaned up {track_file}")
        except PermissionError:
            # Warnings are printed regardless of `verbose`.
            print(f"Warning: Could not delete {track_file} (file may be in use)")
        except Exception as exc:
            print(f"Warning: Error deleting {track_file}: {exc}")

    if removed > 0 and verbose:
        print(f"Successfully deleted {removed} temporary audio file(s)")

    return removed
 | 
						|
 | 
						|
 | 
						|
def save_video(tensor,
                save_file=None,
                fps=30,
                codec_type='libx264_8',
                container='mp4',
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """Save tensor as video with configurable codec and container options.

    Args:
        tensor: A torch tensor or an iterable of frames the imageio writer
            accepts. A tensor is clamped to `value_range`, split along dim 2,
            tiled per slice with `make_grid` and converted to uint8 frames
            (presumably the layout is (batch, channels, frames, H, W) —
            TODO confirm against callers).
        save_file: Target path; when None a random name under /tmp is used.
            The extension is forced to match `container`.
        fps: Frames per second passed to the writer.
        codec_type: Preset name resolved by `_get_codec_params`.
        container: File extension / container name without the dot.
        nrow, normalize, value_range: Forwarded to `make_grid`.
        retry: Maximum number of write attempts.

    Returns:
        The path of the written file, or None when every attempt failed.
        Errors are printed, not raised.
    """
    suffix = f'.{container}'
    cache_file = osp.join('/tmp', rand_name(suffix=suffix)) if save_file is None else save_file
    if not cache_file.endswith(suffix):
        cache_file = osp.splitext(cache_file)[0] + suffix

    # Configure codec parameters
    codec_params = _get_codec_params(codec_type, container)

    # Converted frames are cached in a separate name so that a retry never
    # re-processes data that has already been turned into uint8 frames.
    frames = None
    for _ in range(retry):
        try:
            if frames is None:
                if torch.is_tensor(tensor):
                    grid = tensor.clamp(min(value_range), max(value_range))
                    grid = torch.stack([
                        torchvision.utils.make_grid(u, nrow=nrow, normalize=normalize, value_range=value_range)
                        for u in grid.unbind(2)
                    ], dim=1).permute(1, 2, 3, 0)
                    frames = (grid * 255).type(torch.uint8).cpu().numpy()
                else:
                    frames = tensor

            # Write video (silence ffmpeg logs)
            writer = imageio.get_writer(cache_file, fps=fps, ffmpeg_log_level='error', **codec_params)
            try:
                for frame in frames:
                    writer.append_data(frame)
            finally:
                # Release the writer even when a frame fails to encode.
                writer.close()
            return cache_file

        except Exception as e:
            print(f"error saving {cache_file}: {e}")

    return None
 | 
						|
 | 
						|
 | 
						|
def _get_codec_params(codec_type, container):
 | 
						|
    """Get codec parameters based on codec type and container."""
 | 
						|
    if codec_type == 'libx264_8':
 | 
						|
        return {'codec': 'libx264', 'quality': 8, 'pixelformat': 'yuv420p'}
 | 
						|
    elif codec_type == 'libx264_10':
 | 
						|
        return {'codec': 'libx264', 'quality': 10, 'pixelformat': 'yuv420p'}
 | 
						|
    elif codec_type == 'libx265_28':
 | 
						|
        return {'codec': 'libx265', 'pixelformat': 'yuv420p', 'output_params': ['-crf', '28', '-x265-params', 'log-level=none','-hide_banner', '-nostats']}
 | 
						|
    elif codec_type == 'libx265_8':
 | 
						|
        return {'codec': 'libx265', 'pixelformat': 'yuv420p', 'output_params': ['-crf', '8', '-x265-params', 'log-level=none','-hide_banner', '-nostats']}
 | 
						|
    elif codec_type == 'libx264_lossless':
 | 
						|
        if container == 'mkv':
 | 
						|
            return {'codec': 'ffv1', 'pixelformat': 'rgb24'}
 | 
						|
        else:  # mp4
 | 
						|
            return {'codec': 'libx264', 'output_params': ['-crf', '0'], 'pixelformat': 'yuv444p'}
 | 
						|
    else:  # libx264
 | 
						|
        return {'codec': 'libx264', 'pixelformat': 'yuv420p'}
 | 
						|
 | 
						|
 | 
						|
 | 
						|
 | 
						|
def save_image(tensor,
                save_file,
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                quality='jpeg_95',  # 'jpeg_95', 'jpeg_85', 'jpeg_70', 'jpeg_50', 'webp_95', 'webp_85', 'webp_70', 'webp_50', 'png', 'webp_lossless'
                retry=5):
    """Write a tensor to an image file in the format selected by `quality`.

    The extension of `save_file` is replaced by the one belonging to the
    chosen format. Saving is attempted up to `retry` times; if every attempt
    fails the last error is printed and the path is still returned.

    Returns:
        The (possibly renamed) output path.
    """
    fmt = _get_format_info(quality)

    # The file extension always follows the requested format.
    save_file = osp.splitext(save_file)[0] + fmt['ext']

    last_error = None
    saved = False
    for _ in range(retry):
        try:
            tensor = tensor.clamp(min(value_range), max(value_range))

            if not fmt['use_pil']:
                # JPEG and PNG go through torchvision.
                torchvision.utils.save_image(
                    tensor, save_file, nrow=nrow, normalize=normalize,
                    value_range=value_range, **fmt['params']
                )
            else:
                # WebP and its options go through PIL.
                grid = torchvision.utils.make_grid(tensor, nrow=nrow, normalize=normalize, value_range=value_range)
                scaled = grid.mul(255)
                scaled.add_(0.5).clamp_(0, 255)
                pixels = scaled.permute(1, 2, 0).to('cpu', torch.uint8).numpy()
                Image.fromarray(pixels).save(save_file, **fmt['params'])
        except Exception as exc:
            last_error = exc
        else:
            saved = True
            break

    if not saved:
        print(f'cache_image failed, error: {last_error}', flush=True)

    return save_file
 | 
						|
 | 
						|
 | 
						|
def _get_format_info(quality):
 | 
						|
    """Get format extension and parameters."""
 | 
						|
    formats = {
 | 
						|
        # JPEG with torchvision (works)
 | 
						|
        'jpeg_95': {'ext': '.jpg', 'params': {'quality': 95}, 'use_pil': False},
 | 
						|
        'jpeg_85': {'ext': '.jpg', 'params': {'quality': 85}, 'use_pil': False},
 | 
						|
        'jpeg_70': {'ext': '.jpg', 'params': {'quality': 70}, 'use_pil': False},
 | 
						|
        'jpeg_50': {'ext': '.jpg', 'params': {'quality': 50}, 'use_pil': False},
 | 
						|
        
 | 
						|
        # PNG with torchvision
 | 
						|
        'png': {'ext': '.png', 'params': {}, 'use_pil': False},
 | 
						|
        
 | 
						|
        # WebP with PIL (for quality control)
 | 
						|
        'webp_95': {'ext': '.webp', 'params': {'quality': 95}, 'use_pil': True},
 | 
						|
        'webp_85': {'ext': '.webp', 'params': {'quality': 85}, 'use_pil': True},
 | 
						|
        'webp_70': {'ext': '.webp', 'params': {'quality': 70}, 'use_pil': True},
 | 
						|
        'webp_50': {'ext': '.webp', 'params': {'quality': 50}, 'use_pil': True},
 | 
						|
        'webp_lossless': {'ext': '.webp', 'params': {'lossless': True}, 'use_pil': True},
 | 
						|
    }
 | 
						|
    return formats.get(quality, formats['jpeg_95'])
 | 
						|
 | 
						|
 | 
						|
from PIL import Image, PngImagePlugin
 | 
						|
 | 
						|
def _enc_uc(s):
 | 
						|
    try: return b"ASCII\0\0\0" + s.encode("ascii")
 | 
						|
    except UnicodeEncodeError: return b"UNICODE\0" + s.encode("utf-16le")
 | 
						|
 | 
						|
def _dec_uc(b):
 | 
						|
    if not isinstance(b, (bytes, bytearray)):
 | 
						|
        try: b = bytes(b)
 | 
						|
        except Exception: return None
 | 
						|
    if b.startswith(b"ASCII\0\0\0"): return b[8:].decode("ascii", "ignore")
 | 
						|
    if b.startswith(b"UNICODE\0"):   return b[8:].decode("utf-16le", "ignore")
 | 
						|
    return b.decode("utf-8", "ignore")
 | 
						|
 | 
						|
def save_image_metadata(image_path, metadata_dict, **save_kwargs):
    """Embed `metadata_dict` as JSON into an existing image file, in place.

    The image is re-saved to the same path: PNG stores the JSON in a
    "comment" text entry, JPEG in the `comment` save option, WebP in the EXIF
    UserComment field (needs `piexif`). Extra keyword arguments are forwarded
    to the image `save` call.

    Returns:
        True on success; False (after printing the error) on any failure,
        including an unsupported file extension.
    """
    try:
        payload = json.dumps(metadata_dict, ensure_ascii=False)
        extension = os.path.splitext(image_path)[1].lower()
        with Image.open(image_path) as img:
            if extension == ".png":
                png_info = PngImagePlugin.PngInfo()
                png_info.add_text("comment", payload)
                img.save(image_path, pnginfo=png_info, **save_kwargs)
            elif extension in (".jpg", ".jpeg"):
                img.save(image_path, comment=payload.encode("utf-8"), **save_kwargs)
            elif extension == ".webp":
                import piexif
                exif_dict = {
                    "0th": {},
                    "Exif": {piexif.ExifIFD.UserComment: _enc_uc(payload)},
                    "GPS": {},
                    "1st": {},
                    "thumbnail": None,
                }
                img.save(image_path, format="WEBP", exif=piexif.dump(exif_dict), **save_kwargs)
            else:
                raise ValueError("Unsupported format")
        return True
    except Exception as exc:
        print(f"Error saving metadata: {exc}")
        return False
 | 
						|
 | 
						|
def read_image_metadata(image_path):
    """Read JSON metadata embedded in a PNG, JPEG or WebP image.

    Lookup per format (chosen by file extension):
      - PNG: the "comment" text entry.
      - JPEG: the image comment; if absent or not valid JSON, the EXIF
        UserComment tag (37510).
      - WebP: the EXIF UserComment field (needs `piexif`).

    Returns:
        The decoded JSON value, or None when no metadata is found, the
        extension is unsupported, or any error occurs (the error is printed).
    """
    try:
        ext = os.path.splitext(image_path)[1].lower()
        with Image.open(image_path) as im:
            if ext == ".png":
                val = (getattr(im, "text", {}) or {}).get("comment") or im.info.get("comment")
                return json.loads(val) if val else None

            if ext in (".jpg", ".jpeg"):
                val = im.info.get("comment")
                if isinstance(val, (bytes, bytearray)):
                    val = val.decode("utf-8", "ignore")
                if val:
                    try:
                        return json.loads(val)
                    except Exception:
                        # Not JSON: fall through to the EXIF lookup.
                        pass
                exif = getattr(im, "getexif", lambda: None)()
                if exif:
                    uc = exif.get(37510)  # UserComment
                    s = _dec_uc(uc) if uc else None
                    if s:
                        try:
                            return json.loads(s)
                        except Exception:
                            pass
                return None

            if ext == ".webp":
                # Reuse the already opened image instead of opening the file
                # a second time and leaking that handle.
                exif_bytes = im.info.get("exif")
                if not exif_bytes:
                    return None
                import piexif
                uc = piexif.load(exif_bytes).get("Exif", {}).get(piexif.ExifIFD.UserComment)
                s = _dec_uc(uc) if uc else None
                return json.loads(s) if s else None

            return None
    except Exception as e:
        print(f"Error reading metadata: {e}")
        return None