import os
import time
import random
import functools
from typing import List, Optional, Tuple, Union
from pathlib import Path

import torch
import torch.distributed as dist
from hyvideo.constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE, NEGATIVE_PROMPT_I2V
from hyvideo.vae import load_vae
from hyvideo.modules import load_model
from hyvideo.text_encoder import TextEncoder
from hyvideo.utils.data_utils import align_to, get_closest_ratio, generate_crop_size_list
from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed, get_nd_rotary_pos_embed_new
from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler
from hyvideo.diffusion.pipelines import HunyuanVideoPipeline
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
import cv2


def pad_image(crop_img, size, color=(255, 255, 255), resize_ratio=1):
    # Letterbox `crop_img` into `size` = (target_w, target_h), preserving the aspect
    # ratio and filling the borders with `color`.
    crop_h, crop_w = crop_img.shape[:2]
    target_w, target_h = size
    scale_h, scale_w = target_h / crop_h, target_w / crop_w
    if scale_w > scale_h:
        resize_h = int(target_h * resize_ratio)
        resize_w = int(crop_w / crop_h * resize_h)
    else:
        resize_w = int(target_w * resize_ratio)
        resize_h = int(crop_h / crop_w * resize_w)
    crop_img = cv2.resize(crop_img, (resize_w, resize_h))
    pad_left = (target_w - resize_w) // 2
    pad_top = (target_h - resize_h) // 2
    pad_right = target_w - resize_w - pad_left
    pad_bottom = target_h - resize_h - pad_top
    crop_img = cv2.copyMakeBorder(crop_img, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=color)
    return crop_img
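
# Hedged usage sketch (not part of the original module): `pad_image` letterboxes an
# arbitrary frame into a fixed canvas, e.g. the 336x336 LLaVA input used below.
# The array here is synthetic test data.
def _example_pad_image():
    frame = np.zeros((480, 640, 3), dtype=np.uint8)  # hypothetical 480x640 RGB frame
    padded = pad_image(frame, (336, 336))            # `size` is (target_w, target_h)
    assert padded.shape[:2] == (336, 336)
    return padded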

def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
    num_images, num_image_patches, embed_dim = image_features.shape
    batch_size, sequence_length = input_ids.shape
    left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
    # 1. Create a mask to know where special image tokens are
    special_image_token_mask = input_ids == self.config.image_token_index
    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
    # Compute the maximum embed dimension
    max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
    batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)

    # 2. Compute the positions where text should be written
    # Calculate new positions for text tokens in merged image-text sequence.
    # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
    # `torch.cumsum` computes how each image token shifts subsequent text token positions.
    # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
    new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
    nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
    if left_padding:
        new_token_positions += nb_image_pad[:, None]  # offset for left padding
    text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

    # 3. Create the full embedding, already padded to the maximum position
    final_embedding = torch.zeros(
        batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
    )
    final_attention_mask = torch.zeros(
        batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
    )
    if labels is not None:
        final_labels = torch.full(
            (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
        )
    # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
    # set the corresponding tensors into their correct target device.
    target_device = inputs_embeds.device
    batch_indices, non_image_indices, text_to_overwrite = (
        batch_indices.to(target_device),
        non_image_indices.to(target_device),
        text_to_overwrite.to(target_device),
    )
    attention_mask = attention_mask.to(target_device)

    # 4. Fill the embeddings based on the mask. If we have ["hey", "<image>", "how", "are"]
    # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
    final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
    final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
    if labels is not None:
        final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

    # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
    image_to_overwrite = torch.full(
        (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
    )
    image_to_overwrite[batch_indices, text_to_overwrite] = False
    image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

    if image_to_overwrite.sum() != image_features.shape[:-1].numel():
        raise ValueError(
            f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
            f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
        )

    final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
    final_attention_mask |= image_to_overwrite
    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

    # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
    batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
    indices_to_mask = new_token_positions[batch_indices, pad_indices]

    final_embedding[batch_indices, indices_to_mask] = 0

    if labels is None:
        final_labels = None

    return final_embedding, final_attention_mask, final_labels, position_ids
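
# Hedged toy illustration (synthetic tensors, hypothetical token id and patch count):
# demonstrates the cumsum trick used above to compute where each original token lands
# once a single image placeholder is expanded into `num_image_patches` slots.
def _example_token_positions():
    image_token_id = 32000                                # hypothetical image token id
    num_image_patches = 4                                 # hypothetical patch count
    input_ids = torch.tensor([[1, image_token_id, 5, 6]])
    special = input_ids == image_token_id
    new_positions = torch.cumsum(special * (num_image_patches - 1) + 1, -1) - 1
    # tensor([[0, 4, 5, 6]]): slots 1..4 are reserved for the image patches,
    # so the text after the placeholder resumes at position 5.
    return new_positions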

def patched_llava_forward(
    self,
    input_ids: torch.LongTensor = None,
    pixel_values: torch.FloatTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    vision_feature_layer: Optional[int] = None,
    vision_feature_select_strategy: Optional[str] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
):
    from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    vision_feature_layer = (
        vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
    )
    vision_feature_select_strategy = (
        vision_feature_select_strategy
        if vision_feature_select_strategy is not None
        else self.config.vision_feature_select_strategy
    )

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if pixel_values is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
        )

    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)

    image_features = None
    if pixel_values is not None:
        image_features = self.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
        )

        inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
            image_features, inputs_embeds, input_ids, attention_mask, labels
        )
        cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)

    outputs = self.language_model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        num_logits_to_keep=num_logits_to_keep,
    )

    logits = outputs[0]
    loss = None

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return LlavaCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=image_features if pixel_values is not None else None,
    )
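
# Hedged sketch (illustrative only): how the two helpers above are installed onto
# transformers' Llava class. This mirrors what `Inference.from_pretrained` does
# further down; it is repeated here purely to make the monkey-patching explicit.
def _example_install_llava_patch():
    import transformers
    llava_cls = transformers.models.llava.modeling_llava.LlavaForConditionalGeneration
    llava_cls.forward = patched_llava_forward
    llava_cls._merge_input_ids_with_image_features = _merge_input_ids_with_image_features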

class DataPreprocess(object):
    def __init__(self):
        self.llava_size = (336, 336)
        self.llava_transform = transforms.Compose(
            [
                transforms.Resize(self.llava_size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.ToTensor(),
                transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)),
            ]
        )

    def get_batch(self, image, size):
        image = np.asarray(image)
        llava_item_image = pad_image(image.copy(), self.llava_size)
        uncond_llava_item_image = np.ones_like(llava_item_image) * 255
        cat_item_image = pad_image(image.copy(), size)

        llava_item_tensor = self.llava_transform(Image.fromarray(llava_item_image.astype(np.uint8)))
        uncond_llava_item_tensor = self.llava_transform(Image.fromarray(uncond_llava_item_image))
        cat_item_tensor = torch.from_numpy(cat_item_image.copy()).permute((2, 0, 1)) / 255.0
        # batch = {
        #     "pixel_value_llava": llava_item_tensor.unsqueeze(0),
        #     "uncond_pixel_value_llava": uncond_llava_item_tensor.unsqueeze(0),
        #     'pixel_value_ref': cat_item_tensor.unsqueeze(0),
        # }
        return llava_item_tensor.unsqueeze(0), uncond_llava_item_tensor.unsqueeze(0), cat_item_tensor.unsqueeze(0)
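
# Hedged usage sketch: `DataPreprocess` turns a reference PIL image into the three
# tensors consumed by the custom pipeline. The input image below is synthetic.
def _example_data_preprocess():
    ref = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))  # hypothetical reference image
    llava, uncond_llava, ref_pixels = DataPreprocess().get_batch(ref, (1280, 720))
    # llava / uncond_llava: (1, 3, 336, 336), CLIP-normalized for the LLaVA encoder
    # ref_pixels: (1, 3, 720, 1280), values in [0, 1] for the reference branch
    return llava, uncond_llava, ref_pixels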

class Inference(object):
    def __init__(
        self,
        i2v,
        enable_cfg,
        vae,
        vae_kwargs,
        text_encoder,
        model,
        text_encoder_2=None,
        pipeline=None,
        device=None,
    ):
        self.i2v = i2v
        self.enable_cfg = enable_cfg
        self.vae = vae
        self.vae_kwargs = vae_kwargs
        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2
        self.model = model
        self.pipeline = pipeline
        self.device = "cuda"

    @classmethod
    def from_pretrained(cls, model_filepath, text_encoder_filepath, dtype=torch.bfloat16, VAE_dtype=torch.float16, mixed_precision_transformer=torch.bfloat16, **kwargs):
        device = "cuda"

        import transformers
        # Patch transformers' Llava so the legacy merge behaviour keeps working with transformers > 4.47.
        transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.forward = patched_llava_forward
        transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features = _merge_input_ids_with_image_features

        torch.set_grad_enabled(False)
        text_len = 512
        latent_channels = 16
        precision = "bf16"
        vae_precision = "fp32" if VAE_dtype == torch.float32 else "bf16"
        embedded_cfg_scale = 6
        i2v_condition_type = None
        i2v_mode = "i2v" in model_filepath[0]
        custom = False
        if i2v_mode:
            model_id = "HYVideo-T/2"
            i2v_condition_type = "token_replace"
        elif "custom" in model_filepath[0]:
            model_id = "HYVideo-T/2-custom"
            custom = True
        else:
            model_id = "HYVideo-T/2-cfgdistill"

        if i2v_mode and i2v_condition_type == "latent_concat":
            in_channels = latent_channels * 2 + 1
            image_embed_interleave = 2
        elif i2v_mode and i2v_condition_type == "token_replace":
            in_channels = latent_channels
            image_embed_interleave = 4
        else:
            in_channels = latent_channels
            image_embed_interleave = 1
        out_channels = latent_channels

        pinToMemory = kwargs.pop("pinToMemory", False)
        partialPinning = kwargs.pop("partialPinning", False)
        factor_kwargs = kwargs | {"device": "meta", "dtype": PRECISION_TO_TYPE[precision]}
        if embedded_cfg_scale and i2v_mode:
            factor_kwargs["guidance_embed"] = True

        model = load_model(
            model=model_id,
            i2v_condition_type=i2v_condition_type,
            in_channels=in_channels,
            out_channels=out_channels,
            factor_kwargs=factor_kwargs,
        )

        from mmgp import offload
        # model = Inference.load_state_dict(args, model, model_filepath)
        # model_filepath = "c:/temp/hc/mp_rank_00_model_states.pt"
        offload.load_model_data(model, model_filepath, pinToMemory=pinToMemory, partialPinning=partialPinning)
        # offload.save_model(model, "hunyuan_video_custom_720_bf16.safetensors")
        # offload.save_model(model, "hunyuan_video_custom_720_quanto_bf16_int8.safetensors", do_quantize=True)

        model.mixed_precision = mixed_precision_transformer
        if model.mixed_precision:
            model._lock_dtype = torch.float32
            model.lock_layers_dtypes(torch.float32)
        model.eval()

        # ============================= Build extra models ========================
        # VAE
        if custom:
            vae_configpath = "ckpts/hunyuan_video_custom_VAE_config.json"
            vae_filepath = "ckpts/hunyuan_video_custom_VAE_fp32.safetensors"
        else:
            vae_configpath = "ckpts/hunyuan_video_VAE_config.json"
            vae_filepath = "ckpts/hunyuan_video_VAE_fp32.safetensors"

        # config = AutoencoderKLCausal3D.load_config("ckpts/hunyuan_video_VAE_config.json")
        # config = AutoencoderKLCausal3D.load_config("c:/temp/hvae/config_vae.json")

        vae, _, s_ratio, t_ratio = load_vae(
            "884-16c-hy",
            vae_path=vae_filepath,
            vae_config_path=vae_configpath,
            vae_precision=vae_precision,
            device="cpu",
        )
        vae._model_dtype = torch.float32 if VAE_dtype == torch.float32 else torch.bfloat16
        vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}

        enable_cfg = False
        # Text encoder
        if i2v_mode:
            text_encoder = "llm-i2v"
            tokenizer = "llm-i2v"
            prompt_template = "dit-llm-encode-i2v"
            prompt_template_video = "dit-llm-encode-video-i2v"
        elif custom:
            text_encoder = "llm-i2v"
            tokenizer = "llm-i2v"
            prompt_template = "dit-llm-encode"
            prompt_template_video = "dit-llm-encode-video"
            enable_cfg = True
        else:
            text_encoder = "llm"
            tokenizer = "llm"
            prompt_template = "dit-llm-encode"
            prompt_template_video = "dit-llm-encode-video"

        if prompt_template_video is not None:
            crop_start = PROMPT_TEMPLATE[prompt_template_video].get("crop_start", 0)
        elif prompt_template is not None:
            crop_start = PROMPT_TEMPLATE[prompt_template].get("crop_start", 0)
        else:
            crop_start = 0
        max_length = text_len + crop_start

        # prompt_template
        prompt_template = PROMPT_TEMPLATE[prompt_template] if prompt_template is not None else None

        # prompt_template_video
        prompt_template_video = PROMPT_TEMPLATE[prompt_template_video] if prompt_template_video is not None else None

        text_encoder = TextEncoder(
            text_encoder_type=text_encoder,
            max_length=max_length,
            text_encoder_precision="fp16",
            tokenizer_type=tokenizer,
            i2v_mode=i2v_mode,
            prompt_template=prompt_template,
            prompt_template_video=prompt_template_video,
            hidden_state_skip_layer=2,
            apply_final_norm=False,
            reproduce=True,
            device="cpu",
            image_embed_interleave=image_embed_interleave,
            text_encoder_path=text_encoder_filepath,
        )

        text_encoder_2 = TextEncoder(
            text_encoder_type="clipL",
            max_length=77,
            text_encoder_precision="fp16",
            tokenizer_type="clipL",
            reproduce=True,
            device="cpu",
        )

        return cls(
            i2v=i2v_mode,
            enable_cfg=enable_cfg,
            vae=vae,
            vae_kwargs=vae_kwargs,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            model=model,
            device=device,
        )
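
# Hedged usage sketch: constructing a sampler from local checkpoints. The file paths
# below are placeholders (the transformer name is borrowed from the save_model
# comments above) and are not guaranteed to ship with the repository.
def _example_load_sampler():
    sampler = HunyuanVideoSampler.from_pretrained(
        model_filepath=["ckpts/hunyuan_video_custom_720_bf16.safetensors"],  # placeholder path
        text_encoder_filepath="ckpts/llava_llama3_fp16.safetensors",         # placeholder path
    )
    return sampler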

class HunyuanVideoSampler(Inference):
    def __init__(
        self,
        i2v,
        enable_cfg,
        vae,
        vae_kwargs,
        text_encoder,
        model,
        text_encoder_2=None,
        pipeline=None,
        device=0,
    ):
        super().__init__(
            i2v,
            enable_cfg,
            vae,
            vae_kwargs,
            text_encoder,
            model,
            text_encoder_2=text_encoder_2,
            pipeline=pipeline,
            device=device,
        )

        self.i2v_mode = i2v
        self.enable_cfg = enable_cfg
        self.pipeline = self.load_diffusion_pipeline(
            vae=self.vae,
            text_encoder=self.text_encoder,
            text_encoder_2=self.text_encoder_2,
            model=self.model,
            device=self.device,
        )

        if self.i2v_mode:
            self.default_negative_prompt = NEGATIVE_PROMPT_I2V
        else:
            self.default_negative_prompt = NEGATIVE_PROMPT

    @property
    def _interrupt(self):
        return self.pipeline._interrupt

    @_interrupt.setter
    def _interrupt(self, value):
        self.pipeline._interrupt = value

    def load_diffusion_pipeline(
        self,
        vae,
        text_encoder,
        text_encoder_2,
        model,
        scheduler=None,
        device=None,
        progress_bar_config=None,
        # data_type="video",
    ):
        """Load the denoising scheduler and build the diffusion pipeline for inference."""
        if scheduler is None:
            scheduler = FlowMatchDiscreteScheduler(
                shift=6.0,
                reverse=True,
                solver="euler",
            )

        pipeline = HunyuanVideoPipeline(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            transformer=model,
            scheduler=scheduler,
            progress_bar_config=progress_bar_config,
        )

        return pipeline

    def get_rotary_pos_embed_new(self, video_length, height, width, concat_dict={}):
        target_ndim = 3
        ndim = 5 - 2
        latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]

        if isinstance(self.model.patch_size, int):
            assert all(s % self.model.patch_size == 0 for s in latents_size), \
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), " \
                f"but got {latents_size}."
            rope_sizes = [s // self.model.patch_size for s in latents_size]
        elif isinstance(self.model.patch_size, list):
            assert all(s % self.model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), \
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), " \
                f"but got {latents_size}."
            rope_sizes = [s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)]

        if len(rope_sizes) != target_ndim:
            rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis
        head_dim = self.model.hidden_size // self.model.heads_num
        rope_dim_list = self.model.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal head_dim of attention layer"

        freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(rope_dim_list,
                                                           rope_sizes,
                                                           theta=256,
                                                           use_real=True,
                                                           theta_rescale_factor=1,
                                                           concat_dict=concat_dict)
        return freqs_cos, freqs_sin
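
    # Hedged sketch (not in the original code): reproduces the latent-grid arithmetic
    # used by the two rotary-embedding helpers. `patch_size` here is an assumed value,
    # not one read from the loaded model.
    @staticmethod
    def _example_rope_grid(video_length=129, height=720, width=1280, patch_size=(1, 2, 2)):
        # "884" VAE: 4x temporal and 8x spatial compression.
        latents = [(video_length - 1) // 4 + 1, height // 8, width // 8]  # -> [33, 90, 160]
        return [s // p for s, p in zip(latents, patch_size)]              # -> [33, 45, 80]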

    def get_rotary_pos_embed(self, video_length, height, width, enable_riflex=False):
        target_ndim = 3
        ndim = 5 - 2
        # 884
        vae = "884-16c-hy"
        if "884" in vae:
            latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
        elif "888" in vae:
            latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
        else:
            latents_size = [video_length, height // 8, width // 8]

        if isinstance(self.model.patch_size, int):
            assert all(s % self.model.patch_size == 0 for s in latents_size), (
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), "
                f"but got {latents_size}."
            )
            rope_sizes = [s // self.model.patch_size for s in latents_size]
        elif isinstance(self.model.patch_size, list):
            assert all(
                s % self.model.patch_size[idx] == 0
                for idx, s in enumerate(latents_size)
            ), (
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), "
                f"but got {latents_size}."
            )
            rope_sizes = [
                s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)
            ]

        if len(rope_sizes) != target_ndim:
            rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis
        head_dim = self.model.hidden_size // self.model.heads_num
        rope_dim_list = self.model.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert (
            sum(rope_dim_list) == head_dim
        ), "sum(rope_dim_list) should equal head_dim of attention layer"
        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list,
            rope_sizes,
            theta=256,
            use_real=True,
            theta_rescale_factor=1,
            L_test=(video_length - 1) // 4 + 1,
            enable_riflex=enable_riflex,
        )
        return freqs_cos, freqs_sin

    def generate(
        self,
        input_prompt,
        input_ref_images=None,
        height=192,
        width=336,
        frame_num=129,
        seed=None,
        n_prompt=None,
        sampling_steps=50,
        guide_scale=1.0,
        shift=5.0,
        embedded_guidance_scale=6.0,
        batch_size=1,
        num_videos_per_prompt=1,
        i2v_resolution="720p",
        image_start=None,
        enable_riflex=False,
        i2v_condition_type: str = "token_replace",
        i2v_stability=True,
        VAE_tile_size=None,
        joint_pass=False,
        cfg_star_switch=False,
        **kwargs,
    ):
        if VAE_tile_size is not None:
            self.vae.tile_sample_min_tsize = VAE_tile_size["tile_sample_min_tsize"]
            self.vae.tile_latent_min_tsize = VAE_tile_size["tile_latent_min_tsize"]
            self.vae.tile_sample_min_size = VAE_tile_size["tile_sample_min_size"]
            self.vae.tile_latent_min_size = VAE_tile_size["tile_latent_min_size"]
            self.vae.tile_overlap_factor = VAE_tile_size["tile_overlap_factor"]

        i2v_mode = self.i2v_mode
        if not self.enable_cfg:
            guide_scale = 1.0

        out_dict = dict()

        # ========================================================================
        # Arguments: seed
        # ========================================================================
        if isinstance(seed, torch.Tensor):
            seed = seed.tolist()
        if seed is None:
            seeds = [
                random.randint(0, 1_000_000)
                for _ in range(batch_size * num_videos_per_prompt)
            ]
        elif isinstance(seed, int):
            seeds = [
                seed + i
                for _ in range(batch_size)
                for i in range(num_videos_per_prompt)
            ]
        elif isinstance(seed, (list, tuple)):
            if len(seed) == batch_size:
                seeds = [
                    int(seed[i]) + j
                    for i in range(batch_size)
                    for j in range(num_videos_per_prompt)
                ]
            elif len(seed) == batch_size * num_videos_per_prompt:
                seeds = [int(s) for s in seed]
            else:
                raise ValueError(
                    f"Length of seed must be equal to number of prompt(batch_size) or "
                    f"batch_size * num_videos_per_prompt ({batch_size} * {num_videos_per_prompt}), got {seed}."
                )
        else:
            raise ValueError(
                f"Seed must be an integer, a list of integers, or None, got {seed}."
            )
        from wan.utils.utils import seed_everything
        seed_everything(seed)
        generator = [torch.Generator("cuda").manual_seed(seed) for seed in seeds]
        # generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds]
        out_dict["seeds"] = seeds

        # ========================================================================
        # Arguments: target_width, target_height, target_frame_num
        # ========================================================================
        if width <= 0 or height <= 0 or frame_num <= 0:
            raise ValueError(
                f"`height` and `width` and `frame_num` must be positive integers, got height={height}, width={width}, frame_num={frame_num}"
            )
        if (frame_num - 1) % 4 != 0:
            raise ValueError(
                f"`frame_num-1` must be a multiple of 4, got {frame_num}"
            )

        target_height = align_to(height, 16)
        target_width = align_to(width, 16)
        target_frame_num = frame_num

        out_dict["size"] = (target_height, target_width, target_frame_num)

        if input_ref_images is not None:
            # ip_cfg_scale = 3.0
            ip_cfg_scale = 0
            denoise_strength = 1
            # guide_scale = 7.5
            # shift = 13
            name = "person"
            input_ref_images = input_ref_images[0]

        # ========================================================================
        # Arguments: prompt, new_prompt, negative_prompt
        # ========================================================================
        if not isinstance(input_prompt, str):
            raise TypeError(f"`prompt` must be a string, but got {type(input_prompt)}")
        input_prompt = [input_prompt.strip()]

        # negative prompt
        if n_prompt is None or n_prompt == "":
            n_prompt = self.default_negative_prompt
        if guide_scale == 1.0:
            n_prompt = ""
        if not isinstance(n_prompt, str):
            raise TypeError(
                f"`negative_prompt` must be a string, but got {type(n_prompt)}"
            )
        n_prompt = [n_prompt.strip()]

        # ========================================================================
        # Scheduler
        # ========================================================================
        scheduler = FlowMatchDiscreteScheduler(
            shift=shift,
            reverse=True,
            solver="euler"
        )
        self.pipeline.scheduler = scheduler

        # ---------------------------------
        # Reference condition
        # ---------------------------------
        img_latents = None
        semantic_images = None
        denoise_strength = 0
        ip_cfg_scale = 0
        if i2v_mode:
            if i2v_resolution == "720p":
                bucket_hw_base_size = 960
            elif i2v_resolution == "540p":
                bucket_hw_base_size = 720
            elif i2v_resolution == "360p":
                bucket_hw_base_size = 480
            else:
                raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")

            # semantic_images = [Image.open(i2v_image_path).convert('RGB')]
            semantic_images = [image_start.convert('RGB')]
            origin_size = semantic_images[0].size

            crop_size_list = generate_crop_size_list(bucket_hw_base_size, 32)
            aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
            closest_size, closest_ratio = get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
            ref_image_transform = transforms.Compose([
                transforms.Resize(closest_size),
                transforms.CenterCrop(closest_size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5])
            ])

            semantic_image_pixel_values = [ref_image_transform(semantic_image) for semantic_image in semantic_images]
            semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(self.device)

            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
                img_latents = self.pipeline.vae.encode(semantic_image_pixel_values).latent_dist.mode()  # B, C, F, H, W
                img_latents.mul_(self.pipeline.vae.config.scaling_factor)

            target_height, target_width = closest_size

        # ========================================================================
        # Build RoPE freqs
        # ========================================================================
        if input_ref_images is None:
            freqs_cos, freqs_sin = self.get_rotary_pos_embed(target_frame_num, target_height, target_width, enable_riflex)
        else:
            concat_dict = {'mode': 'timecat-w', 'bias': -1}
            freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(target_frame_num, target_height, target_width, concat_dict)

        n_tokens = freqs_cos.shape[0]

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        # ========================================================================
        # Pipeline inference
        # ========================================================================
        start_time = time.time()

        # "pixel_value_llava": llava_item_tensor.unsqueeze(0),
        # "uncond_pixel_value_llava": uncond_llava_item_tensor.unsqueeze(0),
        # 'pixel_value_ref': cat_item_tensor.unsqueeze(0),
        if input_ref_images is None:
            pixel_value_llava, uncond_pixel_value_llava, pixel_value_ref = None, None, None
            name = None
        else:
            pixel_value_llava, uncond_pixel_value_llava, pixel_value_ref = DataPreprocess().get_batch(input_ref_images, (target_width, target_height))

        samples = self.pipeline(
            prompt=input_prompt,
            height=target_height,
            width=target_width,
            video_length=target_frame_num,
            num_inference_steps=sampling_steps,
            guidance_scale=guide_scale,
            negative_prompt=n_prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            generator=generator,
            output_type="pil",
            name=name,
            pixel_value_llava=pixel_value_llava,
            uncond_pixel_value_llava=uncond_pixel_value_llava,
            pixel_value_ref=pixel_value_ref,
            denoise_strength=denoise_strength,
            ip_cfg_scale=ip_cfg_scale,
            freqs_cis=(freqs_cos, freqs_sin),
            n_tokens=n_tokens,
            embedded_guidance_scale=embedded_guidance_scale,
            data_type="video" if target_frame_num > 1 else "image",
            is_progress_bar=True,
            vae_ver="884-16c-hy",
            enable_tiling=True,
            i2v_mode=i2v_mode,
            i2v_condition_type=i2v_condition_type,
            i2v_stability=i2v_stability,
            img_latents=img_latents,
            semantic_images=semantic_images,
            joint_pass=joint_pass,
            cfg_star_rescale=cfg_star_switch,
            callback=callback,
            callback_steps=callback_steps,
        )[0]
        gen_time = time.time() - start_time

        if samples is None:
            return None
        samples = samples.sub_(0.5).mul_(2).squeeze(0)

        return samples
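
# Hedged end-to-end sketch: text-to-video sampling with an already constructed
# sampler (see `_example_load_sampler` above). The prompt, resolution, and step
# count are illustrative; `generate` returns a (C, F, H, W) tensor scaled to
# [-1, 1], or None if the run was interrupted.
def _example_generate(sampler):
    return sampler.generate(
        input_prompt="a corgi running on the beach",  # hypothetical prompt
        height=720,
        width=1280,
        frame_num=129,       # must satisfy (frame_num - 1) % 4 == 0
        sampling_steps=30,
        seed=42,
    )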