import os
import time
import random
import functools
from typing import List, Optional, Tuple, Union

from pathlib import Path
from einops import rearrange
import torch
import torch.distributed as dist
from hyvideo.constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE, NEGATIVE_PROMPT_I2V
from hyvideo.vae import load_vae
from hyvideo.modules import load_model
from hyvideo.text_encoder import TextEncoder
from hyvideo.utils.data_utils import align_to, get_closest_ratio, generate_crop_size_list
from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed, get_nd_rotary_pos_embed_new
from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler
from hyvideo.diffusion.pipelines import HunyuanVideoPipeline
from hyvideo.diffusion.pipelines import HunyuanVideoAudioPipeline
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
import cv2
from wan.utils.utils import resize_lanczos, calculate_new_dimensions
from hyvideo.data_kits.audio_preprocessor import encode_audio, get_facemask
from transformers import WhisperModel
from transformers import AutoFeatureExtractor
from hyvideo.data_kits.face_align import AlignImage
import librosa

def get_audio_feature(feature_extractor, audio_path, duration):
    # Load the audio track as 16 kHz mono, truncated to the requested duration (in seconds).
    audio_input, sampling_rate = librosa.load(audio_path, duration=duration, sr=16000)
    assert sampling_rate == 16000

    audio_features = []
    # 16000 / 25 = 640 audio samples per video frame at 25 fps; process the track in
    # 30-second windows (750 frames) to match the Whisper feature extractor.
    window = 750 * 640
    for i in range(0, len(audio_input), window):
        audio_feature = feature_extractor(audio_input[i:i + window],
                                          sampling_rate=sampling_rate,
                                          return_tensors="pt",
                                          device="cuda"
                                          ).input_features
        audio_features.append(audio_feature)

    audio_features = torch.cat(audio_features, dim=-1)
    # Also return the number of 25 fps video frames covered by the audio track.
    return audio_features, len(audio_input) // 640
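
# Illustrative usage of get_audio_feature (a sketch; the checkpoint path and audio file below are
# assumptions, not part of this module):
#
#   feature_extractor = AutoFeatureExtractor.from_pretrained("ckpts/whisper-tiny/")
#   audio_feats, num_frames = get_audio_feature(feature_extractor, "sample_speech.wav", duration=5.0)
#   # audio_feats is the concatenated log-mel spectrogram, typically (1, 80, 3000) per 30 s window,
#   # and num_frames is the matching 25 fps frame count (5.0 s -> 125 frames).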

def pad_image(crop_img, size, color=(255, 255, 255), resize_ratio=1):
    """Resize `crop_img` to fit inside `size` (width, height) while keeping its aspect ratio,
    then letterbox the remainder with `color`."""
    crop_h, crop_w = crop_img.shape[:2]
    target_w, target_h = size
    scale_h, scale_w = target_h / crop_h, target_w / crop_w
    if scale_w > scale_h:
        # Height is the limiting dimension: fill it and scale the width to preserve aspect ratio.
        resize_h = int(target_h * resize_ratio)
        resize_w = int(crop_w / crop_h * resize_h)
    else:
        # Width is the limiting dimension.
        resize_w = int(target_w * resize_ratio)
        resize_h = int(crop_h / crop_w * resize_w)
    crop_img = cv2.resize(crop_img, (resize_w, resize_h))
    pad_left = (target_w - resize_w) // 2
    pad_top = (target_h - resize_h) // 2
    pad_right = target_w - resize_w - pad_left
    pad_bottom = target_h - resize_h - pad_top
    crop_img = cv2.copyMakeBorder(crop_img, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=color)
    return crop_img
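
# Example of the letterboxing behaviour (a sketch with made-up sizes): a 480x640 (h x w) image
# padded into a (336, 336) canvas is resized to 336x252 (w x h, aspect ratio preserved) and then
# gets 42-pixel white borders on the top and bottom:
#
#   img = np.zeros((480, 640, 3), dtype=np.uint8)
#   out = pad_image(img, (336, 336))
#   # out.shape == (336, 336, 3)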

def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
    # Legacy (pre transformers 4.47) LLaVA merge of image patch embeddings into the text sequence;
    # re-attached to LlavaForConditionalGeneration in Inference.from_pretrained below.
    num_images, num_image_patches, embed_dim = image_features.shape
    batch_size, sequence_length = input_ids.shape
    left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
    # 1. Create a mask to know where special image tokens are
    special_image_token_mask = input_ids == self.config.image_token_index
    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
    # Compute the maximum embed dimension
    max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
    batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)

    # 2. Compute the positions where text should be written
    # Calculate new positions for text tokens in the merged image-text sequence.
    # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `num_image_patches - 1` embeddings.
    # `torch.cumsum` computes how each image token shifts subsequent text token positions.
    # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
    new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
    nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
    if left_padding:
        new_token_positions += nb_image_pad[:, None]  # offset for left padding
    text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

    # 3. Create the full embedding, already padded to the maximum position
    final_embedding = torch.zeros(
        batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
    )
    final_attention_mask = torch.zeros(
        batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
    )
    if labels is not None:
        final_labels = torch.full(
            (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
        )
    # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
    # set the corresponding tensors into their correct target device.
    target_device = inputs_embeds.device
    batch_indices, non_image_indices, text_to_overwrite = (
        batch_indices.to(target_device),
        non_image_indices.to(target_device),
        text_to_overwrite.to(target_device),
    )
    attention_mask = attention_mask.to(target_device)

    # 4. Fill the embeddings based on the mask. If we have ["hey", "<image>", "how", "are"]
    # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
    final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
    final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
    if labels is not None:
        final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

    # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
    image_to_overwrite = torch.full(
        (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
    )
    image_to_overwrite[batch_indices, text_to_overwrite] = False
    image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

    if image_to_overwrite.sum() != image_features.shape[:-1].numel():
        raise ValueError(
            f"The inputs provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
            f" the number of images given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
        )

    final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
    final_attention_mask |= image_to_overwrite
    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

    # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
    batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
    indices_to_mask = new_token_positions[batch_indices, pad_indices]

    final_embedding[batch_indices, indices_to_mask] = 0

    if labels is None:
        final_labels = None

    return final_embedding, final_attention_mask, final_labels, position_ids
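
# Worked example of the merge arithmetic above (illustrative numbers): with one <image> token in a
# 12-token prompt and num_image_patches = 576, the merged sequence length becomes
# 1 * (576 - 1) + 12 = 587; the text tokens after the <image> marker are shifted right by 575
# positions via the cumulative sum, and the 576 patch embeddings fill the gap.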

def patched_llava_forward(
    self,
    input_ids: torch.LongTensor = None,
    pixel_values: torch.FloatTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    vision_feature_layer: Optional[int] = None,
    vision_feature_select_strategy: Optional[str] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
):
    from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    vision_feature_layer = (
        vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
    )
    vision_feature_select_strategy = (
        vision_feature_select_strategy
        if vision_feature_select_strategy is not None
        else self.config.vision_feature_select_strategy
    )

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if pixel_values is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
        )

    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)

    image_features = None
    if pixel_values is not None:
        image_features = self.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
        )

    # Merge the image patch embeddings into the text sequence (legacy behaviour, see above).
    if image_features is not None:
        inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
            image_features, inputs_embeds, input_ids, attention_mask, labels
        )
    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)

    outputs = self.language_model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        num_logits_to_keep=num_logits_to_keep,
    )

    logits = outputs[0]
    loss = None

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return LlavaCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=image_features if pixel_values is not None else None,
    )

def adapt_model(model, audio_block_name):
    # Attach each audio adapter block as the `audio_adapter` attribute of the double-stream block
    # it conditions, then remove the standalone audio module from the model.
    modules_dict = {k: m for k, m in model.named_modules()}
    for model_layer, avatar_layer in model.double_stream_map.items():
        module = modules_dict[f"{audio_block_name}.{avatar_layer}"]
        target = modules_dict[f"double_blocks.{model_layer}"]
        setattr(target, "audio_adapter", module)
    delattr(model, audio_block_name)
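
# Illustrative mapping (made-up indices): if model.double_stream_map were {0: 0, 4: 1, 8: 2},
# adapt_model(model, "audio_adapter_blocks") would attach audio_adapter_blocks.0/1/2 as the
# `audio_adapter` attribute of double_blocks.0/4/8 respectively, then delete
# model.audio_adapter_blocks.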

class DataPreprocess(object):
    def __init__(self):
        self.llava_size = (336, 336)
        self.llava_transform = transforms.Compose(
            [
                transforms.Resize(self.llava_size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.ToTensor(),
                transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)),
            ]
        )

    def get_batch(self, image, size, pad=False):
        image = np.asarray(image)
        if pad:
            llava_item_image = pad_image(image.copy(), self.llava_size)
        else:
            llava_item_image = image.copy()
        uncond_llava_item_image = np.ones_like(llava_item_image) * 255

        if pad:
            cat_item_image = pad_image(image.copy(), size)
        else:
            cat_item_image = image.copy()
        llava_item_tensor = self.llava_transform(Image.fromarray(llava_item_image.astype(np.uint8)))
        uncond_llava_item_tensor = self.llava_transform(Image.fromarray(uncond_llava_item_image))
        cat_item_tensor = torch.from_numpy(cat_item_image.copy()).permute((2, 0, 1)) / 255.0
        # batch = {
        #     "pixel_value_llava": llava_item_tensor.unsqueeze(0),
        #     "uncond_pixel_value_llava": uncond_llava_item_tensor.unsqueeze(0),
        #     'pixel_value_ref': cat_item_tensor.unsqueeze(0),
        # }
        return llava_item_tensor.unsqueeze(0), uncond_llava_item_tensor.unsqueeze(0), cat_item_tensor.unsqueeze(0)
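
# Usage sketch (shapes only, assuming a PIL RGB image as input):
#
#   llava, uncond_llava, ref = DataPreprocess().get_batch(img, (1280, 720), pad=True)
#   # llava        -> (1, 3, 336, 336)  CLIP-normalised crop for the LLaVA encoder
#   # uncond_llava -> (1, 3, 336, 336)  white image with the same normalisation
#   # ref          -> (1, 3, h, w)      reference image in [0, 1], letterboxed to `size` when pad=True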

class Inference(object):
    def __init__(
        self,
        i2v,
        custom,
        avatar,
        enable_cfg,
        vae,
        vae_kwargs,
        text_encoder,
        model,
        text_encoder_2=None,
        pipeline=None,
        feature_extractor=None,
        wav2vec=None,
        align_instance=None,
        device=None,
    ):
        self.i2v = i2v
        self.custom = custom
        self.avatar = avatar
        self.enable_cfg = enable_cfg
        self.vae = vae
        self.vae_kwargs = vae_kwargs

        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2

        self.model = model
        self.pipeline = pipeline

        self.feature_extractor = feature_extractor
        self.wav2vec = wav2vec
        self.align_instance = align_instance

        self.device = "cuda"

    @classmethod
    def from_pretrained(cls, model_filepath, base_model_type, text_encoder_filepath, dtype=torch.bfloat16, VAE_dtype=torch.float16, mixed_precision_transformer=torch.bfloat16, quantizeTransformer=False, save_quantized=False, **kwargs):

        device = "cuda"

        import transformers
        # Force the legacy behaviour so that transformers versions newer than 4.47 can be used.
        transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.forward = patched_llava_forward
        transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features = _merge_input_ids_with_image_features

        torch.set_grad_enabled(False)
        text_len = 512
        latent_channels = 16
        precision = "bf16"
        vae_precision = "fp32" if VAE_dtype == torch.float32 else "bf16"
        embedded_cfg_scale = 6
        filepath = model_filepath[0]
        i2v_condition_type = None
        i2v_mode = False
        custom = False
        custom_audio = False
        avatar = False
        if base_model_type == "hunyuan_i2v":
            model_id = "HYVideo-T/2"
            i2v_condition_type = "token_replace"
            i2v_mode = True
        elif base_model_type == "hunyuan_custom":
            model_id = "HYVideo-T/2-custom"
            custom = True
        elif base_model_type == "hunyuan_custom_audio":
            model_id = "HYVideo-T/2-custom-audio"
            custom_audio = True
            custom = True
        elif base_model_type == "hunyuan_custom_edit":
            model_id = "HYVideo-T/2-custom-edit"
            custom = True
        elif base_model_type == "hunyuan_avatar":
            model_id = "HYVideo-T/2-avatar"
            text_len = 256
            avatar = True
        else:
            model_id = "HYVideo-T/2-cfgdistill"

        if i2v_mode and i2v_condition_type == "latent_concat":
            in_channels = latent_channels * 2 + 1
            image_embed_interleave = 2
        elif i2v_mode and i2v_condition_type == "token_replace":
            in_channels = latent_channels
            image_embed_interleave = 4
        else:
            in_channels = latent_channels
            image_embed_interleave = 1
        out_channels = latent_channels
        pinToMemory = kwargs.pop("pinToMemory", False)
        partialPinning = kwargs.pop("partialPinning", False)
        factor_kwargs = kwargs | {"device": "meta", "dtype": PRECISION_TO_TYPE[precision]}

        if embedded_cfg_scale and i2v_mode:
            factor_kwargs["guidance_embed"] = True

        model = load_model(
            model=model_id,
            i2v_condition_type=i2v_condition_type,
            in_channels=in_channels,
            out_channels=out_channels,
            factor_kwargs=factor_kwargs,
        )

        from mmgp import offload
        # model = Inference.load_state_dict(args, model, model_filepath)

        # model_filepath ="c:/temp/hc/mp_rank_00_model_states_video.pt"
        offload.load_model_data(model, model_filepath, quantizeTransformer=quantizeTransformer and not save_quantized, pinToMemory=pinToMemory, partialPinning=partialPinning)
        # offload.save_model(model, "hunyuan_video_avatar_edit_720_bf16.safetensors")
        # offload.save_model(model, "hunyuan_video_avatar_edit_720_quanto_bf16_int8.safetensors", do_quantize= True)
        if save_quantized:
            from wan.utils.utils import save_quantized_model
            save_quantized_model(model, filepath, dtype, None)

        model.mixed_precision = mixed_precision_transformer

        if model.mixed_precision:
            model._lock_dtype = torch.float32
            model.lock_layers_dtypes(torch.float32)
        model.eval()

        # ============================= Build extra models ========================
        # VAE
        if custom or avatar:
            vae_configpath = "ckpts/hunyuan_video_custom_VAE_config.json"
            vae_filepath = "ckpts/hunyuan_video_custom_VAE_fp32.safetensors"
        # elif avatar:
        #     vae_configpath = "ckpts/config_vae_avatar.json"
        #     vae_filepath = "ckpts/vae_avatar.pt"
        else:
            vae_configpath = "ckpts/hunyuan_video_VAE_config.json"
            vae_filepath = "ckpts/hunyuan_video_VAE_fp32.safetensors"

        # config = AutoencoderKLCausal3D.load_config("ckpts/hunyuan_video_VAE_config.json")
        # config = AutoencoderKLCausal3D.load_config("c:/temp/hvae/config_vae.json")

        vae, _, s_ratio, t_ratio = load_vae("884-16c-hy", vae_path=vae_filepath, vae_config_path=vae_configpath, vae_precision=vae_precision, device="cpu")

        # vae._model_dtype = torch.float32 if VAE_dtype == torch.float32 else (torch.float16 if avatar else torch.bfloat16)
        vae._model_dtype = torch.float32 if VAE_dtype == torch.float32 else torch.bfloat16
        vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}
        enable_cfg = False
        # Text encoder
        if i2v_mode:
            text_encoder = "llm-i2v"
            tokenizer = "llm-i2v"
            prompt_template = "dit-llm-encode-i2v"
            prompt_template_video = "dit-llm-encode-video-i2v"
        elif custom or avatar:
            text_encoder = "llm-i2v"
            tokenizer = "llm-i2v"
            prompt_template = "dit-llm-encode"
            prompt_template_video = "dit-llm-encode-video"
            enable_cfg = True
        else:
            text_encoder = "llm"
            tokenizer = "llm"
            prompt_template = "dit-llm-encode"
            prompt_template_video = "dit-llm-encode-video"

        if prompt_template_video is not None:
            crop_start = PROMPT_TEMPLATE[prompt_template_video].get("crop_start", 0)
        elif prompt_template is not None:
            crop_start = PROMPT_TEMPLATE[prompt_template].get("crop_start", 0)
        else:
            crop_start = 0
        max_length = text_len + crop_start

        # prompt_template
        prompt_template = PROMPT_TEMPLATE[prompt_template] if prompt_template is not None else None

        # prompt_template_video
        prompt_template_video = PROMPT_TEMPLATE[prompt_template_video] if prompt_template_video is not None else None

        text_encoder = TextEncoder(
            text_encoder_type=text_encoder,
            max_length=max_length,
            text_encoder_precision="fp16",
            tokenizer_type=tokenizer,
            i2v_mode=i2v_mode,
            prompt_template=prompt_template,
            prompt_template_video=prompt_template_video,
            hidden_state_skip_layer=2,
            apply_final_norm=False,
            reproduce=True,
            device="cpu",
            image_embed_interleave=image_embed_interleave,
            text_encoder_path=text_encoder_filepath,
        )

        text_encoder_2 = TextEncoder(
            text_encoder_type="clipL",
            max_length=77,
            text_encoder_precision="fp16",
            tokenizer_type="clipL",
            reproduce=True,
            device="cpu",
        )

        feature_extractor = None
        wav2vec = None
        align_instance = None

        if avatar or custom_audio:
            feature_extractor = AutoFeatureExtractor.from_pretrained("ckpts/whisper-tiny/")
            wav2vec = WhisperModel.from_pretrained("ckpts/whisper-tiny/").to(device="cpu", dtype=torch.float32)
            wav2vec._model_dtype = torch.float32
            wav2vec.requires_grad_(False)
        if avatar:
            align_instance = AlignImage("cuda", det_path="ckpts/det_align/detface.pt")
            align_instance.facedet.model.to("cpu")
            adapt_model(model, "audio_adapter_blocks")
        elif custom_audio:
            adapt_model(model, "audio_models")

        return cls(
            i2v=i2v_mode,
            custom=custom,
            avatar=avatar,
            enable_cfg=enable_cfg,
            vae=vae,
            vae_kwargs=vae_kwargs,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            model=model,
            feature_extractor=feature_extractor,
            wav2vec=wav2vec,
            align_instance=align_instance,
            device=device,
        )
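
# Construction sketch (illustrative; the checkpoint and text-encoder paths are assumptions, not
# fixed by this module):
#
#   sampler = HunyuanVideoSampler.from_pretrained(
#       model_filepath=["ckpts/hunyuan_video_avatar_720_bf16.safetensors"],
#       base_model_type="hunyuan_avatar",
#       text_encoder_filepath="ckpts/llava-llama-3-8b/",
#   )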

class HunyuanVideoSampler(Inference):
    def __init__(
        self,
        i2v,
        custom,
        avatar,
        enable_cfg,
        vae,
        vae_kwargs,
        text_encoder,
        model,
        text_encoder_2=None,
        pipeline=None,
        feature_extractor=None,
        wav2vec=None,
        align_instance=None,
        device=0,
    ):
        super().__init__(
            i2v,
            custom,
            avatar,
            enable_cfg,
            vae,
            vae_kwargs,
            text_encoder,
            model,
            text_encoder_2=text_encoder_2,
            pipeline=pipeline,
            feature_extractor=feature_extractor,
            wav2vec=wav2vec,
            align_instance=align_instance,
            device=device,
        )

        self.i2v_mode = i2v
        self.enable_cfg = enable_cfg
        self.pipeline = self.load_diffusion_pipeline(
            avatar=self.avatar,
            vae=self.vae,
            text_encoder=self.text_encoder,
            text_encoder_2=self.text_encoder_2,
            model=self.model,
            device=self.device,
        )

        if self.i2v_mode:
            self.default_negative_prompt = NEGATIVE_PROMPT_I2V
        else:
            self.default_negative_prompt = NEGATIVE_PROMPT

    @property
    def _interrupt(self):
        return self.pipeline._interrupt

    @_interrupt.setter
    def _interrupt(self, value):
        self.pipeline._interrupt = value

    def load_diffusion_pipeline(
        self,
        avatar,
        vae,
        text_encoder,
        text_encoder_2,
        model,
        scheduler=None,
        device=None,
        progress_bar_config=None,
        # data_type="video",
    ):
        """Build the denoising pipeline (with a default flow-matching scheduler) for inference."""
        if scheduler is None:
            scheduler = FlowMatchDiscreteScheduler(
                shift=6.0,
                reverse=True,
                solver="euler",
            )

        if avatar:
            pipeline = HunyuanVideoAudioPipeline(
                vae=vae,
                text_encoder=text_encoder,
                text_encoder_2=text_encoder_2,
                transformer=model,
                scheduler=scheduler,
                progress_bar_config=progress_bar_config,
            )
        else:
            pipeline = HunyuanVideoPipeline(
                vae=vae,
                text_encoder=text_encoder,
                text_encoder_2=text_encoder_2,
                transformer=model,
                scheduler=scheduler,
                progress_bar_config=progress_bar_config,
            )

        return pipeline

    def get_rotary_pos_embed_new(self, video_length, height, width, concat_dict={}, enable_riflex=False):
        target_ndim = 3
        ndim = 5 - 2
        latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]

        if isinstance(self.model.patch_size, int):
            assert all(s % self.model.patch_size == 0 for s in latents_size), \
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), " \
                f"but got {latents_size}."
            rope_sizes = [s // self.model.patch_size for s in latents_size]
        elif isinstance(self.model.patch_size, list):
            assert all(s % self.model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), \
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), " \
                f"but got {latents_size}."
            rope_sizes = [s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)]

        if len(rope_sizes) != target_ndim:
            rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis
        head_dim = self.model.hidden_size // self.model.heads_num
        rope_dim_list = self.model.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
        freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(rope_dim_list,
                                                           rope_sizes,
                                                           theta=256,
                                                           use_real=True,
                                                           theta_rescale_factor=1,
                                                           concat_dict=concat_dict,
                                                           L_test=(video_length - 1) // 4 + 1,
                                                           enable_riflex=enable_riflex
                                                           )
        return freqs_cos, freqs_sin
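
    # Worked example of the size arithmetic above (illustrative, assuming the usual 2x2 spatial
    # patch size and the 4x temporal / 8x spatial VAE compression): for a 129-frame 720x1280 clip,
    # latents_size = [(129 - 1) // 4 + 1, 720 // 8, 1280 // 8] = [33, 90, 160]; with
    # patch_size = [1, 2, 2] this gives rope_sizes = [33, 45, 80], i.e. the grid over which the
    # 3D rotary embedding is evaluated.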

    def get_rotary_pos_embed(self, video_length, height, width, enable_riflex=False):
        target_ndim = 3
        ndim = 5 - 2
        # 884
        vae = "884-16c-hy"
        if "884" in vae:
            latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
        elif "888" in vae:
            latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
        else:
            latents_size = [video_length, height // 8, width // 8]

        if isinstance(self.model.patch_size, int):
            assert all(s % self.model.patch_size == 0 for s in latents_size), (
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), "
                f"but got {latents_size}."
            )
            rope_sizes = [s // self.model.patch_size for s in latents_size]
        elif isinstance(self.model.patch_size, list):
            assert all(
                s % self.model.patch_size[idx] == 0
                for idx, s in enumerate(latents_size)
            ), (
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), "
                f"but got {latents_size}."
            )
            rope_sizes = [
                s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)
            ]

        if len(rope_sizes) != target_ndim:
            rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis
        head_dim = self.model.hidden_size // self.model.heads_num
        rope_dim_list = self.model.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert (
            sum(rope_dim_list) == head_dim
        ), "sum(rope_dim_list) should equal to head_dim of attention layer"
        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list,
            rope_sizes,
            theta=256,
            use_real=True,
            theta_rescale_factor=1,
            L_test=(video_length - 1) // 4 + 1,
            enable_riflex=enable_riflex
        )
        return freqs_cos, freqs_sin

    def generate(
        self,
        input_prompt,
        input_ref_images=None,
        audio_guide=None,
        input_frames=None,
        input_masks=None,
        input_video=None,
        fps=24,
        height=192,
        width=336,
        frame_num=129,
        seed=None,
        n_prompt=None,
        sampling_steps=50,
        guide_scale=1.0,
        shift=5.0,
        embedded_guidance_scale=6.0,
        batch_size=1,
        num_videos_per_prompt=1,
        i2v_resolution="720p",
        image_start=None,
        enable_RIFLEx=False,
        i2v_condition_type: str = "token_replace",
        i2v_stability=True,
        VAE_tile_size=None,
        joint_pass=False,
        cfg_star_switch=False,
        fit_into_canvas=True,
        conditioning_latents_size=0,
        **kwargs,
    ):

        if VAE_tile_size is not None:
            self.vae.tile_sample_min_tsize = VAE_tile_size["tile_sample_min_tsize"]
            self.vae.tile_latent_min_tsize = VAE_tile_size["tile_latent_min_tsize"]
            self.vae.tile_sample_min_size = VAE_tile_size["tile_sample_min_size"]
            self.vae.tile_latent_min_size = VAE_tile_size["tile_latent_min_size"]
            self.vae.tile_overlap_factor = VAE_tile_size["tile_overlap_factor"]
            self.vae.enable_tiling()

        i2v_mode = self.i2v_mode
        if not self.enable_cfg:
            guide_scale = 1.0

        # ========================================================================
        # Arguments: seed
        # ========================================================================
        if isinstance(seed, torch.Tensor):
            seed = seed.tolist()
        if seed is None:
            seeds = [
                random.randint(0, 1_000_000)
                for _ in range(batch_size * num_videos_per_prompt)
            ]
        elif isinstance(seed, int):
            seeds = [
                seed + i
                for _ in range(batch_size)
                for i in range(num_videos_per_prompt)
            ]
        elif isinstance(seed, (list, tuple)):
            if len(seed) == batch_size:
                seeds = [
                    int(seed[i]) + j
                    for i in range(batch_size)
                    for j in range(num_videos_per_prompt)
                ]
            elif len(seed) == batch_size * num_videos_per_prompt:
                seeds = [int(s) for s in seed]
            else:
                raise ValueError(
                    f"Length of seed must be equal to number of prompt(batch_size) or "
                    f"batch_size * num_videos_per_prompt ({batch_size} * {num_videos_per_prompt}), got {seed}."
                )
        else:
            raise ValueError(
                f"Seed must be an integer, a list of integers, or None, got {seed}."
            )
        from wan.utils.utils import seed_everything
        seed_everything(seed)
        generator = [torch.Generator("cuda").manual_seed(seed) for seed in seeds]
        # generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds]
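
        # Example of the seed expansion above (illustrative numbers): with seed=10, batch_size=2 and
        # num_videos_per_prompt=2, `seeds` becomes [10, 11, 10, 11] -- one generator per output
        # video, repeated per prompt in the batch.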

        # ========================================================================
        # Arguments: target_width, target_height, target_frame_num
        # ========================================================================
        if width <= 0 or height <= 0 or frame_num <= 0:
            raise ValueError(
                f"`height`, `width` and `frame_num` must be positive integers, got height={height}, width={width}, frame_num={frame_num}"
            )
        if (frame_num - 1) % 4 != 0:
            raise ValueError(
                f"`frame_num - 1` must be a multiple of 4, got {frame_num}"
            )

        target_height = align_to(height, 16)
        target_width = align_to(width, 16)
        target_frame_num = frame_num
        audio_strength = 1

        if input_ref_images is not None:
            # ip_cfg_scale = 3.0
            ip_cfg_scale = 0
            denoise_strength = 1
            # guide_scale=7.5
            # shift=13
            name = "person"
            input_ref_images = input_ref_images[0]

        # ========================================================================
        # Arguments: prompt, new_prompt, negative_prompt
        # ========================================================================
        if not isinstance(input_prompt, str):
            raise TypeError(f"`prompt` must be a string, but got {type(input_prompt)}")
        input_prompt = [input_prompt.strip()]

        # negative prompt
        if n_prompt is None or n_prompt == "":
            n_prompt = self.default_negative_prompt
        if guide_scale == 1.0:
            n_prompt = ""
        if not isinstance(n_prompt, str):
            raise TypeError(
                f"`negative_prompt` must be a string, but got {type(n_prompt)}"
            )
        n_prompt = [n_prompt.strip()]

        # ========================================================================
        # Scheduler
        # ========================================================================
        scheduler = FlowMatchDiscreteScheduler(
            shift=shift,
            reverse=True,
            solver="euler"
        )
        self.pipeline.scheduler = scheduler

        # ---------------------------------
        # Reference condition
        # ---------------------------------
        img_latents = None
        semantic_images = None
        denoise_strength = 0
        ip_cfg_scale = 0
        if i2v_mode:
            if i2v_resolution == "720p":
                bucket_hw_base_size = 960
            elif i2v_resolution == "540p":
                bucket_hw_base_size = 720
            elif i2v_resolution == "360p":
                bucket_hw_base_size = 480
            else:
                raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")

            # semantic_images = [Image.open(i2v_image_path).convert('RGB')]
            semantic_images = [image_start.convert('RGB')]
            origin_size = semantic_images[0].size
            h, w = origin_size
            h, w = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
            closest_size = (w, h)
            # crop_size_list = generate_crop_size_list(bucket_hw_base_size, 32)
            # aspect_ratios = np.array([round(float(h)/float(w), 5) for h, w in crop_size_list])
            # closest_size, closest_ratio = get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
            ref_image_transform = transforms.Compose([
                transforms.Resize(closest_size),
                transforms.CenterCrop(closest_size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5])
            ])

            semantic_image_pixel_values = [ref_image_transform(semantic_image) for semantic_image in semantic_images]
            semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(self.device)

            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
                img_latents = self.pipeline.vae.encode(semantic_image_pixel_values).latent_dist.mode()  # B, C, F, H, W
                img_latents.mul_(self.pipeline.vae.config.scaling_factor)

            target_height, target_width = closest_size

        # ========================================================================
        # Build Rope freqs
        # ========================================================================

        if input_ref_images is None:
            freqs_cos, freqs_sin = self.get_rotary_pos_embed(target_frame_num, target_height, target_width, enable_RIFLEx)
        else:
            if self.avatar:
                w, h = input_ref_images.size
                target_height, target_width = calculate_new_dimensions(target_height, target_width, h, w, fit_into_canvas)
                if target_width != w or target_height != h:
                    input_ref_images = input_ref_images.resize((target_width, target_height), resample=Image.Resampling.LANCZOS)

                concat_dict = {'mode': 'timecat', 'bias': -1}
                freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict)
            else:
                if input_frames is not None:
                    target_height, target_width = input_frames.shape[-3:-1]
                elif input_video is not None:
                    target_height, target_width = input_video.shape[-2:]

                concat_dict = {'mode': 'timecat-w', 'bias': -1}
                freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(target_frame_num, target_height, target_width, concat_dict, enable_RIFLEx)

        n_tokens = freqs_cos.shape[0]

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)
        # ========================================================================
        # Pipeline inference
        # ========================================================================

        pixel_value_llava, uncond_pixel_value_llava, pixel_value_ref = None, None, None
        if input_ref_images is None:
            name = None
        else:
            pixel_value_llava, uncond_pixel_value_llava, pixel_value_ref = DataPreprocess().get_batch(input_ref_images, (target_width, target_height), pad=self.custom)

        ref_latents, uncond_audio_prompts, audio_prompts, face_masks, motion_exp, motion_pose = None, None, None, None, None, None

        bg_latents = None
        if input_video is not None:
            pixel_value_bg = input_video.unsqueeze(0)
            pixel_value_mask = torch.zeros_like(input_video).unsqueeze(0)
        if input_frames is not None:
            pixel_value_video_bg = input_frames.permute(-1, 0, 1, 2).unsqueeze(0).float()
            pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1, 1, 1, 3).permute(-1, 0, 1, 2).unsqueeze(0).float()
            pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.)
            if input_video is not None:
                pixel_value_bg = torch.cat([pixel_value_bg, pixel_value_video_bg], dim=2)
                pixel_value_mask = torch.cat([pixel_value_mask, pixel_value_video_mask], dim=2)
            else:
                pixel_value_bg = pixel_value_video_bg
                pixel_value_mask = pixel_value_video_mask
            pixel_value_video_mask, pixel_value_video_bg = None, None
        if input_video is not None or input_frames is not None:
            if pixel_value_bg.shape[2] < frame_num:
                padding_shape = list(pixel_value_bg.shape[0:2]) + [frame_num - pixel_value_bg.shape[2]] + list(pixel_value_bg.shape[3:])
                pixel_value_bg = torch.cat([pixel_value_bg, torch.full(padding_shape, -1, dtype=pixel_value_bg.dtype, device=pixel_value_bg.device)], dim=2)
                pixel_value_mask = torch.cat([pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device=pixel_value_mask.device)], dim=2)

            bg_latents = self.vae.encode(pixel_value_bg).latent_dist.sample()
            pixel_value_mask = pixel_value_mask.div_(127.5).add_(-1.)
            mask_latents = self.vae.encode(pixel_value_mask).latent_dist.sample()
            bg_latents = torch.cat([bg_latents, mask_latents], dim=1)
            bg_latents.mul_(self.vae.config.scaling_factor)

        if self.avatar:
            if n_prompt is None or len(n_prompt) == 0:
                n_prompt = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, Lens changes"

            uncond_pixel_value_llava = pixel_value_llava.clone()

            pixel_value_ref = pixel_value_ref.unsqueeze(0)
            self.align_instance.facedet.model.to("cuda")
            face_masks = get_facemask(pixel_value_ref.to("cuda") * 255, self.align_instance, area=3.0)
            # iii = (face_masks.squeeze(0).squeeze(0).permute(1,2,0).repeat(1,1,3)*255).cpu().numpy().astype(np.uint8)
            # image = Image.fromarray(iii)
            # image.save("mask.png")
            # jjj = (pixel_value_ref.squeeze(0).squeeze(0).permute(1,2,0)*255).cpu().numpy().astype(np.uint8)

            self.align_instance.facedet.model.to("cpu")
            # pixel_value_ref = pixel_value_ref.clone().repeat(1,129,1,1,1)

            pixel_value_ref = pixel_value_ref.repeat(1, 1 + 4 * 2, 1, 1, 1)
            pixel_value_ref = pixel_value_ref * 2 - 1
            pixel_value_ref_for_vae = rearrange(pixel_value_ref, "b f c h w -> b c f h w")

            vae_dtype = self.vae.dtype
            with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_dtype != torch.float32):
                ref_latents = self.vae.encode(pixel_value_ref_for_vae).latent_dist.sample()
                ref_latents = torch.cat([ref_latents[:, :, :1], ref_latents[:, :, 1:2].repeat(1, 1, 31, 1, 1), ref_latents[:, :, -1:]], dim=2)
                pixel_value_ref, pixel_value_ref_for_vae = None, None

                if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor:
                    ref_latents.sub_(self.vae.config.shift_factor).mul_(self.vae.config.scaling_factor)
                else:
                    ref_latents.mul_(self.vae.config.scaling_factor)

                # out_latents = ref_latents / self.vae.config.scaling_factor
                # image = self.vae.decode(out_latents, return_dict=False, generator=generator)[0]
                # image = image.clamp(-1, 1)
                # from wan.utils.utils import cache_video
                # cache_video(tensor=image, save_file="decode.mp4", fps=25, nrow=1, normalize=True, value_range=(-1, 1))

            motion_pose = np.array([25] * 4)
            motion_exp = np.array([30] * 4)
            motion_pose = torch.from_numpy(motion_pose).unsqueeze(0)
            motion_exp = torch.from_numpy(motion_exp).unsqueeze(0)

            face_masks = torch.nn.functional.interpolate(face_masks.float().squeeze(2),
                                                         (ref_latents.shape[-2],
                                                          ref_latents.shape[-1]),
                                                         mode="bilinear").unsqueeze(2).to(dtype=ref_latents.dtype)

        if audio_guide is not None:
            audio_input, audio_len = get_audio_feature(self.feature_extractor, audio_guide, duration=frame_num / fps)
            audio_prompts = audio_input[0]
            weight_dtype = audio_prompts.dtype
            if self.custom:
                audio_len = min(audio_len, frame_num)
                audio_input = audio_input[:, :audio_len]
            audio_prompts = encode_audio(self.wav2vec, audio_prompts.to(dtype=self.wav2vec.dtype), fps, num_frames=audio_len)
            audio_prompts = audio_prompts.to(self.model.dtype)
            segment_size = 129 if self.avatar else frame_num
            if audio_prompts.shape[1] <= segment_size:
                audio_prompts = torch.cat([audio_prompts, torch.zeros_like(audio_prompts[:, :1]).repeat(1, segment_size - audio_prompts.shape[1], 1, 1, 1)], dim=1)
            else:
                audio_prompts = torch.cat([audio_prompts, torch.zeros_like(audio_prompts[:, :1]).repeat(1, 5, 1, 1, 1)], dim=1)
            uncond_audio_prompts = torch.zeros_like(audio_prompts[:, :129])
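
        # Shape sketch for the audio branch above (illustrative): for a 129-frame avatar clip the
        # wav2vec features are zero-padded up to the 129-frame segment (or extended by 5 zero
        # frames when longer), and uncond_audio_prompts is an all-zero tensor with the same
        # per-frame layout, used as the no-audio branch of classifier-free guidance.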

        samples = self.pipeline(
            prompt=input_prompt,
            height=target_height,
            width=target_width,
            video_length=target_frame_num,
            num_inference_steps=sampling_steps,
            guidance_scale=guide_scale,
            negative_prompt=n_prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            generator=generator,
            output_type="pil",
            name=name,

            pixel_value_ref=pixel_value_ref,
            ref_latents=ref_latents,                            # [1, 16, 1, h//8, w//8]
            pixel_value_llava=pixel_value_llava,                # [1, 3, 336, 336]
            uncond_pixel_value_llava=uncond_pixel_value_llava,
            face_masks=face_masks,                              # [b f h w]
            audio_prompts=audio_prompts,
            uncond_audio_prompts=uncond_audio_prompts,
            motion_exp=motion_exp,
            motion_pose=motion_pose,
            fps=torch.from_numpy(np.array(fps)),

            bg_latents=bg_latents,
            audio_strength=audio_strength,

            denoise_strength=denoise_strength,
            ip_cfg_scale=ip_cfg_scale,
            freqs_cis=(freqs_cos, freqs_sin),
            n_tokens=n_tokens,
            embedded_guidance_scale=embedded_guidance_scale,
            data_type="video" if target_frame_num > 1 else "image",
            is_progress_bar=True,
            vae_ver="884-16c-hy",
            enable_tiling=True,
            i2v_mode=i2v_mode,
            i2v_condition_type=i2v_condition_type,
            i2v_stability=i2v_stability,
            img_latents=img_latents,
            semantic_images=semantic_images,
            joint_pass=joint_pass,
            cfg_star_rescale=cfg_star_switch,
            callback=callback,
            callback_steps=callback_steps,
        )[0]

        if samples is None:
            return None
        samples = samples.squeeze(0)

        return samples
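
# End-to-end usage sketch (illustrative; the prompt, file paths and resolution below are
# assumptions, not defaults of this module):
#
#   sampler = HunyuanVideoSampler.from_pretrained(
#       model_filepath=["ckpts/hunyuan_video_custom_720_bf16.safetensors"],
#       base_model_type="hunyuan_custom",
#       text_encoder_filepath="ckpts/llava-llama-3-8b/",
#   )
#   ref = Image.open("reference.png").convert("RGB")
#   video = sampler.generate(
#       input_prompt="a person walking on a beach at sunset",
#       input_ref_images=[ref],
#       height=720, width=1280, frame_num=129,
#       sampling_steps=30, seed=42,
#   )
#   # `video` holds the decoded sample returned by the pipeline (or None if generation was interrupted).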