From d2a9d5483de4c33d0682ab74a7ec68d387253e83 Mon Sep 17 00:00:00 2001
From: deepbeepmeep
Date: Mon, 4 Aug 2025 02:28:19 +0200
Subject: [PATCH] WanGP remuxed

---
 README.md                               |  22 +
 defaults/flux_krea.json                 |  16 +
 docs/FINETUNES.md                       |   9 +-
 docs/LORAS.md                           |  20 +
 flux/flux_main.py                       |   8 +-
 flux/sampling.py                        |   2 +-
 hyvideo/hunyuan.py                      |  35 +-
 ltx_video/ltxv.py                       |  11 +-
 postprocessing/mmaudio/data/av_utils.py |   2 +
 postprocessing/mmaudio/mmaudio.py       |  12 +-
 preprocessing/matanyone/app.py          |  50 +-
 requirements.txt                        |   3 +-
 wan/any2video.py                        |  88 ++--
 wan/diffusion_forcing.py                |  22 +-
 wan/fantasytalking/infer.py             |   4 +-
 wan/fantasytalking/utils.py             |  15 +-
 wan/modules/model.py                    |   6 +-
 wan/multitalk/multitalk.py              |  15 +-
 wan/utils/loras_mutipliers.py           |  91 ++++
 wan/utils/utils.py                      | 286 ++++----
 wgp.py                                  | 672 ++++++++++++++----------
 21 files changed, 875 insertions(+), 514 deletions(-)
 create mode 100644 defaults/flux_krea.json
 create mode 100644 wan/utils/loras_mutipliers.py

diff --git a/README.md b/README.md
index c10a9eb..0ba6bd6 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,28 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
 **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
 
 ## 🔥 Latest Updates :
+### August 4 2025: WanGP v7.6 - Remuxed
+
+With this new version you won't have any excuse if there is no sound in your video.
+
+*Continue Video* now works with any video that already has some sound (hint: Multitalk).
+
+Also, on top of MMAudio and the various sound-driven models, I have added the ability to use your own soundtrack.
+
+As a result you can apply a different sound source to each new video segment when doing a *Continue Video*.
+
+For instance:
+- first video part: use Multitalk with two people speaking
+- second video part: you apply your own soundtrack, which will gently follow the Multitalk conversation
+- third video part: you use a Vace effect and its corresponding control audio will be concatenated to the rest of the audio
+
+To multiply the combinations I have also implemented *Continue Video* with the various image2video models.
+
+Also:
+- End Frame support added for LTX Video models
+- Loras can now be targeted specifically at the High Noise or Low Noise model with Wan 2.2, check the Loras and Finetunes guides
+- Flux Krea Dev support
+
 ### July 30 2025: WanGP v7.5: Just another release ... Wan 2.2 part 2
 
 Here is now Wan 2.2 image2video a very good model if you want to set Start and End frames. Two Wan 2.2 models delivered, only one to go ...
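The *Continue Video* remux described in the README entry above amounts to keeping the source segment's existing audio, appending a new soundtrack for the continuation, and copying the video stream untouched. The sketch below only illustrates that idea with a plain ffmpeg call; it is not the WanGP implementation, and the file names and `source_audio_duration` value are hypothetical.

```python
# Minimal sketch of the "Continue Video" remux idea, assuming ffmpeg is on PATH.
# Not the WanGP code; paths and durations are placeholders.
import subprocess

def append_soundtrack(video_path: str, new_audio_path: str,
                      source_audio_duration: float, output_path: str) -> None:
    # Keep the video's own audio up to source_audio_duration, then append the new track.
    filter_complex = (
        f"[0:a]atrim=0:{source_audio_duration},asetpts=PTS-STARTPTS[src];"
        f"[1:a]asetpts=PTS-STARTPTS[new];"
        f"[src][new]concat=n=2:v=0:a=1[aout]"
    )
    subprocess.run([
        "ffmpeg", "-y", "-i", video_path, "-i", new_audio_path,
        "-filter_complex", filter_complex,
        "-map", "0:v", "-map", "[aout]",
        "-c:v", "copy", "-c:a", "aac",   # copy the video stream, re-encode only the audio
        "-shortest", output_path,
    ], check=True)

# Hypothetical usage: keep the first 5 s of Multitalk audio, then your own soundtrack.
# append_soundtrack("part1_multitalk.mp4", "my_soundtrack.wav", 5.0, "continued.mp4")
```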
diff --git a/defaults/flux_krea.json b/defaults/flux_krea.json
new file mode 100644
index 0000000..3caba1a
--- /dev/null
+++ b/defaults/flux_krea.json
@@ -0,0 +1,16 @@
+{
+    "model": {
+        "name": "Flux 1 Krea Dev 12B",
+        "architecture": "flux",
+        "description": "Cutting-edge output quality, with a focus on aesthetic photography.",
+        "URLs": [
+            "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-krea-dev_bf16.safetensors",
+            "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-krea-dev_quanto_bf16_int8.safetensors"
+        ],
+        "image_outputs": true,
+        "flux-model": "flux-dev"
+    },
+    "prompt": "draw a hat",
+    "resolution": "1280x720",
+    "batch_size": 1
+}
\ No newline at end of file
diff --git a/docs/FINETUNES.md b/docs/FINETUNES.md
index 1c9ee6b..32bc7c6 100644
--- a/docs/FINETUNES.md
+++ b/docs/FINETUNES.md
@@ -55,9 +55,16 @@ For instance if one adds a module *vace_14B* on top of a model with architecture
 - *architecture* : architecture Id of the base model of the finetune (see previous section)
 - *description*: description of the finetune that will appear at the top
 - *URLs*: URLs of all the finetune versions (quantized / non quantized). WanGP will pick the version that is the closest to the user preferences. You will need to follow a naming convention to help WanGP identify the content of each version (see next section). Right now WanGP supports only 8 bits quantized model that have been quantized using **quanto**. WanGP offers a command switch to build easily such a quantized model (see below). *URLs* can contain also paths to local file to allow testing.
+- *URLs2*: URLs of all the finetune versions (quantized / non quantized) of the weights used for the second phase of a model. For instance with Wan 2.2, the first phase contains the High Noise model weights and the second phase contains the Low Noise model weights. This feature can be used with models other than Wan 2.2 to combine different model weights during the same video generation.
 - *modules*: this a list of modules to be combined with the models referenced by the URLs. A module is a model extension that is merged with a model to expand its capabilities. Supported models so far are : *vace_14B* and *multitalk*. For instance the full Vace model is the fusion of a Wan text 2 video and the Vace module.
 - *preload_URLs* : URLs of files to download no matter what (used to load quantization maps for instance)
--*loras* : URLs of Loras that will applied before any other Lora specified by the user. These loras will be quite often Loras accelerator. For instance if you specified here the FusioniX Lora you will be able to reduce the number of generation steps to -*loras_multipliers* : a list of float numbers that defines the weight of each Lora mentioned above.
+-*loras* : URLs of Loras that will be applied before any other Lora specified by the user. These loras will quite often be Lora accelerators. For instance, if you specify here the FusioniX Lora you will be able to reduce the number of generation steps to 10.
+-*loras_multipliers* : a list of float numbers or strings that defines the weight of each Lora mentioned in *loras*. The string syntax is used if you want your lora multiplier to change over the steps (please check the Loras doc) or if you want a multiplier to be applied only during a specific High Noise or Low Noise phase of a Wan 2.2 model. For instance, here the multiplier will only be applied during the High Noise phase; for half of the steps of this phase the multiplier will be 1 and for the other half 1.1.
+```
+"loras" : [ "my_lora.safetensors"],
+"loras_multipliers" : [ "1,1.1;0"]
+```
+
 - *auto_quantize*: if set to True and no quantized model URL is provided, WanGP will perform on the fly quantization if the user expects a quantized model
 -*visible* : by default assumed to be true. If set to false the model will no longer be visible. This can be useful if you create a finetune to override a default model and hide it.
 -*image_outputs* : turn any model that generates a video into a model that generates images. In fact it will adapt the user interface for image generation and ask the model to generate a video with a single frame.
diff --git a/docs/LORAS.md b/docs/LORAS.md
index 0b2d034..e20f7a3 100644
--- a/docs/LORAS.md
+++ b/docs/LORAS.md
@@ -63,6 +63,26 @@ For dynamic effects over generation steps, use comma-separated values:
 - First lora: 0.9 → 0.8 → 0.7
 - Second lora: 1.2 → 1.1 → 1.0
 
+With models like Wan 2.2 that internally use two diffusion models (*High Noise* / *Low Noise*) you can specify which Loras you want to be applied during a specific phase by separating the phases with a ";".
+
+For instance, if you want to disable a lora for the *High Noise* phase and enable it only for the *Low Noise* phase:
+```
+0;1
+```
+
+As usual, you can use any float as a multiplier and have a multiplier vary throughout one phase for one Lora:
+```
+0.9,0.8;1.2,1.1,1
+```
+In this example multipliers 0.9 and 0.8 will be used during the *High Noise* phase and 1.2, 1.1 and 1 during the *Low Noise* phase.
+
+Here is another example for two loras:
+```
+0.9,0.8;1.2,1.1,1
+0.5;0,0.7
+```
+
+Note that the syntax for multipliers can also be used in a Finetune model definition file (except that each multiplier definition is a string in a JSON list).
 
 ## Lora Presets
 
 Lora Presets are combinations of loras with predefined multipliers and prompts.
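To make the per-phase multiplier syntax documented above concrete, here is a minimal standalone sketch of how a string such as `0.9,0.8;1.2,1.1,1` can be expanded into one multiplier per denoising step. It only illustrates the syntax and is not WanGP's actual `parse_loras_multipliers` from `wan/utils/loras_mutipliers.py`; the step count and the High/Low Noise switch step are hypothetical inputs.

```python
# Standalone sketch (not WanGP's parser): expand a per-phase multiplier string
# into one value per denoising step. switch_step is the step where the
# High Noise model hands over to the Low Noise one.
def expand_multiplier(spec: str, num_steps: int, switch_step: int) -> list[float]:
    phases = spec.split(";")
    high = [float(v) for v in phases[0].split(",")]
    low = [float(v) for v in phases[1].split(",")] if len(phases) > 1 else high

    def spread(values: list[float], steps: int) -> list[float]:
        # Stretch the listed values evenly over the steps of one phase.
        return [values[int(i * len(values) / steps)] for i in range(steps)]

    return spread(high, switch_step) + spread(low, num_steps - switch_step)

# "0;1" disables the lora during the High Noise phase and enables it afterwards:
assert expand_multiplier("0;1", num_steps=10, switch_step=5) == [0.0] * 5 + [1.0] * 5
```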
diff --git a/flux/flux_main.py b/flux/flux_main.py index 17b5405..303765a 100644 --- a/flux/flux_main.py +++ b/flux/flux_main.py @@ -58,12 +58,18 @@ class model_factory: # self.name= "flux-dev-kontext" # self.name= "flux-dev" # self.name= "flux-schnell" - self.model = load_flow_model(self.name, model_filename[0], torch_device) + source = model_def.get("source", None) + self.model = load_flow_model(self.name, model_filename[0] if source is None else source, torch_device) self.vae = load_ae(self.name, device=torch_device) # offload.change_dtype(self.model, dtype, True) # offload.save_model(self.model, "flux-dev.safetensors") + + if not source is None: + from wgp import save_model + save_model(self.model, model_type, dtype, None) + if save_quantized: from wgp import save_quantized_model save_quantized_model(self.model, model_type, model_filename[0], dtype, None) diff --git a/flux/sampling.py b/flux/sampling.py index a8f9aae..5a15c5e 100644 --- a/flux/sampling.py +++ b/flux/sampling.py @@ -343,7 +343,7 @@ def denoise( updated_num_steps= len(timesteps) -1 if callback != None: - from wgp import update_loras_slists + from wan.utils.loras_mutipliers import update_loras_slists update_loras_slists(model, loras_slists, updated_num_steps) callback(-1, None, True, override_num_inference_steps = updated_num_steps) from mmgp import offload diff --git a/hyvideo/hunyuan.py b/hyvideo/hunyuan.py index b3cc296..380ec77 100644 --- a/hyvideo/hunyuan.py +++ b/hyvideo/hunyuan.py @@ -21,7 +21,7 @@ from PIL import Image import numpy as np import torchvision.transforms as transforms import cv2 -from wan.utils.utils import resize_lanczos, calculate_new_dimensions +from wan.utils.utils import calculate_new_dimensions, convert_tensor_to_image from hyvideo.data_kits.audio_preprocessor import encode_audio, get_facemask from transformers import WhisperModel from transformers import AutoFeatureExtractor @@ -720,7 +720,6 @@ class HunyuanVideoSampler(Inference): embedded_guidance_scale=6.0, batch_size=1, num_videos_per_prompt=1, - i2v_resolution="720p", image_start=None, enable_RIFLEx = False, i2v_condition_type: str = "token_replace", @@ -846,39 +845,13 @@ class HunyuanVideoSampler(Inference): denoise_strength = 0 ip_cfg_scale = 0 if i2v_mode: - if i2v_resolution == "720p": - bucket_hw_base_size = 960 - elif i2v_resolution == "540p": - bucket_hw_base_size = 720 - elif i2v_resolution == "360p": - bucket_hw_base_size = 480 - else: - raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]") - - # semantic_images = [Image.open(i2v_image_path).convert('RGB')] - semantic_images = [image_start.convert('RGB')] # - origin_size = semantic_images[0].size - h, w = origin_size - h, w = calculate_new_dimensions(height, width, h, w, fit_into_canvas) - closest_size = (w, h) - # crop_size_list = generate_crop_size_list(bucket_hw_base_size, 32) - # aspect_ratios = np.array([round(float(h)/float(w), 5) for h, w in crop_size_list]) - # closest_size, closest_ratio = get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list) - ref_image_transform = transforms.Compose([ - transforms.Resize(closest_size), - transforms.CenterCrop(closest_size), - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]) - ]) - - semantic_image_pixel_values = [ref_image_transform(semantic_image) for semantic_image in semantic_images] - semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(self.device) - + semantic_images = convert_tensor_to_image(image_start) + 
semantic_image_pixel_values = image_start.unsqueeze(0).unsqueeze(2).to(self.device) with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True): img_latents = self.pipeline.vae.encode(semantic_image_pixel_values).latent_dist.mode() # B, C, F, H, W img_latents.mul_(self.pipeline.vae.config.scaling_factor) - target_height, target_width = closest_size + target_height, target_width = image_start.shape[1:] # ======================================================================== # Build Rope freqs diff --git a/ltx_video/ltxv.py b/ltx_video/ltxv.py index 0031d6d..34bae13 100644 --- a/ltx_video/ltxv.py +++ b/ltx_video/ltxv.py @@ -303,14 +303,15 @@ class LTXV: frame_width, frame_height = image_start.size if fit_into_canvas != None: height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas, 32) - conditioning_media_paths.append(image_start) + conditioning_media_paths.append(image_start.unsqueeze(1)) conditioning_start_frames.append(0) conditioning_control_frames.append(False) prefix_size = 1 - if image_end != None: - conditioning_media_paths.append(image_end) - conditioning_start_frames.append(frame_num-1) - conditioning_control_frames.append(False) + + if image_end != None: + conditioning_media_paths.append(image_end.unsqueeze(1)) + conditioning_start_frames.append(frame_num-1) + conditioning_control_frames.append(False) if input_frames!= None: conditioning_media_paths.append(input_frames) diff --git a/postprocessing/mmaudio/data/av_utils.py b/postprocessing/mmaudio/data/av_utils.py index 6fd0b1d..19776dc 100644 --- a/postprocessing/mmaudio/data/av_utils.py +++ b/postprocessing/mmaudio/data/av_utils.py @@ -132,11 +132,13 @@ import torch def remux_with_audio(video_path: Path, output_path: Path, audio: torch.Tensor, sampling_rate: int): from wan.utils.utils import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files + with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f: temp_path = Path(f.name) temp_path_str= str(temp_path) import torchaudio torchaudio.save(temp_path_str, audio.unsqueeze(0) if audio.dim() == 1 else audio, sampling_rate) + combine_video_with_audio_tracks(video_path, [temp_path_str], output_path ) temp_path.unlink(missing_ok=True) diff --git a/postprocessing/mmaudio/mmaudio.py b/postprocessing/mmaudio/mmaudio.py index e153b09..c4f8ce6 100644 --- a/postprocessing/mmaudio/mmaudio.py +++ b/postprocessing/mmaudio/mmaudio.py @@ -76,7 +76,7 @@ def get_model(persistent_models = False, verboseLevel = 1) -> tuple[MMAudio, Fea @torch.inference_mode() def video_to_audio(video, prompt: str, negative_prompt: str, seed: int, num_steps: int, - cfg_strength: float, duration: float, video_save_path , persistent_models = False, verboseLevel = 1): + cfg_strength: float, duration: float, save_path , persistent_models = False, audio_file_only = False, verboseLevel = 1): global device @@ -110,11 +110,17 @@ def video_to_audio(video, prompt: str, negative_prompt: str, seed: int, num_step ) audio = audios.float().cpu()[0] - make_video(video, video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate) + + if audio_file_only: + import torchaudio + torchaudio.save(save_path, audio.unsqueeze(0) if audio.dim() == 1 else audio, seq_cfg.sampling_rate) + else: + make_video(video, video_info, save_path, audio, sampling_rate=seq_cfg.sampling_rate) + offloadobj.unload_all() if not persistent_models: offloadobj.release() torch.cuda.empty_cache() gc.collect() - return video_save_path + return save_path diff 
--git a/preprocessing/matanyone/app.py b/preprocessing/matanyone/app.py index 6445c71..67813ee 100644 --- a/preprocessing/matanyone/app.py +++ b/preprocessing/matanyone/app.py @@ -69,6 +69,10 @@ def get_frames_from_image(image_input, image_state): [[0:nearest_frame], [nearest_frame:], nearest_frame] """ + if image_input is None: + gr.Info("Please select an Image file") + return [gr.update()] * 17 + user_name = time.time() frames = [image_input] * 2 # hardcode: mimic a video with 2 frames image_size = (frames[0].shape[0],frames[0].shape[1]) @@ -94,11 +98,12 @@ def get_frames_from_image(image_input, image_state): gr.update(visible=True, maximum=10, value=10), gr.update(visible=False, maximum=len(frames), value=len(frames)), \ gr.update(visible=True), gr.update(visible=True), \ gr.update(visible=True), gr.update(visible=True),\ - gr.update(visible=True), gr.update(visible=True), \ - gr.update(visible=True), gr.update(value="", visible=True), gr.update(visible=False), \ + gr.update(visible=True), gr.update(visible=False), \ + gr.update(visible=False), gr.update(value="", visible=False), gr.update(visible=False), \ gr.update(visible=False), gr.update(visible=True), \ gr.update(visible=True) + # extract frames from upload video def get_frames_from_video(video_input, video_state): """ @@ -108,7 +113,9 @@ def get_frames_from_video(video_input, video_state): Return [[0:nearest_frame], [nearest_frame:], nearest_frame] """ - + if video_input is None: + gr.Info("Please select a Video file") + return [gr.update()] * 18 while model == None: time.sleep(1) @@ -381,6 +388,7 @@ def save_video(frames, output_path, fps): def mask_to_xyxy_box(mask): rows, cols = np.where(mask == 255) + if len(rows) == 0 or len(cols) == 0: return [] xmin = min(cols) xmax = max(cols) + 1 ymin = min(rows) @@ -449,13 +457,18 @@ def image_matting(video_state, interactive_state, mask_dropdown, erode_kernel_si bbox_info = mask_to_xyxy_box(alpha_output) h = alpha_output.shape[0] w = alpha_output.shape[1] - bbox_info = [str(int(bbox_info[0]/ w * 100 )), str(int(bbox_info[1]/ h * 100 )), str(int(bbox_info[2]/ w * 100 )), str(int(bbox_info[3]/ h * 100 )) ] - bbox_info = ":".join(bbox_info) + if len(bbox_info) == 0: + bbox_info = "" + else: + bbox_info = [str(int(bbox_info[0]/ w * 100 )), str(int(bbox_info[1]/ h * 100 )), str(int(bbox_info[2]/ w * 100 )), str(int(bbox_info[3]/ h * 100 )) ] + bbox_info = ":".join(bbox_info) alpha_output = Image.fromarray(alpha_output) - return foreground_output, alpha_output, bbox_info, gr.update(visible=True), gr.update(visible=True) + # return gr.update(value=foreground_output, visible= True), gr.update(value=alpha_output, visible= True), gr.update(value=bbox_info, visible= True), gr.update(visible=True), gr.update(visible=True) + + return foreground_output, alpha_output, gr.update(visible = True), gr.update(visible = True), gr.update(value=bbox_info, visible= True), gr.update(visible=True), gr.update(visible=True) # video matting -def video_matting(video_state, end_slider, matting_type, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size): +def video_matting(video_state,video_input, end_slider, matting_type, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size): matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg) # if interactive_state["track_end_number"]: # following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] @@ -521,10 +534,21 @@ def video_matting(video_state, 
end_slider, matting_type, interactive_state, mask file_name= video_state["video_name"] file_name = ".".join(file_name.split(".")[:-1]) - foreground_output = save_video(foreground, output_path="./mask_outputs/{}_fg.mp4".format(file_name), fps=fps) - # foreground_output = generate_video_from_frames(foreground, output_path="./results/{}_fg.mp4".format(video_state["video_name"]), fps=fps, audio_path=audio_path) # import video_input to name the output video + + from wan.utils.utils import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files + source_audio_tracks, audio_metadata = extract_audio_tracks(video_input) + output_fg_path = f"./mask_outputs/{file_name}_fg.mp4" + output_fg_temp_path = f"./mask_outputs/{file_name}_fg_tmp.mp4" + if len(source_audio_tracks) == 0: + foreground_output = save_video(foreground, output_path=output_fg_path , fps=fps) + else: + foreground_output_tmp = save_video(foreground, output_path=output_fg_temp_path , fps=fps) + combine_video_with_audio_tracks(output_fg_temp_path, source_audio_tracks, output_fg_path, audio_metadata=audio_metadata) + cleanup_temp_audio_files(source_audio_tracks) + os.remove(foreground_output_tmp) + foreground_output = output_fg_path + alpha_output = save_video(alpha, output_path="./mask_outputs/{}_alpha.mp4".format(file_name), fps=fps) - # alpha_output = generate_video_from_frames(alpha, output_path="./results/{}_alpha.mp4".format(video_state["video_name"]), fps=fps, gray2rgb=True, audio_path=audio_path) # import video_input to name the output video return foreground_output, alpha_output, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True) @@ -912,7 +936,7 @@ def display(tabs, tab_state, vace_video_input, vace_image_input, vace_video_mask inputs=[], outputs=[foreground_video_output, alpha_video_output]).then( fn=video_matting, - inputs=[video_state, end_selection_slider, matting_type, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size], + inputs=[video_state, video_input, end_selection_slider, matting_type, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size], outputs=[foreground_video_output, alpha_video_output,foreground_video_output, alpha_video_output, export_to_vace_video_14B_btn, export_to_current_video_engine_btn] ) @@ -1053,7 +1077,7 @@ def display(tabs, tab_state, vace_video_input, vace_image_input, vace_video_mask foreground_image_output = gr.Image(type="pil", label="Foreground Output", visible=False, elem_classes="image") alpha_image_output = gr.Image(type="pil", label="Mask", visible=False, elem_classes="image") with gr.Row(equal_height=True): - bbox_info = gr.Text(label ="Mask BBox Info (Left:Top:Right:Bottom)", interactive= False) + bbox_info = gr.Text(label ="Mask BBox Info (Left:Top:Right:Bottom)", visible = False, interactive= False) with gr.Row(): # with gr.Row(): export_image_btn = gr.Button(value="Add to current Reference Images", visible=False, elem_classes="new_button") @@ -1116,7 +1140,7 @@ def display(tabs, tab_state, vace_video_input, vace_image_input, vace_video_mask matting_button.click( fn=image_matting, inputs=[image_state, interactive_state, mask_dropdown, erode_kernel_size, dilate_kernel_size, image_selection_slider], - outputs=[foreground_image_output, alpha_image_output,bbox_info, export_image_btn, export_image_mask_btn] + outputs=[foreground_image_output, alpha_image_output,foreground_image_output, alpha_image_output,bbox_info, export_image_btn, export_image_mask_btn] ) diff --git 
a/requirements.txt b/requirements.txt index 968c3fc..b4eed4a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,7 +17,7 @@ gradio==5.23.0 numpy>=1.23.5,<2 einops moviepy==1.0.3 -mmgp==3.5.3 +mmgp==3.5.5 peft==0.15.0 mutagen pydantic==2.10.6 @@ -46,5 +46,6 @@ soundfile ffmpeg-python pyannote.audio pynvml +huggingface_hub[hf_xet] # num2words # spacy \ No newline at end of file diff --git a/wan/any2video.py b/wan/any2video.py index a5e9248..bea70c5 100644 --- a/wan/any2video.py +++ b/wan/any2video.py @@ -141,7 +141,8 @@ class WanAny2V: if save_quantized: from wgp import save_quantized_model save_quantized_model(self.model, model_type, model_filename[0], dtype, base_config_file) - + if self.model2 is not None: + save_quantized_model(self.model2, model_type, model_filename[1], dtype, base_config_file, submodel_no=2) self.sample_neg_prompt = config.sample_neg_prompt if self.model.config.get("vace_in_dim", None) != None: @@ -357,7 +358,7 @@ class WanAny2V: input_frames= None, input_masks = None, input_ref_images = None, - input_video=None, + input_video = None, image_start = None, image_end = None, denoising_strength = 1.0, @@ -395,6 +396,7 @@ class WanAny2V: conditioning_latents_size = 0, keep_frames_parsed = [], model_type = None, + model_mode = None, loras_slists = None, NAG_scale = 0, NAG_tau = 3.5, @@ -475,67 +477,63 @@ class WanAny2V: phantom = model_type in ["phantom_1.3B", "phantom_14B"] fantasy = model_type in ["fantasy"] multitalk = model_type in ["multitalk", "vace_multitalk_14B"] + recam = model_type in ["recam_1.3B"] ref_images_count = 0 trim_frames = 0 extended_overlapped_latents = None - # image2video lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1 - if image_start != None: + # image2video + if model_type in ["i2v", "i2v_2_2", "fantasy", "multitalk", "flf2v_720p"]: any_end_frame = False - if input_frames != None: - _ , preframes_count, height, width = input_frames.shape + if image_start is None: + _ , preframes_count, height, width = input_video.shape lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2] - if hasattr(self, "clip"): - clip_context = self.clip.visual([input_frames[:, -1:]]) if model_type != "flf2v_720p" else self.clip.visual([input_frames[:, -1:], input_frames[:, -1:]]) + if hasattr(self, "clip"): + clip_image_size = self.clip.model.image_size + clip_image = resize_lanczos(input_video[:, -1], clip_image_size, clip_image_size)[:, None, :, :] + clip_context = self.clip.visual([clip_image]) if model_type != "flf2v_720p" else self.clip.visual([clip_image , clip_image ]) + clip_image = None else: clip_context = None - input_frames = input_frames.to(device=self.device).to(dtype= self.VAE_dtype) - enc = torch.concat( [input_frames, torch.zeros( (3, frame_num-preframes_count, height, width), + input_video = input_video.to(device=self.device).to(dtype= self.VAE_dtype) + enc = torch.concat( [input_video, torch.zeros( (3, frame_num-preframes_count, height, width), device=self.device, dtype= self.VAE_dtype)], dim = 1).to(self.device) - color_reference_frame = input_frames[:, -1:].clone() - input_frames = None + color_reference_frame = input_video[:, -1:].clone() + input_video = None else: preframes_count = 1 - image_start = TF.to_tensor(image_start) - any_end_frame = image_end != None + any_end_frame = image_end is not None add_frames_for_end_image = any_end_frame and model_type == "i2v" if any_end_frame: - image_end = TF.to_tensor(image_end) if add_frames_for_end_image: frame_num +=1 lat_frames = int((frame_num - 2) // self.vae_stride[0] 
+ 2) trim_frames = 1 - h, w = image_start.shape[1:] + height, width = image_start.shape[1:] - h, w = calculate_new_dimensions(height, width, h, w, fit_into_canvas) - width, height = w, h - lat_h = round( - h // self.vae_stride[1] // + height // self.vae_stride[1] // self.patch_size[1] * self.patch_size[1]) lat_w = round( - w // self.vae_stride[2] // + width // self.vae_stride[2] // self.patch_size[2] * self.patch_size[2]) - h = lat_h * self.vae_stride[1] - w = lat_w * self.vae_stride[2] - img_interpolated = resize_lanczos(image_start, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype - color_reference_frame = img_interpolated.clone() - if image_end!= None: - img_interpolated2 = resize_lanczos(image_end, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype + height = lat_h * self.vae_stride[1] + width = lat_w * self.vae_stride[2] + image_start_frame = image_start.unsqueeze(1).to(self.device) + color_reference_frame = image_start_frame.clone() + if image_end is not None: + img_end_frame = image_end.unsqueeze(1).to(self.device) if hasattr(self, "clip"): clip_image_size = self.clip.model.image_size image_start = resize_lanczos(image_start, clip_image_size, clip_image_size) - image_start = image_start.sub_(0.5).div_(0.5).to(self.device) #, self.dtype - if image_end!= None: - image_end = resize_lanczos(image_end, clip_image_size, clip_image_size) - image_end = image_end.sub_(0.5).div_(0.5).to(self.device) #, self.dtype + if image_end is not None: image_end = resize_lanczos(image_end, clip_image_size, clip_image_size) if model_type == "flf2v_720p": - clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :] if image_end != None else image_start[:, None, :, :]]) + clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :] if image_end is not None else image_start[:, None, :, :]]) else: clip_context = self.clip.visual([image_start[:, None, :, :]]) else: @@ -543,17 +541,17 @@ class WanAny2V: if any_end_frame: enc= torch.concat([ - img_interpolated, - torch.zeros( (3, frame_num-2, h, w), device=self.device, dtype= self.VAE_dtype), - img_interpolated2, + image_start_frame, + torch.zeros( (3, frame_num-2, height, width), device=self.device, dtype= self.VAE_dtype), + img_end_frame, ], dim=1).to(self.device) else: enc= torch.concat([ - img_interpolated, - torch.zeros( (3, frame_num-1, h, w), device=self.device, dtype= self.VAE_dtype) + image_start_frame, + torch.zeros( (3, frame_num-1, height, width), device=self.device, dtype= self.VAE_dtype) ], dim=1).to(self.device) - image_start = image_end = img_interpolated = img_interpolated2 = None + image_start = image_end = image_start_frame = img_end_frame = None msk = torch.ones(1, frame_num, lat_h, lat_w, device=self.device) if any_end_frame: @@ -582,11 +580,12 @@ class WanAny2V: kwargs.update({'clip_fea': clip_context}) # Recam Master - if target_camera != None: + if recam: + # should be be in fact in input_frames since it is control video not a video to be extended + target_camera = model_mode width = input_video.shape[2] height = input_video.shape[1] input_video = input_video.to(dtype=self.dtype , device=self.device) - input_video = input_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.) 
source_latents = self.vae.encode([input_video])[0] #.to(dtype=self.dtype, device=self.device) del input_video # Process target camera (recammaster) @@ -718,8 +717,13 @@ class WanAny2V: # init denoising updated_num_steps= len(timesteps) if callback != None: - from wan.utils.utils import update_loras_slists - update_loras_slists(self.model, loras_slists, updated_num_steps) + from wan.utils.loras_mutipliers import update_loras_slists + model_switch_step = updated_num_steps + for i, t in enumerate(timesteps): + if t <= switch_threshold: + model_switch_step = i + break + update_loras_slists(self.model, loras_slists, updated_num_steps, model_switch_step= model_switch_step) callback(-1, None, True, override_num_inference_steps = updated_num_steps) if sample_scheduler != None: diff --git a/wan/diffusion_forcing.py b/wan/diffusion_forcing.py index d477402..6960bda 100644 --- a/wan/diffusion_forcing.py +++ b/wan/diffusion_forcing.py @@ -19,7 +19,7 @@ from wan.utils.utils import calculate_new_dimensions from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps) from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler -from wan.utils.utils import update_loras_slists +from wan.utils.loras_mutipliers import update_loras_slists class DTT2V: @@ -199,7 +199,6 @@ class DTT2V: self, input_prompt: Union[str, List[str]], n_prompt: Union[str, List[str]] = "", - image_start: PipelineImageInput = None, input_video = None, height: int = 480, width: int = 832, @@ -242,11 +241,6 @@ class DTT2V: if input_video != None: _ , _ , height, width = input_video.shape - elif image_start != None: - image_start = image_start - frame_width, frame_height = image_start.size - height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas) - image_start = np.array(image_start.resize((width, height))).transpose(2, 0, 1) latent_length = (frame_num - 1) // 4 + 1 @@ -276,18 +270,8 @@ class DTT2V: output_video = input_video - if image_start is not None or output_video is not None: # i !=0 - if output_video is not None: - prefix_video = output_video.to(self.device) - else: - causal_block_size = 1 - causal_attention = False - ar_step = 0 - prefix_video = image_start - prefix_video = torch.tensor(prefix_video).unsqueeze(1) # .to(image_embeds.dtype).unsqueeze(1) - if prefix_video.dtype == torch.uint8: - prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0 - prefix_video = prefix_video.to(self.device) + if output_video is not None: # i !=0 + prefix_video = output_video.to(self.device) prefix_video = self.vae.encode(prefix_video.unsqueeze(0))[0] # [(c, f, h, w)] predix_video_latent_length = prefix_video.shape[1] truncate_len = predix_video_latent_length % causal_block_size diff --git a/wan/fantasytalking/infer.py b/wan/fantasytalking/infer.py index 80d1945..d96bea0 100644 --- a/wan/fantasytalking/infer.py +++ b/wan/fantasytalking/infer.py @@ -6,7 +6,7 @@ from .model import FantasyTalkingAudioConditionModel from .utils import get_audio_features import gc, torch -def parse_audio(audio_path, num_frames, fps = 23, device = "cuda"): +def parse_audio(audio_path, start_frame, num_frames, fps = 23, device = "cuda"): fantasytalking = FantasyTalkingAudioConditionModel(None, 768, 2048).to(device) from mmgp import offload from accelerate import init_empty_weights @@ -24,7 +24,7 @@ def parse_audio(audio_path, num_frames, fps = 23, device = "cuda"): wav2vec = Wav2Vec2Model.from_pretrained(wav2vec_model_dir, device_map="cpu").eval().requires_grad_(False) 
wav2vec.to(device) proj_model.to(device) - audio_wav2vec_fea = get_audio_features( wav2vec, wav2vec_processor, audio_path, fps, num_frames ) + audio_wav2vec_fea = get_audio_features( wav2vec, wav2vec_processor, audio_path, fps, start_frame, num_frames) audio_proj_fea = proj_model(audio_wav2vec_fea) pos_idx_ranges = fantasytalking.split_audio_sequence( audio_proj_fea.size(1), num_frames=num_frames ) diff --git a/wan/fantasytalking/utils.py b/wan/fantasytalking/utils.py index e044934..51f6678 100644 --- a/wan/fantasytalking/utils.py +++ b/wan/fantasytalking/utils.py @@ -26,13 +26,18 @@ def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None): writer.close() -def get_audio_features(wav2vec, audio_processor, audio_path, fps, num_frames): +def get_audio_features(wav2vec, audio_processor, audio_path, fps, start_frame, num_frames): sr = 16000 - audio_input, sample_rate = librosa.load(audio_path, sr=sr) # ι‡‡ζ ·ηŽ‡δΈΊ 16kHz + audio_input, sample_rate = librosa.load(audio_path, sr=sr) # ι‡‡ζ ·ηŽ‡δΈΊ 16kHz start_time = 0 + if start_frame < 0: + pad = int(abs(start_frame)/ fps * sr) + audio_input = np.concatenate([np.zeros(pad), audio_input]) + end_frame = num_frames + else: + end_frame = start_frame + num_frames - start_time = 0 - # end_time = (0 + (num_frames - 1) * 1) / fps - end_time = num_frames / fps + start_time = start_frame / fps + end_time = end_frame / fps start_sample = int(start_time * sr) end_sample = int(end_time * sr) diff --git a/wan/modules/model.py b/wan/modules/model.py index b67478e..b9da473 100644 --- a/wan/modules/model.py +++ b/wan/modules/model.py @@ -762,7 +762,11 @@ class WanModel(ModelMixin, ConfigMixin): offload.shared_state["_chipmunk_layers"] = None def preprocess_loras(self, model_type, sd): - + new_sd = {} + for k,v in sd.items(): + if not k.endswith(".modulation.diff"): + new_sd[ k] = v + sd = new_sd first = next(iter(sd), None) if first == None: return sd diff --git a/wan/multitalk/multitalk.py b/wan/multitalk/multitalk.py index 038efdf..56ba16b 100644 --- a/wan/multitalk/multitalk.py +++ b/wan/multitalk/multitalk.py @@ -74,7 +74,7 @@ def audio_prepare_single(audio_path, sample_rate=16000, duration = 0): return human_speech_array -def audio_prepare_multi(left_path, right_path, audio_type = "add", sample_rate=16000, duration = 0): +def audio_prepare_multi(left_path, right_path, audio_type = "add", sample_rate=16000, duration = 0, pad = 0): if not (left_path==None or right_path==None): human_speech_array1 = audio_prepare_single(left_path, duration = duration) human_speech_array2 = audio_prepare_single(right_path, duration = duration) @@ -91,7 +91,13 @@ def audio_prepare_multi(left_path, right_path, audio_type = "add", sample_rate=1 elif audio_type=='add': new_human_speech1 = np.concatenate([human_speech_array1[: human_speech_array1.shape[0]], np.zeros(human_speech_array2.shape[0])]) new_human_speech2 = np.concatenate([np.zeros(human_speech_array1.shape[0]), human_speech_array2[:human_speech_array2.shape[0]]]) + + #dont include the padding on the summed audio which is used to build the output audio track sum_human_speechs = new_human_speech1 + new_human_speech2 + if pad > 0: + new_human_speech1 = np.concatenate([np.zeros(pad), new_human_speech1]) + new_human_speech2 = np.concatenate([np.zeros(pad), new_human_speech2]) + return new_human_speech1, new_human_speech2, sum_human_speechs def process_tts_single(text, save_dir, voice1): @@ -167,14 +173,13 @@ def process_tts_multi(text, save_dir, voice1, voice2): return s1, s2, save_path_sum -def 
get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combination_type ="add", num_frames = 0, fps = 25, sr = 16000): +def get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combination_type ="add", num_frames = 0, fps = 25, sr = 16000, padded_frames_for_embeddings = 0): wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/chinese-wav2vec2-base") # wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/wav2vec") - - new_human_speech1, new_human_speech2, sum_human_speechs = audio_prepare_multi(audio_guide1, audio_guide2, combination_type, duration= num_frames / fps) + pad = int(padded_frames_for_embeddings/ fps * sr) + new_human_speech1, new_human_speech2, sum_human_speechs = audio_prepare_multi(audio_guide1, audio_guide2, combination_type, duration= num_frames / fps, pad = pad) audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps) audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps) - full_audio_embs = [] if audio_guide1 != None: full_audio_embs.append(audio_embedding_1) # if audio_guide1 != None: full_audio_embs.append(audio_embedding_1) diff --git a/wan/utils/loras_mutipliers.py b/wan/utils/loras_mutipliers.py new file mode 100644 index 0000000..584d406 --- /dev/null +++ b/wan/utils/loras_mutipliers.py @@ -0,0 +1,91 @@ +def preparse_loras_multipliers(loras_multipliers): + if isinstance(loras_multipliers, list): + return [multi.strip(" \r\n") if isinstance(multi, str) else multi for multi in loras_multipliers] + + loras_multipliers = loras_multipliers.strip(" \r\n") + loras_mult_choices_list = loras_multipliers.replace("\r", "").split("\n") + loras_mult_choices_list = [multi.strip() for multi in loras_mult_choices_list if len(multi)>0 and not multi.startswith("#")] + loras_multipliers = " ".join(loras_mult_choices_list) + return loras_multipliers.split(" ") + +def expand_slist(slists_dict, mult_no, num_inference_steps, model_switch_step ): + def expand_one(slist, num_inference_steps): + if not isinstance(slist, list): slist = [slist] + new_slist= [] + if num_inference_steps <=0: + return new_slist + inc = len(slist) / num_inference_steps + pos = 0 + for i in range(num_inference_steps): + new_slist.append(slist[ int(pos)]) + pos += inc + return new_slist + + phase1 = slists_dict["phase1"][mult_no] + phase2 = slists_dict["phase2"][mult_no] + if isinstance(phase1, float) and isinstance(phase2, float) and phase1 == phase2: + return phase1 + return expand_one(phase1, model_switch_step) + expand_one(phase2, num_inference_steps - model_switch_step) + +def parse_loras_multipliers(loras_multipliers, nb_loras, num_inference_steps, merge_slist = None, max_phases = 2, model_switch_step = None): + if model_switch_step is None: + model_switch_step = num_inference_steps + def is_float(element: any) -> bool: + if element is None: + return False + try: + float(element) + return True + except ValueError: + return False + loras_list_mult_choices_nums = [] + slists_dict = { "model_switch_step": model_switch_step} + slists_dict["phase1"] = phase1 = [1.] * nb_loras + slists_dict["phase2"] = phase2 = [1.] 
* nb_loras + + if isinstance(loras_multipliers, list) or len(loras_multipliers) > 0: + list_mult_choices_list = preparse_loras_multipliers(loras_multipliers) + for i, mult in enumerate(list_mult_choices_list): + current_phase = phase1 + if isinstance(mult, str): + mult = mult.strip() + phase_mult = mult.split(";") + shared_phases = len(phase_mult) <=1 + if len(phase_mult) > max_phases: + return "", "", f"Loras can not be defined for more than {max_phases} Denoising phases for this model" + for phase_no, mult in enumerate(phase_mult): + if phase_no > 0: current_phase = phase2 + if "," in mult: + multlist = mult.split(",") + slist = [] + for smult in multlist: + if not is_float(smult): + return "", "", f"Lora sub value no {i+1} ({smult}) in Multiplier definition '{multlist}' is invalid" + slist.append(float(smult)) + else: + if not is_float(mult): + return "", "", f"Lora Multiplier no {i+1} ({mult}) is invalid" + slist = float(mult) + if shared_phases: + phase1[i] = phase2[i] = slist + else: + current_phase[i] = slist + else: + phase1[i] = phase2[i] = float(mult) + + if merge_slist is not None: + slists_dict["phase1"] = phase1 = merge_slist["phase1"] + phase1 + slists_dict["phase2"] = phase2 = merge_slist["phase2"] + phase2 + + loras_list_mult_choices_nums = [ expand_slist(slists_dict, i, num_inference_steps, model_switch_step ) for i in range(len(phase1)) ] + loras_list_mult_choices_nums = [ slist[0] if isinstance(slist, list) else slist for slist in loras_list_mult_choices_nums ] + + return loras_list_mult_choices_nums, slists_dict, "" + +def update_loras_slists(trans, slists_dict, num_inference_steps, model_switch_step = None ): + from mmgp import offload + sz = len(slists_dict["phase1"]) + slists = [ expand_slist(slists_dict, i, num_inference_steps, model_switch_step ) for i in range(sz) ] + nos = [str(l) for l in range(sz)] + offload.activate_loras(trans, nos, slists ) + diff --git a/wan/utils/utils.py b/wan/utils/utils.py index d2afe9d..762d742 100644 --- a/wan/utils/utils.py +++ b/wan/utils/utils.py @@ -18,6 +18,8 @@ import random import ffmpeg import os import tempfile +import subprocess +import json __all__ = ['cache_video', 'cache_image', 'str2bool'] @@ -34,21 +36,6 @@ def seed_everything(seed: int): if torch.backends.mps.is_available(): torch.mps.manual_seed(seed) -def expand_slist(slist, num_inference_steps ): - new_slist= [] - inc = len(slist) / num_inference_steps - pos = 0 - for i in range(num_inference_steps): - new_slist.append(slist[ int(pos)]) - pos += inc - return new_slist - -def update_loras_slists(trans, slists, num_inference_steps ): - from mmgp import offload - slists = [ expand_slist(slist, num_inference_steps ) if isinstance(slist, list) else slist for slist in slists ] - nos = [str(l) for l in range(len(slists))] - offload.activate_loras(trans, nos, slists ) - def resample(video_fps, video_frames_count, max_target_frames_count, target_fps, start_target_frame ): import math @@ -141,10 +128,12 @@ def convert_image_to_video(image): return temp_video.name def resize_lanczos(img, h, w): - img = Image.fromarray(np.clip(255. 
* img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) + img = (img + 1).float().mul_(127.5) + img = Image.fromarray(np.clip(img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) img = img.resize((w,h), resample=Image.Resampling.LANCZOS) - return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0) - + img = torch.from_numpy(np.array(img).astype(np.float32)).movedim(-1, 0) + img = img.div(127.5).sub_(1) + return img def remove_background(img, session=None): if session ==None: @@ -445,109 +434,180 @@ def create_progress_hook(filename): return progress_hook(block_num, block_size, total_size, filename) return hook + +import tempfile, os import ffmpeg -import os -import tempfile -def extract_audio_tracks(source_video, verbose=False, query_only= False): +def extract_audio_tracks(source_video, verbose=False, query_only=False): """ - Extract all audio tracks from source video to temporary files. - - Args: - source_video: Path to video with audio to extract - verbose: Enable verbose output (default: False) - + Extract all audio tracks from a source video into temporary AAC files. + Returns: - List of temporary audio file paths, or empty list if no audio tracks + Tuple: + - List of temp file paths for extracted audio tracks + - List of corresponding metadata dicts: + {'codec', 'sample_rate', 'channels', 'duration', 'language'} + where 'duration' is set to container duration (for consistency). """ + probe = ffmpeg.probe(source_video) + audio_streams = [s for s in probe['streams'] if s['codec_type'] == 'audio'] + container_duration = float(probe['format'].get('duration', 0.0)) + + if not audio_streams: + if query_only: return 0 + if verbose: print(f"No audio track found in {source_video}") + return [], [] + + if query_only: + return len(audio_streams) + + if verbose: + print(f"Found {len(audio_streams)} audio track(s), container duration = {container_duration:.3f}s") + + file_paths = [] + metadata = [] + + for i, stream in enumerate(audio_streams): + fd, temp_path = tempfile.mkstemp(suffix=f'_track{i}.aac', prefix='audio_') + os.close(fd) + + file_paths.append(temp_path) + metadata.append({ + 'codec': stream.get('codec_name'), + 'sample_rate': int(stream.get('sample_rate', 0)), + 'channels': int(stream.get('channels', 0)), + 'duration': container_duration, + 'language': stream.get('tags', {}).get('language', None) + }) + + ffmpeg.input(source_video).output( + temp_path, + **{f'map': f'0:a:{i}', 'acodec': 'aac', 'b:a': '128k'} + ).overwrite_output().run(quiet=not verbose) + + return file_paths, metadata + + +import subprocess + +def combine_and_concatenate_video_with_audio_tracks( + save_path_tmp, video_path, + source_audio_tracks, new_audio_tracks, + source_audio_duration, audio_sampling_rate, + new_audio_from_start=False, + source_audio_metadata=None, + audio_bitrate='128k', + audio_codec='aac' +): + inputs, filters, maps, idx = ['-i', video_path], [], ['-map', '0:v'], 1 + metadata_args = [] + sources = source_audio_tracks or [] + news = new_audio_tracks or [] + + duplicate_source = len(sources) == 1 and len(news) > 1 + N = len(news) if source_audio_duration == 0 else max(len(sources), len(news)) or 1 + + for i in range(N): + s = (sources[i] if i < len(sources) + else sources[0] if duplicate_source else None) + n = news[i] if len(news) == N else (news[0] if news else None) + + if source_audio_duration == 0: + if n: + inputs += ['-i', n] + filters.append(f'[{idx}:a]apad=pad_dur=100[aout{i}];') + idx += 1 + else: + 
filters.append(f'anullsrc=r={audio_sampling_rate}:cl=mono,apad=pad_dur=100[aout{i}];') + else: + if s: + inputs += ['-i', s] + meta = source_audio_metadata[i] if source_audio_metadata and i < len(source_audio_metadata) else {} + needs_filter = ( + meta.get('codec') != audio_codec or + meta.get('sample_rate') != audio_sampling_rate or + meta.get('channels') != 1 or + meta.get('duration', 0) < source_audio_duration + ) + if needs_filter: + filters.append( + f'[{idx}:a]aresample={audio_sampling_rate},aformat=channel_layouts=mono,' + f'apad=pad_dur={source_audio_duration},atrim=0:{source_audio_duration},asetpts=PTS-STARTPTS[s{i}];') + else: + filters.append( + f'[{idx}:a]apad=pad_dur={source_audio_duration},atrim=0:{source_audio_duration},asetpts=PTS-STARTPTS[s{i}];') + if lang := meta.get('language'): + metadata_args += ['-metadata:s:a:' + str(i), f'language={lang}'] + idx += 1 + else: + filters.append( + f'anullsrc=r={audio_sampling_rate}:cl=mono,atrim=0:{source_audio_duration},asetpts=PTS-STARTPTS[s{i}];') + + if n: + inputs += ['-i', n] + start = '0' if new_audio_from_start else source_audio_duration + filters.append( + f'[{idx}:a]aresample={audio_sampling_rate},aformat=channel_layouts=mono,' + f'atrim=start={start},asetpts=PTS-STARTPTS[n{i}];' + f'[s{i}][n{i}]concat=n=2:v=0:a=1[aout{i}];') + idx += 1 + else: + filters.append(f'[s{i}]apad=pad_dur=100[aout{i}];') + + maps += ['-map', f'[aout{i}]'] + + cmd = ['ffmpeg', '-y', *inputs, + '-filter_complex', ''.join(filters), + *maps, *metadata_args, + '-c:v', 'copy', + '-c:a', audio_codec, + '-b:a', audio_bitrate, + '-ar', str(audio_sampling_rate), + '-ac', '1', + '-shortest', save_path_tmp] + try: - # Check if source video has audio - probe = ffmpeg.probe(source_video) - audio_streams = [s for s in probe['streams'] if s['codec_type'] == 'audio'] - - if not audio_streams: - if query_only: return 0 - if verbose: - print(f"No audio track found in {source_video}") - return [] - if query_only: return len(audio_streams) - if verbose: - print(f"Found {len(audio_streams)} audio track(s)") - - # Create temporary audio files for each track - temp_audio_files = [] - for i in range(len(audio_streams)): - fd, temp_path = tempfile.mkstemp(suffix=f'_track{i}.aac', prefix='audio_') - os.close(fd) # Close file descriptor immediately - temp_audio_files.append(temp_path) - - # Extract each audio track - for i, temp_path in enumerate(temp_audio_files): - (ffmpeg - .input(source_video) - .output(temp_path, **{f'map': f'0:a:{i}', 'acodec': 'aac'}) - .overwrite_output() - .run(quiet=not verbose)) - - return temp_audio_files - - except ffmpeg.Error as e: - print(f"FFmpeg error during audio extraction: {e}") - return 0 if query_only else [] - except Exception as e: - print(f"Error during audio extraction: {e}") - return 0 if query_only else [] + subprocess.run(cmd, check=True, capture_output=True, text=True) + except subprocess.CalledProcessError as e: + raise Exception(f"FFmpeg error: {e.stderr}") -def combine_video_with_audio_tracks(target_video, audio_tracks, output_video, verbose=False): - """ - Combine video with audio tracks. Output duration matches video length exactly. 
- - Args: - target_video: Path to video to receive the audio - audio_tracks: List of audio file paths to combine - output_video: Path for the output video - verbose: Enable verbose output (default: False) - - Returns: - True if successful, False otherwise - """ + +import ffmpeg + + +import subprocess +import ffmpeg + +def combine_video_with_audio_tracks(target_video, audio_tracks, output_video, + audio_metadata=None, verbose=False): if not audio_tracks: - if verbose: - print("No audio tracks to combine") - return False - - try: - # Get video duration to ensure exact alignment - video_probe = ffmpeg.probe(target_video) - video_duration = float(video_probe['streams'][0]['duration']) - - if verbose: - print(f"Target video duration: {video_duration:.3f} seconds") - - # Combine target video with all audio tracks, force video duration - video = ffmpeg.input(target_video).video - audio_inputs = [ffmpeg.input(audio_path).audio for audio_path in audio_tracks] - - # Create output with video duration as master timing - inputs = [video] + audio_inputs - (ffmpeg - .output(*inputs, output_video, - vcodec='copy', - acodec='copy', - t=video_duration) # Force exact video duration - .overwrite_output() - .run(quiet=not verbose)) - - if verbose: - print(f"Successfully created {output_video} with {len(audio_tracks)} audio track(s) aligned to video duration") - return True - - except ffmpeg.Error as e: - print(f"FFmpeg error during video combination: {e}") - return False - except Exception as e: - print(f"Error during video combination: {e}") - return False + if verbose: print("No audio tracks to combine."); return False + + dur = float(next(s for s in ffmpeg.probe(target_video)['streams'] + if s['codec_type'] == 'video')['duration']) + if verbose: print(f"Video duration: {dur:.3f}s") + + cmd = ['ffmpeg', '-y', '-i', target_video] + for path in audio_tracks: + cmd += ['-i', path] + + cmd += ['-map', '0:v'] + for i in range(len(audio_tracks)): + cmd += ['-map', f'{i+1}:a'] + + for i, meta in enumerate(audio_metadata or []): + if (lang := meta.get('language')): + cmd += ['-metadata:s:a:' + str(i), f'language={lang}'] + + cmd += ['-c:v', 'copy', '-c:a', 'copy', '-t', str(dur), output_video] + + result = subprocess.run(cmd, capture_output=not verbose, text=True) + if result.returncode != 0: + raise Exception(f"FFmpeg error:\n{result.stderr}") + if verbose: + print(f"Created {output_video} with {len(audio_tracks)} audio track(s)") + return True + def cleanup_temp_audio_files(audio_tracks, verbose=False): """ diff --git a/wgp.py b/wgp.py index 103d26e..6497adf 100644 --- a/wgp.py +++ b/wgp.py @@ -16,9 +16,9 @@ import json import wan from wan.utils import notification_sound from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES, VACE_SIZE_CONFIGS -from wan.utils.utils import expand_slist, update_loras_slists +from wan.utils.loras_mutipliers import preparse_loras_multipliers, parse_loras_multipliers from wan.utils.utils import cache_video, convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video -from wan.utils.utils import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files, calculate_new_dimensions +from wan.utils.utils import extract_audio_tracks, combine_video_with_audio_tracks, combine_and_concatenate_video_with_audio_tracks, cleanup_temp_audio_files, calculate_new_dimensions from wan.modules.attention import get_attention_modes, get_supported_attention_modes from huggingface_hub import hf_hub_download, snapshot_download @@ 
-50,8 +50,8 @@ global_queue_ref = [] AUTOSAVE_FILENAME = "queue.zip" PROMPT_VARS_MAX = 10 -target_mmgp_version = "3.5.3" -WanGP_version = "7.5" +target_mmgp_version = "3.5.5" +WanGP_version = "7.6" settings_version = 2.23 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -183,14 +183,14 @@ def process_prompt_and_add_tasks(state, model_choice): inputs["model_filename"] = model_filename mode = inputs["mode"] - if mode == "edit": + if mode.startswith("edit_"): edit_video_source =gen.get("edit_video_source", None) edit_overrides =gen.get("edit_overrides", None) _ , _ , _, frames_count = get_video_info(edit_video_source) if frames_count > max_source_video_frames: gr.Info(f"Post processing is not supported on videos longer than {max_source_video_frames} frames. Output Video will be truncated") # return - for k in ["image_start", "image_end", "image_refs", "video_guide", "audio_guide", "audio_guide2", "video_mask", "image_mask"]: + for k in ["image_start", "image_end", "image_refs", "video_guide", "audio_guide", "audio_guide2", "audio_source" , "video_mask", "image_mask"]: inputs[k] = None inputs.update(edit_overrides) del gen["edit_video_source"], gen["edit_overrides"] @@ -201,7 +201,7 @@ def process_prompt_and_add_tasks(state, model_choice): if len(spatial_upsampling) >0: prompt += ["Spatial Upsampling"] temporal_upsampling = inputs.get("temporal_upsampling","") if len(temporal_upsampling) >0: prompt += ["Temporal Upsampling"] - if image_outputs and len(temporal_upsampling) > 0: + if has_image_file_extension(edit_video_source) and len(temporal_upsampling) > 0: gr.Info("Temporal Upsampling can not be used with an Image") return film_grain_intensity = inputs.get("film_grain_intensity",0) @@ -209,14 +209,26 @@ def process_prompt_and_add_tasks(state, model_choice): # if film_grain_intensity >0: prompt += [f"Film Grain: intensity={film_grain_intensity}, saturation={film_grain_saturation}"] if film_grain_intensity >0: prompt += ["Film Grain"] MMAudio_setting = inputs.get("MMAudio_setting",0) - seed = inputs.get("seed",None) repeat_generation= inputs.get("repeat_generation",1) - if repeat_generation > 1 and (MMAudio_setting == 0 or seed != -1): - gr.Info("It is useless to generate more than one sample if you don't use MMAudio with a random seed") - return - if MMAudio_setting !=0: prompt += ["MMAudio"] + if mode =="edit_remux": + audio_source = inputs["audio_source"] + if MMAudio_setting== 1: + prompt += ["MMAudio"] + audio_source = None + inputs["audio_source"] = audio_source + else: + if audio_source is None: + gr.Info("You must provide a custom Audio") + return + prompt += ["Custom Audio"] + repeat_generation == 1 + + seed = inputs.get("seed",None) if len(prompt) == 0: - gr.Info("You must choose at least one Post Processing Method") + if mode=="edit_remux": + gr.Info("You must choose at least one Remux Method") + else: + gr.Info("You must choose at least one Post Processing Method") return inputs["prompt"] = ", ".join(prompt) add_video_task(**inputs) @@ -261,6 +273,7 @@ def process_prompt_and_add_tasks(state, model_choice): force_fps = inputs["force_fps"] audio_guide = inputs["audio_guide"] audio_guide2 = inputs["audio_guide2"] + audio_source = inputs["audio_source"] video_guide = inputs["video_guide"] image_guide = inputs["image_guide"] video_mask = inputs["video_mask"] @@ -280,6 +293,14 @@ def process_prompt_and_add_tasks(state, model_choice): 
MMAudio_setting = inputs["MMAudio_setting"] image_mode = inputs["image_mode"] switch_threshold = inputs["switch_threshold"] + loras_multipliers = inputs["loras_multipliers"] + activated_loras = inputs["activated_loras"] + + if len(loras_multipliers) > 0: + _, _, errors = parse_loras_multipliers(loras_multipliers, len(activated_loras), num_inference_steps) + if len(errors) > 0: + gr.Info(f"Error parsing Loras Multipliers: {errors}") + return if no_steps_skipping: skip_steps_cache_type = "" if switch_threshold is not None and switch_threshold != 0 and len(skip_steps_cache_type) > 0: @@ -322,6 +343,9 @@ def process_prompt_and_add_tasks(state, model_choice): else: frames_positions = None + if audio_source is not None and MMAudio_setting != 0: + gr.Info("MMAudio and Custom Audio Soundtrack can't not be used at the same time") + return if len(filter_letters(image_prompt_type, "VLG")) > 0 and len(keep_frames_video_source) > 0: if not is_integer(keep_frames_video_source) or int(keep_frames_video_source) == 0: gr.Info("The number of frames to keep must be a non null integer") @@ -354,13 +378,13 @@ def process_prompt_and_add_tasks(state, model_choice): if not "I" in video_prompt_type and not not "V" in video_prompt_type: gr.Info("To get good results with Multitalk and two people speaking, it is recommended to set a Reference Frame or a Control Video (potentially truncated) that contains the two people one on each side") - if len(filter_letters(image_prompt_type, "VL")) > 0 : - if "R" in audio_prompt_type: - gr.Info("Remuxing is not yet supported if there is a video source") - audio_prompt_type= audio_prompt_type.replace("R" ,"") - if "A" in audio_prompt_type: - gr.Info("Creating an Audio track is not yet supported if there is a video source") - return + # if len(filter_letters(image_prompt_type, "VL")) > 0 : + # if "R" in audio_prompt_type: + # gr.Info("Remuxing is not yet supported if there is a video source") + # audio_prompt_type= audio_prompt_type.replace("R" ,"") + # if "A" in audio_prompt_type: + # gr.Info("Creating an Audio track is not yet supported if there is a video source") + # return if model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_audio", "hunyuan_avatar"]: if image_refs == None : @@ -383,19 +407,23 @@ def process_prompt_and_add_tasks(state, model_choice): image_refs = None if "V" in video_prompt_type: - if video_guide is None and image_guide is None: - if image_outputs: + if image_outputs: + if image_guide is None: gr.Info("You must provide a Control Image") - else: - gr.Info("You must provide a Control Video") - return - if "A" in video_prompt_type and not "U" in video_prompt_type: - if video_mask is None and image_mask is None: - if image_outputs: - gr.Info("You must provide a Image Mask") - else: - gr.Info("You must provide a Video Mask") return + else: + if video_guide is None: + gr.Info("You must provide a Control Video") + return + if "A" in video_prompt_type and not "U" in video_prompt_type: + if image_outputs: + if image_mask is None: + gr.Info("You must provide a Image Mask") + return + else: + if video_mask is None: + gr.Info("You must provide a Video Mask") + return else: video_mask = None image_mask = None @@ -419,6 +447,13 @@ def process_prompt_and_add_tasks(state, model_choice): keep_frames_video_guide = "" denoising_strength = 1.0 + if image_outputs: + video_guide = None + video_mask = None + else: + image_guide = None + image_mask = None + if "S" in image_prompt_type: if image_start == None or isinstance(image_start, list) and len(image_start) == 0: 
@@ -489,6 +524,7 @@ def process_prompt_and_add_tasks(state, model_choice): "image_refs": image_refs, "audio_guide": audio_guide, "audio_guide2": audio_guide2, + "audio_source": audio_source, "video_guide": video_guide, "image_guide": image_guide, "video_mask": video_mask, @@ -705,7 +741,7 @@ def save_queue_action(state): task_id_s = task.get('id', f"task_{task_index}") image_keys = ["image_start", "image_end", "image_refs", "image_guide", "image_mask"] - video_keys = ["video_guide", "video_mask", "video_source", "audio_guide", "audio_guide2"] + video_keys = ["video_guide", "video_mask", "video_source", "audio_guide", "audio_guide2", "audio_source"] for key in image_keys: images_pil = params_copy.get(key) @@ -881,7 +917,7 @@ def load_queue_action(filepath, state, evt:gr.EventData): params['state'] = state image_keys = ["image_start", "image_end", "image_refs", "image_guide", "image_mask"] - video_keys = ["video_guide", "video_mask", "video_source", "audio_guide", "audio_guide2"] + video_keys = ["video_guide", "video_mask", "video_source", "audio_guide", "audio_guide2", "audio_source"] loaded_pil_images = {} loaded_video_paths = {} @@ -1103,7 +1139,7 @@ def autosave_queue(): task_id_s = task.get('id', f"task_{task_index}") image_keys = ["image_start", "image_end", "image_refs", "image_guide", "image_mask"] - video_keys = ["video_guide", "video_mask", "video_source", "audio_guide", "audio_guide2"] + video_keys = ["video_guide", "video_mask", "video_source", "audio_guide", "audio_guide2", "audio_source" ] for key in image_keys: images_pil = params_copy.get(key) @@ -1829,7 +1865,14 @@ def get_model_fps(model_type): return fps def get_computed_fps(force_fps, base_model_type , video_guide, video_source ): - if force_fps == "control" and video_guide != None: + if force_fps == "auto": + if video_source != None: + fps, _, _, _ = get_video_info(video_source) + elif video_guide != None: + fps, _, _, _ = get_video_info(video_guide) + else: + fps = get_model_fps(base_model_type) + elif force_fps == "control" and video_guide != None: fps, _, _, _ = get_video_info(video_guide) elif force_fps == "source" and video_source != None: fps, _, _, _ = get_video_info(video_source) @@ -1979,7 +2022,8 @@ def fix_settings(model_type, ui_defaults): video_prompt_type = video_prompt_type.replace("I", "KI") if remove_background_images_ref != 0: remove_background_images_ref = 1 - ui_defaults["remove_background_images_ref"] = remove_background_images_ref + if model_type in ["hunyuan_avatar"]: remove_background_images_ref = 0 + ui_defaults["remove_background_images_ref"] = remove_background_images_ref ui_defaults["video_prompt_type"] = video_prompt_type @@ -2103,6 +2147,7 @@ def get_default_settings(model_type): ui_defaults.update({ "guidance_scale": 7.5, "flow_shift": 5, + "remove_background_images_ref": 0, "skip_steps_start_step_perc": 25, "video_length": 129, "video_prompt_type": "I", @@ -2265,11 +2310,48 @@ if args.compile: #args.fastest or compile="transformer" lock_ui_compile = True -def save_quantized_model(model, model_type, model_filename, dtype, config_file): + +def save_model(model, model_type, dtype, config_file, submodel_no = 1): + model_def = get_model_def(model_type) + if model_def == None: return + url_key = "URLs" if submodel_no <=1 else "URLs" + str(submodel_no) + URLs= model_def.get(url_key, None) + if URLs is None: return + if isinstance(URLs, str): + print("Unable to save model for a finetune that references external files") + return + from mmgp import offload + if dtype == torch.bfloat16: + dtypestr= 
"bf16" + else: + dtypestr= "fp16" + model_filename = None + for url in URLs: + if "quanto" not in url and dtypestr in url: + model_filename = os.path.basename(url) + break + if model_filename is None: + print(f"No target filename mentioned in {url_key}") + return + if not os.path.isfile(model_filename): + offload.save_model(model, os.path.join("ckpts",model_filename), config_file_path=config_file) + print(f"New model file '{model_filename}' had been created for finetune Id '{model_type}'.") + finetune_file = os.path.join(os.path.dirname(model_def["path"]) , model_type + ".json") + with open(finetune_file, 'r', encoding='utf-8') as reader: + saved_finetune_def = json.load(reader) + del saved_finetune_def["model"]["source"] + del model_def["source"] + with open(finetune_file, "w", encoding="utf-8") as writer: + writer.write(json.dumps(saved_finetune_def, indent=4)) + print(f"The 'source' entry has been removed in the '{finetune_file}' definition file.") + +def save_quantized_model(model, model_type, model_filename, dtype, config_file, submodel_no = 1): if "quanto" in model_filename: return model_def = get_model_def(model_type) if model_def == None: return - URLs= model_def["URLs"] + url_key = "URLs" if submodel_no <=1 else "URLs" + str(submodel_no) + URLs= model_def.get(url_key, None) + if URLs is None: return if isinstance(URLs, str): print("Unable to create a quantized model for a finetune that references external files") return @@ -2297,7 +2379,7 @@ def save_quantized_model(model, model_type, model_filename, dtype, config_file) finetune_file = os.path.join(os.path.dirname(model_def["path"]) , model_type + ".json") with open(finetune_file, 'r', encoding='utf-8') as reader: saved_finetune_def = json.load(reader) - saved_finetune_def["model"]["URLs"] = URLs + saved_finetune_def["model"][url_key] = URLs with open(finetune_file, "w", encoding="utf-8") as writer: writer.write(json.dumps(saved_finetune_def, indent=4)) print(f"The '{finetune_file}' definition file has been automatically updated with the local path to the new quantized model.") @@ -2416,8 +2498,13 @@ def download_models(model_filename, model_type, submodel_no = 1): model_family = get_model_family(model_type) model_def = get_model_def(model_type) + source = model_def.get("source", None) + + key_name = "URLs" if submodel_no <= 1 else f"URLs{submodel_no}" - if not model_type in modules_files: + if source is not None: + model_filename = None + elif not model_type in modules_files: if not os.path.isfile(model_filename ): URLs = get_model_recursive_prop(model_type, key_name, return_list= False) if isinstance(URLs, str): @@ -3107,7 +3194,7 @@ def refresh_gallery(state): #, msg base_model_type = get_base_model_type(model_type) model_def = get_model_def(model_type) is_image = model_def.get("image_outputs", False) - onemorewindow_visible = test_any_sliding_window(base_model_type) and params.get("image_mode",0) == 0 + onemorewindow_visible = test_any_sliding_window(base_model_type) and params.get("image_mode",0) == 0 and not params.get("mode","").startswith("edit_") enhanced = False if prompt.startswith("!enhanced!\n"): enhanced = True @@ -3330,7 +3417,7 @@ def select_video(state, input_file_list, event_data: gr.EventData): video_video_guide_outpainting = configs.get("video_guide_outpainting", "") video_outpainting = "" if len(video_video_guide_outpainting) > 0 and not video_video_guide_outpainting.startswith("#") \ - and (any_letters(video_video_prompt_type, "VFK") or any_letters(video_image_prompt_type, "VL")) : + and 
(any_letters(video_video_prompt_type, "VFK") ) : video_video_guide_outpainting = video_video_guide_outpainting.split(" ") video_outpainting = f"Top={video_video_guide_outpainting[0]}%, Bottom={video_video_guide_outpainting[1]}%, Left={video_video_guide_outpainting[2]}%, Right={video_video_guide_outpainting[3]}%" video_num_inference_steps = configs.get("num_inference_steps", 0) @@ -3409,7 +3496,7 @@ def select_video(state, input_file_list, event_data: gr.EventData): else: html = get_default_video_info() visible= len(file_list) > 0 - return choice, html, gr.update(visible=visible and not is_image) , gr.update(visible=visible and is_image), gr.update(visible=visible and not is_image) + return choice, html, gr.update(visible=visible and not is_image) , gr.update(visible=visible and is_image), gr.update(visible=visible and not is_image) , gr.update(visible=visible and not is_image) def convert_image(image): @@ -3739,10 +3826,6 @@ def preprocess_video(height, width, video_in, max_frames, start_frame=0, fit_can return torch.stack(torch_frames) -def update_loras_slists(trans, slists, num_inference_steps ): - slists = [ expand_slist(slist, num_inference_steps ) if isinstance(slist, list) else slist for slist in slists ] - nos = [str(l) for l in range(len(slists))] - offload.activate_loras(trans, nos, slists ) def parse_keep_frames_video_guide(keep_frames, video_length): @@ -3817,7 +3900,6 @@ def perform_spatial_upsampling(sample, spatial_upsampling): scale = 1.5 else: scale = 2 - sample = (sample + 1) / 2 h, w = sample.shape[-2:] h *= scale h = round(h/16) * 16 @@ -3830,7 +3912,6 @@ def perform_spatial_upsampling(sample, spatial_upsampling): return resize_lanczos(frame, h, w).unsqueeze(1) sample = torch.cat(process_images_multithread(upsample_frames, frames_to_upsample, "upsample", wrap_in_list = False), dim=1) frames_to_upsample = None - sample.mul_(2).sub_(1) return sample def any_audio_track(model_type): @@ -3852,13 +3933,6 @@ def get_available_filename(target_path, video_source, suffix = "", force_extensi return full_path counter += 1 -def preparse_loras_multipliers(loras_multipliers): - loras_multipliers = loras_multipliers.strip(" \r\n") - loras_mult_choices_list = loras_multipliers.replace("\r", "").split("\n") - loras_mult_choices_list = [multi for multi in loras_mult_choices_list if len(multi)>0 and not multi.startswith("#")] - loras_multipliers = " ".join(loras_mult_choices_list) - return loras_multipliers.split(" ") - def set_seed(seed): import random seed = random.randint(0, 99999999) if seed == None or seed < 0 else seed @@ -3872,6 +3946,7 @@ def set_seed(seed): def edit_video( send_cmd, state, + mode, video_source, seed, temporal_upsampling, @@ -3882,6 +3957,7 @@ def edit_video( MMAudio_prompt, MMAudio_neg_prompt, repeat_generation, + audio_source, **kwargs ): @@ -3900,8 +3976,11 @@ def edit_video( has_already_audio = False audio_tracks = [] if MMAudio_setting == 0: - audio_tracks = extract_audio_tracks(video_source) + audio_tracks, audio_metadata = extract_audio_tracks(video_source) has_already_audio = len(audio_tracks) > 0 + + if audio_source is not None: + audio_tracks = [audio_source] with lock: file_list = gen["file_list"] @@ -3967,11 +4046,18 @@ def edit_video( repeat_no +=1 gen["repeat_no"] = repeat_no suffix = "" if "_post" in video_source else "_post" + + if audio_source is not None: + audio_prompt_type = configs.get("audio_prompt_type", "") + if not "T" in audio_prompt_type:audio_prompt_type += "T" + configs["audio_prompt_type"] = audio_prompt_type + any_change = True + if 
any_mmaudio: send_cmd("progress", [0, get_latest_status(state,"MMAudio Soundtrack Generation")]) from postprocessing.mmaudio.mmaudio import video_to_audio new_video_path = get_available_filename(save_path, video_source, suffix) - video_to_audio(video_path, prompt = MMAudio_prompt, negative_prompt = MMAudio_neg_prompt, seed = seed, num_steps = 25, cfg_strength = 4.5, duration= frames_count /output_fps, video_save_path = new_video_path , persistent_models = server_config.get("mmaudio_enabled", 0) == 2, verboseLevel = verbose_level) + video_to_audio(video_path, prompt = MMAudio_prompt, negative_prompt = MMAudio_neg_prompt, seed = seed, num_steps = 25, cfg_strength = 4.5, duration= frames_count /output_fps, save_path = new_video_path , persistent_models = server_config.get("mmaudio_enabled", 0) == 2, verboseLevel = verbose_level) configs["MMAudio_setting"] = MMAudio_setting configs["MMAudio_prompt"] = MMAudio_prompt configs["MMAudio_neg_prompt"] = MMAudio_neg_prompt @@ -3980,7 +4066,7 @@ def edit_video( elif len(audio_tracks) > 0: # combine audio files and new video file new_video_path = get_available_filename(save_path, video_source, suffix) - combine_video_with_audio_tracks(video_path, audio_tracks, new_video_path) + combine_video_with_audio_tracks(video_path, audio_tracks, new_video_path, audio_metadata=audio_metadata) else: new_video_path = video_path if tmp_path != None: @@ -4061,6 +4147,7 @@ def generate_video( mask_expand, audio_guide, audio_guide2, + audio_source, audio_prompt_type, speakers_locations, sliding_window_size, @@ -4107,8 +4194,8 @@ def generate_video( global wan_model, offloadobj, reload_needed, save_path gen = get_gen_info(state) torch.set_grad_enabled(False) - if mode == "edit": - edit_video(send_cmd, state, video_source, seed, temporal_upsampling, spatial_upsampling, film_grain_intensity, film_grain_saturation, MMAudio_setting, MMAudio_prompt, MMAudio_neg_prompt, repeat_generation) + if mode.startswith("edit_"): + edit_video(send_cmd, state, mode, video_source, seed, temporal_upsampling, spatial_upsampling, film_grain_intensity, film_grain_saturation, MMAudio_setting, MMAudio_prompt, MMAudio_neg_prompt, repeat_generation, audio_source) return with lock: file_list = gen["file_list"] @@ -4118,7 +4205,6 @@ def generate_video( model_def = get_model_def(model_type) is_image = image_mode == 1 if is_image: - # min_frames_if_references = server_config.get("min_frames_if_references", 5) video_length = min_frames_if_references if "I" in video_prompt_type else 1 else: batch_size = 1 @@ -4127,12 +4213,12 @@ def generate_video( if image_guide is not None and isinstance(image_guide, Image.Image): video_guide = convert_image_to_video(image_guide) temp_filenames_list.append(video_guide) - image_guide = None + image_guide = None if image_mask is not None and isinstance(image_mask, Image.Image): video_mask = convert_image_to_video(image_mask) temp_filenames_list.append(video_mask) - image_mask = None + image_mask = None fit_canvas = server_config.get("fit_canvas", 0) @@ -4183,59 +4269,28 @@ def generate_video( prompts = [part for part in prompts if len(prompt)>0] parsed_keep_frames_video_source= max_source_video_frames if len(keep_frames_video_source) ==0 else int(keep_frames_video_source) - - loras = state["loras"] - loras_slists = [] transformer_loras_filenames, transformer_loras_multipliers = get_transformer_loras(model_type) + if transformer_loras_filenames != None: + loras_list_mult_choices_nums, loras_slists, errors = parse_loras_multipliers(transformer_loras_multipliers, 
len(transformer_loras_filenames), num_inference_steps) + if len(errors) > 0: raise Exception(f"Error parsing Transformer Loras: {errors}") + loras_selected = transformer_loras_filenames + if hasattr(wan_model, "get_loras_transformer"): extra_loras_transformers, extra_loras_multipliers = wan_model.get_loras_transformer(get_model_recursive_prop, **locals()) - transformer_loras_filenames += extra_loras_transformers - transformer_loras_multipliers += extra_loras_multipliers - if len(loras) > 0 or len(transformer_loras_filenames) > 0 : - def is_float(element: any) -> bool: - if element is None: - return False - try: - float(element) - return True - except ValueError: - return False - loras_list_mult_choices_nums = [] - loras_multipliers = loras_multipliers.strip(" \r\n") - if len(loras_multipliers) > 0: - list_mult_choices_list = preparse_loras_multipliers(loras_multipliers) - for i, mult in enumerate(list_mult_choices_list): - mult = mult.strip() - if "," in mult: - multlist = mult.split(",") - slist = [] - for smult in multlist: - if not is_float(smult): - raise gr.Error(f"Lora sub value no {i+1} ({smult}) in Multiplier definition '{multlist}' is invalid") - slist.append(float(smult)) - loras_slists.append(slist) - slist = expand_slist(slist, num_inference_steps ) - loras_list_mult_choices_nums.append(slist) - else: - if not is_float(mult): - raise gr.Error(f"Lora Multiplier no {i+1} ({mult}) is invalid") - mult = float(mult) - loras_slists.append(mult) - loras_list_mult_choices_nums.append(mult) - if len(loras_list_mult_choices_nums ) < len(activated_loras): - loras_list_mult_choices_nums += [1.0] * ( len(activated_loras) - len(loras_list_mult_choices_nums ) ) - if len(loras_slists ) < len(activated_loras): - loras_slists += [1.0] * ( len(activated_loras) - len(loras_slists ) ) - lora_dir = get_lora_dir(model_type) - loras_selected = [ os.path.join(lora_dir, lora) for lora in activated_loras] + loras_list_mult_choices_nums, loras_slists, errors = parse_loras_multipliers(extra_loras_multipliers, len(extra_loras_transformers), num_inference_steps, merge_slist= loras_slists ) + if len(errors) > 0: raise Exception(f"Error parsing Extra Transformer Loras: {errors}") + loras_selected += extra_loras_transformers + loras = state["loras"] + if len(loras) > 0: + loras_list_mult_choices_nums, loras_slists, errors = parse_loras_multipliers(loras_multipliers, len(activated_loras), num_inference_steps, merge_slist= loras_slists ) + if len(errors) > 0: raise Exception(f"Error parsing Loras: {errors}") + lora_dir = get_lora_dir(model_type) + loras_selected += [ os.path.join(lora_dir, lora) for lora in activated_loras] + + if len(loras_selected) > 0: pinnedLora = profile !=5 # and transformer_loras_filenames == None False # # # split_linear_modules_map = getattr(trans,"split_linear_modules_map", None) - if transformer_loras_filenames != None: - loras_selected = transformer_loras_filenames + loras_selected - loras_list_mult_choices_nums = transformer_loras_multipliers + loras_list_mult_choices_nums - loras_slists = transformer_loras_multipliers + loras_slists - offload.load_loras_into_model(trans , loras_selected, loras_list_mult_choices_nums, activate_all_loras=True, preprocess_sd=get_loras_preprocessor(trans, base_model_type), pinnedLora=pinnedLora, split_linear_modules_map = split_linear_modules_map) errors = trans._loras_errors if len(errors) > 0: @@ -4283,12 +4338,15 @@ def generate_video( video_source = max(mp4_files, key=os.path.getmtime) if mp4_files else None fps = get_computed_fps(force_fps, 
base_model_type , video_guide, video_source ) - control_audio_tracks = [] - if "R" in audio_prompt_type and video_guide != None and video_source == None and MMAudio_setting == 0 and not any_audio_track(base_model_type): - control_audio_tracks = extract_audio_tracks(video_guide) + control_audio_tracks = source_audio_tracks = source_audio_metadata = [] + if "R" in audio_prompt_type and video_guide is not None and MMAudio_setting == 0 and not any_letters(audio_prompt_type, "ABX"): + control_audio_tracks, _ = extract_audio_tracks(video_guide) + if video_source is not None: + source_audio_tracks, source_audio_metadata = extract_audio_tracks(video_source) + reset_control_aligment = "T" in video_prompt_type if test_any_sliding_window(model_type) : - if video_source != None: + if video_source is not None: current_video_length += sliding_window_overlap sliding_window = current_video_length > sliding_window_size reuse_frames = min(sliding_window_size - 4, sliding_window_overlap) @@ -4303,8 +4361,8 @@ def generate_video( any_background_ref = False outpainting_dims = None if video_guide_outpainting== None or len(video_guide_outpainting) == 0 or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#") else [int(v) for v in video_guide_outpainting.split(" ")] - if image_refs != None and len(image_refs) > 0 and (hunyuan_custom or phantom or hunyuan_avatar or vace or flux): - frames_positions_list = [ int(pos)-1 for pos in frames_positions.split(" ")] if frames_positions !=None and len(frames_positions)> 0 else [] + if image_refs is not None and len(image_refs) > 0: + frames_positions_list = [ int(pos)-1 for pos in frames_positions.split(" ")] if frames_positions is not None and len(frames_positions)> 0 else [] frames_positions_list = frames_positions_list[:len(image_refs)] nb_frames_positions = len(frames_positions_list) if nb_frames_positions > 0: @@ -4319,7 +4377,6 @@ def generate_video( default_image_size = calculate_new_dimensions(height, width, h, w, fit_canvas) fit_canvas = None if len(image_refs) > nb_frames_positions: - if hunyuan_avatar: remove_background_images_ref = 0 any_background_ref = "K" in video_prompt_type if remove_background_images_ref > 0: send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")]) @@ -4373,14 +4430,9 @@ def generate_video( trans.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404] else: raise gr.Error("Teacache not supported for this model") - source_video = None - target_camera = None - merged_audio_data = None - if recam: - source_video = preprocess_video(width=width, height=height,video_in=video_source, max_frames= current_video_length, start_frame = 0, fit_canvas= fit_canvas == 1) - target_camera = model_mode - - source_audio = audio_guide + output_new_audio_data = None + output_new_audio_filepath = None + original_audio_guide = audio_guide audio_proj_split = None audio_proj_full = None audio_scale = None @@ -4390,26 +4442,30 @@ def generate_video( import librosa duration = librosa.get_duration(path=audio_guide) combination_type = "add" - if audio_guide2 != None: + if audio_guide2 is not None: duration2 = librosa.get_duration(path=audio_guide2) if "C" in audio_prompt_type: duration += duration2 else: duration = min(duration, duration2) combination_type = "para" if "P" in audio_prompt_type else "add" - elif "X" in audio_prompt_type: - from preprocessing.speakers_separator import extract_dual_audio - combination_type = "para" - if args.save_speakers: - audio_guide, 
audio_guide2 = "speaker1.wav", "speaker2.wav" - else: - audio_guide, audio_guide2 = get_available_filename(save_path, audio_guide, "_tmp1", ".wav"), get_available_filename(save_path, audio_guide, "_tmp2", ".wav") - extract_dual_audio(source_audio, audio_guide, audio_guide2 ) - current_video_length = min(int(fps * duration // 4) * 4 + 5, current_video_length) + else: + if "X" in audio_prompt_type: + from preprocessing.speakers_separator import extract_dual_audio + combination_type = "para" + if args.save_speakers: + audio_guide, audio_guide2 = "speaker1.wav", "speaker2.wav" + else: + audio_guide, audio_guide2 = get_available_filename(save_path, audio_guide, "_tmp1", ".wav"), get_available_filename(save_path, audio_guide, "_tmp2", ".wav") + extract_dual_audio(original_audio_guide, audio_guide, audio_guide2 ) + output_new_audio_filepath = original_audio_guide + current_video_length = min(int(fps * duration //latent_size) * latent_size + latent_size + 1, current_video_length) if fantasy: - audio_proj_split_full, audio_context_lens_full = parse_audio(audio_guide, num_frames= current_video_length, fps= fps, device= processing_device ) + # audio_proj_split_full, audio_context_lens_full = parse_audio(audio_guide, num_frames= max_source_video_frames, fps= fps, padded_frames_for_embeddings= (reuse_frames if reset_control_aligment else 0), device= processing_device ) audio_scale = 1.0 elif multitalk: from wan.multitalk.multitalk import get_full_audio_embeddings - audio_proj_full, merged_audio_data = get_full_audio_embeddings(audio_guide1 = audio_guide, audio_guide2= audio_guide2, combination_type= combination_type , num_frames= current_video_length, sr= audio_sampling_rate, fps =fps) + # pad audio_proj_full if aligned to beginning of window to simulate source window overlap + audio_proj_full, output_new_audio_data = get_full_audio_embeddings(audio_guide1 = audio_guide, audio_guide2= audio_guide2, combination_type= combination_type , num_frames= max_source_video_frames, sr= audio_sampling_rate, fps =fps, padded_frames_for_embeddings = (reuse_frames if reset_control_aligment else 0)) + if output_new_audio_filepath is not None: output_new_audio_data = None if not args.save_speakers and "X" in audio_prompt_type: os.remove(audio_guide) os.remove(audio_guide2) @@ -4531,59 +4587,60 @@ def generate_video( gen["window_no"] = window_no return_latent_slice = None - window_start_frame = guide_start_frame - (reuse_frames if window_no > 1 else source_video_overlap_frames_count) if reuse_frames > 0: return_latent_slice = slice(-(reuse_frames - 1 + discard_last_frames ) // latent_size - 1, None if discard_last_frames == 0 else -(discard_last_frames // latent_size) ) refresh_preview = {"image_guide" : None, "image_mask" : None} - if fantasy: - window_latent_start_frame = (window_start_frame ) // latent_size - window_latent_size= (current_video_length - 1) // latent_size + 1 - audio_proj_split = audio_proj_split_full[:, window_latent_start_frame:window_latent_start_frame + window_latent_size].clone() - audio_context_lens = audio_context_lens_full[window_latent_start_frame:window_latent_start_frame + window_latent_size].clone() - if multitalk: - from wan.multitalk.multitalk import get_window_audio_embeddings - audio_proj_split = get_window_audio_embeddings(audio_proj_full, audio_start_idx= window_start_frame, clip_length = current_video_length) - if i2v and window_no > 1: - src_video = pre_video_guide - if hunyuan_custom or hunyuan_avatar or flux: - src_ref_images = image_refs - elif phantom: - src_ref_images = 
image_refs.copy() if image_refs != None else None - elif window_no == 1 and (video_source != None and len(video_source) > 0 or image_start != None) and (diffusion_forcing or ltxv or vace): - if "L" in image_prompt_type: - from wan.utils.utils import get_video_frame - refresh_preview["video_source"] = get_video_frame(video_source, 0) - if image_start != None: + + src_ref_images = image_refs + image_start_tensor = image_end_tensor = None + if window_no == 1 and (video_source is not None or image_start is not None): + if image_start is not None: new_height, new_width = calculate_new_dimensions(height, width, image_start.height, image_start.width, fit_canvas, 32) - prefix_video = image_start.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) - prefix_video = torch.from_numpy(np.array(prefix_video).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0).unsqueeze(1) - pre_video_guide = prefix_video - image_start = None + image_start_tensor = image_start.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) + image_start_tensor = torch.from_numpy(np.array(image_start_tensor).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0) + pre_video_guide = prefix_video = image_start_tensor.unsqueeze(1) + if image_end is not None: + image_end_tensor = image_end.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) + image_end_tensor = torch.from_numpy(np.array(image_end_tensor).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0) else: + if "L" in image_prompt_type: + from wan.utils.utils import get_video_frame + refresh_preview["video_source"] = get_video_frame(video_source, 0) prefix_video = preprocess_video(width=width, height=height,video_in=video_source, max_frames= parsed_keep_frames_video_source , start_frame = 0, fit_canvas= sample_fit_canvas, target_fps = fps, block_size = 32 if ltxv else 16) prefix_video = prefix_video.permute(3, 0, 1, 2) prefix_video = prefix_video.float().div_(127.5).sub_(1.) 
# c, f, h, w pre_video_guide = prefix_video[:, -reuse_frames:] source_video_overlap_frames_count = pre_video_guide.shape[1] source_video_frames_count = prefix_video.shape[1] - if vace or ltxv: - if sample_fit_canvas != None: image_size = pre_video_guide.shape[-2:] - guide_start_frame = prefix_video.shape[1] + if sample_fit_canvas != None: image_size = pre_video_guide.shape[-2:] + guide_start_frame = prefix_video.shape[1] sample_fit_canvas = None + window_start_frame = guide_start_frame - (reuse_frames if window_no > 1 else source_video_overlap_frames_count) guide_end_frame = guide_start_frame + current_video_length - (source_video_overlap_frames_count if window_no == 1 else reuse_frames) + alignment_shift = source_video_frames_count if reset_control_aligment else 0 + aligned_guide_start_frame = guide_start_frame - alignment_shift + aligned_guide_end_frame = guide_end_frame - alignment_shift + aligned_window_start_frame = window_start_frame - alignment_shift + if fantasy: + audio_proj_split , audio_context_lens = parse_audio(audio_guide, start_frame = aligned_window_start_frame, num_frames= current_video_length, fps= fps, device= processing_device ) + if multitalk: + from wan.multitalk.multitalk import get_window_audio_embeddings + # special treatment for start frame pos when alignement to first frame requested as otherwise the start frame number will be negative due to overlapped frames (has been previously compensated later with padding) + audio_proj_split = get_window_audio_embeddings(audio_proj_full, audio_start_idx= aligned_window_start_frame + (source_video_overlap_frames_count if reset_control_aligment else 0 ), clip_length = current_video_length) - if (vace or t2v or ltxv) and video_guide != None : + if video_guide is not None: keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate) if len(error) > 0: raise gr.Error(f"invalid keep frames {keep_frames_video_guide}") - keep_frames_parsed = keep_frames_parsed[guide_start_frame: guide_end_frame ] - if ltxv and video_guide != None: + keep_frames_parsed = keep_frames_parsed[aligned_guide_start_frame: aligned_guide_end_frame ] + + if ltxv and video_guide is not None: preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw") status_info = "Extracting " + processes_names[preprocess_type] send_cmd("progress", [0, get_latest_status(state, status_info)]) # start one frame ealier to faciliate latents merging later - src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if guide_start_frame == 0 else 1), start_frame = guide_start_frame - (0 if guide_start_frame == 0 else 1), fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in video_prompt_type else "identity", block_size =32 ) + src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), start_frame = aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in 
video_prompt_type else "identity", block_size =32 ) if src_video != None: src_video = src_video[ :(len(src_video)-1)// latent_size * latent_size +1 ] refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy()) @@ -4594,7 +4651,7 @@ def generate_video( sample_fit_canvas = None if t2v and "G" in video_prompt_type: - video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = guide_start_frame, fit_canvas= sample_fit_canvas, target_fps = fps) + video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, target_fps = fps) if video_guide_processed == None: src_video = pre_video_guide else: @@ -4624,11 +4681,11 @@ def generate_video( elif len(extra_process_list) == 2: status_info += ", " + processes_names[extra_process_list[0]] + " and " + processes_names[extra_process_list[1]] if preprocess_type2 is not None: - context_scale = [ control_net_weight /2, control_net_weight2 /2] + context_scale = [ control_net_weight /2, control_net_weight2 /2] send_cmd("progress", [0, get_latest_status(state, status_info)]) - video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) , start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1 ) + video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) , start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1 ) if preprocess_type2 != None: - video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed), start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 ) + video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 ) if video_guide_processed != None: if sample_fit_canvas != None: @@ -4639,14 +4696,14 @@ def generate_video( refresh_preview["video_guide"] = [refresh_preview["video_guide"], Image.fromarray(video_guide_processed2[0].cpu().numpy())] if video_mask_processed != None: refresh_preview["video_mask"] = 
Image.fromarray(video_mask_processed[0].cpu().numpy()) - frames_to_inject_parsed = frames_to_inject[guide_start_frame: guide_end_frame] + frames_to_inject_parsed = frames_to_inject[aligned_guide_start_frame: aligned_guide_end_frame] src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_processed] if video_guide_processed2 == None else [video_guide_processed, video_guide_processed2], [video_mask_processed] if video_guide_processed2 == None else [video_mask_processed, video_mask_processed2], [image_refs_copy] if video_guide_processed2 == None else [image_refs_copy, image_refs_copy], current_video_length, image_size = image_size, device ="cpu", keep_video_guide_frames=keep_frames_parsed, - start_frame = guide_start_frame, + start_frame = aligned_guide_start_frame, pre_src_video = [pre_video_guide] if video_guide_processed2 == None else [pre_video_guide, pre_video_guide], fit_into_canvas = sample_fit_canvas, inject_frames= frames_to_inject_parsed, @@ -4706,15 +4763,14 @@ def generate_video( try: samples = wan_model.generate( input_prompt = prompt, - image_start = image_start, - image_end = image_end if image_end != None else None, - input_frames = src_video, + image_start = image_start_tensor, + image_end = image_end_tensor, + input_frames = src_video, input_ref_images= src_ref_images, input_masks = src_mask, - input_video= pre_video_guide if diffusion_forcing or ltxv or hunyuan_custom_edit else source_video, + input_video= pre_video_guide, denoising_strength=denoising_strength, prefix_frames_count = source_video_overlap_frames_count if window_no <= 1 else reuse_frames, - target_camera= target_camera, frame_num= (current_video_length // latent_size)* latent_size + 1, batch_size = batch_size, height = height, @@ -4768,12 +4824,12 @@ def generate_video( offloadobj = offloadobj, ) except Exception as e: - if len(control_audio_tracks) > 0: - cleanup_temp_audio_files(control_audio_tracks) + if len(control_audio_tracks) > 0 or len(source_audio_tracks) > 0: + cleanup_temp_audio_files(control_audio_tracks + source_audio_tracks) remove_temp_filenames(temp_filenames_list) offloadobj.unload_all() offload.unload_loras_from_model(trans) - if trans is not None: offload.unload_loras_from_model(trans2) + if trans is not None: offload.unload_loras_from_model(trans) # if compile: # cache_size = torch._dynamo.config.cache_size_limit # torch.compiler.reset() @@ -4898,27 +4954,34 @@ def generate_video( new_video_path.append(img_path) img.save(img_path) video_path= new_video_path - elif len(control_audio_tracks) > 0 or source_audio != None or any_mmaudio or merged_audio_data is not None: + elif len(control_audio_tracks) > 0 or len(source_audio_tracks) > 0 or output_new_audio_filepath is not None or any_mmaudio or output_new_audio_data is not None or audio_source is not None: save_path_tmp = video_path[:-4] + "_tmp.mp4" cache_video( tensor=sample[None], save_file=save_path_tmp, fps=output_fps, nrow=1, normalize=True, value_range=(-1, 1)) - if len(control_audio_tracks) > 0: - combine_video_with_audio_tracks(save_path_tmp, control_audio_tracks, video_path ) - elif any_mmaudio: + output_new_audio_temp_filepath = None + new_audio_from_start = reset_control_aligment + source_audio_duration = source_video_frames_count / fps + if any_mmaudio: send_cmd("progress", [0, get_latest_status(state,"MMAudio Soundtrack Generation")]) from postprocessing.mmaudio.mmaudio import video_to_audio - video_to_audio(save_path_tmp, prompt = MMAudio_prompt, negative_prompt = MMAudio_neg_prompt, seed = seed, num_steps = 
25, cfg_strength = 4.5, duration= sample.shape[1] /fps, video_save_path = video_path, persistent_models = server_config.get("mmaudio_enabled", 0) == 2, verboseLevel = verbose_level) - else: - if merged_audio_data is not None: - import soundfile as sf - output_audio_path = get_available_filename(save_path, f"tmp{time_flag}.wav" ) - sf.write(output_audio_path, merged_audio_data, audio_sampling_rate) - else: - output_audio_path = None - final_command = [ "ffmpeg", "-y", "-i", save_path_tmp, "-i", source_audio if output_audio_path == None else output_audio_path, "-c:v", "libx264", "-c:a", "aac", "-shortest", "-loglevel", "warning", "-nostats", video_path, ] - import subprocess - subprocess.run(final_command, check=True) - if output_audio_path != None: os.remove(output_audio_path) + output_new_audio_filepath = output_new_audio_temp_filepath = get_available_filename(save_path, f"tmp{time_flag}.wav" ) + video_to_audio(save_path_tmp, prompt = MMAudio_prompt, negative_prompt = MMAudio_neg_prompt, seed = seed, num_steps = 25, cfg_strength = 4.5, duration= sample.shape[1] /fps, save_path = output_new_audio_filepath, persistent_models = server_config.get("mmaudio_enabled", 0) == 2, audio_file_only = True, verboseLevel = verbose_level) + new_audio_from_start = False + elif audio_source is not None: + output_new_audio_filepath = audio_source + new_audio_from_start = True + elif output_new_audio_data is not None: + import soundfile as sf + output_new_audio_filepath = output_new_audio_temp_filepath = get_available_filename(save_path, f"tmp{time_flag}.wav" ) + sf.write(output_new_audio_filepath, output_new_audio_data, audio_sampling_rate) + if output_new_audio_filepath is not None: + new_audio_tracks = [output_new_audio_filepath] + else: + new_audio_tracks = control_audio_tracks + + combine_and_concatenate_video_with_audio_tracks(video_path, save_path_tmp, source_audio_tracks, new_audio_tracks, source_audio_duration, audio_sampling_rate, new_audio_from_start = new_audio_from_start, source_audio_metadata= source_audio_metadata ) os.remove(save_path_tmp) + if output_new_audio_temp_filepath is not None: os.remove(output_new_audio_temp_filepath) + else: cache_video( tensor=sample[None], save_file=video_path, fps=output_fps, nrow=1, normalize=True, value_range=(-1, 1)) @@ -4986,8 +5049,8 @@ def generate_video( if not trans2 is None: offload.unload_loras_from_model(trans2) - if len(control_audio_tracks) > 0: - cleanup_temp_audio_files(control_audio_tracks) + if len(control_audio_tracks) > 0 or len(source_audio_tracks) > 0: + cleanup_temp_audio_files(control_audio_tracks + source_audio_tracks) remove_temp_filenames(temp_filenames_list) @@ -5520,7 +5583,8 @@ def cancel_lset(): def save_lset(state, lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_cbox): - lset_name = os.path.splitext(lset_name)[0] + if lset_name.endswith(".json") or lset_name.endswith(".lset"): + lset_name = os.path.splitext(lset_name)[0] loras_presets = state["loras_presets"] loras = state["loras"] @@ -5590,7 +5654,8 @@ def delete_lset(state, lset_name): if len(lset_name) > 0: lset_name_filename = os.path.join( get_lora_dir(state["model_type"]), sanitize_file_name(lset_name)) if not os.path.isfile(lset_name_filename): - raise gr.Error(f"Preset '{lset_name}' not found ") + gr.Info(f"Preset '{lset_name}' not found ") + return [gr.update()]*7 os.remove(lset_name_filename) lset_choices = compute_lset_choices(loras_presets) pos = next( (i for i, item in enumerate(lset_choices) if item[1]==lset_name ), -1) @@ -5604,7 +5669,7 @@ def 
delete_lset(state, lset_name): lset_choices = compute_lset_choices(loras_presets) lset_choices.append((get_new_preset_msg(), "")) - selected_lset_name = "" if pos < -1 else lset_choices[pos][1] + selected_lset_name = "" if pos < 0 else lset_choices[min(pos, len(lset_choices)-1)][1] return gr.Dropdown(choices=lset_choices, value= selected_lset_name), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Checkbox(visible= False) def refresh_lora_list(state, lset_name, loras_choices): @@ -5849,7 +5914,7 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None if "lset_name" in inputs: inputs.pop("lset_name") - unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "image_guide", "video_source", "video_mask", "image_mask", "audio_guide", "audio_guide2"] + unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "image_guide", "video_source", "video_mask", "image_mask", "audio_guide", "audio_guide2", "audio_source"] for k in unsaved_params: inputs.pop(k) if model_type == None: model_type = state["model_type"] @@ -5979,7 +6044,7 @@ def image_to_ref_image_set(state, input_file_list, choice, target, target_name): return file_list[choice] -def apply_post_processing(state, input_file_list, choice, PP_temporal_upsampling, PP_spatial_upsampling, PP_film_grain_intensity, PP_film_grain_saturation, PP_MMAudio_setting, PP_MMAudio_prompt, PP_MMAudio_neg_prompt, PP_MMAudio_seed, PP_repeat_generation): +def apply_post_processing(state, input_file_list, choice, PP_temporal_upsampling, PP_spatial_upsampling, PP_film_grain_intensity, PP_film_grain_saturation): gen = get_gen_info(state) file_list, file_settings_list = get_file_list(state, input_file_list) if len(file_list) == 0 or choice == None or choice < 0 or choice > len(file_list) : @@ -5993,18 +6058,38 @@ def apply_post_processing(state, input_file_list, choice, PP_temporal_upsampling "spatial_upsampling":PP_spatial_upsampling, "film_grain_intensity": PP_film_grain_intensity, "film_grain_saturation": PP_film_grain_saturation, - "MMAudio_setting" : PP_MMAudio_setting, - "MMAudio_prompt" : PP_MMAudio_prompt, - "MMAudio_neg_prompt": PP_MMAudio_neg_prompt, - "seed": PP_MMAudio_seed, - "repeat_generation": PP_repeat_generation, } gen["edit_video_source"] = file_list[choice] gen["edit_overrides"] = overrides in_progress = gen.get("in_progress", False) - return "edit", get_unique_id() if not in_progress else gr.update(), get_unique_id() if in_progress else gr.update() + return "edit_postprocessing", get_unique_id() if not in_progress else gr.update(), get_unique_id() if in_progress else gr.update() + + +def remux_audio(state, input_file_list, choice, PP_MMAudio_setting, PP_MMAudio_prompt, PP_MMAudio_neg_prompt, PP_MMAudio_seed, PP_repeat_generation, PP_custom_audio): + gen = get_gen_info(state) + file_list, file_settings_list = get_file_list(state, input_file_list) + if len(file_list) == 0 or choice == None or choice < 0 or choice > len(file_list) : + return gr.update(), gr.update(), gr.update() + + if not file_list[choice].endswith(".mp4"): + gr.Info("Post processing is only available with Videos") + return gr.update(), gr.update(), gr.update() + overrides = { + "MMAudio_setting" : PP_MMAudio_setting, + "MMAudio_prompt" : PP_MMAudio_prompt, + "MMAudio_neg_prompt": PP_MMAudio_neg_prompt, + "seed": PP_MMAudio_seed, + "repeat_generation": PP_repeat_generation, + "audio_source": PP_custom_audio, + } + + 
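Note that `remux_audio()` only queues an "edit_remux" job; the actual muxing happens in `edit_video()` through `combine_video_with_audio_tracks()` (or MMAudio when that option is selected), which together with `combine_and_concatenate_video_with_audio_tracks()` replaces the inline ffmpeg invocation removed from `generate_video()` earlier in this patch. For reference only, a stand-alone remux in the spirit of that removed call could be sketched as follows (assuming `ffmpeg` is on the PATH; unlike the removed call, which re-encoded with libx264, this sketch copies the video stream, and it is not the project's helper):

```python
# Stand-alone sketch of a basic soundtrack remux; WanGP's own helpers also handle
# audio metadata and multi-segment concatenation and are what the code above uses.
import subprocess

def sketch_remux(video_path: str, audio_path: str, output_path: str) -> None:
    """Mux a replacement soundtrack onto a video without re-encoding the video stream."""
    subprocess.run(
        ["ffmpeg", "-y",
         "-i", video_path,     # input video (its existing audio is ignored by the -map below)
         "-i", audio_path,     # replacement soundtrack
         "-map", "0:v:0",      # video stream from the first input
         "-map", "1:a:0",      # audio stream from the second input
         "-c:v", "copy",       # keep the video as-is
         "-c:a", "aac",        # encode the soundtrack to AAC for the MP4 container
         "-shortest",          # stop at whichever stream ends first
         "-loglevel", "warning", "-nostats",
         output_path],
        check=True,
    )
```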
gen["edit_video_source"] = file_list[choice] + gen["edit_overrides"] = overrides + + in_progress = gen.get("in_progress", False) + return "edit_remux", get_unique_id() if not in_progress else gr.update(), get_unique_id() if in_progress else gr.update() def eject_video_from_gallery(state, input_file_list, choice): @@ -6278,6 +6363,7 @@ def save_inputs( mask_expand, audio_guide, audio_guide2, + audio_source, audio_prompt_type, speakers_locations, sliding_window_size, @@ -6520,6 +6606,11 @@ def refresh_video_prompt_type_video_mask(state, video_prompt_type, video_prompt_ image_outputs = image_mode == 1 return video_prompt_type, gr.update(visible= visible and not image_outputs), gr.update(visible= visible and image_outputs), gr.update(visible= visible ) +def refresh_video_prompt_type_alignment(state, video_prompt_type, video_prompt_type_video_guide): + video_prompt_type = del_in_sequence(video_prompt_type, "T") + video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide) + return video_prompt_type + def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide, image_mode): video_prompt_type = del_in_sequence(video_prompt_type, "PDESLCMGUV") video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide) @@ -6890,7 +6981,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non elif diffusion_forcing or ltxv: image_prompt_type_value= ui_defaults.get("image_prompt_type","T") # image_prompt_type = gr.Radio( [("Start Video with Image", "S"),("Start and End Video with Images", "SE"), ("Continue Video", "V"),("Text Prompt Only", "T")], value =image_prompt_type_value, label="Location", show_label= False, visible= True, scale= 3) - image_prompt_type = gr.Radio( [("Text Prompt Only", "T"),("Start Video with Image", "S"),("Continue Video", "V")], value =image_prompt_type_value, label="Location", show_label= False, visible= True , scale= 3) + image_prompt_type_choices = [("Text Prompt Only", "T"),("Start Video with Image", "S")] + if ltxv: + image_prompt_type_choices += [("Use both a Start and an End Image", "SE")] + image_prompt_type_choices += [("Continue Video", "V")] + image_prompt_type = gr.Radio( image_prompt_type_choices, value =image_prompt_type_value, label="Location", show_label= False, visible= True , scale= 3) # image_start = gr.Image(label= "Image as a starting point for a new video", type ="pil",value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value ) image_start = gr.Gallery(preview= True, @@ -6944,8 +7039,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non if test_class_i2v(model_type) or hunyuan_i2v: # image_prompt_type_value= ui_defaults.get("image_prompt_type","SE" if flf2v else "S" ) image_prompt_type_value= ui_defaults.get("image_prompt_type","S" ) - image_prompt_type_choices = [("Use only a Start Image", "S")] + image_prompt_type_choices = [("Start Video with Image", "S")] image_prompt_type_choices += [("Use both a Start and an End Image", "SE")] + if not hunyuan_i2v: + image_prompt_type_choices += [("Continue Video", "V")] + image_prompt_type = gr.Radio( image_prompt_type_choices, value =image_prompt_type_value, label="Location", show_label= False, visible= not hunyuan_i2v, scale= 3) any_start_image = True any_end_image = True @@ -6956,14 +7054,19 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non image_end = gr.Gallery(preview= True, label="Images as ending points for new 
videos", type ="pil", #file_types= "image", columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None)) + if hunyuan_i2v: + video_source = gr.Video(value=None, visible=False) + else: + video_source = gr.Video(label= "Video to Continue", visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),) + any_video_source = True else: image_prompt_type = gr.Radio(choices=[("", "")], value="") image_start = gr.Gallery(value=None) image_end = gr.Gallery(value=None) - video_source = gr.Video(value=None, visible=False) + video_source = gr.Video(value=None, visible=False) + any_video_source = False model_mode = gr.Dropdown(value=None, visible=False) keep_frames_video_source = gr.Text(visible=False) - any_video_source = False with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or ltxv or flux and model_reference_image) as video_prompt_column: video_prompt_type_value= ui_defaults.get("video_prompt_type","") @@ -7343,7 +7446,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non with gr.Column(): gr.Markdown("Upsampling - postprocessing that may improve fluidity and the size of the video") - def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grain_intensity, film_grain_saturation, element_class= None, max_height= None): + def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grain_intensity, film_grain_saturation, element_class= None, max_height= None, image_outputs = False): temporal_upsampling = gr.Dropdown( choices=[ ("Disabled", ""), @@ -7351,7 +7454,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ("Rife x4 frames/s", "rife4"), ], value=temporal_upsampling, - visible=True, + visible=not image_outputs, scale = 1, label="Temporal Upsampling", elem_classes= element_class @@ -7376,34 +7479,39 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non film_grain_saturation = gr.Slider(0.0, 1, value=film_grain_saturation, step=0.01, label="Film Grain Saturation") return temporal_upsampling, spatial_upsampling, film_grain_intensity, film_grain_saturation - temporal_upsampling, spatial_upsampling, film_grain_intensity, film_grain_saturation = gen_upsampling_dropdowns(ui_defaults.get("temporal_upsampling", ""), ui_defaults.get("spatial_upsampling", ""), ui_defaults.get("film_grain_intensity", 0), ui_defaults.get("film_grain_saturation", 0.5)) + temporal_upsampling, spatial_upsampling, film_grain_intensity, film_grain_saturation = gen_upsampling_dropdowns(ui_defaults.get("temporal_upsampling", ""), ui_defaults.get("spatial_upsampling", ""), ui_defaults.get("film_grain_intensity", 0), ui_defaults.get("film_grain_saturation", 0.5), image_outputs= image_outputs) - with gr.Tab("MMAudio", visible = server_config.get("mmaudio_enabled", 0) != 0 and not any_audio_track(base_model_type) and not image_outputs) as mmaudio_tab: - with gr.Column(): + with gr.Tab("Audio", visible = not image_outputs) as audio_tab: + with gr.Column(visible = server_config.get("mmaudio_enabled", 0) != 0) as mmaudio_col: gr.Markdown("Add a soundtrack based on the content of the Generated Video") - def gen_mmaudio_dropdowns(MMAudio_setting, MMAudio_prompt, MMAudio_neg_prompt, MMAudio_seed = None, element_class = None, max_height = None ): - with gr.Row(max_height=max_height): - 
MMAudio_setting = gr.Dropdown( - choices=[ - ("Disabled", 0), - ("Enabled", 1), - ], - value=MMAudio_setting, - visible=True, - scale = 1, - label="MMAudio", - elem_classes= element_class, - # max_height = max_height - ) - if MMAudio_seed != None: - MMAudio_seed = gr.Slider(-1, 999999999, value=MMAudio_seed, step=1, scale=3, label="Seed (-1 for random)") - with gr.Row(max_height=max_height): - MMAudio_prompt = gr.Text(MMAudio_prompt, label="Prompt (1 or 2 keywords)", elem_classes= element_class) - MMAudio_neg_prompt = gr.Text(MMAudio_neg_prompt, label="Negative Prompt (1 or 2 keywords)", elem_classes= element_class) + with gr.Row(): + MMAudio_setting = gr.Dropdown( + choices=[("Disabled", 0), ("Enabled", 1), ], + value=ui_defaults.get("MMAudio_setting", 0), visible=True, scale = 1, label="MMAudio", + ) + # if MMAudio_seed != None: + # MMAudio_seed = gr.Slider(-1, 999999999, value=MMAudio_seed, step=1, scale=3, label="Seed (-1 for random)") + with gr.Row(): + MMAudio_prompt = gr.Text(ui_defaults.get("MMAudio_prompt", ""), label="Prompt (1 or 2 keywords)") + MMAudio_neg_prompt = gr.Text(ui_defaults.get("MMAudio_neg_prompt", ""), label="Negative Prompt (1 or 2 keywords)") + - return MMAudio_setting, MMAudio_prompt, MMAudio_neg_prompt, MMAudio_seed - MMAudio_setting, MMAudio_prompt, MMAudio_neg_prompt, _ = gen_mmaudio_dropdowns(ui_defaults.get("MMAudio_setting", 0), ui_defaults.get("MMAudio_prompt", ""), ui_defaults.get("MMAudio_neg_prompt", "")) + with gr.Column(visible = (t2v or vace) and not fantasy) as audio_prompt_type_remux_row: + gr.Markdown("You may transfer the exising audio tracks of a Control Video") + audio_prompt_type_remux = gr.Dropdown( + choices=[ + ("No Remux", ""), + ("Remux Audio Files from Control Video if any and if no MMAudio / Custom Soundtrack", "R"), + ], + value=filter_letters(audio_prompt_type_value, "R"), + label="Remux Audio Files", + visible = True + ) + with gr.Column(): + gr.Markdown("Add Custom Soundtrack to Video") + audio_source = gr.Audio(value= ui_defaults.get("audio_source", None), type="filepath", label="Soundtrack", show_download_button= True) + with gr.Tab("Quality", visible = not (ltxv and no_negative_prompt or flux)) as quality_tab: with gr.Column(visible = not (hunyuan_i2v or hunyuan_t2v or hunyuan_video_custom or hunyuan_video_avatar or ltxv) ) as skip_layer_guidance_row: @@ -7506,6 +7614,16 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20 if vace else 0), step=1, label="Noise to be added to overlapped frames to reduce blur effect" , visible = vace) sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True) + video_prompt_type_alignment = gr.Dropdown( + choices=[ + ("Aligned to the beginning of the Source Video", ""), + ("Aligned to the beginning of the First Window of the new Video Sample", "T"), + ], + value=filter_letters(video_prompt_type_value, "T"), + label="Control Video / Control Audio temporal alignment when any Source Video", + visible = vace or ltxv or t2v + ) + multi_prompts_gen_type = gr.Dropdown( choices=[ ("Will create new generated Video", 0), @@ -7533,10 +7651,15 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non gr.Markdown("You can change the Default number of Frames Per Second of the output Video, in the 
absence of a Control Video this may create unwanted slowdown / acceleration") force_fps_choices = [(f"Model Default ({fps} fps)", "")] - - if vace or ltxv or t2v: + if any_control_video and (any_video_source or recammaster): + force_fps_choices += [("Auto fps: Source Video if any, or Control Video if any, or Model Default", "auto")] + elif any_control_video : + force_fps_choices += [("Auto fps: Control Video if any, or Model Default", "auto")] + elif any_video_source or recammaster: + force_fps_choices += [("Auto fps: Source Video if any, or Model Default", "auto")] + if any_control_video: force_fps_choices += [("Control Video fps", "control")] - if vace or ltxv or diffusion_forcing or recammaster: + if any_video_source or recammaster: force_fps_choices += [("Source Video fps", "source")] force_fps_choices += [ ("16", "16"), @@ -7552,17 +7675,6 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non label=f"Override Frames Per Second (model default={fps} fps)" ) - with gr.Column(visible = (t2v or vace) and not (multitalk or fantasy)) as audio_prompt_type_remux_row: - gr.Markdown("You may transfer the exising audio tracks of a Control Video") - audio_prompt_type_remux = gr.Dropdown( - choices=[ - ("No Remux", ""), - ("Remux Audio Files from Control Video if any", "R"), - ], - value=filter_letters(audio_prompt_type_value, "R"), - label="Remux Audio Files", - visible = True - ) with gr.Row(): @@ -7585,7 +7697,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non refresh_form_trigger = gr.Text(interactive= False, visible=False) fill_wizard_prompt_trigger = gr.Text(interactive= False, visible=False) - with gr.Accordion("Video Info and Late Post Processing", open=False) as video_info_accordion: + with gr.Accordion("Video Info and Late Post Processing & Audio Remuxing", open=False) as video_info_accordion: with gr.Tabs() as video_info_tabs: with gr.Tab("Information", id="video_info"): default_visibility = {} if update_form else {"visible" : False} @@ -7606,21 +7718,35 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non with gr.Tab("Post Processing", id= "post_processing", visible = True) as video_postprocessing_tab: with gr.Group(elem_classes= "postprocess"): with gr.Column(): - PP_temporal_upsampling, PP_spatial_upsampling, PP_film_grain_intensity, PP_film_grain_saturation = gen_upsampling_dropdowns("", "", 0, 0.5, element_class ="postprocess") - with gr.Column(visible = server_config.get("mmaudio_enabled", 0) == 1) as PP_MMAudio_col: - PP_MMAudio_setting, PP_MMAudio_prompt, PP_MMAudio_neg_prompt, _ = gen_mmaudio_dropdowns( 0, "" , "", None, element_class ="postprocess" ) - PP_MMAudio_seed = gr.Slider(-1, 999999999, value=-1, step=1, label="Seed (-1 for random)") - PP_repeat_generation = gr.Slider(1, 25.0, value=1, step=1, label="Number of Sample Videos to Generate") + PP_temporal_upsampling, PP_spatial_upsampling, PP_film_grain_intensity, PP_film_grain_saturation = gen_upsampling_dropdowns("", "", 0, 0.5, element_class ="postprocess", image_outputs = False) with gr.Row(): video_info_postprocessing_btn = gr.Button("Apply Postprocessing", size ="sm", visible=True) video_info_eject_video2_btn = gr.Button("Eject Video", size ="sm", visible=True) + with gr.Tab("Audio Remuxing", id= "audio_remuxing", visible = True) as audio_remuxing_tab: + with gr.Group(elem_classes= "postprocess"): + with gr.Column(visible = server_config.get("mmaudio_enabled", 0) != 0) as PP_MMAudio_col: + with
@@ -7606,21 +7718,35 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                with gr.Tab("Post Processing", id= "post_processing", visible = True) as video_postprocessing_tab:
                    with gr.Group(elem_classes= "postprocess"):
                        with gr.Column():
-                            PP_temporal_upsampling, PP_spatial_upsampling, PP_film_grain_intensity, PP_film_grain_saturation = gen_upsampling_dropdowns("", "", 0, 0.5, element_class ="postprocess")
-                            with gr.Column(visible = server_config.get("mmaudio_enabled", 0) == 1) as PP_MMAudio_col:
-                                PP_MMAudio_setting, PP_MMAudio_prompt, PP_MMAudio_neg_prompt, _ = gen_mmaudio_dropdowns( 0, "" , "", None, element_class ="postprocess" )
-                                PP_MMAudio_seed = gr.Slider(-1, 999999999, value=-1, step=1, label="Seed (-1 for random)")
-                            PP_repeat_generation = gr.Slider(1, 25.0, value=1, step=1, label="Number of Sample Videos to Generate")
+                            PP_temporal_upsampling, PP_spatial_upsampling, PP_film_grain_intensity, PP_film_grain_saturation = gen_upsampling_dropdowns("", "", 0, 0.5, element_class ="postprocess", image_outputs = False)
                        with gr.Row():
                            video_info_postprocessing_btn = gr.Button("Apply Postprocessing", size ="sm", visible=True)
                            video_info_eject_video2_btn = gr.Button("Eject Video", size ="sm", visible=True)
+                with gr.Tab("Audio Remuxing", id= "audio_remuxing", visible = True) as audio_remuxing_tab:
+                    with gr.Group(elem_classes= "postprocess"):
+                        with gr.Column(visible = server_config.get("mmaudio_enabled", 0) != 0) as PP_MMAudio_col:
+                            with gr.Row():
+                                PP_MMAudio_setting = gr.Dropdown(
+                                    choices=[("Add Custom Audio Soundtrack", 0), ("Use MMAudio to generate a Soundtrack based on the Video", 1), ],
+                                    value=0, visible=True, scale = 1, label="MMAudio", show_label= False, elem_classes= "postprocess",
+                                )
+                            with gr.Column(visible = False) as PP_MMAudio_row:
+                                with gr.Row():
+                                    PP_MMAudio_prompt = gr.Text("", label="Prompt (1 or 2 keywords)", elem_classes= "postprocess")
+                                    PP_MMAudio_neg_prompt = gr.Text("", label="Negative Prompt (1 or 2 keywords)", elem_classes= "postprocess")
+                                PP_MMAudio_seed = gr.Slider(-1, 999999999, value=-1, step=1, label="Seed (-1 for random)")
+                                PP_repeat_generation = gr.Slider(1, 25.0, value=1, step=1, label="Number of Sample Videos to Generate")
+                        with gr.Row(visible = True) as PP_custom_audio_row:
+                            PP_custom_audio = gr.Audio(label = "Soundtrack", type="filepath", show_download_button= True,)
+                        with gr.Row():
+                            video_info_remux_audio_btn = gr.Button("Remux Audio", size ="sm", visible=True)
+                            video_info_eject_video3_btn = gr.Button("Eject Video", size ="sm", visible=True)
                with gr.Tab("Add Videos / Images", id= "video_add"):
                    files_to_load = gr.Files(label= "Files to Load in Gallery", height=120)
                    with gr.Row():
                        video_info_add_videos_btn = gr.Button("Add Videos / Images", size ="sm")

        if not update_form:
-            generate_btn = gr.Button("Generate")
            generate_trigger = gr.Text(visible = False)
            add_to_queue_btn = gr.Button("Add New Prompt To Queue", visible = False)
@@ -7665,13 +7791,15 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
            single_hidden_trigger_btn = gr.Button("trigger_countdown", visible=False, elem_id="trigger_info_single_btn")

        extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column,
-                        prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, save_lset_prompt_drop, advanced_row, speed_tab, mmaudio_tab, quality_tab,
+                        prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, save_lset_prompt_drop, advanced_row, speed_tab, audio_tab, mmaudio_col, quality_tab,
                        sliding_window_tab, misc_tab, prompt_enhancer_row, inference_steps_row, skip_layer_guidance_row, audio_guide_row, RIFLEx_setting_col,
                        video_prompt_type_video_guide, video_prompt_type_video_mask, video_prompt_type_image_refs, apg_col, audio_prompt_type_sources, audio_prompt_type_remux_row,
                        video_guide_outpainting_col,video_guide_outpainting_top, video_guide_outpainting_bottom, video_guide_outpainting_left, video_guide_outpainting_right, video_guide_outpainting_checkbox, video_guide_outpainting_row,
                        show_advanced, video_info_to_control_video_btn, video_info_to_video_source_btn, sample_solver_row,
-                        video_buttons_row, image_buttons_row, video_postprocessing_tab, video_info_to_start_image_btn, video_info_to_end_image_btn, video_info_to_reference_image_btn, video_info_to_image_guide_btn, video_info_to_image_mask_btn,
-                        NAG_col, speakers_locations_row, guidance_row, guidance_row2, resolution_group, cfg_free_guidance_col, control_net_weights_row, image_mode_tabs, min_frames_if_references_col] # presets_column,
+                        video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab, PP_MMAudio_row, PP_custom_audio_row,
+                        video_info_to_start_image_btn, video_info_to_end_image_btn, video_info_to_reference_image_btn, video_info_to_image_guide_btn, video_info_to_image_mask_btn,
+                        NAG_col, speakers_locations_row, guidance_row, guidance_row2, resolution_group, cfg_free_guidance_col, control_net_weights_row, image_mode_tabs,
+                        min_frames_if_references_col, video_prompt_type_alignment] # presets_column,
        if update_form:
            locals_dict = locals()
            gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs
@@ -7692,6 +7820,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
        video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_images_ref, frames_positions, video_guide_outpainting_col])
        video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide, image_mode], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, mask_expand])
        video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask, image_mode], outputs = [video_prompt_type, video_mask, image_mask, mask_expand])
+        video_prompt_type_alignment.input(fn=refresh_video_prompt_type_alignment, inputs = [state, video_prompt_type, video_prompt_type_alignment], outputs = [video_prompt_type])
        multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=[multi_prompts_gen_type, image_mode], outputs=[prompt, wizard_prompt])
        video_guide_outpainting_top.input(fn=update_video_guide_outpainting, inputs=[video_guide_outpainting, video_guide_outpainting_top, gr.State(0)], outputs = [video_guide_outpainting], trigger_mode="multiple" )
        video_guide_outpainting_bottom.input(fn=update_video_guide_outpainting, inputs=[video_guide_outpainting, video_guide_outpainting_bottom,gr.State(1)], outputs = [video_guide_outpainting], trigger_mode="multiple" )
@@ -7701,9 +7830,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
        show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name]).then(
            fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
        queue_df.select( fn=handle_celll_selection, inputs=state, outputs=[queue_df, modal_image_display, modal_container])
-        gr.on( triggers=[output.change, output.select], fn=select_video, inputs=[state, output], outputs=[last_choice, video_info, video_buttons_row, image_buttons_row, video_postprocessing_tab])
+        gr.on( triggers=[output.change, output.select], fn=select_video, inputs=[state, output], outputs=[last_choice, video_info, video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab])
        preview_trigger.change(refresh_preview, inputs= [state], outputs= [preview])
-
+        PP_MMAudio_setting.change(fn = lambda value : [gr.update(visible = value == 1), gr.update(visible = value == 0)] , inputs = [PP_MMAudio_setting], outputs = [PP_MMAudio_row, PP_custom_audio_row] )
        def refresh_status_async(state, progress=gr.Progress()):
            gen = get_gen_info(state)
            gen["progress"] = progress
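
# Illustrative sketch, standalone and outside the patch: the PP_MMAudio_setting.change wiring
# above follows the usual Gradio show/hide pattern - the callback returns one
# gr.update(visible=...) per output container, so selecting MMAudio reveals the prompt fields
# while selecting a custom soundtrack reveals the audio picker instead.

import gradio as gr

with gr.Blocks() as demo:
    mode = gr.Dropdown(choices=[("Add Custom Audio Soundtrack", 0), ("Use MMAudio", 1)], value=0)
    with gr.Column(visible=False) as mmaudio_col:      # MMAudio prompt fields
        gr.Text(label="Prompt (1 or 2 keywords)")
    with gr.Row(visible=True) as custom_audio_row:     # custom soundtrack picker
        gr.Audio(type="filepath", label="Soundtrack")
    # Same shape as the handler added in this hunk: one update per output component.
    mode.change(fn=lambda value: [gr.update(visible=value == 1), gr.update(visible=value == 0)],
                inputs=[mode], outputs=[mmaudio_col, custom_audio_row])

# demo.launch()  # uncomment to try the toggle locally
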
@@ -7757,7 +7886,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
            ).then(
                fn=use_video_settings, inputs =[state, output, last_choice] , outputs= [model_family, model_choice, refresh_form_trigger])
        video_info_add_videos_btn.click(fn=add_videos_to_gallery, inputs =[state, output, last_choice, files_to_load], outputs = [output, files_to_load, video_info_tabs] )
-        gr.on(triggers=[video_info_eject_video_btn.click, video_info_eject_video2_btn.click, video_info_eject_image_btn.click], fn=eject_video_from_gallery, inputs =[state, output, last_choice], outputs = [output, video_info, video_buttons_row] )
+        gr.on(triggers=[video_info_eject_video_btn.click, video_info_eject_video2_btn.click, video_info_eject_video3_btn.click, video_info_eject_image_btn.click], fn=eject_video_from_gallery, inputs =[state, output, last_choice], outputs = [output, video_info, video_buttons_row] )
        video_info_to_control_video_btn.click(fn=video_to_control_video, inputs =[state, output, last_choice], outputs = [video_guide] )
        video_info_to_video_source_btn.click(fn=video_to_source_video, inputs =[state, output, last_choice], outputs = [video_source] )
        video_info_to_start_image_btn.click(fn=image_to_ref_image_add, inputs =[state, output, last_choice, image_start, gr.State("Start Image")], outputs = [image_start] )
@@ -7765,7 +7894,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
        video_info_to_image_guide_btn.click(fn=image_to_ref_image_set, inputs =[state, output, last_choice, image_guide, gr.State("Control Image")], outputs = [image_guide] )
        video_info_to_image_mask_btn.click(fn=image_to_ref_image_set, inputs =[state, output, last_choice, image_mask, gr.State("Image Mask")], outputs = [image_mask] )
        video_info_to_reference_image_btn.click(fn=image_to_ref_image_add, inputs =[state, output, last_choice, image_refs, gr.State("Ref Image")], outputs = [image_refs] )
-        video_info_postprocessing_btn.click(fn=apply_post_processing, inputs =[state, output, last_choice, PP_temporal_upsampling, PP_spatial_upsampling, PP_film_grain_intensity, PP_film_grain_saturation, PP_MMAudio_setting, PP_MMAudio_prompt, PP_MMAudio_neg_prompt, PP_MMAudio_seed, PP_repeat_generation], outputs = [mode, generate_trigger, add_to_queue_trigger ] )
+        video_info_postprocessing_btn.click(fn=apply_post_processing, inputs =[state, output, last_choice, PP_temporal_upsampling, PP_spatial_upsampling, PP_film_grain_intensity, PP_film_grain_saturation], outputs = [mode, generate_trigger, add_to_queue_trigger ] )
+        video_info_remux_audio_btn.click(fn=remux_audio, inputs =[state, output, last_choice, PP_MMAudio_setting, PP_MMAudio_prompt, PP_MMAudio_neg_prompt, PP_MMAudio_seed, PP_repeat_generation, PP_custom_audio], outputs = [mode, generate_trigger, add_to_queue_trigger ] )
        save_lset_btn.click(validate_save_lset, inputs=[state, lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
        delete_lset_btn.click(validate_delete_lset, inputs=[state, lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
        confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
@@ -7999,7 +8129,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
            )
    return ( state, loras_choices, lset_name, resolution,
-            video_guide, image_guide, video_mask, image_mask, image_refs, prompt_enhancer_row, mmaudio_tab, PP_MMAudio_col
+            video_guide, image_guide, video_mask, image_mask, image_refs, prompt_enhancer_row, audio_tab, PP_MMAudio_col
            )
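
# Illustrative sketch, not the wgp.py implementation: the "Remux Audio" button is wired to a
# remux_audio handler; conceptually, swapping a video's soundtrack without touching the frames
# is a stream-copy remux, which ffmpeg can perform along these lines:

import subprocess

def remux_soundtrack(video_path: str, audio_path: str, output_path: str) -> None:
    # Copy the video stream as-is and mux in the chosen soundtrack.
    subprocess.run([
        "ffmpeg", "-y",
        "-i", video_path,       # source video (its original audio is not mapped)
        "-i", audio_path,       # custom soundtrack selected in the UI
        "-map", "0:v:0",        # keep the video stream of the first input
        "-map", "1:a:0",        # take the audio stream of the second input
        "-c:v", "copy",         # no re-encoding of the frames
        "-c:a", "aac",
        "-shortest",            # stop at the shorter of the two streams
        output_path,
    ], check=True)
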