diff --git a/README.md b/README.md index ea24df7..3b7fbd5 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models ## 🔥 Latest News!! -* May 26 2025: 👋 WanGP v5.31 : Added Phantom 14B, a model that you can use to transfer objects / people in the video. My preference goes to Vace that remains the king of controlnets. +* May 28 2025: 👋 WanGP v5.31 : Added Phantom 14B, a model that you can use to transfer objects / people in the video. My preference goes to Vace that remains the king of controlnets. * May 26 2025: 👋 WanGP v5.3 : Happy with a Video generation and want to do more generations using the same settings but you can't remember what you did or you find it to hard to copy / paste one per one each setting from the file metadata ? Rejoice ! There are now multiple ways to turn this tedious process into a one click task: - Select one Video recently generated in the Video Gallery and click *Use Selected Video Settings* - Click *Drop File Here* and select a Video you saved somewhere, if the settings metadata have been saved with the Video you will be able to extract them automatically diff --git a/hyvideo/data_kits/audio_dataset.py b/hyvideo/data_kits/audio_dataset.py new file mode 100644 index 0000000..e65cee5 --- /dev/null +++ b/hyvideo/data_kits/audio_dataset.py @@ -0,0 +1,170 @@ +import os +import cv2 +import math +import json +import torch +import random +import librosa +import traceback +import torchvision +import numpy as np +import pandas as pd +from PIL import Image +from einops import rearrange +from torch.utils.data import Dataset +from decord import VideoReader, cpu +from transformers import CLIPImageProcessor +import torchvision.transforms as transforms +from torchvision.transforms import ToPILImage + + + +def get_audio_feature(feature_extractor, audio_path): + audio_input, sampling_rate = librosa.load(audio_path, sr=16000) + assert sampling_rate == 16000 + + audio_features = [] + window = 750*640 + for i in range(0, len(audio_input), window): + audio_feature = feature_extractor(audio_input[i:i+window], + sampling_rate=sampling_rate, + return_tensors="pt", + ).input_features + audio_features.append(audio_feature) + + audio_features = torch.cat(audio_features, dim=-1) + return audio_features, len(audio_input) // 640 + + +class VideoAudioTextLoaderVal(Dataset): + def __init__( + self, + image_size: int, + meta_file: str, + **kwargs, + ): + super().__init__() + self.meta_file = meta_file + self.image_size = image_size + self.text_encoder = kwargs.get("text_encoder", None) # llava_text_encoder + self.text_encoder_2 = kwargs.get("text_encoder_2", None) # clipL_text_encoder + self.feature_extractor = kwargs.get("feature_extractor", None) + self.meta_files = [] + + csv_data = pd.read_csv(meta_file) + for idx in range(len(csv_data)): + self.meta_files.append( + { + "videoid": str(csv_data["videoid"][idx]), + "image_path": str(csv_data["image"][idx]), + "audio_path": str(csv_data["audio"][idx]), + "prompt": str(csv_data["prompt"][idx]), + "fps": float(csv_data["fps"][idx]) + } + ) + + self.llava_transform = transforms.Compose( + [ + transforms.Resize((336, 336), interpolation=transforms.InterpolationMode.BILINEAR), + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)), + ] + ) + self.clip_image_processor = CLIPImageProcessor() + + self.device = torch.device("cuda") + self.weight_dtype = torch.float16 + + + def __len__(self): + return len(self.meta_files) + + @staticmethod + def get_text_tokens(text_encoder, description, dtype_encode="video"): + text_inputs = text_encoder.text2tokens(description, data_type=dtype_encode) + text_ids = text_inputs["input_ids"].squeeze(0) + text_mask = text_inputs["attention_mask"].squeeze(0) + return text_ids, text_mask + + def get_batch_data(self, idx): + meta_file = self.meta_files[idx] + videoid = meta_file["videoid"] + image_path = meta_file["image_path"] + audio_path = meta_file["audio_path"] + prompt = "Authentic, Realistic, Natural, High-quality, Lens-Fixed, " + meta_file["prompt"] + fps = meta_file["fps"] + + img_size = self.image_size + ref_image = Image.open(image_path).convert('RGB') + + # Resize reference image + w, h = ref_image.size + scale = img_size / min(w, h) + new_w = round(w * scale / 64) * 64 + new_h = round(h * scale / 64) * 64 + + if img_size == 704: + img_size_long = 1216 + if new_w * new_h > img_size * img_size_long: + import math + scale = math.sqrt(img_size * img_size_long / w / h) + new_w = round(w * scale / 64) * 64 + new_h = round(h * scale / 64) * 64 + + ref_image = ref_image.resize((new_w, new_h), Image.LANCZOS) + + ref_image = np.array(ref_image) + ref_image = torch.from_numpy(ref_image) + + audio_input, audio_len = get_audio_feature(self.feature_extractor, audio_path) + audio_prompts = audio_input[0] + + motion_bucket_id_heads = np.array([25] * 4) + motion_bucket_id_exps = np.array([30] * 4) + motion_bucket_id_heads = torch.from_numpy(motion_bucket_id_heads) + motion_bucket_id_exps = torch.from_numpy(motion_bucket_id_exps) + fps = torch.from_numpy(np.array(fps)) + + to_pil = ToPILImage() + pixel_value_ref = rearrange(ref_image.clone().unsqueeze(0), "b h w c -> b c h w") # (b c h w) + + pixel_value_ref_llava = [self.llava_transform(to_pil(image)) for image in pixel_value_ref] + pixel_value_ref_llava = torch.stack(pixel_value_ref_llava, dim=0) + pixel_value_ref_clip = self.clip_image_processor( + images=Image.fromarray((pixel_value_ref[0].permute(1,2,0)).data.cpu().numpy().astype(np.uint8)), + return_tensors="pt" + ).pixel_values[0] + pixel_value_ref_clip = pixel_value_ref_clip.unsqueeze(0) + + # Encode text prompts + + text_ids, text_mask = self.get_text_tokens(self.text_encoder, prompt) + text_ids_2, text_mask_2 = self.get_text_tokens(self.text_encoder_2, prompt) + + # Output batch + batch = { + "text_prompt": prompt, # + "videoid": videoid, + "pixel_value_ref": pixel_value_ref.to(dtype=torch.float16), # 参考图,用于vae提特征 (1, 3, h, w), 取值范围(0, 255) + "pixel_value_ref_llava": pixel_value_ref_llava.to(dtype=torch.float16), # 参考图,用于llava提特征 (1, 3, 336, 336), 取值范围 = CLIP取值范围 + "pixel_value_ref_clip": pixel_value_ref_clip.to(dtype=torch.float16), # 参考图,用于clip_image_encoder提特征 (1, 3, 244, 244), 取值范围 = CLIP取值范围 + "audio_prompts": audio_prompts.to(dtype=torch.float16), + "motion_bucket_id_heads": motion_bucket_id_heads.to(dtype=text_ids.dtype), + "motion_bucket_id_exps": motion_bucket_id_exps.to(dtype=text_ids.dtype), + "fps": fps.to(dtype=torch.float16), + "text_ids": text_ids.clone(), # 对应llava_text_encoder + "text_mask": text_mask.clone(), # 对应llava_text_encoder + "text_ids_2": text_ids_2.clone(), # 对应clip_text_encoder + "text_mask_2": text_mask_2.clone(), # 对应clip_text_encoder + "audio_len": audio_len, + "image_path": image_path, + "audio_path": audio_path, + } + return batch + + def __getitem__(self, idx): + return self.get_batch_data(idx) + + + + \ No newline at end of file diff --git a/hyvideo/data_kits/audio_preprocessor.py b/hyvideo/data_kits/audio_preprocessor.py new file mode 100644 index 0000000..89ac470 --- /dev/null +++ b/hyvideo/data_kits/audio_preprocessor.py @@ -0,0 +1,72 @@ + +import os +import cv2 +import json +import time +import decord +import einops +import librosa +import torch +import random +import argparse +import traceback +import numpy as np +from tqdm import tqdm +from PIL import Image +from einops import rearrange + + + +def get_facemask(ref_image, align_instance, area=1.25): + # ref_image: (b f c h w) + bsz, f, c, h, w = ref_image.shape + images = rearrange(ref_image, "b f c h w -> (b f) h w c").data.cpu().numpy().astype(np.uint8) + face_masks = [] + for image in images: + image_pil = Image.fromarray(image).convert("RGB") + _, _, bboxes_list = align_instance(np.array(image_pil)[:,:,[2,1,0]], maxface=True) + try: + bboxSrc = bboxes_list[0] + except: + bboxSrc = [0, 0, w, h] + x1, y1, ww, hh = bboxSrc + x2, y2 = x1 + ww, y1 + hh + ww, hh = (x2-x1) * area, (y2-y1) * area + center = [(x2+x1)//2, (y2+y1)//2] + x1 = max(center[0] - ww//2, 0) + y1 = max(center[1] - hh//2, 0) + x2 = min(center[0] + ww//2, w) + y2 = min(center[1] + hh//2, h) + + face_mask = np.zeros_like(np.array(image_pil)) + face_mask[int(y1):int(y2), int(x1):int(x2)] = 1.0 + face_masks.append(torch.from_numpy(face_mask[...,:1])) + face_masks = torch.stack(face_masks, dim=0) # (b*f, h, w, c) + face_masks = rearrange(face_masks, "(b f) h w c -> b c f h w", b=bsz, f=f) + face_masks = face_masks.to(device=ref_image.device, dtype=ref_image.dtype) + return face_masks + + +def encode_audio(wav2vec, audio_feats, fps, num_frames=129): + if fps == 25: + start_ts = [0] + step_ts = [1] + elif fps == 12.5: + start_ts = [0] + step_ts = [2] + num_frames = min(num_frames, 400) + audio_feats = wav2vec.encoder(audio_feats.unsqueeze(0)[:, :, :3000], output_hidden_states=True).hidden_states + audio_feats = torch.stack(audio_feats, dim=2) + audio_feats = torch.cat([torch.zeros_like(audio_feats[:,:4]), audio_feats], 1) + + audio_prompts = [] + for bb in range(1): + audio_feats_list = [] + for f in range(num_frames): + cur_t = (start_ts[bb] + f * step_ts[bb]) * 2 + audio_clip = audio_feats[bb:bb+1, cur_t: cur_t+10] + audio_feats_list.append(audio_clip) + audio_feats_list = torch.stack(audio_feats_list, 1) + audio_prompts.append(audio_feats_list) + audio_prompts = torch.cat(audio_prompts) + return audio_prompts \ No newline at end of file diff --git a/hyvideo/data_kits/data_tools.py b/hyvideo/data_kits/data_tools.py new file mode 100644 index 0000000..a7d6077 --- /dev/null +++ b/hyvideo/data_kits/data_tools.py @@ -0,0 +1,41 @@ +import os +import cv2 +import torch +import numpy as np +import imageio +import torchvision +from einops import rearrange + + +def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8, quality=8): + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = torch.clamp(x,0,1) + x = (x * 255).numpy().astype(np.uint8) + outputs.append(x) + + os.makedirs(os.path.dirname(path), exist_ok=True) + imageio.mimsave(path, outputs, fps=fps, quality=quality) + +def pad_image(crop_img, size, color=(255, 255, 255), resize_ratio=1): + crop_h, crop_w = crop_img.shape[:2] + target_w, target_h = size + scale_h, scale_w = target_h / crop_h, target_w / crop_w + if scale_w > scale_h: + resize_h = int(target_h*resize_ratio) + resize_w = int(crop_w / crop_h * resize_h) + else: + resize_w = int(target_w*resize_ratio) + resize_h = int(crop_h / crop_w * resize_w) + crop_img = cv2.resize(crop_img, (resize_w, resize_h)) + pad_left = (target_w - resize_w) // 2 + pad_top = (target_h - resize_h) // 2 + pad_right = target_w - resize_w - pad_left + pad_bottom = target_h - resize_h - pad_top + crop_img = cv2.copyMakeBorder(crop_img, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=color) + return crop_img \ No newline at end of file diff --git a/hyvideo/data_kits/face_align/__init__.py b/hyvideo/data_kits/face_align/__init__.py new file mode 100644 index 0000000..d660022 --- /dev/null +++ b/hyvideo/data_kits/face_align/__init__.py @@ -0,0 +1 @@ +from .align import AlignImage \ No newline at end of file diff --git a/hyvideo/data_kits/face_align/align.py b/hyvideo/data_kits/face_align/align.py new file mode 100644 index 0000000..610c441 --- /dev/null +++ b/hyvideo/data_kits/face_align/align.py @@ -0,0 +1,34 @@ +import os +import sys +import torch +from .detface import DetFace + +class AlignImage(object): + def __init__(self, device='cuda', det_path=''): + self.facedet = DetFace(pt_path=det_path, confThreshold=0.5, nmsThreshold=0.45, device=device) + + @torch.no_grad() + def __call__(self, im, maxface=False): + bboxes, kpss, scores = self.facedet.detect(im) + face_num = bboxes.shape[0] + + five_pts_list = [] + scores_list = [] + bboxes_list = [] + for i in range(face_num): + five_pts_list.append(kpss[i].reshape(5,2)) + scores_list.append(scores[i]) + bboxes_list.append(bboxes[i]) + + if maxface and face_num>1: + max_idx = 0 + max_area = (bboxes[0, 2])*(bboxes[0, 3]) + for i in range(1, face_num): + area = (bboxes[i,2])*(bboxes[i,3]) + if area>max_area: + max_idx = i + five_pts_list = [five_pts_list[max_idx]] + scores_list = [scores_list[max_idx]] + bboxes_list = [bboxes_list[max_idx]] + + return five_pts_list, scores_list, bboxes_list \ No newline at end of file diff --git a/hyvideo/data_kits/face_align/detface.py b/hyvideo/data_kits/face_align/detface.py new file mode 100644 index 0000000..d04d293 --- /dev/null +++ b/hyvideo/data_kits/face_align/detface.py @@ -0,0 +1,283 @@ +# -*- coding: UTF-8 -*- +import os +import cv2 +import numpy as np +import torch +import torchvision + + +def xyxy2xywh(x): + # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center + y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center + y[:, 2] = x[:, 2] - x[:, 0] # width + y[:, 3] = x[:, 3] - x[:, 1] # height + return y + + +def xywh2xyxy(x): + # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x + y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y + y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x + y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y + return y + + +def box_iou(box1, box2): + # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py + """ + Return intersection-over-union (Jaccard index) of boxes. + Both sets of boxes are expected to be in (x1, y1, x2, y2) format. + Arguments: + box1 (Tensor[N, 4]) + box2 (Tensor[M, 4]) + Returns: + iou (Tensor[N, M]): the NxM matrix containing the pairwise + IoU values for every element in boxes1 and boxes2 + """ + + def box_area(box): + # box = 4xn + return (box[2] - box[0]) * (box[3] - box[1]) + + area1 = box_area(box1.T) + area2 = box_area(box2.T) + + # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2) + inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - + torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2) + # iou = inter / (area1 + area2 - inter) + return inter / (area1[:, None] + area2 - inter) + + +def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None): + # Rescale coords (xyxy) from img1_shape to img0_shape + if ratio_pad is None: # calculate from img0_shape + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new + pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding + else: + gain = ratio_pad[0][0] + pad = ratio_pad[1] + + coords[:, [0, 2]] -= pad[0] # x padding + coords[:, [1, 3]] -= pad[1] # y padding + coords[:, :4] /= gain + clip_coords(coords, img0_shape) + return coords + + +def clip_coords(boxes, img_shape): + # Clip bounding xyxy bounding boxes to image shape (height, width) + boxes[:, 0].clamp_(0, img_shape[1]) # x1 + boxes[:, 1].clamp_(0, img_shape[0]) # y1 + boxes[:, 2].clamp_(0, img_shape[1]) # x2 + boxes[:, 3].clamp_(0, img_shape[0]) # y2 + + +def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None): + # Rescale coords (xyxy) from img1_shape to img0_shape + if ratio_pad is None: # calculate from img0_shape + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new + pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding + else: + gain = ratio_pad[0][0] + pad = ratio_pad[1] + + coords[:, [0, 2, 4, 6, 8]] -= pad[0] # x padding + coords[:, [1, 3, 5, 7, 9]] -= pad[1] # y padding + coords[:, :10] /= gain + #clip_coords(coords, img0_shape) + coords[:, 0].clamp_(0, img0_shape[1]) # x1 + coords[:, 1].clamp_(0, img0_shape[0]) # y1 + coords[:, 2].clamp_(0, img0_shape[1]) # x2 + coords[:, 3].clamp_(0, img0_shape[0]) # y2 + coords[:, 4].clamp_(0, img0_shape[1]) # x3 + coords[:, 5].clamp_(0, img0_shape[0]) # y3 + coords[:, 6].clamp_(0, img0_shape[1]) # x4 + coords[:, 7].clamp_(0, img0_shape[0]) # y4 + coords[:, 8].clamp_(0, img0_shape[1]) # x5 + coords[:, 9].clamp_(0, img0_shape[0]) # y5 + return coords + + +def show_results(img, xywh, conf, landmarks, class_num): + h,w,c = img.shape + tl = 1 or round(0.002 * (h + w) / 2) + 1 # line/font thickness + x1 = int(xywh[0] * w - 0.5 * xywh[2] * w) + y1 = int(xywh[1] * h - 0.5 * xywh[3] * h) + x2 = int(xywh[0] * w + 0.5 * xywh[2] * w) + y2 = int(xywh[1] * h + 0.5 * xywh[3] * h) + cv2.rectangle(img, (x1,y1), (x2, y2), (0,255,0), thickness=tl, lineType=cv2.LINE_AA) + + clors = [(255,0,0),(0,255,0),(0,0,255),(255,255,0),(0,255,255)] + + for i in range(5): + point_x = int(landmarks[2 * i] * w) + point_y = int(landmarks[2 * i + 1] * h) + cv2.circle(img, (point_x, point_y), tl+1, clors[i], -1) + + tf = max(tl - 1, 1) # font thickness + label = str(conf)[:5] + cv2.putText(img, label, (x1, y1 - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA) + return img + + +def make_divisible(x, divisor): + # Returns x evenly divisible by divisor + return (x // divisor) * divisor + + +def non_max_suppression_face(prediction, conf_thres=0.5, iou_thres=0.45, classes=None, agnostic=False, labels=()): + """Performs Non-Maximum Suppression (NMS) on inference results + Returns: + detections with shape: nx6 (x1, y1, x2, y2, conf, cls) + """ + + nc = prediction.shape[2] - 15 # number of classes + xc = prediction[..., 4] > conf_thres # candidates + + # Settings + min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height + # time_limit = 10.0 # seconds to quit after + redundant = True # require redundant detections + multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img) + merge = False # use merge-NMS + + # t = time.time() + output = [torch.zeros((0, 16), device=prediction.device)] * prediction.shape[0] + for xi, x in enumerate(prediction): # image index, image inference + # Apply constraints + # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height + x = x[xc[xi]] # confidence + + # Cat apriori labels if autolabelling + if labels and len(labels[xi]): + l = labels[xi] + v = torch.zeros((len(l), nc + 15), device=x.device) + v[:, :4] = l[:, 1:5] # box + v[:, 4] = 1.0 # conf + v[range(len(l)), l[:, 0].long() + 15] = 1.0 # cls + x = torch.cat((x, v), 0) + + # If none remain process next image + if not x.shape[0]: + continue + + # Compute conf + x[:, 15:] *= x[:, 4:5] # conf = obj_conf * cls_conf + + # Box (center x, center y, width, height) to (x1, y1, x2, y2) + box = xywh2xyxy(x[:, :4]) + + # Detections matrix nx6 (xyxy, conf, landmarks, cls) + if multi_label: + i, j = (x[:, 15:] > conf_thres).nonzero(as_tuple=False).T + x = torch.cat((box[i], x[i, j + 15, None], x[i, 5:15] ,j[:, None].float()), 1) + else: # best class only + conf, j = x[:, 15:].max(1, keepdim=True) + x = torch.cat((box, conf, x[:, 5:15], j.float()), 1)[conf.view(-1) > conf_thres] + + # Filter by class + if classes is not None: + x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] + + # If none remain process next image + n = x.shape[0] # number of boxes + if not n: + continue + + # Batched NMS + c = x[:, 15:16] * (0 if agnostic else max_wh) # classes + boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores + i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS + #if i.shape[0] > max_det: # limit detections + # i = i[:max_det] + if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean) + # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) + iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix + weights = iou * scores[None] # box weights + x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes + if redundant: + i = i[iou.sum(1) > 1] # require redundancy + + output[xi] = x[i] + # if (time.time() - t) > time_limit: + # break # time limit exceeded + + return output + + +class DetFace(): + def __init__(self, pt_path, confThreshold=0.5, nmsThreshold=0.45, device='cuda'): + assert os.path.exists(pt_path) + + self.inpSize = 416 + self.conf_thres = confThreshold + self.iou_thres = nmsThreshold + self.test_device = torch.device(device if torch.cuda.is_available() else "cpu") + self.model = torch.jit.load(pt_path).to(self.test_device) + self.last_w = 416 + self.last_h = 416 + self.grids = None + + @torch.no_grad() + def detect(self, srcimg): + # t0=time.time() + + h0, w0 = srcimg.shape[:2] # orig hw + r = self.inpSize / min(h0, w0) # resize image to img_size + h1 = int(h0*r+31)//32*32 + w1 = int(w0*r+31)//32*32 + + img = cv2.resize(srcimg, (w1,h1), interpolation=cv2.INTER_LINEAR) + + # Convert + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR to RGB + + # Run inference + img = torch.from_numpy(img).to(self.test_device).permute(2,0,1) + img = img.float()/255 # uint8 to fp16/32 0-1 + if img.ndimension() == 3: + img = img.unsqueeze(0) + + # Inference + if h1 != self.last_h or w1 != self.last_w or self.grids is None: + grids = [] + for scale in [8,16,32]: + ny = h1//scale + nx = w1//scale + yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)]) + grid = torch.stack((xv, yv), 2).view((1,1,ny, nx, 2)).float() + grids.append(grid.to(self.test_device)) + self.grids = grids + self.last_w = w1 + self.last_h = h1 + + pred = self.model(img, self.grids).cpu() + + # Apply NMS + det = non_max_suppression_face(pred, self.conf_thres, self.iou_thres)[0] + # Process detections + # det = pred[0] + bboxes = np.zeros((det.shape[0], 4)) + kpss = np.zeros((det.shape[0], 5, 2)) + scores = np.zeros((det.shape[0])) + # gn = torch.tensor([w0, h0, w0, h0]).to(pred) # normalization gain whwh + # gn_lks = torch.tensor([w0, h0, w0, h0, w0, h0, w0, h0, w0, h0]).to(pred) # normalization gain landmarks + det = det.cpu().numpy() + + for j in range(det.shape[0]): + # xywh = (xyxy2xywh(det[j, :4].view(1, 4)) / gn).view(4).cpu().numpy() + bboxes[j, 0] = det[j, 0] * w0/w1 + bboxes[j, 1] = det[j, 1] * h0/h1 + bboxes[j, 2] = det[j, 2] * w0/w1 - bboxes[j, 0] + bboxes[j, 3] = det[j, 3] * h0/h1 - bboxes[j, 1] + scores[j] = det[j, 4] + # landmarks = (det[j, 5:15].view(1, 10) / gn_lks).view(5,2).cpu().numpy() + kpss[j, :, :] = det[j, 5:15].reshape(5, 2) * np.array([[w0/w1,h0/h1]]) + # class_num = det[j, 15].cpu().numpy() + # orgimg = show_results(orgimg, xywh, conf, landmarks, class_num) + return bboxes, kpss, scores diff --git a/hyvideo/diffusion/pipelines/__init__.py b/hyvideo/diffusion/pipelines/__init__.py index e44cb61..d684744 100644 --- a/hyvideo/diffusion/pipelines/__init__.py +++ b/hyvideo/diffusion/pipelines/__init__.py @@ -1 +1,2 @@ from .pipeline_hunyuan_video import HunyuanVideoPipeline +from .pipeline_hunyuan_video_audio import HunyuanVideoAudioPipeline \ No newline at end of file diff --git a/hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py b/hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py index fa88145..f7339e5 100644 --- a/hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py +++ b/hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py @@ -1142,7 +1142,7 @@ class HunyuanVideoPipeline(DiffusionPipeline): target_dtype = PRECISION_TO_TYPE[precision] autocast_enabled = target_dtype != torch.float32 and not disable_autocast - vae_dtype = PRECISION_TO_TYPE[vae_precision] + vae_dtype = self.vae._model_dtype # PRECISION_TO_TYPE[vae_precision] vae_autocast_enabled = vae_dtype != torch.float32 and not disable_autocast # 7. Denoising loop @@ -1262,6 +1262,7 @@ class HunyuanVideoPipeline(DiffusionPipeline): guidance=guidance_expand, pipeline=self, x_id=j, + step_no=i, callback = callback, ) if self._interrupt: @@ -1290,6 +1291,7 @@ class HunyuanVideoPipeline(DiffusionPipeline): freqs_sin=freqs_cis[1], # [seqlen, head_dim] guidance=guidance_expand, pipeline=self, + step_no=i, callback = callback, ) if self._interrupt: @@ -1404,7 +1406,6 @@ class HunyuanVideoPipeline(DiffusionPipeline): else: image = latents - image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 image = image.cpu().float() diff --git a/hyvideo/diffusion/pipelines/pipeline_hunyuan_video_audio.py b/hyvideo/diffusion/pipelines/pipeline_hunyuan_video_audio.py new file mode 100644 index 0000000..358615b --- /dev/null +++ b/hyvideo/diffusion/pipelines/pipeline_hunyuan_video_audio.py @@ -0,0 +1,1359 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Modified from diffusers==0.29.2 +# +# ============================================================================== +import inspect +from typing import Any, Callable, Dict, List, Optional, Union, Tuple +import numpy as np +import torch +from packaging import version +from diffusers.utils import BaseOutput +from dataclasses import dataclass +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, ImageProjection +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline + +from hyvideo.constants import PRECISION_TO_TYPE +from hyvideo.vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D +from hyvideo.text_encoder import TextEncoder +from einops import rearrange +from ...modules import HYVideoDiffusionTransformer + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """""" + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + +@dataclass +class HunyuanVideoPipelineOutput(BaseOutput): + videos: Union[torch.Tensor, np.ndarray] + + +class HunyuanVideoAudioPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using HunyuanVideo. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`TextEncoder`]): + Frozen text-encoder. + text_encoder_2 ([`TextEncoder`]): + Frozen text-encoder_2. + transformer ([`HYVideoDiffusionTransformer`]): + A `HYVideoDiffusionTransformer` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = ["text_encoder_2"] + _exclude_from_cpu_offload = ["transformer"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: TextEncoder, + transformer: HYVideoDiffusionTransformer, + scheduler: KarrasDiffusionSchedulers, + text_encoder_2: Optional[TextEncoder] = None, + progress_bar_config: Dict[str, Any] = None, + args=None, + ): + super().__init__() + + # ========================================================================================== + if progress_bar_config is None: + progress_bar_config = {} + if not hasattr(self, '_progress_bar_config'): + self._progress_bar_config = {} + self._progress_bar_config.update(progress_bar_config) + + self.args = args + # ========================================================================================== + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2 + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def encode_prompt( + self, + prompt, + name, + device, + num_videos_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + pixel_value_llava: Optional[torch.Tensor] = None, + uncond_pixel_value_llava: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_attention_mask: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + text_encoder: Optional[TextEncoder] = None, + data_type: Optional[str] = "image", + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_videos_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + pixel_value_llava (`torch.Tensor`, *optional*): + The image tensor for llava. + uncond_pixel_value_llava (`torch.Tensor`, *optional*): + The image tensor for llava. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + attention_mask (`torch.Tensor`, *optional*): + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_attention_mask (`torch.Tensor`, *optional*): + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + text_encoder (TextEncoder, *optional*): + """ + if text_encoder is None: + text_encoder = self.text_encoder + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(text_encoder.model, lora_scale) + else: + scale_lora_layers(text_encoder.model, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer) + text_inputs = text_encoder.text2tokens(prompt, data_type=data_type, name=name) + + if pixel_value_llava is not None: + text_inputs['pixel_value_llava'] = pixel_value_llava + text_inputs['attention_mask'] = torch.cat([text_inputs['attention_mask'], torch.ones((1, 575 * len(pixel_value_llava))).to(text_inputs['attention_mask'])], dim=1) + + if clip_skip is None: + prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type) + prompt_embeds = prompt_outputs.hidden_state + else: + prompt_outputs = text_encoder.encode(text_inputs, output_hidden_states=True, data_type=data_type) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = text_encoder.model.text_model.final_layer_norm(prompt_embeds) + + attention_mask = prompt_outputs.attention_mask + if attention_mask is not None: + attention_mask = attention_mask.to(device) + bs_embed, seq_len = attention_mask.shape + attention_mask = attention_mask.repeat(1, num_videos_per_prompt) + attention_mask = attention_mask.view(bs_embed * num_videos_per_prompt, seq_len) + + if text_encoder is not None: + prompt_embeds_dtype = text_encoder.dtype + elif self.transformer is not None: + prompt_embeds_dtype = self.transformer.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + if prompt_embeds.ndim == 2: + bs_embed, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1) + else: + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, text_encoder.tokenizer) + uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type) + if uncond_pixel_value_llava is not None: + uncond_input['pixel_value_llava'] = uncond_pixel_value_llava + uncond_input['attention_mask'] = torch.cat([uncond_input['attention_mask'], torch.ones((1, 575 * len(uncond_pixel_value_llava))).to(uncond_input['attention_mask'])], dim=1) + + negative_prompt_outputs = text_encoder.encode(uncond_input, data_type=data_type) + negative_prompt_embeds = negative_prompt_outputs.hidden_state + + negative_attention_mask = negative_prompt_outputs.attention_mask + if negative_attention_mask is not None: + negative_attention_mask = negative_attention_mask.to(device) + _, seq_len = negative_attention_mask.shape + negative_attention_mask = negative_attention_mask.repeat(1, num_videos_per_prompt) + negative_attention_mask = negative_attention_mask.view(batch_size * num_videos_per_prompt, seq_len) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + if negative_prompt_embeds.ndim == 2: + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, -1) + else: + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + if text_encoder is not None: + if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(text_encoder.model, lora_scale) + + return prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask + + def encode_prompt_audio_text_base( + self, + prompt, + uncond_prompt, + pixel_value_llava, + uncond_pixel_value_llava, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + text_encoder: Optional[TextEncoder] = None, + data_type: Optional[str] = "image", + name = "person" + ): + if text_encoder is None: + text_encoder = self.text_encoder + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(text_encoder.model, lora_scale) + else: + scale_lora_layers(text_encoder.model, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + prompt_embeds = None + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer) + text_inputs = text_encoder.text2tokens(prompt, data_type=data_type, name=name) # data_type: video, text_inputs: {'input_ids', 'attention_mask'} + + text_keys = ['input_ids', 'attention_mask'] + + if pixel_value_llava is not None: + text_inputs['pixel_value_llava'] = pixel_value_llava + text_inputs['attention_mask'] = torch.cat([text_inputs['attention_mask'], torch.ones((1, 575)).to(text_inputs['attention_mask'])], dim=1) + + + if clip_skip is None: + prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type) + prompt_embeds = prompt_outputs.hidden_state + else: + prompt_outputs = text_encoder.encode(text_inputs, output_hidden_states=True, data_type=data_type) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = text_encoder.model.text_model.final_layer_norm(prompt_embeds) + + attention_mask = prompt_outputs.attention_mask + if attention_mask is not None: + attention_mask = attention_mask.to(device) + bs_embed, seq_len = attention_mask.shape + attention_mask = attention_mask.repeat(1, num_images_per_prompt) + attention_mask = attention_mask.view(bs_embed * num_images_per_prompt, seq_len) + + if text_encoder is not None: + prompt_embeds_dtype = text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + if prompt_embeds.ndim == 2: + bs_embed, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, -1) + else: + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, text_encoder.tokenizer) + # max_length = prompt_embeds.shape[1] + uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type, name=name) + + # if hasattr(text_encoder.model.config, "use_attention_mask") and text_encoder.model.config.use_attention_mask: + # attention_mask = uncond_input.attention_mask.to(device) + # else: + # attention_mask = None + if uncond_pixel_value_llava is not None: + uncond_input['pixel_value_llava'] = uncond_pixel_value_llava + uncond_input['attention_mask'] = torch.cat([uncond_input['attention_mask'], torch.ones((1, 575)).to(uncond_input['attention_mask'])], dim=1) + + negative_prompt_outputs = text_encoder.encode(uncond_input, data_type=data_type) + negative_prompt_embeds = negative_prompt_outputs.hidden_state + + negative_attention_mask = negative_prompt_outputs.attention_mask + if negative_attention_mask is not None: + negative_attention_mask = negative_attention_mask.to(device) + _, seq_len = negative_attention_mask.shape + negative_attention_mask = negative_attention_mask.repeat(1, num_images_per_prompt) + negative_attention_mask = negative_attention_mask.view(batch_size * num_images_per_prompt, seq_len) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + if negative_prompt_embeds.ndim == 2: + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + else: + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if text_encoder is not None: + if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(text_encoder.model, lora_scale) + + return prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask + + def decode_latents(self, latents, enable_tiling=True): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + if enable_tiling: + self.vae.enable_tiling() + image = self.vae.decode(latents, return_dict=False)[0] + self.vae.disable_tiling() + else: + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + if image.ndim==4: image = image.cpu().permute(0, 2, 3, 1).float() + else: image = image.cpu().float() + return image + + def prepare_extra_func_kwargs(self, func, kwargs): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + extra_step_kwargs = {} + + for k, v in kwargs.items(): + accepts = k in set(inspect.signature(func).parameters.keys()) + if accepts: + extra_step_kwargs[k] = v + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + frame, + callback_steps, + pixel_value_llava=None, + uncond_pixel_value_llava=None, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + vae_ver='88-4c-sd' + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if frame is not None: + if '884' in vae_ver: + if frame!=1 and (frame-1)%4!=0: + raise ValueError(f'`frame` has to be 1 or a multiple of 4 but is {frame}.') + elif '888' in vae_ver: + if frame!=1 and (frame-1)%8!=0: + raise ValueError(f'`frame` has to be 1 or a multiple of 8 but is {frame}.') + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if pixel_value_llava is not None and uncond_pixel_value_llava is not None: + if len(pixel_value_llava) != len(uncond_pixel_value_llava): + raise ValueError( + "`pixel_value_llava` and `uncond_pixel_value_llava` must have the same length when passed directly, but" + f" got: `pixel_value_llava` {len(pixel_value_llava)} != `uncond_pixel_value_llava`" + f" {len(uncond_pixel_value_llava)}." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps.to(device), num_inference_steps - t_start + + def prepare_latents(self, batch_size, num_channels_latents, height, width, frame, dtype, device, generator, latents=None, ref_latents=None, timestep=None): + shape = ( + batch_size, + num_channels_latents, + frame, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + + if timestep is not None: + init_latents = ref_latents.clone().repeat(1,1,frame,1,1).to(device).to(dtype) + latents = latents + + # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + + return latents + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + # return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + + ref_latents: Union[torch.Tensor], # [1, 16, 1, h//8, w//8] + # uncond_ref_latents: Union[torch.Tensor], + pixel_value_llava: Union[torch.Tensor], # [1, 3, 336, 336] + uncond_pixel_value_llava: Union[torch.Tensor], + pixel_value_ref, + face_masks: Union[torch.Tensor], # [b f h w] + audio_prompts: Union[torch.Tensor], + uncond_audio_prompts: Union[torch.Tensor], + motion_exp: Union[torch.Tensor], + motion_pose: Union[torch.Tensor], + fps: Union[torch.Tensor], + + height: int, + width: int, + video_length: int, + data_type: str = "video", + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[ + Callable[[int, int, Dict], None], + PipelineCallback, + MultiPipelineCallbacks, + ] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None, + vae_ver: str = "88-4c-sd", + enable_tiling: bool = False, + n_tokens: Optional[int] = None, + embedded_guidance_scale: Optional[float] = None, + joint_pass = False, + cfg_star_rescale = False, + name = None, + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + video_length (`int`): + The number of frames in the generated video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~HunyuanVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + + # num_inference_steps = 50 + + # 0. Default height and width to transformer + # height = height or self.transformer.config.sample_size * self.vae_scale_factor + # width = width or self.transformer.config.sample_size * self.vae_scale_factor + # to deal with lora scaling and other possible forward hooks + + transformer = self.transformer + + if transformer.enable_teacache: + teacache_multiplier = transformer.teacache_multiplier + transformer.accumulated_rel_l1_distance = 0 + transformer.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + video_length, + callback_steps, + pixel_value_llava, + uncond_pixel_value_llava, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + vae_ver=vae_ver + ) + + self._guidance_scale = guidance_scale + self.start_cfg_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + + # ========== Encode text prompt (image prompt) ========== + prompt_embeds, negative_prompt_embeds, prompt_mask, negative_prompt_mask = \ + self.encode_prompt_audio_text_base( + prompt=prompt, + uncond_prompt=negative_prompt, + pixel_value_llava=pixel_value_llava, + uncond_pixel_value_llava=uncond_pixel_value_llava, + device=device, + num_images_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + text_encoder=self.text_encoder, + data_type=data_type, + name= name, + # **kwargs + ) + if self.text_encoder_2 is not None: + prompt_embeds_2, negative_prompt_embeds_2, prompt_mask_2, negative_prompt_mask_2 = \ + self.encode_prompt_audio_text_base( + prompt=prompt, + uncond_prompt=negative_prompt, + pixel_value_llava=None, + uncond_pixel_value_llava=None, + device=device, + num_images_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=None, + negative_prompt_embeds=None, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + text_encoder=self.text_encoder_2, + # **kwargs + ) + else: + prompt_embeds_2 = None + negative_prompt_embeds_2 = None + prompt_mask_2 = None + negative_prompt_mask_2 = None + + if self.transformer.mixed_precision: + latent_dtype = torch.float32 + else: + latent_dtype = torch.bfloat16 + if prompt_embeds != None: + prompt_embeds = prompt_embeds.to(torch.bfloat16) + if negative_prompt_embeds != None: + negative_prompt_embeds = negative_prompt_embeds.to(torch.bfloat16) + if prompt_embeds_2 != None: + prompt_embeds_2 = prompt_embeds_2.to(torch.bfloat16) + if negative_prompt_embeds_2 != None: + negative_prompt_embeds_2 = negative_prompt_embeds_2.to(torch.bfloat16) + if audio_prompts != None: + audio_prompts = audio_prompts.to(torch.bfloat16) + if face_masks!= None: + face_masks = face_masks.to(torch.bfloat16) + if ref_latents != None: + ref_latents = ref_latents.to(torch.bfloat16) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds_input = torch.cat([negative_prompt_embeds, prompt_embeds]) + if prompt_mask is not None: + prompt_mask_input = torch.cat([negative_prompt_mask, prompt_mask]) + if prompt_embeds_2 is not None: + prompt_embeds_2_input = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) + if prompt_mask_2 is not None: + prompt_mask_2_input = torch.cat([negative_prompt_mask_2, prompt_mask_2]) + + if self.do_classifier_free_guidance and ref_latents != None: + ref_latents = torch.cat([ref_latents, ref_latents], dim=0) + + + # 4. Prepare timesteps + extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs( + self.scheduler.set_timesteps, {"n_tokens": n_tokens} + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas, **extra_set_timesteps_kwargs, + ) + + video_length = audio_prompts.shape[1] // 4 * 4 + 1 + if "884" in vae_ver: + video_length = (video_length - 1) // 4 + 1 + elif "888" in vae_ver: + video_length = (video_length - 1) // 8 + 1 + else: + video_length = video_length + + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + infer_length = (audio_prompts.shape[1] // 128 + 1) * 32 + 1 + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + infer_length, + latent_dtype, #prompt_embeds.dtype, + device, + generator, + latents, + ref_latents[-1:] if ref_latents != None else None, + timesteps[:1] + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_func_kwargs( + self.scheduler.step, {"generator": generator, "eta": eta}, + ) + + vae_precision = "fp16" # torch.float16 + precision = "bf16" # torch.bfloat16 + disable_autocast = True + + target_dtype = PRECISION_TO_TYPE[precision] + autocast_enabled = (target_dtype != torch.float32) and not disable_autocast + vae_dtype = self.vae._model_dtype #PRECISION_TO_TYPE[vae_precision] + vae_autocast_enabled = (vae_dtype != torch.float32) and not disable_autocast + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + latents_all = latents.clone() + pad_audio_length = (audio_prompts.shape[1] // 128 + 1) * 128 + 4 - audio_prompts.shape[1] + audio_prompts_all = torch.cat([audio_prompts, torch.zeros_like(audio_prompts[:, :pad_audio_length])], dim=1) + + + shift = 0 + shift_offset = 10 + frames_per_batch = 33 + self.cache_tensor = None + + """ If the total length is shorter than 129, shift is not required """ + if video_length == 33 or infer_length == 33: + infer_length = 33 + shift_offset = 0 + latents_all = latents_all[:, :, :33] + audio_prompts_all = audio_prompts_all[:, :132] + joint_pass = joint_pass or not self.do_classifier_free_guidance + + if callback != None: + callback(-1, None, True, override_num_inference_steps = num_inference_steps) + + latent_items = 2 if self.do_classifier_free_guidance else 1 + + fps = torch.from_numpy(np.array(fps)).unsqueeze(0).to(dtype=torch.float16) + + if self._interrupt: + return [None] + + if transformer.enable_teacache: + cache_size = round( infer_length / frames_per_batch ) + transformer.previous_residual = [None] * latent_items + cache_all_previous_residual = [None] * latent_items + cache_all_previous_modulated_input = None + cache_should_calc = [True] * cache_size + cache_accumulated_rel_l1_distance = [0.] * cache_size + cache_teacache_skipped_steps = [0] * cache_size + + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # init + pred_latents = torch.zeros_like( + latents_all, + dtype=latents_all.dtype, + ) + counter = torch.zeros( + (latents_all.shape[0], latents_all.shape[1], infer_length, 1, 1), + dtype=latents_all.dtype, + ).to(device=latents_all.device) + + cache_slot_no = 0 + for index_start in range(0, infer_length, frames_per_batch): + self.scheduler._step_index = None + + index_start = index_start - shift + idx_list = [ii % latents_all.shape[2] for ii in range(index_start, index_start + frames_per_batch)] + latents = latents_all[:, :, idx_list].clone() + + idx_list_audio = [ii % audio_prompts_all.shape[1] for ii in range(index_start * 4, (index_start + frames_per_batch) * 4 - 3)] + audio_prompts = audio_prompts_all[:, idx_list_audio].clone() + + # expand the latents if we are doing classifier free guidance + if self.do_classifier_free_guidance: + latent_model_input = torch.cat([latents] * 2) + else: + latent_model_input = latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + embedded_hw = (latent_model_input.shape[-1] // 2) * (latent_model_input.shape[-2] // 2) * 3072 + img_ref_len = (latent_model_input.shape[-1] // 2) * (latent_model_input.shape[-2] // 2) * ( 1) + img_all_len = (latents_all.shape[-1] // 2) * (latents_all.shape[-2] // 2) * latents_all.shape[-3] + + if transformer.enable_teacache and cache_size > 1: + for l in range(latent_items): + if cache_all_previous_residual[l] != None: + bsz = cache_all_previous_residual[l].shape[0] + transformer.previous_residual[l][:, img_ref_len:] = cache_all_previous_residual[l].reshape(1, -1, embedded_hw) [:, idx_list].reshape(1, -1, 3072) + if cache_all_previous_modulated_input != None: + transformer.previous_modulated_input[:, img_ref_len:] = cache_all_previous_modulated_input.reshape(1, -1, embedded_hw) [:, idx_list].reshape(1, -1, 3072) + transformer.should_calc = cache_should_calc[cache_slot_no] + transformer.accumulated_rel_l1_distance = cache_accumulated_rel_l1_distance[cache_slot_no] + transformer.teacache_skipped_steps = cache_teacache_skipped_steps[cache_slot_no] + + + if self.do_classifier_free_guidance: + if i < num_inference_steps * 0.2 : + self._guidance_scale = (1 - i / len(timesteps)) * (self.start_cfg_scale - 2) + 2 + audio_prompts_input = torch.cat([uncond_audio_prompts, audio_prompts], dim=0) + face_masks_input = torch.cat([face_masks * 0.6] * 2, dim=0) + else: + # define 10-50 step cfg + self._guidance_scale = (1 - i / len(timesteps)) * (6.5 - 3.5) + 3.5 # 5-2 +2 + + prompt_embeds_input = torch.cat([prompt_embeds, prompt_embeds]) + if prompt_mask is not None: + prompt_mask_input = torch.cat([prompt_mask, prompt_mask]) + if prompt_embeds_2 is not None: + prompt_embeds_2_input = torch.cat([prompt_embeds_2, prompt_embeds_2]) + if prompt_mask_2 is not None: + prompt_mask_2_input = torch.cat([prompt_mask_2, prompt_mask_2]) + audio_prompts_input = torch.cat([uncond_audio_prompts, audio_prompts], dim=0) + face_masks_input = torch.cat([face_masks] * 2, dim=0) + + motion_exp_input = torch.cat([motion_exp] * 2, dim=0) + motion_pose_input = torch.cat([motion_pose] * 2, dim=0) + fps_input = torch.cat([fps] * 2, dim=0) + + else: + audio_prompts_input = audio_prompts + face_masks_input = face_masks + motion_exp_input = motion_exp + motion_pose_input = motion_pose + fps_input = fps + + t_expand = t.repeat(latent_model_input.shape[0]) + guidance_expand = None + + with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled): + additional_kwargs = { + "pipeline": self, + "step_no": i, + } + if joint_pass: + additional_kwargs.update({ + "motion_exp": motion_exp_input, + "motion_pose": motion_pose_input, + "fps": fps_input, + "audio_prompts": audio_prompts_input, + "face_mask": face_masks_input + }) + noise_pred = self.transformer(latent_model_input, t_expand, ref_latents=ref_latents, text_states=prompt_embeds_input, text_mask=prompt_mask_input, text_states_2=prompt_embeds_2_input, freqs_cos=freqs_cis[0], freqs_sin=freqs_cis[1], guidance=guidance_expand, **additional_kwargs,) + if self._interrupt: + return [None] + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + else: + additional_kwargs.update({ + "motion_exp": motion_exp_input[:1], + "motion_pose": motion_pose_input[:1], + "fps": fps_input[:1], + "audio_prompts": audio_prompts_input[:1], + "face_mask": face_masks_input[:1] + }) + noise_pred_uncond = self.transformer(latent_model_input[:1], t_expand[:1], ref_latents=ref_latents[:1], text_states=prompt_embeds_input[:1], text_mask=prompt_mask_input[:1], text_states_2=prompt_embeds_2_input[:1], freqs_cos=freqs_cis[0], freqs_sin=freqs_cis[1], guidance=guidance_expand, x_id = 0, **additional_kwargs,) + if self._interrupt: + return [None] + noise_pred_uncond = noise_pred_uncond[0] + additional_kwargs.update({ + "motion_exp": motion_exp_input[1:], + "motion_pose": motion_pose_input[1:], + "fps": fps_input[1:], + "audio_prompts": audio_prompts_input[1:], + "face_mask": face_masks_input[1:] + }) + noise_pred_text = self.transformer(latent_model_input[1:], t_expand[1:], ref_latents=ref_latents[1:], text_states=prompt_embeds_input[1:], text_mask=prompt_mask_input[1:], text_states_2=prompt_embeds_2_input[1:], freqs_cos=freqs_cis[0], freqs_sin=freqs_cis[1], guidance=guidance_expand, x_id = 1, **additional_kwargs,) + if self._interrupt: + return [None] + noise_pred_text = noise_pred_text[0] + + # perform guidance + if self.do_classifier_free_guidance: + if cfg_star_rescale: + batch_size = 1 + positive_flat = noise_pred_text.view(batch_size, -1) + negative_flat = noise_pred_uncond.view(batch_size, -1) + dot_product = torch.sum( + positive_flat * negative_flat, dim=1, keepdim=True + ) + squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8 + positive_flat, negative_flat = None, None + alpha = dot_product / squared_norm + noise_pred_uncond *= alpha + + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + noise_pred_text, noise_pred_uncond = None, None + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + noise_pred = None + + latents = latents.to(torch.bfloat16) + for iii in range(frames_per_batch): + p = (index_start + iii) % pred_latents.shape[2] + pred_latents[:, :, p] += latents[:, :, iii] + counter[:, :, p] += 1 + + if transformer.enable_teacache and cache_size > 1: + for l in range(latent_items): + if transformer.previous_residual[l] != None: + bsz = transformer.previous_residual[l].shape[0] + if cache_all_previous_residual[l] == None: + cache_all_previous_residual[l] = torch.zeros((bsz, img_all_len, 3072 ), device=transformer.previous_residual[l].device, dtype=transformer.previous_residual[l].dtype) + cache_all_previous_residual[l].reshape(bsz, -1, embedded_hw)[:, idx_list] = transformer.previous_residual[l][:, img_ref_len:].reshape(bsz, -1, embedded_hw) + + if transformer.previous_modulated_input != None: + if cache_all_previous_modulated_input == None: + cache_all_previous_modulated_input = torch.zeros((1, img_all_len, 3072 ), device=transformer.previous_modulated_input.device, dtype=transformer.previous_modulated_input.dtype) + cache_all_previous_modulated_input.reshape(1, -1, embedded_hw)[:, idx_list] = transformer.previous_modulated_input[:, img_ref_len:].reshape(1, -1, embedded_hw) + cache_should_calc[cache_slot_no] = transformer.should_calc + cache_accumulated_rel_l1_distance[cache_slot_no] = transformer.accumulated_rel_l1_distance + cache_teacache_skipped_steps[cache_slot_no] = transformer.teacache_skipped_steps + + cache_slot_no += 1 + + shift += shift_offset + shift = shift % frames_per_batch + pred_latents = pred_latents / counter + latents_all = pred_latents + + if callback is not None: + callback(i, latents_all.squeeze(0), False) + + latents = latents_all.float()[:, :, :video_length] + + if not output_type == "latent": + expand_temporal_dim = False + if len(latents.shape) == 4: + if isinstance(self.vae, AutoencoderKLCausal3D): + latents = latents.unsqueeze(2) + expand_temporal_dim = True + elif len(latents.shape) == 5: + pass + else: + raise ValueError( + f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}.") + + if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor: + latents = latents / self.vae.config.scaling_factor + self.vae.config.shift_factor + else: + latents = latents / self.vae.config.scaling_factor + + with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled): + image = self.vae.decode(latents, return_dict=False, generator=generator)[0] + if image is None: + return (None, ) + + if expand_temporal_dim or image.shape[2] == 1: + image = image.squeeze(2) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().float() + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return image + + return HunyuanVideoPipelineOutput(videos=image) diff --git a/hyvideo/hunyuan.py b/hyvideo/hunyuan.py index f753184..e909e2f 100644 --- a/hyvideo/hunyuan.py +++ b/hyvideo/hunyuan.py @@ -5,7 +5,7 @@ import functools from typing import List, Optional, Tuple, Union from pathlib import Path - +from einops import rearrange import torch import torch.distributed as dist from hyvideo.constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE, NEGATIVE_PROMPT_I2V @@ -16,11 +16,34 @@ from hyvideo.utils.data_utils import align_to, get_closest_ratio, generate_crop_ from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed, get_nd_rotary_pos_embed_new from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler from hyvideo.diffusion.pipelines import HunyuanVideoPipeline +from hyvideo.diffusion.pipelines import HunyuanVideoAudioPipeline from PIL import Image import numpy as np import torchvision.transforms as transforms import cv2 from wan.utils.utils import resize_lanczos, calculate_new_dimensions +from hyvideo.data_kits.audio_preprocessor import encode_audio, get_facemask +from transformers import WhisperModel +from transformers import AutoFeatureExtractor +from hyvideo.data_kits.face_align import AlignImage +import librosa + +def get_audio_feature(feature_extractor, audio_path, duration): + audio_input, sampling_rate = librosa.load(audio_path, duration=duration, sr=16000) + assert sampling_rate == 16000 + + audio_features = [] + window = 750*640 + for i in range(0, len(audio_input), window): + audio_feature = feature_extractor(audio_input[i:i+window], + sampling_rate=sampling_rate, + return_tensors="pt", + device="cuda" + ).input_features + audio_features.append(audio_feature) + + audio_features = torch.cat(audio_features, dim=-1) + return audio_features, len(audio_input) // 640 def pad_image(crop_img, size, color=(255, 255, 255), resize_ratio=1): crop_h, crop_w = crop_img.shape[:2] @@ -212,6 +235,14 @@ def patched_llava_forward( image_hidden_states=image_features if pixel_values is not None else None, ) +def adapt_avatar_model(model): + modules_dict= { k: m for k, m in model.named_modules()} + for model_layer, avatar_layer in model.double_stream_map.items(): + module = modules_dict[f"audio_adapter_blocks.{avatar_layer}"] + target = modules_dict[f"double_blocks.{model_layer}"] + setattr(target, "audio_adapter", module ) + delattr(model, "audio_adapter_blocks") + class DataPreprocess(object): def __init__(self): self.llava_size = (336, 336) @@ -223,12 +254,18 @@ class DataPreprocess(object): ] ) - def get_batch(self, image , size): + def get_batch(self, image , size, pad = False): image = np.asarray(image) - llava_item_image = pad_image(image.copy(), self.llava_size) + if pad: + llava_item_image = pad_image(image.copy(), self.llava_size) + else: + llava_item_image = image.copy() uncond_llava_item_image = np.ones_like(llava_item_image) * 255 - cat_item_image = pad_image(image.copy(), size) + if pad: + cat_item_image = pad_image(image.copy(), size) + else: + cat_item_image = image.copy() llava_item_tensor = self.llava_transform(Image.fromarray(llava_item_image.astype(np.uint8))) uncond_llava_item_tensor = self.llava_transform(Image.fromarray(uncond_llava_item_image)) cat_item_tensor = torch.from_numpy(cat_item_image.copy()).permute((2, 0, 1)) / 255.0 @@ -243,6 +280,8 @@ class Inference(object): def __init__( self, i2v, + custom, + avatar, enable_cfg, vae, vae_kwargs, @@ -250,9 +289,14 @@ class Inference(object): model, text_encoder_2=None, pipeline=None, + feature_extractor=None, + wav2vec=None, + align_instance=None, device=None, ): self.i2v = i2v + self.custom = custom + self.avatar = avatar self.enable_cfg = enable_cfg self.vae = vae self.vae_kwargs = vae_kwargs @@ -263,8 +307,11 @@ class Inference(object): self.model = model self.pipeline = pipeline - self.device = "cuda" + self.feature_extractor=feature_extractor + self.wav2vec=wav2vec + self.align_instance=align_instance + self.device = "cuda" @classmethod @@ -285,15 +332,21 @@ class Inference(object): i2v_condition_type = None i2v_mode = "i2v" in model_filepath[0] custom = False + avatar = False if i2v_mode: model_id = "HYVideo-T/2" i2v_condition_type = "token_replace" elif "custom" in model_filepath[0]: model_id = "HYVideo-T/2-custom" custom = True + elif "avatar" in model_filepath[0]: + model_id = "HYVideo-T/2-avatar" + text_len = 256 + avatar = True else: model_id = "HYVideo-T/2-cfgdistill" + if i2v_mode and i2v_condition_type == "latent_concat": in_channels = latent_channels * 2 + 1 image_embed_interleave = 2 @@ -323,11 +376,11 @@ class Inference(object): from mmgp import offload # model = Inference.load_state_dict(args, model, model_filepath) - # model_filepath ="c:/temp/hc/mp_rank_00_model_states.pt" + # model_filepath ="c:/temp/avatar/mp_rank_00_model_states.pt" offload.load_model_data(model, model_filepath, pinToMemory = pinToMemory, partialPinning = partialPinning) pass - # offload.save_model(model, "hunyuan_video_custom_720_bf16.safetensors") - # offload.save_model(model, "hunyuan_video_custom_720_quanto_bf16_int8.safetensors", do_quantize= True) + # offload.save_model(model, "hunyuan_video_avatar_720_bf16.safetensors") + # offload.save_model(model, "hunyuan_video_avatar_720_quanto_bf16_int8.safetensors", do_quantize= True) model.mixed_precision = mixed_precision_transformer @@ -338,9 +391,12 @@ class Inference(object): # ============================= Build extra models ======================== # VAE - if custom: + if custom or avatar: vae_configpath = "ckpts/hunyuan_video_custom_VAE_config.json" vae_filepath = "ckpts/hunyuan_video_custom_VAE_fp32.safetensors" + # elif avatar: + # vae_configpath = "ckpts/config_vae_avatar.json" + # vae_filepath = "ckpts/vae_avatar.pt" else: vae_configpath = "ckpts/hunyuan_video_VAE_config.json" vae_filepath = "ckpts/hunyuan_video_VAE_fp32.safetensors" @@ -350,6 +406,7 @@ class Inference(object): vae, _, s_ratio, t_ratio = load_vae( "884-16c-hy", vae_path= vae_filepath, vae_config_path= vae_configpath, vae_precision= vae_precision, device= "cpu", ) + vae._model_dtype = torch.float32 if VAE_dtype == torch.float32 else (torch.float16 if avatar else torch.bfloat16) vae._model_dtype = torch.float32 if VAE_dtype == torch.float32 else torch.bfloat16 vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio} enable_cfg = False @@ -359,7 +416,7 @@ class Inference(object): tokenizer = "llm-i2v" prompt_template = "dit-llm-encode-i2v" prompt_template_video = "dit-llm-encode-video-i2v" - elif custom : + elif custom or avatar : text_encoder = "llm-i2v" tokenizer = "llm-i2v" prompt_template = "dit-llm-encode" @@ -411,14 +468,33 @@ class Inference(object): device="cpu", ) + feature_extractor = None + wav2vec = None + align_instance = None + + if avatar: + feature_extractor = AutoFeatureExtractor.from_pretrained("ckpts/whisper-tiny/") + wav2vec = WhisperModel.from_pretrained("ckpts/whisper-tiny/").to(device="cpu", dtype=torch.float32) + wav2vec._model_dtype = torch.float32 + wav2vec.requires_grad_(False) + align_instance = AlignImage("cuda", det_path="ckpts/det_align/detface.pt") + align_instance.facedet.model.to("cpu") + + adapt_avatar_model(model) + return cls( i2v=i2v_mode, + custom=custom, + avatar=avatar, enable_cfg = enable_cfg, vae=vae, vae_kwargs=vae_kwargs, text_encoder=text_encoder, text_encoder_2=text_encoder_2, model=model, + feature_extractor=feature_extractor, + wav2vec=wav2vec, + align_instance=align_instance, device=device, ) @@ -428,6 +504,8 @@ class HunyuanVideoSampler(Inference): def __init__( self, i2v, + custom, + avatar, enable_cfg, vae, vae_kwargs, @@ -435,10 +513,15 @@ class HunyuanVideoSampler(Inference): model, text_encoder_2=None, pipeline=None, + feature_extractor=None, + wav2vec=None, + align_instance=None, device=0, ): super().__init__( i2v, + custom, + avatar, enable_cfg, vae, vae_kwargs, @@ -446,12 +529,16 @@ class HunyuanVideoSampler(Inference): model, text_encoder_2=text_encoder_2, pipeline=pipeline, + feature_extractor=feature_extractor, + wav2vec=wav2vec, + align_instance=align_instance, device=device, ) self.i2v_mode = i2v self.enable_cfg = enable_cfg self.pipeline = self.load_diffusion_pipeline( + avatar = self.avatar, vae=self.vae, text_encoder=self.text_encoder, text_encoder_2=self.text_encoder_2, @@ -474,6 +561,7 @@ class HunyuanVideoSampler(Inference): def load_diffusion_pipeline( self, + avatar, vae, text_encoder, text_encoder_2, @@ -491,14 +579,24 @@ class HunyuanVideoSampler(Inference): solver="euler", ) - pipeline = HunyuanVideoPipeline( - vae=vae, - text_encoder=text_encoder, - text_encoder_2=text_encoder_2, - transformer=model, - scheduler=scheduler, - progress_bar_config=progress_bar_config, - ) + if avatar: + pipeline = HunyuanVideoAudioPipeline( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + transformer=model, + scheduler=scheduler, + progress_bar_config=progress_bar_config, + ) + else: + pipeline = HunyuanVideoPipeline( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + transformer=model, + scheduler=scheduler, + progress_bar_config=progress_bar_config, + ) return pipeline @@ -588,6 +686,8 @@ class HunyuanVideoSampler(Inference): self, input_prompt, input_ref_images = None, + audio_guide = None, + fps = 24, height=192, width=336, frame_num=129, @@ -617,14 +717,12 @@ class HunyuanVideoSampler(Inference): self.vae.tile_sample_min_size = VAE_tile_size["tile_sample_min_size"] self.vae.tile_latent_min_size = VAE_tile_size["tile_latent_min_size"] self.vae.tile_overlap_factor = VAE_tile_size["tile_overlap_factor"] + self.vae.enable_tiling() i2v_mode= self.i2v_mode if not self.enable_cfg: guide_scale=1.0 - - out_dict = dict() - # ======================================================================== # Arguments: seed # ======================================================================== @@ -663,7 +761,6 @@ class HunyuanVideoSampler(Inference): seed_everything(seed) generator = [torch.Generator("cuda").manual_seed(seed) for seed in seeds] # generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds] - out_dict["seeds"] = seeds # ======================================================================== # Arguments: target_width, target_height, target_frame_num @@ -681,8 +778,6 @@ class HunyuanVideoSampler(Inference): target_width = align_to(width, 16) target_frame_num = frame_num - out_dict["size"] = (target_height, target_width, target_frame_num) - if input_ref_images != None: # ip_cfg_scale = 3.0 ip_cfg_scale = 0 @@ -769,28 +864,91 @@ class HunyuanVideoSampler(Inference): if input_ref_images == None: freqs_cos, freqs_sin = self.get_rotary_pos_embed(target_frame_num, target_height, target_width, enable_riflex) else: - concat_dict = {'mode': 'timecat-w', 'bias': -1} - freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(target_frame_num, target_height, target_width, concat_dict) + if self.avatar: + w, h = input_ref_images.size + target_height, target_width = calculate_new_dimensions(target_height, target_width, h, w, fit_into_canvas) + concat_dict = {'mode': 'timecat', 'bias': -1} + freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict) + else: + concat_dict = {'mode': 'timecat-w', 'bias': -1} + freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(target_frame_num, target_height, target_width, concat_dict) n_tokens = freqs_cos.shape[0] - callback = kwargs.pop("callback", None) callback_steps = kwargs.pop("callback_steps", None) # ======================================================================== # Pipeline inference # ======================================================================== - start_time = time.time() - - # "pixel_value_llava": llava_item_tensor.unsqueeze(0), - # "uncond_pixel_value_llava": uncond_llava_item_tensor.unsqueeze(0), - # 'pixel_value_ref': cat_item_tensor.unsqueeze(0), + pixel_value_llava, uncond_pixel_value_llava, pixel_value_ref = None, None, None if input_ref_images == None: - pixel_value_llava, uncond_pixel_value_llava, pixel_value_ref = None, None, None name = None else: - pixel_value_llava, uncond_pixel_value_llava, pixel_value_ref = DataPreprocess().get_batch(input_ref_images, (target_width, target_height)) + pixel_value_llava, uncond_pixel_value_llava, pixel_value_ref = DataPreprocess().get_batch(input_ref_images, (target_width, target_height), pad = self.custom) + + ref_latents, uncond_audio_prompts, audio_prompts, face_masks, motion_exp, motion_pose = None, None, None, None, None, None + + if audio_guide != None: + if n_prompt == None or len(n_prompt) == 0: + n_prompt = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, Lens changes" + + uncond_pixel_value_llava = pixel_value_llava.clone() + + pixel_value_ref = pixel_value_ref.unsqueeze(0) + self.align_instance.facedet.model.to("cuda") + face_masks = get_facemask(pixel_value_ref.to("cuda")*255, self.align_instance, area=3.0) + # iii = (face_masks.squeeze(0).squeeze(0).permute(1,2,0).repeat(1,1,3)*255).cpu().numpy().astype(np.uint8) + # image = Image.fromarray(iii) + # image.save("mask.png") + # jjj = (pixel_value_ref.squeeze(0).squeeze(0).permute(1,2,0)*255).cpu().numpy().astype(np.uint8) + + self.align_instance.facedet.model.to("cpu") + # pixel_value_ref = pixel_value_ref.clone().repeat(1,129,1,1,1) + + pixel_value_ref = pixel_value_ref.repeat(1,1+4*2,1,1,1) + pixel_value_ref = pixel_value_ref * 2 - 1 + pixel_value_ref_for_vae = rearrange(pixel_value_ref, "b f c h w -> b c f h w") + + vae_dtype = self.vae.dtype + with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_dtype != torch.float32): + ref_latents = self.vae.encode(pixel_value_ref_for_vae).latent_dist.sample() + ref_latents = torch.cat( [ref_latents[:,:, :1], ref_latents[:,:, 1:2].repeat(1,1,31,1,1), ref_latents[:,:, -1:]], dim=2) + pixel_value_ref, pixel_value_ref_for_vae = None, None + + if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor: + ref_latents.sub_(self.vae.config.shift_factor).mul_(self.vae.config.scaling_factor) + else: + ref_latents.mul_(self.vae.config.scaling_factor) + + # out_latents= ref_latents / self.vae.config.scaling_factor + # image = self.vae.decode(out_latents, return_dict=False, generator=generator)[0] + # image = image.clamp(-1, 1) + # from wan.utils.utils import cache_video + # cache_video( tensor=image, save_file="decode.mp4", fps=25, nrow=1, normalize=True, value_range=(-1, 1)) + + + face_masks = torch.nn.functional.interpolate(face_masks.float().squeeze(2), + (ref_latents.shape[-2], + ref_latents.shape[-1]), + mode="bilinear").unsqueeze(2).to(dtype=ref_latents.dtype) + + audio_input, audio_len = get_audio_feature(self.feature_extractor, audio_guide, duration = frame_num/fps ) + audio_prompts = audio_input[0] + weight_dtype = audio_prompts.dtype + + motion_pose = np.array([25] * 4) + motion_exp = np.array([30] * 4) + motion_pose = torch.from_numpy(motion_pose).unsqueeze(0) + motion_exp = torch.from_numpy(motion_exp).unsqueeze(0) + audio_prompts = encode_audio(self.wav2vec, audio_prompts.to(dtype=self.wav2vec.dtype), fps, num_frames=audio_len) + audio_prompts = audio_prompts.to(self.model.dtype) + if audio_prompts.shape[1] <= 129: + audio_prompts = torch.cat([audio_prompts, torch.zeros_like(audio_prompts[:, :1]).repeat(1,129-audio_prompts.shape[1], 1, 1, 1)], dim=1) + else: + audio_prompts = torch.cat([audio_prompts, torch.zeros_like(audio_prompts[:, :1]).repeat(1, 5, 1, 1, 1)], dim=1) + uncond_audio_prompts = torch.zeros_like(audio_prompts[:,:129]) + # target_frame_num = min(target_frame_num, audio_len) samples = self.pipeline( prompt=input_prompt, height=target_height, @@ -803,9 +961,18 @@ class HunyuanVideoSampler(Inference): generator=generator, output_type="pil", name = name, - pixel_value_llava = pixel_value_llava, - uncond_pixel_value_llava=uncond_pixel_value_llava, - pixel_value_ref=pixel_value_ref, + + pixel_value_ref = pixel_value_ref, + ref_latents=ref_latents, # [1, 16, 1, h//8, w//8] + pixel_value_llava=pixel_value_llava, # [1, 3, 336, 336] + uncond_pixel_value_llava=uncond_pixel_value_llava, + face_masks=face_masks, # [b f h w] + audio_prompts=audio_prompts, + uncond_audio_prompts=uncond_audio_prompts, + motion_exp=motion_exp, + motion_pose=motion_pose, + fps= torch.from_numpy(np.array(fps)), + denoise_strength=denoise_strength, ip_cfg_scale=ip_cfg_scale, freqs_cis=(freqs_cos, freqs_sin), @@ -825,9 +992,9 @@ class HunyuanVideoSampler(Inference): callback = callback, callback_steps = callback_steps, )[0] - gen_time = time.time() - start_time + if samples == None: return None - samples = samples.sub_(0.5).mul_(2).squeeze(0) + samples = samples.squeeze(0) return samples diff --git a/hyvideo/modules/audio_adapters.py b/hyvideo/modules/audio_adapters.py new file mode 100644 index 0000000..3fdef93 --- /dev/null +++ b/hyvideo/modules/audio_adapters.py @@ -0,0 +1,220 @@ +""" +This module provides the implementation of an Audio Projection Model, which is designed for +audio processing tasks. The model takes audio embeddings as input and outputs context tokens +that can be used for various downstream applications, such as audio analysis or synthesis. + +The AudioProjModel class is based on the ModelMixin class from the diffusers library, which +provides a foundation for building custom models. This implementation includes multiple linear +layers with ReLU activation functions and a LayerNorm for normalization. + +Key Features: +- Audio embedding input with flexible sequence length and block structure. +- Multiple linear layers for feature transformation. +- ReLU activation for non-linear transformation. +- LayerNorm for stabilizing and speeding up training. +- Rearrangement of input embeddings to match the model's expected input shape. +- Customizable number of blocks, channels, and context tokens for adaptability. + +The module is structured to be easily integrated into larger systems or used as a standalone +component for audio feature extraction and processing. + +Classes: +- AudioProjModel: A class representing the audio projection model with configurable parameters. + +Functions: +- (none) + +Dependencies: +- torch: For tensor operations and neural network components. +- diffusers: For the ModelMixin base class. +- einops: For tensor rearrangement operations. + +""" + +import torch +from diffusers import ModelMixin +from einops import rearrange + +import math +import torch.nn as nn + +class AudioProjNet2(ModelMixin): + """Audio Projection Model + + This class defines an audio projection model that takes audio embeddings as input + and produces context tokens as output. The model is based on the ModelMixin class + and consists of multiple linear layers and activation functions. It can be used + for various audio processing tasks. + + Attributes: + seq_len (int): The length of the audio sequence. + blocks (int): The number of blocks in the audio projection model. + channels (int): The number of channels in the audio projection model. + intermediate_dim (int): The intermediate dimension of the model. + context_tokens (int): The number of context tokens in the output. + output_dim (int): The output dimension of the context tokens. + + Methods: + __init__(self, seq_len=5, blocks=12, channels=768, intermediate_dim=512, context_tokens=32, output_dim=768): + Initializes the AudioProjModel with the given parameters. + forward(self, audio_embeds): + Defines the forward pass for the AudioProjModel. + Parameters: + audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, blocks, channels). + Returns: + context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim). + + """ + + def __init__( + self, + seq_len=5, + blocks=12, # add a new parameter blocks + channels=768, # add a new parameter channels + intermediate_dim=512, + output_dim=768, + context_tokens=4, + ): + super().__init__() + + self.seq_len = seq_len + self.blocks = blocks + self.channels = channels + self.input_dim = ( + seq_len * blocks * channels + ) + self.intermediate_dim = intermediate_dim + self.context_tokens = context_tokens + self.output_dim = output_dim + + # define multiple linear layers + self.proj1 = nn.Linear(self.input_dim, intermediate_dim) + self.proj2 = nn.Linear(intermediate_dim, intermediate_dim) + self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim) + + self.norm = nn.LayerNorm(output_dim) + + + def forward(self, audio_embeds): + + video_length = audio_embeds.shape[1] + audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c") + batch_size, window_size, blocks, channels = audio_embeds.shape + audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels) + + audio_embeds = torch.relu(self.proj1(audio_embeds)) + audio_embeds = torch.relu(self.proj2(audio_embeds)) + + context_tokens = self.proj3(audio_embeds).reshape( + batch_size, self.context_tokens, self.output_dim + ) + context_tokens = self.norm(context_tokens) + out_all = rearrange( + context_tokens, "(bz f) m c -> bz f m c", f=video_length + ) + + return out_all + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttentionCA(nn.Module): + def __init__(self, *, dim=3072, dim_head=1024, heads=33): + super().__init__() + self.scale = dim_head ** -0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head #* heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + import torch.nn.init as init + init.zeros_(self.to_out.weight) + if self.to_out.bias is not None: + init.zeros_(self.to_out.bias) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, t, aa, D) + latent (torch.Tensor): latent features + shape (b, t, hw, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + # print("latents shape: ", latents.shape) + # print("x shape: ", x.shape) + q = self.to_q(latents) + k, v = self.to_kv(x).chunk(2, dim=-1) + + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + # out = out.permute(0, 2, 1, 3) + return self.to_out(out) + #def forward(self, x, latents): + # """ + # Args: + # x (torch.Tensor): image features + # shape (b, t, aa, D) + # latent (torch.Tensor): latent features + # shape (b, t, hw, D) + # """ + # if get_sequence_parallel_state(): + # sp_size = nccl_info.sp_size + # sp_rank = nccl_info.rank_within_group + # print("rank:", latents.shape, sp_size, sp_rank) + # latents = torch.chunk(latents, sp_size, dim=1)[sp_rank] + + # x = self.norm1(x) + # latents = self.norm2(latents) + # # print("latents shape: ", latents.shape) + # # print("x shape: ", x.shape) + # q = self.to_q(latents) + # k, v = self.to_kv(x).chunk(2, dim=-1) + + # # print("q, k, v: ", q.shape, k.shape, v.shape) + + # # attention + # #scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + # #weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + # #weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + # #out = weight @ v + # def shrink_head(encoder_state, dim): + # local_heads = encoder_state.shape[dim] // nccl_info.sp_size + # return encoder_state.narrow(dim, nccl_info.rank_within_group * local_heads, local_heads) + + # if get_sequence_parallel_state(): + # # batch_size, seq_len, attn_heads, head_dim + # q = all_to_all_4D(q, scatter_dim=2, gather_dim=1) # [2, 32256, 24, 128] + # k = shrink_head(k ,dim=2) + # v = shrink_head(v ,dim=2) + # qkv = torch.stack([query, key, value], dim=2) + # attn = flash_attn_no_pad(qkv, causal=False, dropout_p=0.0, softmax_scale=None) + # # out = out.permute(0, 2, 1, 3) + # #b, s, a, d = attn.shape + # #attn = attn.reshape(b, s, -1) + # + # out = self.to_out(attn) + # if get_sequence_parallel_state(): + # out = all_gather(out, dim=1) + # return out diff --git a/hyvideo/modules/embed_layers.py b/hyvideo/modules/embed_layers.py index 3d65ed1..f4f6785 100644 --- a/hyvideo/modules/embed_layers.py +++ b/hyvideo/modules/embed_layers.py @@ -53,10 +53,11 @@ class PatchEmbed(nn.Module): def forward(self, x): x = self.proj(x) + shape = x.shape if self.flatten: x = x.flatten(2).transpose(1, 2) # BCHW -> BNC x = self.norm(x) - return x + return x, shape class TextProjection(nn.Module): diff --git a/hyvideo/modules/models.py b/hyvideo/modules/models.py index ef3f766..5e76a2a 100644 --- a/hyvideo/modules/models.py +++ b/hyvideo/modules/models.py @@ -19,6 +19,7 @@ from .token_refiner import SingleTokenRefiner import numpy as np from mmgp import offload from wan.modules.attention import pay_attention +from .audio_adapters import AudioProjNet2, PerceiverAttentionCA def get_linear_split_map(): hidden_size = 3072 @@ -589,7 +590,8 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): use_attention_mask: bool = True, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, - attention_mode: Optional[str] = "sdpa" + attention_mode: Optional[str] = "sdpa", + avatar = False, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -708,6 +710,45 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): get_activation_layer("silu"), **factory_kwargs, ) + avatar_audio = avatar + if avatar_audio: + self.ref_in = PatchEmbed( + self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs + ) + + # -------------------- audio_proj_model -------------------- + self.audio_proj = AudioProjNet2(seq_len=10, blocks=5, channels=384, intermediate_dim=1024, output_dim=3072, context_tokens=4) + + # -------------------- motion-embeder -------------------- + self.motion_exp = TimestepEmbedder( + self.hidden_size // 4, + get_activation_layer("silu"), + **factory_kwargs + ) + self.motion_pose = TimestepEmbedder( + self.hidden_size // 4, + get_activation_layer("silu"), + **factory_kwargs + ) + + self.fps_proj = TimestepEmbedder( + self.hidden_size, + get_activation_layer("silu"), + **factory_kwargs + ) + + self.before_proj = nn.Linear(self.hidden_size, self.hidden_size) + + # -------------------- audio_insert_model -------------------- + self.double_stream_list = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19] + self.single_stream_list = [] + self.double_stream_map = {str(i): j for j, i in enumerate(self.double_stream_list)} + self.single_stream_map = {str(i): j+len(self.double_stream_list) for j, i in enumerate(self.single_stream_list)} + + self.audio_adapter_blocks = nn.ModuleList([ + PerceiverAttentionCA(dim=3072, dim_head=1024, heads=33) for _ in range(len(self.double_stream_list) + len(self.single_stream_list)) + ]) + def lock_layers_dtypes(self, dtype = torch.float32): @@ -750,11 +791,17 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000. pipeline=None, x_id = 0, + step_no = 0, callback = None, + audio_prompts = None, + motion_exp = None, + motion_pose = None, + fps = None, + face_mask = None, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: img = x - batch_no, _, ot, oh, ow = x.shape + bsz, _, ot, oh, ow = x.shape del x txt = text_states tt, th, tw = ( @@ -765,6 +812,17 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): # Prepare modulation vectors. vec = self.time_in(t) + if motion_exp != None: + vec += self.motion_exp(motion_exp.view(-1)).view(bsz, -1) # (b, 3072) + if motion_pose != None: + vec += self.motion_pose(motion_pose.view(-1)).view(bsz, -1) # (b, 3072) + if fps != None: + vec += self.fps_proj(fps) # (b, 3072) + if audio_prompts != None: + audio_feature_all = self.audio_proj(audio_prompts) + audio_feature_pad = audio_feature_all[:,:1].repeat(1,3,1,1) + audio_feature_all_insert = torch.cat([audio_feature_pad, audio_feature_all], dim=1).view(bsz, ot, 16, 3072) + audio_feature_all = None if self.i2v_condition_type == "token_replace": token_replace_t = torch.zeros_like(t) @@ -777,7 +835,6 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): # token_replace_mask_txt = None # text modulation - # vec = vec + self.vector_in(text_states_2) vec_2 = self.vector_in(text_states_2) del text_states_2 vec += vec_2 @@ -793,12 +850,17 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): ) # our timestep_embedding is merged into guidance_in(TimestepEmbedder) - vec = vec + self.guidance_in(guidance) + vec += self.guidance_in(guidance) # Embed image and text. - img = self.img_in(img) - if ref_latents != None: - ref_latents = self.img_in(ref_latents) + img, shape_mask = self.img_in(img) + if audio_prompts != None: + ref_latents_first = ref_latents[:, :, :1].clone() + ref_latents,_ = self.ref_in(ref_latents) + ref_latents_first,_ = self.img_in(ref_latents_first) + elif ref_latents != None: + ref_latents, _ = self.img_in(ref_latents) + if self.text_projection == "linear": txt = self.txt_in(txt) elif self.text_projection == "single_refiner": @@ -808,7 +870,18 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): f"Unsupported text_projection: {self.text_projection}" ) - if ref_latents == None: + if audio_prompts != None: + img += self.before_proj(ref_latents) + ref_length = ref_latents_first.shape[-2] # [b s c] + img = torch.cat([ref_latents_first, img], dim=-2) # t c + img_len = img.shape[1] + mask_len = img_len - ref_length + if face_mask.shape[2] == 1: + face_mask = face_mask.repeat(1,1,ot,1,1) # repeat if number of mask frame is 1 + face_mask = torch.nn.functional.interpolate(face_mask, size=[ot, shape_mask[-2], shape_mask[-1]], mode="nearest") + # face_mask = face_mask.view(-1,mask_len,1).repeat(1,1,img.shape[-1]).type_as(img) + face_mask = face_mask.view(-1,mask_len,1).type_as(img) + elif ref_latents == None: ref_length = None else: ref_length = ref_latents.shape[-2] @@ -828,13 +901,13 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): if x_id == 0: self.should_calc = True inp = img[0:1] - vec_ = vec + vec_ = vec[0:1] ( img_mod1_shift, img_mod1_scale, _ , _ , _ , _ , ) = self.double_blocks[0].img_mod(vec_).chunk(6, dim=-1) normed_inp = self.double_blocks[0].img_norm1(inp) normed_inp = normed_inp.to(torch.bfloat16) modulated_inp = modulate( normed_inp, shift=img_mod1_shift, scale=img_mod1_scale ) del normed_inp, img_mod1_shift, img_mod1_scale - if self.teacache_counter <= self.teacache_start_step or self.teacache_counter == self.num_steps-1: + if step_no <= self.teacache_start_step or step_no == self.num_steps-1: self.accumulated_rel_l1_distance = 0 else: coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02] @@ -846,9 +919,6 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): else: self.accumulated_rel_l1_distance = 0 self.previous_modulated_input = modulated_inp - self.teacache_counter += 1 - if self.teacache_counter == self.num_steps: - self.teacache_counter = 0 else: self.should_calc = True @@ -859,7 +929,7 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): self.previous_residual[x_id] = None ori_img = img[0:1].clone() # --------------------- Pass through DiT blocks ------------------------ - for _, block in enumerate(self.double_blocks): + for layer_num, block in enumerate(self.double_blocks): for i in range(len(img)): if callback != None: callback(-1, None, False, True) @@ -880,6 +950,16 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): img[i], txt[i] = block(*double_block_args) double_block_args = None + # insert audio feature to img + if audio_prompts != None: + audio_adapter = getattr(self.double_blocks[layer_num], "audio_adapter", None) + if audio_adapter != None: + real_img = img[i:i+1,ref_length:].view(1, ot, -1, 3072) + real_img = audio_adapter(audio_feature_all_insert[i:i+1], real_img).view(1, -1, 3072) + real_img *= face_mask[i:i+1] + img[i:i+1, ref_length:] += real_img + real_img = None + for _, block in enumerate(self.single_blocks): for i in range(len(img)): @@ -932,6 +1012,7 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): img = torch.cat(img_list) img_list = None + # img = self.unpatchify(img, tt, th, tw) img = self.unpatchify(img, tt, th, tw) return img @@ -1015,5 +1096,14 @@ HUNYUAN_VIDEO_CONFIG = { "heads_num": 24, "mlp_width_ratio": 4, }, + 'HYVideo-T/2-avatar': { # 9.0B / 12.5B + 'mm_double_blocks_depth': 20, + 'mm_single_blocks_depth': 40, + 'rope_dim_list': [16, 56, 56], + 'hidden_size': 3072, + 'heads_num': 24, + 'mlp_width_ratio': 4, + 'avatar': True, + }, } \ No newline at end of file diff --git a/wgp.py b/wgp.py index 45e0e50..1ae7ab2 100644 --- a/wgp.py +++ b/wgp.py @@ -190,15 +190,19 @@ def process_prompt_and_add_tasks(state, model_choice): if video_length > sliding_window_size: gr.Info(f"The Number of Frames to generate ({video_length}) is greater than the Sliding Window Size ({sliding_window_size}) , multiple Windows will be generated") - if "phantom" in model_filename or "hunyuan_video_custom" in model_filename: + if "phantom" in model_filename or "hunyuan_video_custom" in model_filename or "hunyuan_video_avatar" in model_filename: image_refs = inputs["image_refs"] - + audio_guide = inputs["audio_guide"] if image_refs == None : gr.Info("You must provide an Image Reference") return - if len(image_refs) > 1 and "hunyuan_video_custom" in model_filename: - gr.Info("Only one Image Reference (a person) is supported for the moment by Hunyuan Custom") + if len(image_refs) > 1 and ("hunyuan_video_custom" in model_filename or "hunyuan_video_avatar" in model_filename): + gr.Info("Only one Image Reference (a person) is supported for the moment by Hunyuan Custom / Avatar") return + if audio_guide == None and "hunyuan_video_avatar" in model_filename: + gr.Info("You must provide an audio file") + return + if any(isinstance(image[0], str) for image in image_refs) : gr.Info("Reference Image should be an Image") return @@ -1539,7 +1543,9 @@ wan_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_mbf16.safetensors", "ckpts/w ltxv_choices= ["ckpts/ltxv_0.9.7_13B_dev_bf16.safetensors", "ckpts/ltxv_0.9.7_13B_dev_quanto_bf16_int8.safetensors", "ckpts/ltxv_0.9.7_13B_distilled_lora128_bf16.safetensors"] hunyuan_choices= ["ckpts/hunyuan_video_720_bf16.safetensors", "ckpts/hunyuan_video_720_quanto_int8.safetensors", "ckpts/hunyuan_video_i2v_720_bf16v2.safetensors", "ckpts/hunyuan_video_i2v_720_quanto_int8v2.safetensors", - "ckpts/hunyuan_video_custom_720_bf16.safetensors", "ckpts/hunyuan_video_custom_720_quanto_bf16_int8.safetensors" ] + "ckpts/hunyuan_video_custom_720_bf16.safetensors", "ckpts/hunyuan_video_custom_720_quanto_bf16_int8.safetensors", + "ckpts/hunyuan_video_avatar_720_bf16.safetensors", "ckpts/hunyuan_video_avatar_720_quanto_bf16_int8.safetensors", + ] transformer_choices = wan_choices_t2v + wan_choices_i2v + ltxv_choices + hunyuan_choices def get_dependent_models(model_filename, quantization, dtype_policy ): @@ -1549,12 +1555,13 @@ def get_dependent_models(model_filename, quantization, dtype_policy ): return [get_model_filename("ltxv_13B", quantization, dtype_policy)] else: return [] -model_types = [ "t2v_1.3B", "t2v", "i2v", "i2v_720p", "flf2v_720p", "vace_1.3B","vace_14B","moviigen", "phantom_1.3B", "phantom_14B", "fantasy", "fun_inp_1.3B", "fun_inp", "recam_1.3B", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "ltxv_13B", "ltxv_13B_distilled", "hunyuan", "hunyuan_i2v", "hunyuan_custom"] +model_types = [ "t2v_1.3B", "t2v", "i2v", "i2v_720p", "flf2v_720p", "vace_1.3B","vace_14B","moviigen", "phantom_1.3B", "phantom_14B", "fantasy", "fun_inp_1.3B", "fun_inp", "recam_1.3B", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "ltxv_13B", "ltxv_13B_distilled", "hunyuan", "hunyuan_i2v", "hunyuan_custom", "hunyuan_avatar"] model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B", "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "vace_14B" : "Vace_14B","recam_1.3B": "recammaster_1.3B", "flf2v_720p" : "FLF2V_720p", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B", "sky_df_720p_14B" : "sky_reels2_diffusion_forcing_720p_14B", "moviigen" :"moviigen", - "phantom_1.3B" : "phantom_1.3B", "phantom_14B" : "phantom_14B", "fantasy" : "fantasy", "ltxv_13B" : "ltxv_0.9.7_13B_dev", "ltxv_13B_distilled" : "ltxv_0.9.7_13B_distilled", "hunyuan" : "hunyuan_video_720", "hunyuan_i2v" : "hunyuan_video_i2v_720", "hunyuan_custom" : "hunyuan_video_custom" } + "phantom_1.3B" : "phantom_1.3B", "phantom_14B" : "phantom_14B", "fantasy" : "fantasy", "ltxv_13B" : "ltxv_0.9.7_13B_dev", "ltxv_13B_distilled" : "ltxv_0.9.7_13B_distilled", + "hunyuan" : "hunyuan_video_720", "hunyuan_i2v" : "hunyuan_video_i2v_720", "hunyuan_custom" : "hunyuan_video_custom", "hunyuan_avatar" : "hunyuan_video_avatar" } def get_model_type(model_filename): @@ -1639,7 +1646,10 @@ def get_model_name(model_filename, description_container = [""]): description = "A good looking image 2 video model, but not so good in prompt adherence." elif "hunyuan_video_custom" in model_filename: model_name = "Hunyuan Video Custom 720p 13B" - description = "The Hunyuan Video Custom model is proably the best model to transfer people (only people for the momment) as it is quite good to keep their identity. However it is slow as to get good results, you need to generate 720p videos with 30 steps." + description = "The Hunyuan Video Custom model is probably the best model to transfer people (only people for the momment) as it is quite good to keep their identity. However it is slow as to get good results, you need to generate 720p videos with 30 steps." + elif "hunyuan_video_avatar" in model_filename: + model_name = "Hunyuan Video Avatar 720p 13B" + description = "With the Hunyuan Video Avatar model you can animate a person based on the content of an audio input. Please note that the video generator works by processing 128 frames segment at a time (even if you ask less). The good news is that it will concatenate multiple segments for long video generation (max 3 segments recommended as the quality will get worse)." else: model_name = "Wan2.1 text2video" model_name += " 14B" if "14B" in model_filename else " 1.3B" @@ -1758,7 +1768,14 @@ def get_default_settings(filename): ui_defaults.update({ "guidance_scale": 7.5, "flow_shift": 13, - "resolution": "1280x720" + "resolution": "1280x720", + }) + elif get_model_type(filename) in ("hunyuan_avatar"): + ui_defaults.update({ + "guidance_scale": 7.5, + "flow_shift": 5, + "tea_cache_start_step_perc": 25, + "video_length": 129, }) elif get_model_type(filename) in ("vace_14B"): ui_defaults.update({ @@ -1954,8 +1971,13 @@ def download_models(transformer_filename): text_encoder_filename = get_hunyuan_text_encoder_filename(text_encoder_quantization) model_def = { "repoId" : "DeepBeepMeep/HunyuanVideo", - "sourceFolderList" : [ "llava-llama-3-8b", "clip_vit_large_patch14", "" ], - "fileList" :[ ["config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", "preprocessor_config.json"] + computeList(text_encoder_filename) , ["config.json", "merges.txt", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", "vocab.json"], [ "hunyuan_video_720_quanto_int8_map.json", "hunyuan_video_custom_VAE_fp32.safetensors", "hunyuan_video_custom_VAE_config.json", "hunyuan_video_VAE_fp32.safetensors", "hunyuan_video_VAE_config.json" , "hunyuan_video_720_quanto_int8_map.json" ] + computeList(transformer_filename) ] + "sourceFolderList" : [ "llava-llama-3-8b", "clip_vit_large_patch14", "whisper-tiny" , "det_align", "" ], + "fileList" :[ ["config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", "preprocessor_config.json"] + computeList(text_encoder_filename) , + ["config.json", "merges.txt", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", "vocab.json"], + ["config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json"], + ["detface.pt"], + [ "hunyuan_video_720_quanto_int8_map.json", "hunyuan_video_custom_VAE_fp32.safetensors", "hunyuan_video_custom_VAE_config.json", "hunyuan_video_VAE_fp32.safetensors", "hunyuan_video_VAE_config.json" , "hunyuan_video_720_quanto_int8_map.json" ] + computeList(transformer_filename) + ] } else: @@ -2122,6 +2144,14 @@ def load_hunyuan_model(model_filename, quantizeTransformer = False, dtype = torc pipe = { "transformer" : hunyuan_model.model, "text_encoder" : hunyuan_model.text_encoder, "text_encoder_2" : hunyuan_model.text_encoder_2, "vae" : hunyuan_model.vae } + if hunyuan_model.wav2vec != None: + pipe["wav2vec"] = hunyuan_model.wav2vec + + + # if hunyuan_model.align_instance != None: + # pipe["align_instance"] = hunyuan_model.align_instance.facedet.model + + from hyvideo.modules.models import get_linear_split_map split_linear_modules_map = get_linear_split_map() @@ -2818,9 +2848,13 @@ def generate_video( hunyuan_t2v = "hunyuan_video_720" in model_filename hunyuan_i2v = "hunyuan_video_i2v" in model_filename hunyuan_custom = "hunyuan_video_custom" in model_filename + hunyuan_avatar = "hunyuan_video_avatar" in model_filename + fantasy = "fantasy" in model_filename if diffusion_forcing or hunyuan_t2v or hunyuan_i2v or hunyuan_custom: fps = 24 - elif audio_guide != None: + elif hunyuan_avatar: + fps = 25 + elif fantasy: fps = 23 elif ltxv: fps = 30 @@ -2829,11 +2863,13 @@ def generate_video( latent_size = 8 if ltxv else 4 original_image_refs = image_refs - if image_refs != None and len(image_refs) > 0 and (hunyuan_custom or phantom or vace): - send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")]) + if image_refs != None and len(image_refs) > 0 and (hunyuan_custom or phantom or hunyuan_avatar or vace): + if hunyuan_avatar: remove_background_images_ref = 0 + if remove_background_images_ref > 0: + send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")]) os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg") from wan.utils.utils import resize_and_remove_background - image_refs = resize_and_remove_background(image_refs, width, height, remove_background_images_ref, fit_into_canvas= not vace) + image_refs = resize_and_remove_background(image_refs, width, height, remove_background_images_ref, fit_into_canvas= not (vace or hunyuan_avatar) ) # no fit for vace ref images as it is done later update_task_thumbnails(task, locals()) send_cmd("output") @@ -2866,13 +2902,14 @@ def generate_video( audio_proj_split = None audio_scale = None audio_context_lens = None - if audio_guide != None: + if (fantasy or hunyuan_avatar) and audio_guide != None: from fantasytalking.infer import parse_audio import librosa duration = librosa.get_duration(path=audio_guide) - current_video_length = min(int(fps * duration // 4) * 4 + 5, current_video_length) - audio_proj_split, audio_context_lens = parse_audio(audio_guide, num_frames= current_video_length, fps= fps, device= processing_device ) - audio_scale = 1.0 + current_video_length = min(int(fps * duration // 4) * 4 + 5, current_video_length) + if fantasy: + audio_proj_split, audio_context_lens = parse_audio(audio_guide, num_frames= current_video_length, fps= fps, device= processing_device ) + audio_scale = 1.0 import random if seed == None or seed <0: @@ -2990,7 +3027,7 @@ def generate_video( if reuse_frames > 0: return_latent_slice = slice(-(reuse_frames - 1 + discard_last_frames ) // latent_size, None if discard_last_frames == 0 else -(discard_last_frames // latent_size) ) - if hunyuan_custom: + if hunyuan_custom or hunyuan_avatar: src_ref_images = image_refs elif phantom: src_ref_images = image_refs.copy() if image_refs != None else None @@ -3098,6 +3135,7 @@ def generate_video( cfg_star_switch = cfg_star_switch, cfg_zero_step = cfg_zero_step, audio_cfg_scale= audio_guidance_scale, + audio_guide=audio_guide, audio_proj= audio_proj_split, audio_scale= audio_scale, audio_context_lens= audio_context_lens, @@ -4502,6 +4540,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non hunyuan_t2v = "hunyuan_video_720" in model_filename hunyuan_i2v = "hunyuan_video_i2v" in model_filename hunyuan_video_custom = "hunyuan_video_custom" in model_filename + hunyuan_video_avatar = "hunyuan_video_avatar" in model_filename sliding_window_enabled = vace or diffusion_forcing or ltxv new_line_text = "each new line of prompt will be used for a window" if sliding_window_enabled else "each new line of prompt will generate a new video" @@ -4575,7 +4614,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non model_mode = gr.Dropdown(value=None, visible=False) keep_frames_video_source = gr.Text(visible=False) - with gr.Column(visible= vace or phantom or hunyuan_video_custom) as video_prompt_column: + with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar) as video_prompt_column: video_prompt_type_value= ui_defaults.get("video_prompt_type","") video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False) with gr.Row(): @@ -4624,14 +4663,14 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ("Keep it for first Image (landscape) and remove it for other Images (objects / faces)", 2), ], value=ui_defaults.get("remove_background_images_ref",1), - label="Remove Background of Images References", scale = 3, visible= "I" in video_prompt_type_value + label="Remove Background of Images References", scale = 3, visible= "I" in video_prompt_type_value and not hunyuan_video_avatar ) # remove_background_images_ref = gr.Checkbox(value=ui_defaults.get("remove_background_images_ref",1), label= "Remove Background of Images References", visible= "I" in video_prompt_type_value, scale =1 ) video_mask = gr.Video(label= "Video Mask (for Inpainting or Outpaing, white pixels = Mask)", visible= "M" in video_prompt_type_value, value= ui_defaults.get("video_mask", None)) - audio_guide = gr.Audio(value= ui_defaults.get("audio_guide", None), type="filepath", label="Voice to follow", show_download_button= True, visible= fantasy ) + audio_guide = gr.Audio(value= ui_defaults.get("audio_guide", None), type="filepath", label="Voice to follow", show_download_button= True, visible= fantasy or hunyuan_video_avatar ) advanced_prompt = advanced_ui prompt_vars=[] @@ -4720,6 +4759,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non video_length = gr.Slider(17, 737, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True) elif fantasy: video_length = gr.Slider(5, 233, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (23 = 1s)", interactive= True) + elif hunyuan_video_avatar: + video_length = gr.Slider(5, 401, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (25 = 1s)", interactive= True) elif hunyuan_t2v or hunyuan_i2v: video_length = gr.Slider(5, 337, value=ui_defaults.get("video_length", 97), step=4, label="Number of frames (24 = 1s)", interactive= True) else: @@ -4809,7 +4850,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ) with gr.Tab("Quality", visible = not ltxv) as quality_tab: - with gr.Column(visible = not (hunyuan_i2v or hunyuan_t2v or hunyuan_video_custom) ) as skip_layer_guidance_row: + with gr.Column(visible = not (hunyuan_i2v or hunyuan_t2v or hunyuan_video_custom or hunyuan_video_avatar) ) as skip_layer_guidance_row: gr.Markdown("Skip Layer Guidance (improves video quality)") with gr.Row(): slg_switch = gr.Dropdown(