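"""HunyuanVideo inference wrapper.

Provides checkpoint loading (`Inference.from_pretrained`), reference-image
preprocessing (`DataPreprocess`) and video sampling
(`HunyuanVideoSampler.generate`), plus a patched LLaVA forward pass so the
text encoder keeps its legacy image-token merge behaviour on recent
`transformers` releases.
"""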
import os
import time
import random
import functools
from typing import List, Optional, Tuple, Union

from pathlib import Path

import torch
import torch.distributed as dist
from hyvideo.constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE, NEGATIVE_PROMPT_I2V
from hyvideo.vae import load_vae
from hyvideo.modules import load_model
from hyvideo.text_encoder import TextEncoder
from hyvideo.utils.data_utils import align_to, get_closest_ratio, generate_crop_size_list
from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed, get_nd_rotary_pos_embed_new
from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler
from hyvideo.diffusion.pipelines import HunyuanVideoPipeline
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
import cv2

def pad_image(crop_img, size, color=(255, 255, 255), resize_ratio=1):
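    """Resize `crop_img` to fit inside `size` (width, height), preserving aspect ratio,
    and pad the remaining area with `color` so the result is centred on the canvas.

    Example (sketch, using a hypothetical 480x640 frame as a NumPy array):
        frame = np.zeros((480, 640, 3), dtype=np.uint8)
        padded = pad_image(frame, (336, 336))  # -> shape (336, 336, 3)
    """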
    crop_h, crop_w = crop_img.shape[:2]
    target_w, target_h = size
    scale_h, scale_w = target_h / crop_h, target_w / crop_w
    if scale_w > scale_h:
        resize_h = int(target_h * resize_ratio)
        resize_w = int(crop_w / crop_h * resize_h)
    else:
        resize_w = int(target_w * resize_ratio)
        resize_h = int(crop_h / crop_w * resize_w)
    crop_img = cv2.resize(crop_img, (resize_w, resize_h))
    pad_left = (target_w - resize_w) // 2
    pad_top = (target_h - resize_h) // 2
    pad_right = target_w - resize_w - pad_left
    pad_bottom = target_h - resize_h - pad_top
    crop_img = cv2.copyMakeBorder(crop_img, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=color)
    return crop_img


def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
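    """Replacement for `LlavaForConditionalGeneration._merge_input_ids_with_image_features`.

    Expands every `<image>` placeholder token into its image patch embeddings and
    rebuilds the input embeddings, attention mask, labels and position ids
    accordingly (the pre-4.47 `transformers` behaviour that this module
    monkey-patches back in, see `Inference.from_pretrained`).
    """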
    num_images, num_image_patches, embed_dim = image_features.shape
    batch_size, sequence_length = input_ids.shape
    left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
    # 1. Create a mask to know where special image tokens are
    special_image_token_mask = input_ids == self.config.image_token_index
    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
    # Compute the maximum embed dimension
    max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
    batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)

    # 2. Compute the positions where text should be written
    # Calculate new positions for text tokens in merged image-text sequence.
    # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
    # `torch.cumsum` computes how each image token shifts subsequent text token positions.
    # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
    new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
    nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
    if left_padding:
        new_token_positions += nb_image_pad[:, None]  # offset for left padding
    text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

    # 3. Create the full embedding, already padded to the maximum position
    final_embedding = torch.zeros(
        batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
    )
    final_attention_mask = torch.zeros(
        batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
    )
    if labels is not None:
        final_labels = torch.full(
            (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
        )
    # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
    # set the corresponding tensors into their correct target device.
    target_device = inputs_embeds.device
    batch_indices, non_image_indices, text_to_overwrite = (
        batch_indices.to(target_device),
        non_image_indices.to(target_device),
        text_to_overwrite.to(target_device),
    )
    attention_mask = attention_mask.to(target_device)

    # 4. Fill the embeddings based on the mask. If we have ["hey", "<image>", "how", "are"]
    # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
    final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
    final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
    if labels is not None:
        final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

    # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
    image_to_overwrite = torch.full(
        (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
    )
    image_to_overwrite[batch_indices, text_to_overwrite] = False
    image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

    if image_to_overwrite.sum() != image_features.shape[:-1].numel():
        raise ValueError(
            f"The inputs provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
            f" the number of images given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
        )

    final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
    final_attention_mask |= image_to_overwrite
    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

    # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
    batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
    indices_to_mask = new_token_positions[batch_indices, pad_indices]

    final_embedding[batch_indices, indices_to_mask] = 0

    if labels is None:
        final_labels = None

    return final_embedding, final_attention_mask, final_labels, position_ids


def patched_llava_forward(
    self,
    input_ids: torch.LongTensor = None,
    pixel_values: torch.FloatTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    vision_feature_layer: Optional[int] = None,
    vision_feature_select_strategy: Optional[str] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
):
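    """Patched `LlavaForConditionalGeneration.forward`.

    Keeps the legacy image-token merge path (see
    `_merge_input_ids_with_image_features` above) and skips loss computation,
    so newer `transformers` releases behave like the version this pipeline was
    written against.
    """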
    from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    vision_feature_layer = (
        vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
    )
    vision_feature_select_strategy = (
        vision_feature_select_strategy
        if vision_feature_select_strategy is not None
        else self.config.vision_feature_select_strategy
    )

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if pixel_values is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
        )

    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)

    image_features = None
    if pixel_values is not None:
        image_features = self.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
        )

    inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
        image_features, inputs_embeds, input_ids, attention_mask, labels
    )
    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)

    outputs = self.language_model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        num_logits_to_keep=num_logits_to_keep,
    )

    logits = outputs[0]

    loss = None

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return LlavaCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=image_features if pixel_values is not None else None,
    )


class DataPreprocess(object):
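    """Prepares a reference image for the pipeline: a 336x336 CLIP/LLaVA-normalized
    tensor, its all-white "unconditional" counterpart, and the padded reference
    tensor, each returned with a batch dimension by `get_batch`."""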
    def __init__(self):
        self.llava_size = (336, 336)
        self.llava_transform = transforms.Compose(
            [
                transforms.Resize(self.llava_size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.ToTensor(),
                transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)),
            ]
        )

    def get_batch(self, image, size):
        image = np.asarray(image)
        llava_item_image = pad_image(image.copy(), self.llava_size)
        uncond_llava_item_image = np.ones_like(llava_item_image) * 255
        cat_item_image = pad_image(image.copy(), size)

        llava_item_tensor = self.llava_transform(Image.fromarray(llava_item_image.astype(np.uint8)))
        uncond_llava_item_tensor = self.llava_transform(Image.fromarray(uncond_llava_item_image))
        cat_item_tensor = torch.from_numpy(cat_item_image.copy()).permute((2, 0, 1)) / 255.0
        # batch = {
        #     "pixel_value_llava": llava_item_tensor.unsqueeze(0),
        #     "uncond_pixel_value_llava": uncond_llava_item_tensor.unsqueeze(0),
        #     'pixel_value_ref': cat_item_tensor.unsqueeze(0),
        # }
        return llava_item_tensor.unsqueeze(0), uncond_llava_item_tensor.unsqueeze(0), cat_item_tensor.unsqueeze(0)


class Inference(object):
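    """Holds the loaded HunyuanVideo components (transformer, VAE, text encoders)
    and builds them from local checkpoint files via `from_pretrained`."""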
    def __init__(
        self,
        i2v,
        enable_cfg,
        vae,
        vae_kwargs,
        text_encoder,
        model,
        text_encoder_2=None,
        pipeline=None,
        device=None,
    ):
        self.i2v = i2v
        self.enable_cfg = enable_cfg
        self.vae = vae
        self.vae_kwargs = vae_kwargs

        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2

        self.model = model
        self.pipeline = pipeline

        self.device = "cuda"

    @classmethod
    def from_pretrained(cls, model_filepath, text_encoder_filepath, dtype=torch.bfloat16, VAE_dtype=torch.float16, mixed_precision_transformer=torch.bfloat16, **kwargs):
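        """Build an instance of this class from local checkpoint files.

        Monkey-patches the LLaVA forward/merge functions, instantiates the DiT
        transformer on the meta device, loads its weights through `mmgp.offload`,
        then loads the VAE and the two text encoders on CPU.
        """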
        device = "cuda"

        import transformers
        transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.forward = patched_llava_forward  # force legacy behaviour to be able to use transformers v>(4.47)
        transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features = _merge_input_ids_with_image_features

        torch.set_grad_enabled(False)
        text_len = 512
        latent_channels = 16
        precision = "bf16"
        vae_precision = "fp32" if VAE_dtype == torch.float32 else "bf16"
        embedded_cfg_scale = 6
        i2v_condition_type = None
        i2v_mode = "i2v" in model_filepath[0]
        custom = False
        if i2v_mode:
            model_id = "HYVideo-T/2"
            i2v_condition_type = "token_replace"
        elif "custom" in model_filepath[0]:
            model_id = "HYVideo-T/2-custom"
            custom = True
        else:
            model_id = "HYVideo-T/2-cfgdistill"

        if i2v_mode and i2v_condition_type == "latent_concat":
            in_channels = latent_channels * 2 + 1
            image_embed_interleave = 2
        elif i2v_mode and i2v_condition_type == "token_replace":
            in_channels = latent_channels
            image_embed_interleave = 4
        else:
            in_channels = latent_channels
            image_embed_interleave = 1
        out_channels = latent_channels
        pinToMemory = kwargs.pop("pinToMemory", False)
        partialPinning = kwargs.pop("partialPinning", False)
        factor_kwargs = kwargs | {"device": "meta", "dtype": PRECISION_TO_TYPE[precision]}

        if embedded_cfg_scale and i2v_mode:
            factor_kwargs["guidance_embed"] = True

        model = load_model(
            model=model_id,
            i2v_condition_type=i2v_condition_type,
            in_channels=in_channels,
            out_channels=out_channels,
            factor_kwargs=factor_kwargs,
        )

        from mmgp import offload
        # model = Inference.load_state_dict(args, model, model_filepath)

        # model_filepath = "c:/temp/hc/mp_rank_00_model_states.pt"
        offload.load_model_data(model, model_filepath, pinToMemory=pinToMemory, partialPinning=partialPinning)
        # offload.save_model(model, "hunyuan_video_custom_720_bf16.safetensors")
        # offload.save_model(model, "hunyuan_video_custom_720_quanto_bf16_int8.safetensors", do_quantize=True)

        model.mixed_precision = mixed_precision_transformer

        if model.mixed_precision:
            model._lock_dtype = torch.float32
            model.lock_layers_dtypes(torch.float32)
        model.eval()

        # ============================= Build extra models ========================
        # VAE
        if custom:
            vae_configpath = "ckpts/hunyuan_video_custom_VAE_config.json"
            vae_filepath = "ckpts/hunyuan_video_custom_VAE_fp32.safetensors"
        else:
            vae_configpath = "ckpts/hunyuan_video_VAE_config.json"
            vae_filepath = "ckpts/hunyuan_video_VAE_fp32.safetensors"

        # config = AutoencoderKLCausal3D.load_config("ckpts/hunyuan_video_VAE_config.json")
        # config = AutoencoderKLCausal3D.load_config("c:/temp/hvae/config_vae.json")

        vae, _, s_ratio, t_ratio = load_vae("884-16c-hy", vae_path=vae_filepath, vae_config_path=vae_configpath, vae_precision=vae_precision, device="cpu")

        vae._model_dtype = torch.float32 if VAE_dtype == torch.float32 else torch.bfloat16
        vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}
        enable_cfg = False
        # Text encoder
        if i2v_mode:
            text_encoder = "llm-i2v"
            tokenizer = "llm-i2v"
            prompt_template = "dit-llm-encode-i2v"
            prompt_template_video = "dit-llm-encode-video-i2v"
        elif custom:
            text_encoder = "llm-i2v"
            tokenizer = "llm-i2v"
            prompt_template = "dit-llm-encode"
            prompt_template_video = "dit-llm-encode-video"
            enable_cfg = True
        else:
            text_encoder = "llm"
            tokenizer = "llm"
            prompt_template = "dit-llm-encode"
            prompt_template_video = "dit-llm-encode-video"

        if prompt_template_video is not None:
            crop_start = PROMPT_TEMPLATE[prompt_template_video].get("crop_start", 0)
        elif prompt_template is not None:
            crop_start = PROMPT_TEMPLATE[prompt_template].get("crop_start", 0)
        else:
            crop_start = 0
        max_length = text_len + crop_start

        # prompt_template
        prompt_template = PROMPT_TEMPLATE[prompt_template] if prompt_template is not None else None

        # prompt_template_video
        prompt_template_video = PROMPT_TEMPLATE[prompt_template_video] if prompt_template_video is not None else None

        text_encoder = TextEncoder(
            text_encoder_type=text_encoder,
            max_length=max_length,
            text_encoder_precision="fp16",
            tokenizer_type=tokenizer,
            i2v_mode=i2v_mode,
            prompt_template=prompt_template,
            prompt_template_video=prompt_template_video,
            hidden_state_skip_layer=2,
            apply_final_norm=False,
            reproduce=True,
            device="cpu",
            image_embed_interleave=image_embed_interleave,
            text_encoder_path=text_encoder_filepath,
        )

        text_encoder_2 = TextEncoder(
            text_encoder_type="clipL",
            max_length=77,
            text_encoder_precision="fp16",
            tokenizer_type="clipL",
            reproduce=True,
            device="cpu",
        )

        return cls(
            i2v=i2v_mode,
            enable_cfg=enable_cfg,
            vae=vae,
            vae_kwargs=vae_kwargs,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            model=model,
            device=device,
        )


class HunyuanVideoSampler(Inference):
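    """High-level sampler: wires the loaded components into a `HunyuanVideoPipeline`
    and exposes `generate()` for text-to-video and image-to-video sampling."""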
    def __init__(
        self,
        i2v,
        enable_cfg,
        vae,
        vae_kwargs,
        text_encoder,
        model,
        text_encoder_2=None,
        pipeline=None,
        device=0,
    ):
        super().__init__(
            i2v,
            enable_cfg,
            vae,
            vae_kwargs,
            text_encoder,
            model,
            text_encoder_2=text_encoder_2,
            pipeline=pipeline,
            device=device,
        )

        self.i2v_mode = i2v
        self.enable_cfg = enable_cfg
        self.pipeline = self.load_diffusion_pipeline(
            vae=self.vae,
            text_encoder=self.text_encoder,
            text_encoder_2=self.text_encoder_2,
            model=self.model,
            device=self.device,
        )

        if self.i2v_mode:
            self.default_negative_prompt = NEGATIVE_PROMPT_I2V
        else:
            self.default_negative_prompt = NEGATIVE_PROMPT

    @property
    def _interrupt(self):
        return self.pipeline._interrupt

    @_interrupt.setter
    def _interrupt(self, value):
        self.pipeline._interrupt = value

    def load_diffusion_pipeline(
        self,
        vae,
        text_encoder,
        text_encoder_2,
        model,
        scheduler=None,
        device=None,
        progress_bar_config=None,
        # data_type="video",
    ):
        """Build the `HunyuanVideoPipeline`, creating a default flow-match scheduler if none is given."""
        if scheduler is None:
            scheduler = FlowMatchDiscreteScheduler(
                shift=6.0,
                reverse=True,
                solver="euler",
            )

        pipeline = HunyuanVideoPipeline(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            transformer=model,
            scheduler=scheduler,
            progress_bar_config=progress_bar_config,
        )

        return pipeline

    def get_rotary_pos_embed_new(self, video_length, height, width, concat_dict={}):
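        """Rotary position embeddings for the reference-conditioned ("custom") path,
        where the reference image is concatenated into the token grid as described
        by `concat_dict` (e.g. {'mode': 'timecat-w', 'bias': -1} in `generate`)."""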
        target_ndim = 3
        ndim = 5 - 2
        latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]

        if isinstance(self.model.patch_size, int):
            assert all(s % self.model.patch_size == 0 for s in latents_size), \
                f"Latent size (last {ndim} dimensions) should be divisible by patch size ({self.model.patch_size}), " \
                f"but got {latents_size}."
            rope_sizes = [s // self.model.patch_size for s in latents_size]
        elif isinstance(self.model.patch_size, list):
            assert all(s % self.model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), \
                f"Latent size (last {ndim} dimensions) should be divisible by patch size ({self.model.patch_size}), " \
                f"but got {latents_size}."
            rope_sizes = [s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)]

        if len(rope_sizes) != target_ndim:
            rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis
        head_dim = self.model.hidden_size // self.model.heads_num
        rope_dim_list = self.model.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal head_dim of attention layer"
        freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(rope_dim_list,
                                                           rope_sizes,
                                                           theta=256,
                                                           use_real=True,
                                                           theta_rescale_factor=1,
                                                           concat_dict=concat_dict)
        return freqs_cos, freqs_sin

    def get_rotary_pos_embed(self, video_length, height, width, enable_riflex=False):
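        """Rotary position embeddings for the standard text/image-to-video path.

        `enable_riflex` and `L_test` are forwarded to `get_nd_rotary_pos_embed`
        (RIFLEx-style frequency adjustment for videos longer than the training length)."""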
        target_ndim = 3
        ndim = 5 - 2
        # 884
        vae = "884-16c-hy"
        if "884" in vae:
            latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
        elif "888" in vae:
            latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
        else:
            latents_size = [video_length, height // 8, width // 8]

        if isinstance(self.model.patch_size, int):
            assert all(s % self.model.patch_size == 0 for s in latents_size), (
                f"Latent size (last {ndim} dimensions) should be divisible by patch size ({self.model.patch_size}), "
                f"but got {latents_size}."
            )
            rope_sizes = [s // self.model.patch_size for s in latents_size]
        elif isinstance(self.model.patch_size, list):
            assert all(
                s % self.model.patch_size[idx] == 0
                for idx, s in enumerate(latents_size)
            ), (
                f"Latent size (last {ndim} dimensions) should be divisible by patch size ({self.model.patch_size}), "
                f"but got {latents_size}."
            )
            rope_sizes = [
                s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)
            ]

        if len(rope_sizes) != target_ndim:
            rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis
        head_dim = self.model.hidden_size // self.model.heads_num
        rope_dim_list = self.model.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert (
            sum(rope_dim_list) == head_dim
        ), "sum(rope_dim_list) should equal head_dim of attention layer"
        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list,
            rope_sizes,
            theta=256,
            use_real=True,
            theta_rescale_factor=1,
            L_test=(video_length - 1) // 4 + 1,
            enable_riflex=enable_riflex,
        )
        return freqs_cos, freqs_sin

    def generate(
        self,
        input_prompt,
        input_ref_images=None,
        height=192,
        width=336,
        frame_num=129,
        seed=None,
        n_prompt=None,
        sampling_steps=50,
        guide_scale=1.0,
        shift=5.0,
        embedded_guidance_scale=6.0,
        batch_size=1,
        num_videos_per_prompt=1,
        i2v_resolution="720p",
        image_start=None,
        enable_riflex=False,
        i2v_condition_type: str = "token_replace",
        i2v_stability=True,
        VAE_tile_size=None,
        joint_pass=False,
        cfg_star_switch=False,
        **kwargs,
    ):
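        """Run the sampling pipeline and return the generated video tensor
        (rescaled via ``(x - 0.5) * 2`` and squeezed), or None if generation was interrupted.

        `image_start` drives the image-to-video path, `input_ref_images` the
        reference-conditioned ("custom") path; `kwargs` may carry `callback` /
        `callback_steps` for progress reporting.
        """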
        if VAE_tile_size is not None:
            self.vae.tile_sample_min_tsize = VAE_tile_size["tile_sample_min_tsize"]
            self.vae.tile_latent_min_tsize = VAE_tile_size["tile_latent_min_tsize"]
            self.vae.tile_sample_min_size = VAE_tile_size["tile_sample_min_size"]
            self.vae.tile_latent_min_size = VAE_tile_size["tile_latent_min_size"]
            self.vae.tile_overlap_factor = VAE_tile_size["tile_overlap_factor"]

        i2v_mode = self.i2v_mode
        if not self.enable_cfg:
            guide_scale = 1.0

        out_dict = dict()

        # ========================================================================
        # Arguments: seed
        # ========================================================================
        if isinstance(seed, torch.Tensor):
            seed = seed.tolist()
        if seed is None:
            seeds = [
                random.randint(0, 1_000_000)
                for _ in range(batch_size * num_videos_per_prompt)
            ]
        elif isinstance(seed, int):
            seeds = [
                seed + i
                for _ in range(batch_size)
                for i in range(num_videos_per_prompt)
            ]
        elif isinstance(seed, (list, tuple)):
            if len(seed) == batch_size:
                seeds = [
                    int(seed[i]) + j
                    for i in range(batch_size)
                    for j in range(num_videos_per_prompt)
                ]
            elif len(seed) == batch_size * num_videos_per_prompt:
                seeds = [int(s) for s in seed]
            else:
                raise ValueError(
                    f"Length of seed must be equal to number of prompts (batch_size) or "
                    f"batch_size * num_videos_per_prompt ({batch_size} * {num_videos_per_prompt}), got {seed}."
                )
        else:
            raise ValueError(
                f"Seed must be an integer, a list of integers, or None, got {seed}."
            )
        from wan.utils.utils import seed_everything
        seed_everything(seed)
        generator = [torch.Generator("cuda").manual_seed(seed) for seed in seeds]
        # generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds]
        out_dict["seeds"] = seeds

        # ========================================================================
        # Arguments: target_width, target_height, target_frame_num
        # ========================================================================
        if width <= 0 or height <= 0 or frame_num <= 0:
            raise ValueError(
                f"`height`, `width` and `frame_num` must be positive integers, got height={height}, width={width}, frame_num={frame_num}"
            )
        if (frame_num - 1) % 4 != 0:
            raise ValueError(
                f"`frame_num - 1` must be a multiple of 4, got {frame_num}"
            )

        target_height = align_to(height, 16)
        target_width = align_to(width, 16)
        target_frame_num = frame_num

        out_dict["size"] = (target_height, target_width, target_frame_num)

        if input_ref_images is not None:
            # ip_cfg_scale = 3.0
            ip_cfg_scale = 0
            denoise_strength = 1
            # guide_scale = 7.5
            # shift = 13
            name = "person"
            input_ref_images = input_ref_images[0]

        # ========================================================================
        # Arguments: prompt, new_prompt, negative_prompt
        # ========================================================================
        if not isinstance(input_prompt, str):
            raise TypeError(f"`prompt` must be a string, but got {type(input_prompt)}")
        input_prompt = [input_prompt.strip()]

        # negative prompt
        if n_prompt is None or n_prompt == "":
            n_prompt = self.default_negative_prompt
        if guide_scale == 1.0:
            n_prompt = ""
        if not isinstance(n_prompt, str):
            raise TypeError(
                f"`negative_prompt` must be a string, but got {type(n_prompt)}"
            )
        n_prompt = [n_prompt.strip()]

        # ========================================================================
        # Scheduler
        # ========================================================================
        scheduler = FlowMatchDiscreteScheduler(
            shift=shift,
            reverse=True,
            solver="euler"
        )
        self.pipeline.scheduler = scheduler

        # ---------------------------------
        # Reference condition
        # ---------------------------------
        img_latents = None
        semantic_images = None
        denoise_strength = 0
        ip_cfg_scale = 0
        if i2v_mode:
            if i2v_resolution == "720p":
                bucket_hw_base_size = 960
            elif i2v_resolution == "540p":
                bucket_hw_base_size = 720
            elif i2v_resolution == "360p":
                bucket_hw_base_size = 480
            else:
                raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")

            # semantic_images = [Image.open(i2v_image_path).convert('RGB')]
            semantic_images = [image_start.convert('RGB')]

            origin_size = semantic_images[0].size

            crop_size_list = generate_crop_size_list(bucket_hw_base_size, 32)
            aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
            closest_size, closest_ratio = get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
            ref_image_transform = transforms.Compose([
                transforms.Resize(closest_size),
                transforms.CenterCrop(closest_size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5])
            ])

            semantic_image_pixel_values = [ref_image_transform(semantic_image) for semantic_image in semantic_images]
            semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(self.device)

            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
                img_latents = self.pipeline.vae.encode(semantic_image_pixel_values).latent_dist.mode()  # B, C, F, H, W
                img_latents.mul_(self.pipeline.vae.config.scaling_factor)

            target_height, target_width = closest_size

        # ========================================================================
        # Build RoPE freqs
        # ========================================================================

        if input_ref_images is None:
            freqs_cos, freqs_sin = self.get_rotary_pos_embed(target_frame_num, target_height, target_width, enable_riflex)
        else:
            concat_dict = {'mode': 'timecat-w', 'bias': -1}
            freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(target_frame_num, target_height, target_width, concat_dict)

        n_tokens = freqs_cos.shape[0]

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)
        # ========================================================================
        # Pipeline inference
        # ========================================================================
        start_time = time.time()

        # "pixel_value_llava": llava_item_tensor.unsqueeze(0),
        # "uncond_pixel_value_llava": uncond_llava_item_tensor.unsqueeze(0),
        # 'pixel_value_ref': cat_item_tensor.unsqueeze(0),
        if input_ref_images is None:
            pixel_value_llava, uncond_pixel_value_llava, pixel_value_ref = None, None, None
            name = None
        else:
            pixel_value_llava, uncond_pixel_value_llava, pixel_value_ref = DataPreprocess().get_batch(input_ref_images, (target_width, target_height))
        samples = self.pipeline(
            prompt=input_prompt,
            height=target_height,
            width=target_width,
            video_length=target_frame_num,
            num_inference_steps=sampling_steps,
            guidance_scale=guide_scale,
            negative_prompt=n_prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            generator=generator,
            output_type="pil",
            name=name,
            pixel_value_llava=pixel_value_llava,
            uncond_pixel_value_llava=uncond_pixel_value_llava,
            pixel_value_ref=pixel_value_ref,
            denoise_strength=denoise_strength,
            ip_cfg_scale=ip_cfg_scale,
            freqs_cis=(freqs_cos, freqs_sin),
            n_tokens=n_tokens,
            embedded_guidance_scale=embedded_guidance_scale,
            data_type="video" if target_frame_num > 1 else "image",
            is_progress_bar=True,
            vae_ver="884-16c-hy",
            enable_tiling=True,
            i2v_mode=i2v_mode,
            i2v_condition_type=i2v_condition_type,
            i2v_stability=i2v_stability,
            img_latents=img_latents,
            semantic_images=semantic_images,
            joint_pass=joint_pass,
            cfg_star_rescale=cfg_star_switch,
            callback=callback,
            callback_steps=callback_steps,
        )[0]
        gen_time = time.time() - start_time
        if samples is None:
            return None
        samples = samples.sub_(0.5).mul_(2).squeeze(0)

        return samples
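
# Example usage (sketch; the checkpoint and text-encoder paths below are
# placeholders, not files guaranteed to ship with this repository):
#
#   sampler = HunyuanVideoSampler.from_pretrained(
#       model_filepath=["ckpts/hunyuan_video_720_bf16.safetensors"],
#       text_encoder_filepath="ckpts/llava-llama-3-8b",
#   )
#   video = sampler.generate(
#       input_prompt="a cat walking on grass",
#       height=720, width=1280, frame_num=129,
#       sampling_steps=30, seed=42,
#   )  # video tensor rescaled to [-1, 1], or None if interrupted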