mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 22:26:36 +00:00
Add mkv video export support and support for embedding source images when using i2v
This commit is contained in:
parent
a4d62fc1cd
commit
799c3a2e5e
4
.gitignore
vendored
4
.gitignore
vendored
@ -41,3 +41,7 @@ gradio_outputs/
|
||||
ckpts/
|
||||
loras/
|
||||
loras_i2v/
|
||||
|
||||
wgp_config.json
|
||||
|
||||
settings/
|
||||
|
||||
110
extract_source_images.py
Executable file
110
extract_source_images.py
Executable file
@ -0,0 +1,110 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Extract Source Images from MKV Video Files
|
||||
|
||||
This utility extracts source images that were embedded as attachments in MKV video files
|
||||
generated by Wan2GP with source image embedding enabled.
|
||||
|
||||
Usage:
|
||||
python extract_source_images.py video.mkv [output_directory]
|
||||
|
||||
Examples:
|
||||
# Extract to same directory as video
|
||||
python extract_source_images.py generated_video.mkv
|
||||
|
||||
# Extract to specific directory
|
||||
python extract_source_images.py generated_video.mkv ./extracted_sources/
|
||||
|
||||
# Extract from multiple videos
|
||||
python extract_source_images.py *.mkv
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
import glob
|
||||
from shared.utils.audio_video import extract_source_images
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Extract source images from MKV video files with embedded attachments",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog=__doc__
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'video_files',
|
||||
nargs='+',
|
||||
help='MKV video file(s) to extract source images from'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-o', '--output-dir',
|
||||
help='Output directory for extracted images (default: same as video directory)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-v', '--verbose',
|
||||
action='store_true',
|
||||
help='Enable verbose output'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Expand glob patterns
|
||||
video_files = []
|
||||
for pattern in args.video_files:
|
||||
if '*' in pattern or '?' in pattern:
|
||||
video_files.extend(glob.glob(pattern))
|
||||
else:
|
||||
video_files.append(pattern)
|
||||
|
||||
if not video_files:
|
||||
print("No video files found matching the specified patterns.")
|
||||
return 1
|
||||
|
||||
total_extracted = 0
|
||||
|
||||
for video_file in video_files:
|
||||
if not os.path.exists(video_file):
|
||||
print(f"Warning: File not found: {video_file}")
|
||||
continue
|
||||
|
||||
if not video_file.lower().endswith('.mkv'):
|
||||
print(f"Warning: Skipping non-MKV file: {video_file}")
|
||||
continue
|
||||
|
||||
if args.verbose:
|
||||
print(f"\nProcessing: {video_file}")
|
||||
|
||||
# Determine output directory
|
||||
if args.output_dir:
|
||||
output_dir = args.output_dir
|
||||
else:
|
||||
# Create subdirectory next to video file
|
||||
video_dir = os.path.dirname(video_file) or '.'
|
||||
video_name = os.path.splitext(os.path.basename(video_file))[0]
|
||||
output_dir = os.path.join(video_dir, f"{video_name}_sources")
|
||||
|
||||
try:
|
||||
extracted_files = extract_source_images(video_file, output_dir)
|
||||
|
||||
if extracted_files:
|
||||
print(f"✓ Extracted {len(extracted_files)} source image(s) from {video_file}")
|
||||
if args.verbose:
|
||||
for img_file in extracted_files:
|
||||
print(f" → {img_file}")
|
||||
total_extracted += len(extracted_files)
|
||||
else:
|
||||
print(f"ℹ No source images found in {video_file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Error processing {video_file}: {e}")
|
||||
|
||||
print(f"\nTotal: Extracted {total_extracted} source image(s) from {len(video_files)} video file(s)")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@ -229,7 +229,8 @@ def save_video(tensor,
|
||||
nrow=8,
|
||||
normalize=True,
|
||||
value_range=(-1, 1),
|
||||
retry=5):
|
||||
retry=5,
|
||||
source_images=None):
|
||||
"""Save tensor as video with configurable codec and container options."""
|
||||
|
||||
if torch.is_tensor(tensor) and len(tensor.shape) == 4:
|
||||
@ -265,6 +266,14 @@ def save_video(tensor,
|
||||
writer.append_data(frame)
|
||||
|
||||
writer.close()
|
||||
|
||||
# Embed source images if provided and container supports it
|
||||
if source_images and container == 'mkv':
|
||||
try:
|
||||
cache_file = embed_source_images_metadata(cache_file, source_images)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to embed source images: {e}")
|
||||
|
||||
return cache_file
|
||||
|
||||
except Exception as e:
|
||||
@ -272,6 +281,324 @@ def save_video(tensor,
|
||||
print(f"error saving {save_file}: {e}")
|
||||
|
||||
|
||||
def embed_source_images_metadata(video_path, source_images):
|
||||
"""
|
||||
Embed source images as attachments in MKV video files using FFmpeg.
|
||||
|
||||
Args:
|
||||
video_path (str): Path to the video file
|
||||
source_images (dict): Dictionary containing source images
|
||||
Expected keys: 'image_start', 'image_end', 'image_refs'
|
||||
Values should be PIL Images or file paths
|
||||
|
||||
Returns:
|
||||
str: Path to the video file with embedded attachments
|
||||
"""
|
||||
import tempfile
|
||||
import subprocess
|
||||
import os
|
||||
|
||||
if not source_images:
|
||||
return video_path
|
||||
|
||||
# Create temporary directory for image files
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
attachment_files = []
|
||||
|
||||
# Process each source image type
|
||||
for img_type, img_data in source_images.items():
|
||||
if img_data is None:
|
||||
continue
|
||||
|
||||
# Handle different image input types
|
||||
if isinstance(img_data, list):
|
||||
# Multiple images (e.g., image_refs)
|
||||
for i, img in enumerate(img_data):
|
||||
if img is not None:
|
||||
img_path = _save_temp_image(img, temp_dir, f"{img_type}_{i}")
|
||||
if img_path:
|
||||
attachment_files.append((img_path, f"{img_type}_{i}.jpg"))
|
||||
else:
|
||||
# Single image
|
||||
img_path = _save_temp_image(img_data, temp_dir, img_type)
|
||||
if img_path:
|
||||
attachment_files.append((img_path, f"{img_type}.jpg"))
|
||||
|
||||
if not attachment_files:
|
||||
return video_path
|
||||
|
||||
# Build FFmpeg command
|
||||
ffmpeg_cmd = ['ffmpeg', '-y', '-i', video_path]
|
||||
|
||||
# Add attachment parameters
|
||||
for i, (file_path, filename) in enumerate(attachment_files):
|
||||
ffmpeg_cmd.extend(['-attach', file_path])
|
||||
ffmpeg_cmd.extend(['-metadata:s:t:' + str(i), f'mimetype=image/jpeg'])
|
||||
ffmpeg_cmd.extend(['-metadata:s:t:' + str(i), f'filename={filename}'])
|
||||
|
||||
# Output parameters
|
||||
ffmpeg_cmd.extend(['-c', 'copy']) # Copy streams without re-encoding
|
||||
|
||||
# Create output file
|
||||
output_path = video_path.replace('.mkv', '_with_sources.mkv')
|
||||
ffmpeg_cmd.append(output_path)
|
||||
|
||||
# Verify all attachment files exist before running FFmpeg
|
||||
for file_path, filename in attachment_files:
|
||||
if not os.path.exists(file_path):
|
||||
print(f"ERROR: Attachment file missing: {file_path}")
|
||||
return video_path
|
||||
|
||||
try:
|
||||
result = subprocess.run(ffmpeg_cmd, capture_output=True, text=True, check=True)
|
||||
|
||||
# Verify output file was created
|
||||
if not os.path.exists(output_path):
|
||||
print(f"ERROR: FFmpeg completed but output file {output_path} was not created")
|
||||
return video_path
|
||||
|
||||
# Check output file size and streams before replacing
|
||||
output_size = os.path.getsize(output_path)
|
||||
|
||||
# Verify the output file actually has attachments
|
||||
try:
|
||||
import subprocess as sp
|
||||
probe_result = sp.run([
|
||||
'ffprobe', '-v', 'quiet', '-print_format', 'json',
|
||||
'-show_streams', output_path
|
||||
], capture_output=True, text=True)
|
||||
|
||||
if probe_result.returncode == 0:
|
||||
import json
|
||||
probe_data = json.loads(probe_result.stdout)
|
||||
streams = probe_data.get('streams', [])
|
||||
attachment_streams = [s for s in streams if s.get('disposition', {}).get('attached_pic') == 1]
|
||||
|
||||
if len(attachment_streams) == 0:
|
||||
print(f"WARNING: Output file has no attachment streams despite FFmpeg success!")
|
||||
|
||||
except Exception as probe_error:
|
||||
pass
|
||||
|
||||
# Replace original file with the one containing attachments
|
||||
import shutil
|
||||
|
||||
try:
|
||||
# Backup original file first
|
||||
backup_path = video_path + ".backup"
|
||||
shutil.copy2(video_path, backup_path)
|
||||
|
||||
# Replace original with new file - use explicit error handling
|
||||
try:
|
||||
shutil.move(output_path, video_path)
|
||||
except Exception as move_error:
|
||||
print(f"ERROR: shutil.move() failed: {move_error}")
|
||||
# Restore backup and return
|
||||
if os.path.exists(backup_path):
|
||||
shutil.move(backup_path, video_path)
|
||||
return video_path
|
||||
|
||||
# Verify replacement actually worked by checking file exists and size
|
||||
if not os.path.exists(video_path):
|
||||
print(f"ERROR: File replacement failed - target file doesn't exist!")
|
||||
# Restore backup
|
||||
if os.path.exists(backup_path):
|
||||
shutil.move(backup_path, video_path)
|
||||
return video_path
|
||||
|
||||
final_size = os.path.getsize(video_path)
|
||||
|
||||
if final_size == output_size:
|
||||
# Remove backup
|
||||
os.remove(backup_path)
|
||||
else:
|
||||
print(f"ERROR: File replacement failed - size mismatch! Expected {output_size}, got {final_size}")
|
||||
# Restore backup
|
||||
if os.path.exists(backup_path):
|
||||
shutil.move(backup_path, video_path)
|
||||
return video_path
|
||||
|
||||
except Exception as move_error:
|
||||
print(f"ERROR: File replacement failed: {move_error}")
|
||||
# Try to restore backup if it exists
|
||||
backup_path = video_path + ".backup"
|
||||
if os.path.exists(backup_path):
|
||||
try:
|
||||
shutil.move(backup_path, video_path)
|
||||
except:
|
||||
pass
|
||||
return video_path
|
||||
|
||||
return video_path
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"FFmpeg error embedding source images: {e.stderr}")
|
||||
# Clean up temp file if it exists
|
||||
if os.path.exists(output_path):
|
||||
os.remove(output_path)
|
||||
return video_path
|
||||
|
||||
finally:
|
||||
# Clean up temporary directory
|
||||
import shutil
|
||||
if os.path.exists(temp_dir):
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
def _save_temp_image(img_data, temp_dir, name):
|
||||
"""
|
||||
Save image data to a temporary file.
|
||||
|
||||
Args:
|
||||
img_data: PIL Image, file path, or tensor
|
||||
temp_dir: Temporary directory path
|
||||
name: Base name for the file
|
||||
|
||||
Returns:
|
||||
str: Path to saved temporary file, or None if failed
|
||||
"""
|
||||
import os
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
try:
|
||||
temp_path = os.path.join(temp_dir, f"{name}.jpg")
|
||||
|
||||
if isinstance(img_data, str):
|
||||
# File path - copy the file
|
||||
if os.path.exists(img_data):
|
||||
import shutil
|
||||
shutil.copy2(img_data, temp_path)
|
||||
return temp_path
|
||||
elif hasattr(img_data, 'save'):
|
||||
# PIL Image
|
||||
img_data.save(temp_path, 'JPEG', quality=95)
|
||||
return temp_path
|
||||
elif torch.is_tensor(img_data):
|
||||
# Tensor - convert to PIL and save
|
||||
if img_data.dim() == 4:
|
||||
img_data = img_data.squeeze(0)
|
||||
if img_data.dim() == 3:
|
||||
# Convert from tensor to PIL
|
||||
if img_data.shape[0] == 3: # CHW format
|
||||
img_data = img_data.permute(1, 2, 0)
|
||||
# Normalize to 0-255 range
|
||||
if img_data.max() <= 1.0:
|
||||
img_data = (img_data * 255).clamp(0, 255)
|
||||
img_array = img_data.cpu().numpy().astype('uint8')
|
||||
img_pil = Image.fromarray(img_array)
|
||||
img_pil.save(temp_path, 'JPEG', quality=95)
|
||||
return temp_path
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"ERROR: Exception in _save_temp_image for {name}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
def extract_source_images(video_path, output_dir=None):
|
||||
"""
|
||||
Extract embedded source images from MKV video files.
|
||||
|
||||
Args:
|
||||
video_path (str): Path to the MKV video file
|
||||
output_dir (str): Directory to save extracted images (optional)
|
||||
|
||||
Returns:
|
||||
list: List of extracted image file paths
|
||||
"""
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
if output_dir is None:
|
||||
output_dir = os.path.dirname(video_path)
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
try:
|
||||
# First, probe the video to find attachment streams (attached pics)
|
||||
probe_cmd = [
|
||||
'ffprobe', '-v', 'quiet', '-print_format', 'json',
|
||||
'-show_streams', video_path
|
||||
]
|
||||
|
||||
result = subprocess.run(probe_cmd, capture_output=True, text=True, check=True)
|
||||
import json as json_module
|
||||
probe_data = json_module.loads(result.stdout)
|
||||
|
||||
# Find attachment streams (attached pics)
|
||||
attachment_streams = []
|
||||
for i, stream in enumerate(probe_data.get('streams', [])):
|
||||
# Check for attachment streams in multiple ways:
|
||||
# 1. Traditional attached_pic flag
|
||||
# 2. Video streams with image-like metadata (filename, mimetype)
|
||||
# 3. MJPEG codec which is commonly used for embedded images
|
||||
is_attached_pic = stream.get('disposition', {}).get('attached_pic', 0) == 1
|
||||
|
||||
# Check for image metadata in video streams (our case after metadata embedding)
|
||||
tags = stream.get('tags', {})
|
||||
has_image_metadata = (
|
||||
'FILENAME' in tags and tags['FILENAME'].lower().endswith(('.jpg', '.jpeg', '.png')) or
|
||||
'filename' in tags and tags['filename'].lower().endswith(('.jpg', '.jpeg', '.png')) or
|
||||
'MIMETYPE' in tags and tags['MIMETYPE'].startswith('image/') or
|
||||
'mimetype' in tags and tags['mimetype'].startswith('image/')
|
||||
)
|
||||
|
||||
# Check for MJPEG codec (common for embedded images)
|
||||
is_mjpeg = stream.get('codec_name') == 'mjpeg'
|
||||
|
||||
if (stream.get('codec_type') == 'video' and
|
||||
(is_attached_pic or (has_image_metadata and is_mjpeg))):
|
||||
attachment_streams.append(i)
|
||||
|
||||
if not attachment_streams:
|
||||
print(f"No attachment streams found in {video_path}")
|
||||
return []
|
||||
|
||||
# Extract each attachment stream
|
||||
extracted_files = []
|
||||
for stream_idx in attachment_streams:
|
||||
# Get original filename from metadata if available
|
||||
stream_info = probe_data['streams'][stream_idx]
|
||||
tags = stream_info.get('tags', {})
|
||||
original_filename = (
|
||||
tags.get('filename') or
|
||||
tags.get('FILENAME') or
|
||||
f'attachment_{stream_idx}.png'
|
||||
)
|
||||
|
||||
# Clean filename for filesystem
|
||||
safe_filename = os.path.basename(original_filename)
|
||||
if not safe_filename.lower().endswith(('.jpg', '.jpeg', '.png')):
|
||||
safe_filename += '.png'
|
||||
|
||||
output_file = os.path.join(output_dir, safe_filename)
|
||||
|
||||
# Extract the attachment stream
|
||||
extract_cmd = [
|
||||
'ffmpeg', '-y', '-i', video_path,
|
||||
'-map', f'0:{stream_idx}', '-frames:v', '1',
|
||||
output_file
|
||||
]
|
||||
|
||||
try:
|
||||
subprocess.run(extract_cmd, capture_output=True, text=True, check=True)
|
||||
if os.path.exists(output_file):
|
||||
extracted_files.append(output_file)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Failed to extract attachment {stream_idx}: {e.stderr}")
|
||||
|
||||
return extracted_files
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error extracting source images: {e.stderr}")
|
||||
return []
|
||||
|
||||
|
||||
def _get_codec_params(codec_type, container):
|
||||
"""Get codec parameters based on codec type and container."""
|
||||
if codec_type == 'libx264_8':
|
||||
|
||||
@ -34,7 +34,7 @@ def seed_everything(seed: int):
|
||||
|
||||
def has_video_file_extension(filename):
|
||||
extension = os.path.splitext(filename)[-1].lower()
|
||||
return extension in [".mp4"]
|
||||
return extension in [".mp4", ".mkv"]
|
||||
|
||||
def has_image_file_extension(filename):
|
||||
extension = os.path.splitext(filename)[-1].lower()
|
||||
|
||||
210
wgp.py
210
wgp.py
@ -2338,6 +2338,8 @@ reload_needed = False
|
||||
save_path = server_config.get("save_path", os.path.join(os.getcwd(), "outputs"))
|
||||
image_save_path = server_config.get("image_save_path", os.path.join(os.getcwd(), "outputs"))
|
||||
if not "video_output_codec" in server_config: server_config["video_output_codec"]= "libx264_8"
|
||||
if not "video_container" in server_config: server_config["video_container"]= "mp4"
|
||||
if not "embed_source_images" in server_config: server_config["embed_source_images"]= False
|
||||
if not "image_output_codec" in server_config: server_config["image_output_codec"]= "jpeg_95"
|
||||
|
||||
preload_model_policy = server_config.get("preload_model_policy", [])
|
||||
@ -2988,6 +2990,8 @@ def apply_changes( state,
|
||||
max_frames_multiplier_choice = 1,
|
||||
display_stats_choice = 0,
|
||||
video_output_codec_choice = None,
|
||||
video_container_choice = None,
|
||||
embed_source_images_choice = None,
|
||||
image_output_codec_choice = None,
|
||||
audio_output_codec_choice = None,
|
||||
last_resolution_choice = None,
|
||||
@ -3026,6 +3030,8 @@ def apply_changes( state,
|
||||
"max_frames_multiplier" : max_frames_multiplier_choice,
|
||||
"display_stats" : display_stats_choice,
|
||||
"video_output_codec" : video_output_codec_choice,
|
||||
"video_container" : video_container_choice,
|
||||
"embed_source_images" : embed_source_images_choice,
|
||||
"image_output_codec" : image_output_codec_choice,
|
||||
"audio_output_codec" : audio_output_codec_choice,
|
||||
"last_model_type" : state["model_type"],
|
||||
@ -3073,7 +3079,7 @@ def apply_changes( state,
|
||||
reset_prompt_enhancer()
|
||||
if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list", "fit_canvas", "depth_anything_v2_variant",
|
||||
"notification_sound_enabled", "notification_sound_volume", "mmaudio_enabled", "max_frames_multiplier", "display_stats",
|
||||
"video_output_codec", "image_output_codec", "audio_output_codec"] for change in changes ):
|
||||
"video_output_codec", "video_container", "embed_source_images", "image_output_codec", "audio_output_codec"] for change in changes ):
|
||||
model_family = gr.Dropdown()
|
||||
model_choice = gr.Dropdown()
|
||||
else:
|
||||
@ -4194,7 +4200,7 @@ def edit_video(
|
||||
any_change = False
|
||||
if sample != None:
|
||||
video_path =get_available_filename(save_path, video_source, "_tmp") if any_mmaudio or has_already_audio else get_available_filename(save_path, video_source, "_post")
|
||||
save_video( tensor=sample[None], save_file=video_path, fps=output_fps, nrow=1, normalize=True, value_range=(-1, 1), codec_type= server_config.get("video_output_codec", None))
|
||||
save_video( tensor=sample[None], save_file=video_path, fps=output_fps, nrow=1, normalize=True, value_range=(-1, 1), codec_type= server_config.get("video_output_codec", None), container=server_config.get("video_container", "mp4"))
|
||||
|
||||
if any_mmaudio or has_already_audio: tmp_path = video_path
|
||||
any_change = True
|
||||
@ -5361,7 +5367,11 @@ def generate_video(
|
||||
save_prompt = original_prompts[0]
|
||||
|
||||
from shared.utils.utils import truncate_for_filesystem
|
||||
extension = "jpg" if is_image else "mp4"
|
||||
if is_image:
|
||||
extension = "jpg"
|
||||
else:
|
||||
container = server_config.get("video_container", "mp4")
|
||||
extension = container
|
||||
|
||||
if os.name == 'nt':
|
||||
file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(truncate_for_filesystem(save_prompt,50)).strip()}.{extension}"
|
||||
@ -5381,8 +5391,9 @@ def generate_video(
|
||||
video_path= new_image_path
|
||||
elif len(control_audio_tracks) > 0 or len(source_audio_tracks) > 0 or output_new_audio_filepath is not None or any_mmaudio or output_new_audio_data is not None or audio_source is not None:
|
||||
video_path = os.path.join(save_path, file_name)
|
||||
save_path_tmp = video_path[:-4] + "_tmp.mp4"
|
||||
save_video( tensor=sample[None], save_file=save_path_tmp, fps=output_fps, nrow=1, normalize=True, value_range=(-1, 1), codec_type = server_config.get("video_output_codec", None))
|
||||
container = server_config.get("video_container", "mp4")
|
||||
save_path_tmp = video_path.rsplit('.', 1)[0] + f"_tmp.{container}"
|
||||
save_video( tensor=sample[None], save_file=save_path_tmp, fps=output_fps, nrow=1, normalize=True, value_range=(-1, 1), codec_type = server_config.get("video_output_codec", None), container=server_config.get("video_container", "mp4"))
|
||||
output_new_audio_temp_filepath = None
|
||||
new_audio_from_start = reset_control_aligment
|
||||
source_audio_duration = source_video_frames_count / fps
|
||||
@ -5409,7 +5420,17 @@ def generate_video(
|
||||
if output_new_audio_temp_filepath is not None: os.remove(output_new_audio_temp_filepath)
|
||||
|
||||
else:
|
||||
save_video( tensor=sample[None], save_file=video_path, fps=output_fps, nrow=1, normalize=True, value_range=(-1, 1), codec_type= server_config.get("video_output_codec", None))
|
||||
# Prepare source images for embedding if enabled
|
||||
source_images = {}
|
||||
if server_config.get("embed_source_images", False) and server_config.get("video_container", "mp4") == "mkv":
|
||||
if image_start is not None:
|
||||
source_images["image_start"] = image_start
|
||||
if image_end is not None:
|
||||
source_images["image_end"] = image_end
|
||||
if image_refs is not None:
|
||||
source_images["image_refs"] = image_refs
|
||||
|
||||
save_video( tensor=sample[None], save_file=video_path, fps=output_fps, nrow=1, normalize=True, value_range=(-1, 1), codec_type= server_config.get("video_output_codec", None), container=server_config.get("video_container", "mp4"), source_images=source_images if source_images else None)
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
@ -5447,11 +5468,43 @@ def generate_video(
|
||||
elif metadata_choice == "metadata":
|
||||
if is_image:
|
||||
save_image_metadata(path, configs)
|
||||
else:
|
||||
elif path.endswith('.mp4'):
|
||||
from mutagen.mp4 import MP4
|
||||
file = MP4(path)
|
||||
file.tags['©cmt'] = [json.dumps(configs)]
|
||||
file.save()
|
||||
elif path.endswith('.mkv'):
|
||||
# For MKV files, embed metadata using FFmpeg
|
||||
try:
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
# Create temporary file with metadata
|
||||
temp_path = path.replace('.mkv', '_temp_with_metadata.mkv')
|
||||
|
||||
# Use FFmpeg to add metadata while preserving ALL streams (including attachments)
|
||||
ffmpeg_cmd = [
|
||||
'ffmpeg', '-y', '-i', path,
|
||||
'-metadata', f'comment={json.dumps(configs)}',
|
||||
'-map', '0', # Map all streams from input (including attachments)
|
||||
'-c', 'copy', # Copy streams without re-encoding
|
||||
temp_path
|
||||
]
|
||||
|
||||
result = subprocess.run(ffmpeg_cmd, capture_output=True, text=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
# Replace original with metadata version
|
||||
import shutil
|
||||
shutil.move(temp_path, path)
|
||||
else:
|
||||
print(f"Warning: Failed to add metadata to MKV file: {result.stderr}")
|
||||
# Clean up temp file if it exists
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error adding metadata to MKV file {path}: {e}")
|
||||
if is_image:
|
||||
print(f"New image saved to Path: "+ path)
|
||||
else:
|
||||
@ -6349,7 +6402,7 @@ def apply_post_processing(state, input_file_list, choice, PP_temporal_upsampling
|
||||
if len(file_list) == 0 or choice == None or choice < 0 or choice > len(file_list) :
|
||||
return gr.update(), gr.update(), gr.update()
|
||||
|
||||
if not file_list[choice].endswith(".mp4"):
|
||||
if not (file_list[choice].endswith(".mp4") or file_list[choice].endswith(".mkv")):
|
||||
gr.Info("Post processing is only available with Videos")
|
||||
return gr.update(), gr.update(), gr.update()
|
||||
overrides = {
|
||||
@ -6372,7 +6425,7 @@ def remux_audio(state, input_file_list, choice, PP_MMAudio_setting, PP_MMAudio_p
|
||||
if len(file_list) == 0 or choice == None or choice < 0 or choice > len(file_list) :
|
||||
return gr.update(), gr.update(), gr.update()
|
||||
|
||||
if not file_list[choice].endswith(".mp4"):
|
||||
if not (file_list[choice].endswith(".mp4") or file_list[choice].endswith(".mkv")):
|
||||
gr.Info("Post processing is only available with Videos")
|
||||
return gr.update(), gr.update(), gr.update()
|
||||
overrides = {
|
||||
@ -6477,6 +6530,80 @@ def export_settings(state):
|
||||
return text_base64, sanitize_file_name(model_type + "_" + datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss") + ".json")
|
||||
|
||||
|
||||
def extract_and_apply_source_images(file_path, state):
|
||||
"""
|
||||
Extract embedded source images from MKV files and apply them to the UI state.
|
||||
|
||||
Args:
|
||||
file_path (str): Path to the MKV video file
|
||||
state (dict): UI state dictionary
|
||||
|
||||
Returns:
|
||||
int: Number of source images extracted and applied
|
||||
"""
|
||||
if not file_path.endswith('.mkv'):
|
||||
return 0
|
||||
|
||||
try:
|
||||
from shared.utils.audio_video import extract_source_images
|
||||
import tempfile
|
||||
import os
|
||||
from PIL import Image
|
||||
|
||||
# Create temporary directory for extracted images
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
extracted_files = extract_source_images(file_path, temp_dir)
|
||||
|
||||
if not extracted_files:
|
||||
return 0
|
||||
|
||||
# Process extracted images and apply to state
|
||||
applied_count = 0
|
||||
|
||||
for img_path in extracted_files:
|
||||
img_name = os.path.basename(img_path).lower()
|
||||
|
||||
try:
|
||||
# Load the image
|
||||
pil_image = Image.open(img_path)
|
||||
|
||||
# Apply based on filename
|
||||
if 'image_start' in img_name:
|
||||
# Apply as start image
|
||||
current_settings = get_model_settings(state, state["model_type"]) or {}
|
||||
current_settings['image_start'] = [pil_image]
|
||||
set_model_settings(state, state["model_type"], current_settings)
|
||||
applied_count += 1
|
||||
|
||||
elif 'image_end' in img_name:
|
||||
# Apply as end image
|
||||
current_settings = get_model_settings(state, state["model_type"]) or {}
|
||||
current_settings['image_end'] = [pil_image]
|
||||
set_model_settings(state, state["model_type"], current_settings)
|
||||
applied_count += 1
|
||||
|
||||
elif 'image_refs' in img_name:
|
||||
# Apply as reference image
|
||||
current_settings = get_model_settings(state, state["model_type"]) or {}
|
||||
existing_refs = current_settings.get('image_refs', [])
|
||||
if not isinstance(existing_refs, list):
|
||||
existing_refs = []
|
||||
existing_refs.append(pil_image)
|
||||
current_settings['image_refs'] = existing_refs
|
||||
set_model_settings(state, state["model_type"], current_settings)
|
||||
applied_count += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing extracted image {img_path}: {e}")
|
||||
continue
|
||||
|
||||
return applied_count
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error extracting source images from {file_path}: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
def use_video_settings(state, input_file_list, choice):
|
||||
gen = get_gen_info(state)
|
||||
file_list, file_settings_list = get_file_list(state, input_file_list)
|
||||
@ -6495,11 +6622,21 @@ def use_video_settings(state, input_file_list, choice):
|
||||
defaults = get_default_settings(model_type) if defaults == None else defaults
|
||||
defaults.update(configs)
|
||||
prompt = configs.get("prompt", "")
|
||||
|
||||
# Extract and apply embedded source images from MKV files
|
||||
extracted_images = extract_and_apply_source_images(file_name, state)
|
||||
|
||||
set_model_settings(state, model_type, defaults)
|
||||
|
||||
# Update info message to include source image extraction
|
||||
if has_image_file_extension(file_name):
|
||||
gr.Info(f"Settings Loaded from Image with prompt '{prompt[:100]}'")
|
||||
else:
|
||||
gr.Info(f"Settings Loaded from Video with prompt '{prompt[:100]}'")
|
||||
info_msg = f"Settings Loaded from Video with prompt '{prompt[:100]}'"
|
||||
if extracted_images:
|
||||
info_msg += f" + {extracted_images} source image(s) extracted"
|
||||
gr.Info(info_msg)
|
||||
|
||||
if models_compatible:
|
||||
return gr.update(), gr.update(), str(time.time())
|
||||
else:
|
||||
@ -6527,6 +6664,32 @@ def get_settings_from_file(state, file_path, allow_json, merge_with_defaults, sw
|
||||
any_image_or_video = True
|
||||
except:
|
||||
pass
|
||||
elif file_path.endswith(".mkv"):
|
||||
# For MKV files, try to read metadata from attachments or use ffprobe
|
||||
try:
|
||||
import subprocess
|
||||
# Try to get metadata using ffprobe
|
||||
result = subprocess.run([
|
||||
'ffprobe', '-v', 'quiet', '-print_format', 'json',
|
||||
'-show_format', file_path
|
||||
], capture_output=True, text=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
import json as json_module
|
||||
probe_data = json_module.loads(result.stdout)
|
||||
format_tags = probe_data.get('format', {}).get('tags', {})
|
||||
|
||||
# Look for our metadata in various possible tag locations
|
||||
for tag_key in ['comment', 'COMMENT', 'description', 'DESCRIPTION']:
|
||||
if tag_key in format_tags:
|
||||
try:
|
||||
configs = json.loads(format_tags[tag_key])
|
||||
any_image_or_video = True
|
||||
break
|
||||
except:
|
||||
continue
|
||||
except:
|
||||
pass
|
||||
elif has_image_file_extension(file_path):
|
||||
try:
|
||||
configs = read_image_metadata(file_path)
|
||||
@ -6606,8 +6769,16 @@ def load_settings_from_file(state, file_path):
|
||||
prompt = configs.get("prompt", "")
|
||||
is_image = configs.get("is_image", False)
|
||||
|
||||
# Extract and apply embedded source images from MKV files
|
||||
extracted_images = 0
|
||||
if file_path.endswith('.mkv'):
|
||||
extracted_images = extract_and_apply_source_images(file_path, state)
|
||||
|
||||
if any_video_or_image_file:
|
||||
gr.Info(f"Settings Loaded from {'Image' if is_image else 'Video'} generated with prompt '{prompt[:100]}'")
|
||||
info_msg = f"Settings Loaded from {'Image' if is_image else 'Video'} generated with prompt '{prompt[:100]}'"
|
||||
if extracted_images > 0:
|
||||
info_msg += f" + {extracted_images} source image(s) extracted and applied"
|
||||
gr.Info(info_msg)
|
||||
else:
|
||||
gr.Info(f"Settings Loaded from Settings file with prompt '{prompt[:100]}'")
|
||||
|
||||
@ -8931,6 +9102,21 @@ def generate_configuration_tab(state, blocks, header, model_family, model_choice
|
||||
label="Video Codec to use"
|
||||
)
|
||||
|
||||
video_container_choice = gr.Dropdown(
|
||||
choices=[
|
||||
("MP4 (Universal Compatibility)", 'mp4'),
|
||||
("MKV (Advanced Features + Source Image Embedding)", 'mkv'),
|
||||
],
|
||||
value=server_config.get("video_container", "mp4"),
|
||||
label="Video Container Format"
|
||||
)
|
||||
|
||||
embed_source_images_choice = gr.Checkbox(
|
||||
value=server_config.get("embed_source_images", False),
|
||||
label="Embed Source Images in Video Files (requires MKV format)",
|
||||
info="Automatically embeds i2v source images as attachments in the video file for reference"
|
||||
)
|
||||
|
||||
image_output_codec_choice = gr.Dropdown(
|
||||
choices=[
|
||||
("JPEG Quality 85", 'jpeg_85'),
|
||||
@ -9019,6 +9205,8 @@ def generate_configuration_tab(state, blocks, header, model_family, model_choice
|
||||
max_frames_multiplier_choice,
|
||||
display_stats_choice,
|
||||
video_output_codec_choice,
|
||||
video_container_choice,
|
||||
embed_source_images_choice,
|
||||
image_output_codec_choice,
|
||||
audio_output_codec_choice,
|
||||
resolution,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user