diff --git a/generate.py b/generate.py
index 73f273e..4f114d4 100644
--- a/generate.py
+++ b/generate.py
@@ -41,6 +41,14 @@ EXAMPLE_PROMPT = {
         "last_frame": "examples/flf2v_input_last_frame.png",
     },
+    "vace-1.3B": {
+        "src_ref_images": './bag.jpg,./heben.png',
+        "prompt": "优雅的女士在精品店仔细挑选包包,她身穿一袭黑色修身连衣裙,搭配珍珠项链,展现出成熟女性的魅力。手中拿着一款复古风格的棕色皮质半月形手提包,正细致地观察其工艺与质地。店内灯光柔和,木质装潢营造出温馨而高级的氛围。中景,侧拍捕捉女士挑选瞬间,展现其品味与气质。"
+    },
+    "vace-14B": {
+        "src_ref_images": './bag.jpg,./heben.png',
+        "prompt": "优雅的女士在精品店仔细挑选包包,她身穿一袭黑色修身连衣裙,搭配珍珠项链,展现出成熟女性的魅力。手中拿着一款复古风格的棕色皮质半月形手提包,正细致地观察其工艺与质地。店内灯光柔和,木质装潢营造出温馨而高级的氛围。中景,侧拍捕捉女士挑选瞬间,展现其品味与气质。"
+    }
 }
@@ -50,6 +58,7 @@ def _validate_args(args):
     assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}"
     assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}"
 
+    # TODO(wangang.wa): need to be confirmed
     # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks.
     if args.sample_steps is None:
         args.sample_steps = 40 if "i2v" in args.task else 50
@@ -141,6 +150,21 @@ def _parse_args():
         type=str,
         default=None,
         help="The file to save the generated image or video to.")
+    parser.add_argument(
+        "--src_video",
+        type=str,
+        default=None,
+        help="The file of the source video. Default None.")
+    parser.add_argument(
+        "--src_mask",
+        type=str,
+        default=None,
+        help="The file of the source mask. Default None.")
+    parser.add_argument(
+        "--src_ref_images",
+        type=str,
+        default=None,
+        help="The file list of the source reference images. Separated by ','. Default None.")
     parser.add_argument(
         "--prompt",
         type=str,
@@ -397,7 +421,7 @@ def generate(args):
             guide_scale=args.sample_guide_scale,
             seed=args.base_seed,
             offload_model=args.offload_model)
-    else:
+    elif "flf2v" in args.task:
         if args.prompt is None:
             args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
         if args.first_frame is None or args.last_frame is None:
@@ -457,6 +481,60 @@ def generate(args):
             seed=args.base_seed,
             offload_model=args.offload_model
         )
+    elif "vace" in args.task:
+        if args.prompt is None:
+            args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
+            args.src_video = EXAMPLE_PROMPT[args.task].get("src_video", None)
+            args.src_mask = EXAMPLE_PROMPT[args.task].get("src_mask", None)
+            args.src_ref_images = EXAMPLE_PROMPT[args.task].get("src_ref_images", None)
+
+        logging.info(f"Input prompt: {args.prompt}")
+        if args.use_prompt_extend and args.use_prompt_extend != 'plain':
+            logging.info("Extending prompt ...")
+            if rank == 0:
+                prompt = prompt_expander.forward(args.prompt)
+                logging.info(f"Prompt extended from '{args.prompt}' to '{prompt}'")
+                input_prompt = [prompt]
+            else:
+                input_prompt = [None]
+            if dist.is_initialized():
+                dist.broadcast_object_list(input_prompt, src=0)
+            args.prompt = input_prompt[0]
+            logging.info(f"Extended prompt: {args.prompt}")
+
+        logging.info("Creating WanVace pipeline.")
+        wan_vace = wan.WanVace(
+            config=cfg,
+            checkpoint_dir=args.ckpt_dir,
+            device_id=device,
+            rank=rank,
+            t5_fsdp=args.t5_fsdp,
+            dit_fsdp=args.dit_fsdp,
+            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
+            t5_cpu=args.t5_cpu,
+        )
+
+        src_video, src_mask, src_ref_images = wan_vace.prepare_source([args.src_video],
+                                                                      [args.src_mask],
+                                                                      [None if args.src_ref_images is None else args.src_ref_images.split(',')],
+                                                                      args.frame_num, SIZE_CONFIGS[args.size], device)
+
+        logging.info(f"Generating video...")
+        video = wan_vace.generate(
+            args.prompt,
+            src_video,
+            src_mask,
+            src_ref_images,
+            size=SIZE_CONFIGS[args.size],
+            frame_num=args.frame_num,
+            shift=args.sample_shift,
+            sample_solver=args.sample_solver,
+            sampling_steps=args.sample_steps,
+            guide_scale=args.sample_guide_scale,
+            seed=args.base_seed,
+            offload_model=args.offload_model)
+    else:
+        raise ValueError(f"Unknown task type: {args.task}")
 
     if rank == 0:
         if args.save_file is None:
diff --git a/wan/__init__.py b/wan/__init__.py
index d6c25f4..74c0661 100644
--- a/wan/__init__.py
+++ b/wan/__init__.py
@@ -2,3 +2,4 @@ from . import configs, distributed, modules
 from .image2video import WanI2V
 from .text2video import WanT2V
 from .first_last_frame2video import WanFLF2V
+from .vace import WanVace
diff --git a/wan/configs/__init__.py b/wan/configs/__init__.py
index cccda2f..e7f95d7 100644
--- a/wan/configs/__init__.py
+++ b/wan/configs/__init__.py
@@ -22,7 +22,9 @@ WAN_CONFIGS = {
     't2v-1.3B': t2v_1_3B,
     'i2v-14B': i2v_14B,
     't2i-14B': t2i_14B,
-    'flf2v-14B': flf2v_14B
+    'flf2v-14B': flf2v_14B,
+    'vace-1.3B': t2v_1_3B,
+    'vace-14B': t2v_14B,
 }
 
 SIZE_CONFIGS = {
@@ -46,4 +48,6 @@ SUPPORTED_SIZES = {
     'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
     'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
     't2i-14B': tuple(SIZE_CONFIGS.keys()),
+    'vace-1.3B': ('480*832', '832*480'),
+    'vace-14B': ('720*1280', '1280*720', '480*832', '832*480')
 }
diff --git a/wan/modules/__init__.py b/wan/modules/__init__.py
index f8935bb..a8172ca 100644
--- a/wan/modules/__init__.py
+++ b/wan/modules/__init__.py
@@ -2,11 +2,13 @@ from .attention import flash_attention
 from .model import WanModel
 from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
 from .tokenizers import HuggingfaceTokenizer
+from .vace_model import VaceWanModel
 from .vae import WanVAE
 
 __all__ = [
     'WanVAE',
     'WanModel',
+    'VaceWanModel',
     'T5Model',
     'T5Encoder',
     'T5Decoder',
diff --git a/wan/modules/vace_model.py b/wan/modules/vace_model.py
new file mode 100644
index 0000000..85425d0
--- /dev/null
+++ b/wan/modules/vace_model.py
@@ -0,0 +1,237 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import torch +import torch.cuda.amp as amp +import torch.nn as nn +from diffusers.configuration_utils import register_to_config +from .model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d + + +class VaceWanAttentionBlock(WanAttentionBlock): + def __init__( + self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + block_id=0 + ): + super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps) + self.block_id = block_id + if block_id == 0: + self.before_proj = nn.Linear(self.dim, self.dim) + nn.init.zeros_(self.before_proj.weight) + nn.init.zeros_(self.before_proj.bias) + self.after_proj = nn.Linear(self.dim, self.dim) + nn.init.zeros_(self.after_proj.weight) + nn.init.zeros_(self.after_proj.bias) + + def forward(self, c, x, **kwargs): + if self.block_id == 0: + c = self.before_proj(c) + x + all_c = [] + else: + all_c = list(torch.unbind(c)) + c = all_c.pop(-1) + c = super().forward(c, **kwargs) + c_skip = self.after_proj(c) + all_c += [c_skip, c] + c = torch.stack(all_c) + return c + + +class BaseWanAttentionBlock(WanAttentionBlock): + def __init__( + self, + cross_attn_type, + dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + block_id=None + ): + super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps) + self.block_id = block_id + + def forward(self, x, hints, context_scale=1.0, **kwargs): + x = super().forward(x, **kwargs) + if self.block_id is not None: + x = x + hints[self.block_id] * context_scale + return x + + +class VaceWanModel(WanModel): + @register_to_config + def __init__(self, + vace_layers=None, + vace_in_dim=None, + model_type='t2v', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6): + super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim, + num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps) + + self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers + self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim + + assert 0 in self.vace_layers + self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)} + + # blocks + self.blocks = nn.ModuleList([ + BaseWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, + self.cross_attn_norm, self.eps, + block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None) + for i in range(self.num_layers) + ]) + + # vace blocks + self.vace_blocks = nn.ModuleList([ + VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, + self.cross_attn_norm, self.eps, block_id=i) + for i in self.vace_layers + ]) + + # vace patch embeddings + self.vace_patch_embedding = nn.Conv3d( + self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size + ) + + def forward_vace( + self, + x, + vace_context, + seq_len, + kwargs + ): + # embeddings + c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] + c = [u.flatten(2).transpose(1, 2) for u in c] + c = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + dim=1) for u in c + ]) + + # arguments 
+ new_kwargs = dict(x=x) + new_kwargs.update(kwargs) + + for block in self.vace_blocks: + c = block(c, **new_kwargs) + hints = torch.unbind(c)[:-1] + return hints + + def forward( + self, + x, + t, + vace_context, + context, + seq_len, + vace_context_scale=1.0, + clip_fea=None, + y=None, + ): + r""" + Forward pass through the diffusion model + + Args: + x (List[Tensor]): + List of input video tensors, each with shape [C_in, F, H, W] + t (Tensor): + Diffusion timesteps tensor of shape [B] + context (List[Tensor]): + List of text embeddings each with shape [L, C] + seq_len (`int`): + Maximum sequence length for positional encoding + clip_fea (Tensor, *optional*): + CLIP image features for image-to-video mode + y (List[Tensor], *optional*): + Conditional video inputs for image-to-video mode, same shape as x + + Returns: + List[Tensor]: + List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] + """ + # if self.model_type == 'i2v': + # assert clip_fea is not None and y is not None + # params + device = self.patch_embedding.weight.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + + # if y is not None: + # x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + assert seq_lens.max() <= seq_len + x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + dim=1) for u in x + ]) + + # time embeddings + with amp.autocast(dtype=torch.float32): + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context + context_lens = None + context = self.text_embedding( + torch.stack([ + torch.cat( + [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ])) + + # if clip_fea is not None: + # context_clip = self.img_emb(clip_fea) # bs x 257 x dim + # context = torch.concat([context_clip, context], dim=1) + + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens) + + hints = self.forward_vace(x, vace_context, seq_len, kwargs) + kwargs['hints'] = hints + kwargs['context_scale'] = vace_context_scale + + for block in self.blocks: + x = block(x, **kwargs) + + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return [u.float() for u in x] \ No newline at end of file diff --git a/wan/utils/__init__.py b/wan/utils/__init__.py index 6e9a339..ba3fe7d 100644 --- a/wan/utils/__init__.py +++ b/wan/utils/__init__.py @@ -1,8 +1,10 @@ from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps) from .fm_solvers_unipc import FlowUniPCMultistepScheduler +from .vace_processor import VaceVideoProcessor __all__ = [ 'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps', - 'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler' + 'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler', + 'VaceVideoProcessor' ] diff --git a/wan/utils/vace_processor.py b/wan/utils/vace_processor.py new file mode 100644 index 0000000..5f7224f --- /dev/null +++ b/wan/utils/vace_processor.py @@ -0,0 +1,270 @@ +# Copyright 2024-2025 The 
Alibaba Wan Team Authors. All rights reserved. +import numpy as np +from PIL import Image +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF + + +class VaceImageProcessor(object): + def __init__(self, downsample=None, seq_len=None): + self.downsample = downsample + self.seq_len = seq_len + + def _pillow_convert(self, image, cvt_type='RGB'): + if image.mode != cvt_type: + if image.mode == 'P': + image = image.convert(f'{cvt_type}A') + if image.mode == f'{cvt_type}A': + bg = Image.new(cvt_type, + size=(image.width, image.height), + color=(255, 255, 255)) + bg.paste(image, (0, 0), mask=image) + image = bg + else: + image = image.convert(cvt_type) + return image + + def _load_image(self, img_path): + if img_path is None or img_path == '': + return None + img = Image.open(img_path) + img = self._pillow_convert(img) + return img + + def _resize_crop(self, img, oh, ow, normalize=True): + """ + Resize, center crop, convert to tensor, and normalize. + """ + # resize and crop + iw, ih = img.size + if iw != ow or ih != oh: + # resize + scale = max(ow / iw, oh / ih) + img = img.resize( + (round(scale * iw), round(scale * ih)), + resample=Image.Resampling.LANCZOS + ) + assert img.width >= ow and img.height >= oh + + # center crop + x1 = (img.width - ow) // 2 + y1 = (img.height - oh) // 2 + img = img.crop((x1, y1, x1 + ow, y1 + oh)) + + # normalize + if normalize: + img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(1) + return img + + def _image_preprocess(self, img, oh, ow, normalize=True, **kwargs): + return self._resize_crop(img, oh, ow, normalize) + + def load_image(self, data_key, **kwargs): + return self.load_image_batch(data_key, **kwargs) + + def load_image_pair(self, data_key, data_key2, **kwargs): + return self.load_image_batch(data_key, data_key2, **kwargs) + + def load_image_batch(self, *data_key_batch, normalize=True, seq_len=None, **kwargs): + seq_len = self.seq_len if seq_len is None else seq_len + imgs = [] + for data_key in data_key_batch: + img = self._load_image(data_key) + imgs.append(img) + w, h = imgs[0].size + dh, dw = self.downsample[1:] + + # compute output size + scale = min(1., np.sqrt(seq_len / ((h / dh) * (w / dw)))) + oh = int(h * scale) // dh * dh + ow = int(w * scale) // dw * dw + assert (oh // dh) * (ow // dw) <= seq_len + imgs = [self._image_preprocess(img, oh, ow, normalize) for img in imgs] + return *imgs, (oh, ow) + + +class VaceVideoProcessor(object): + def __init__(self, downsample, min_area, max_area, min_fps, max_fps, zero_start, seq_len, keep_last, **kwargs): + self.downsample = downsample + self.min_area = min_area + self.max_area = max_area + self.min_fps = min_fps + self.max_fps = max_fps + self.zero_start = zero_start + self.keep_last = keep_last + self.seq_len = seq_len + assert seq_len >= min_area / (self.downsample[1] * self.downsample[2]) + + def set_area(self, area): + self.min_area = area + self.max_area = area + + def set_seq_len(self, seq_len): + self.seq_len = seq_len + + @staticmethod + def resize_crop(video: torch.Tensor, oh: int, ow: int): + """ + Resize, center crop and normalize for decord loaded video (torch.Tensor type) + + Parameters: + video - video to process (torch.Tensor): Tensor from `reader.get_batch(frame_ids)`, in shape of (T, H, W, C) + oh - target height (int) + ow - target width (int) + + Returns: + The processed video (torch.Tensor): Normalized tensor range [-1, 1], in shape of (C, T, H, W) + + Raises: + """ + # permute ([t, h, w, c] -> [t, c, h, w]) + video = video.permute(0, 3, 1, 
2) + + # resize and crop + ih, iw = video.shape[2:] + if ih != oh or iw != ow: + # resize + scale = max(ow / iw, oh / ih) + video = F.interpolate( + video, + size=(round(scale * ih), round(scale * iw)), + mode='bicubic', + antialias=True + ) + assert video.size(3) >= ow and video.size(2) >= oh + + # center crop + x1 = (video.size(3) - ow) // 2 + y1 = (video.size(2) - oh) // 2 + video = video[:, :, y1:y1 + oh, x1:x1 + ow] + + # permute ([t, c, h, w] -> [c, t, h, w]) and normalize + video = video.transpose(0, 1).float().div_(127.5).sub_(1.) + return video + + def _video_preprocess(self, video, oh, ow): + return self.resize_crop(video, oh, ow) + + def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, rng): + target_fps = min(fps, self.max_fps) + duration = frame_timestamps[-1].mean() + x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box + h, w = y2 - y1, x2 - x1 + ratio = h / w + df, dh, dw = self.downsample + + area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw)) + of = min( + (int(duration * target_fps) - 1) // df + 1, + int(self.seq_len / area_z) + ) + + # deduce target shape of the [latent video] + target_area_z = min(area_z, int(self.seq_len / of)) + oh = round(np.sqrt(target_area_z * ratio)) + ow = int(target_area_z / oh) + of = (of - 1) * df + 1 + oh *= dh + ow *= dw + + # sample frame ids + target_duration = of / target_fps + begin = 0. if self.zero_start else rng.uniform(0, duration - target_duration) + timestamps = np.linspace(begin, begin + target_duration, of) + frame_ids = np.argmax(np.logical_and( + timestamps[:, None] >= frame_timestamps[None, :, 0], + timestamps[:, None] < frame_timestamps[None, :, 1] + ), axis=1).tolist() + return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps + + def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng): + duration = frame_timestamps[-1].mean() + x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box + h, w = y2 - y1, x2 - x1 + ratio = h / w + df, dh, dw = self.downsample + + area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw)) + of = min( + (len(frame_timestamps) - 1) // df + 1, + int(self.seq_len / area_z) + ) + + # deduce target shape of the [latent video] + target_area_z = min(area_z, int(self.seq_len / of)) + oh = round(np.sqrt(target_area_z * ratio)) + ow = int(target_area_z / oh) + of = (of - 1) * df + 1 + oh *= dh + ow *= dw + + # sample frame ids + target_duration = duration + target_fps = of / target_duration + timestamps = np.linspace(0., target_duration, of) + frame_ids = np.argmax(np.logical_and( + timestamps[:, None] >= frame_timestamps[None, :, 0], + timestamps[:, None] <= frame_timestamps[None, :, 1] + ), axis=1).tolist() + # print(oh, ow, of, target_duration, target_fps, len(frame_timestamps), len(frame_ids)) + return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps + + + def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng): + if self.keep_last: + return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng) + else: + return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng) + + def load_video(self, data_key, crop_box=None, seed=2024, **kwargs): + return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs) + + def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs): + return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs) + + def load_video_batch(self, 
*data_key_batch, crop_box=None, seed=2024, **kwargs): + rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000) + # read video + import decord + decord.bridge.set_bridge('torch') + readers = [] + for data_k in data_key_batch: + reader = decord.VideoReader(data_k) + readers.append(reader) + + fps = readers[0].get_avg_fps() + length = min([len(r) for r in readers]) + frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)] + frame_timestamps = np.array(frame_timestamps, dtype=np.float32) + h, w = readers[0].next().shape[:2] + frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, frame_timestamps, h, w, crop_box, rng) + + # preprocess video + videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers] + videos = [self._video_preprocess(video, oh, ow) for video in videos] + return *videos, frame_ids, (oh, ow), fps + # return videos if len(videos) > 1 else videos[0] + + +def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size, device): + for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)): + if sub_src_video is None and sub_src_mask is None: + src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device) + src_mask[i] = torch.ones((1, num_frames, image_size[0], image_size[1]), device=device) + for i, ref_images in enumerate(src_ref_images): + if ref_images is not None: + for j, ref_img in enumerate(ref_images): + if ref_img is not None and ref_img.shape[-2:] != image_size: + canvas_height, canvas_width = image_size + ref_height, ref_width = ref_img.shape[-2:] + white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1] + scale = min(canvas_height / ref_height, canvas_width / ref_width) + new_height = int(ref_height * scale) + new_width = int(ref_width * scale) + resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1) + top = (canvas_height - new_height) // 2 + left = (canvas_width - new_width) // 2 + white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image + src_ref_images[i][j] = white_canvas + return src_video, src_mask, src_ref_images diff --git a/wan/vace.py b/wan/vace.py new file mode 100644 index 0000000..1b94e88 --- /dev/null +++ b/wan/vace.py @@ -0,0 +1,717 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import os +import sys +import gc +import math +import time +import random +import types +import logging +import traceback +from contextlib import contextmanager +from functools import partial + +from PIL import Image +import torchvision.transforms.functional as TF +import torch +import torch.nn.functional as F +import torch.cuda.amp as amp +import torch.distributed as dist +import torch.multiprocessing as mp +from tqdm import tqdm + +from .text2video import (WanT2V, T5EncoderModel, WanVAE, shard_model, FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, retrieve_timesteps, FlowUniPCMultistepScheduler) +from .modules.vace_model import VaceWanModel +from .utils.vace_processor import VaceVideoProcessor + + +class WanVace(WanT2V): + def __init__( + self, + config, + checkpoint_dir, + device_id=0, + rank=0, + t5_fsdp=False, + dit_fsdp=False, + use_usp=False, + t5_cpu=False, + ): + r""" + Initializes the Wan text-to-video generation model components. 
+ + Args: + config (EasyDict): + Object containing model parameters initialized from config.py + checkpoint_dir (`str`): + Path to directory containing model checkpoints + device_id (`int`, *optional*, defaults to 0): + Id of target GPU device + rank (`int`, *optional*, defaults to 0): + Process rank for distributed training + t5_fsdp (`bool`, *optional*, defaults to False): + Enable FSDP sharding for T5 model + dit_fsdp (`bool`, *optional*, defaults to False): + Enable FSDP sharding for DiT model + use_usp (`bool`, *optional*, defaults to False): + Enable distribution strategy of USP. + t5_cpu (`bool`, *optional*, defaults to False): + Whether to place T5 model on CPU. Only works without t5_fsdp. + """ + self.device = torch.device(f"cuda:{device_id}") + self.config = config + self.rank = rank + self.t5_cpu = t5_cpu + + self.num_train_timesteps = config.num_train_timesteps + self.param_dtype = config.param_dtype + + shard_fn = partial(shard_model, device_id=device_id) + self.text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + device=torch.device('cpu'), + checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), + tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), + shard_fn=shard_fn if t5_fsdp else None) + + self.vae_stride = config.vae_stride + self.patch_size = config.patch_size + self.vae = WanVAE( + vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), + device=self.device) + + logging.info(f"Creating VaceWanModel from {checkpoint_dir}") + self.model = VaceWanModel.from_pretrained(checkpoint_dir) + self.model.eval().requires_grad_(False) + + if use_usp: + from xfuser.core.distributed import \ + get_sequence_parallel_world_size + + from .distributed.xdit_context_parallel import (usp_attn_forward, + usp_dit_forward, + usp_dit_forward_vace) + for block in self.model.blocks: + block.self_attn.forward = types.MethodType( + usp_attn_forward, block.self_attn) + for block in self.model.vace_blocks: + block.self_attn.forward = types.MethodType( + usp_attn_forward, block.self_attn) + self.model.forward = types.MethodType(usp_dit_forward, self.model) + self.model.forward_vace = types.MethodType(usp_dit_forward_vace, self.model) + self.sp_size = get_sequence_parallel_world_size() + else: + self.sp_size = 1 + + if dist.is_initialized(): + dist.barrier() + if dit_fsdp: + self.model = shard_fn(self.model) + else: + self.model.to(self.device) + + self.sample_neg_prompt = config.sample_neg_prompt + + self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]), + min_area=720*1280, + max_area=720*1280, + min_fps=config.sample_fps, + max_fps=config.sample_fps, + zero_start=True, + seq_len=75600, + keep_last=True) + + def vace_encode_frames(self, frames, ref_images, masks=None, vae=None): + vae = self.vae if vae is None else vae + if ref_images is None: + ref_images = [None] * len(frames) + else: + assert len(frames) == len(ref_images) + + if masks is None: + latents = vae.encode(frames) + else: + masks = [torch.where(m > 0.5, 1.0, 0.0) for m in masks] + inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)] + reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)] + inactive = vae.encode(inactive) + reactive = vae.encode(reactive) + latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] + + cat_latents = [] + for latent, refs in zip(latents, ref_images): + if refs is not None: + if masks is None: + ref_latent = vae.encode(refs) + else: + ref_latent = 
vae.encode(refs) + ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent] + assert all([x.shape[1] == 1 for x in ref_latent]) + latent = torch.cat([*ref_latent, latent], dim=1) + cat_latents.append(latent) + return cat_latents + + def vace_encode_masks(self, masks, ref_images=None, vae_stride=None): + vae_stride = self.vae_stride if vae_stride is None else vae_stride + if ref_images is None: + ref_images = [None] * len(masks) + else: + assert len(masks) == len(ref_images) + + result_masks = [] + for mask, refs in zip(masks, ref_images): + c, depth, height, width = mask.shape + new_depth = int((depth + 3) // vae_stride[0]) + height = 2 * (int(height) // (vae_stride[1] * 2)) + width = 2 * (int(width) // (vae_stride[2] * 2)) + + # reshape + mask = mask[0, :, :, :] + mask = mask.view( + depth, height, vae_stride[1], width, vae_stride[1] + ) # depth, height, 8, width, 8 + mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width + mask = mask.reshape( + vae_stride[1] * vae_stride[2], depth, height, width + ) # 8*8, depth, height, width + + # interpolation + mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0) + + if refs is not None: + length = len(refs) + mask_pad = torch.zeros_like(mask[:, :length, :, :]) + mask = torch.cat((mask_pad, mask), dim=1) + result_masks.append(mask) + return result_masks + + def vace_latent(self, z, m): + return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] + + def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device): + area = image_size[0] * image_size[1] + self.vid_proc.set_area(area) + if area == 720*1280: + self.vid_proc.set_seq_len(75600) + elif area == 480*832: + self.vid_proc.set_seq_len(32760) + else: + raise NotImplementedError(f'image_size {image_size} is not supported') + + image_sizes = [] + for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)): + if sub_src_mask is not None and sub_src_video is not None: + src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask) + src_video[i] = src_video[i].to(device) + src_mask[i] = src_mask[i].to(device) + src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1) + image_sizes.append(src_video[i].shape[2:]) + elif sub_src_video is None: + src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device) + src_mask[i] = torch.ones_like(src_video[i], device=device) + image_sizes.append(image_size) + else: + src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video) + src_video[i] = src_video[i].to(device) + src_mask[i] = torch.ones_like(src_video[i], device=device) + image_sizes.append(src_video[i].shape[2:]) + + for i, ref_images in enumerate(src_ref_images): + if ref_images is not None: + image_size = image_sizes[i] + for j, ref_img in enumerate(ref_images): + if ref_img is not None: + ref_img = Image.open(ref_img).convert("RGB") + ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) + if ref_img.shape[-2:] != image_size: + canvas_height, canvas_width = image_size + ref_height, ref_width = ref_img.shape[-2:] + white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1] + scale = min(canvas_height / ref_height, canvas_width / ref_width) + new_height = int(ref_height * scale) + new_width = int(ref_width * scale) + resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', 
align_corners=False).squeeze(0).unsqueeze(1) + top = (canvas_height - new_height) // 2 + left = (canvas_width - new_width) // 2 + white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image + ref_img = white_canvas + src_ref_images[i][j] = ref_img.to(device) + return src_video, src_mask, src_ref_images + + def decode_latent(self, zs, ref_images=None, vae=None): + vae = self.vae if vae is None else vae + if ref_images is None: + ref_images = [None] * len(zs) + else: + assert len(zs) == len(ref_images) + + trimed_zs = [] + for z, refs in zip(zs, ref_images): + if refs is not None: + z = z[:, len(refs):, :, :] + trimed_zs.append(z) + + return vae.decode(trimed_zs) + + + + def generate(self, + input_prompt, + input_frames, + input_masks, + input_ref_images, + size=(1280, 720), + frame_num=81, + context_scale=1.0, + shift=5.0, + sample_solver='unipc', + sampling_steps=50, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True): + r""" + Generates video frames from text prompt using diffusion process. + + Args: + input_prompt (`str`): + Text prompt for content generation + size (tupele[`int`], *optional*, defaults to (1280,720)): + Controls video resolution, (width,height). + frame_num (`int`, *optional*, defaults to 81): + How many frames to sample from a video. The number should be 4n+1 + shift (`float`, *optional*, defaults to 5.0): + Noise schedule shift parameter. Affects temporal dynamics + sample_solver (`str`, *optional*, defaults to 'unipc'): + Solver used to sample the video. + sampling_steps (`int`, *optional*, defaults to 40): + Number of diffusion sampling steps. Higher values improve quality but slow generation + guide_scale (`float`, *optional*, defaults 5.0): + Classifier-free guidance scale. Controls prompt adherence vs. creativity + n_prompt (`str`, *optional*, defaults to ""): + Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` + seed (`int`, *optional*, defaults to -1): + Random seed for noise generation. If -1, use random seed. + offload_model (`bool`, *optional*, defaults to True): + If True, offloads models to CPU during generation to save VRAM + + Returns: + torch.Tensor: + Generated video frames tensor. 
Dimensions: (C, N H, W) where: + - C: Color channels (3 for RGB) + - N: Number of frames (81) + - H: Frame height (from size) + - W: Frame width from size) + """ + # preprocess + # F = frame_num + # target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1, + # size[1] // self.vae_stride[1], + # size[0] // self.vae_stride[2]) + # + # seq_len = math.ceil((target_shape[2] * target_shape[3]) / + # (self.patch_size[1] * self.patch_size[2]) * + # target_shape[1] / self.sp_size) * self.sp_size + + if n_prompt == "": + n_prompt = self.sample_neg_prompt + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) + + if not self.t5_cpu: + self.text_encoder.model.to(self.device) + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) + if offload_model: + self.text_encoder.model.cpu() + else: + context = self.text_encoder([input_prompt], torch.device('cpu')) + context_null = self.text_encoder([n_prompt], torch.device('cpu')) + context = [t.to(self.device) for t in context] + context_null = [t.to(self.device) for t in context_null] + + # vace context encode + z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks) + m0 = self.vace_encode_masks(input_masks, input_ref_images) + z = self.vace_latent(z0, m0) + + target_shape = list(z0[0].shape) + target_shape[0] = int(target_shape[0] / 2) + noise = [ + torch.randn( + target_shape[0], + target_shape[1], + target_shape[2], + target_shape[3], + dtype=torch.float32, + device=self.device, + generator=seed_g) + ] + seq_len = math.ceil((target_shape[2] * target_shape[3]) / + (self.patch_size[1] * self.patch_size[2]) * + target_shape[1] / self.sp_size) * self.sp_size + + @contextmanager + def noop_no_sync(): + yield + + no_sync = getattr(self.model, 'no_sync', noop_no_sync) + + # evaluation mode + with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): + + if sample_solver == 'unipc': + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sample_scheduler.set_timesteps( + sampling_steps, device=self.device, shift=shift) + timesteps = sample_scheduler.timesteps + elif sample_solver == 'dpm++': + sample_scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + timesteps, _ = retrieve_timesteps( + sample_scheduler, + device=self.device, + sigmas=sampling_sigmas) + else: + raise NotImplementedError("Unsupported solver.") + + # sample videos + latents = noise + + arg_c = {'context': context, 'seq_len': seq_len} + arg_null = {'context': context_null, 'seq_len': seq_len} + + for _, t in enumerate(tqdm(timesteps)): + latent_model_input = latents + timestep = [t] + + timestep = torch.stack(timestep) + + self.model.to(self.device) + noise_pred_cond = self.model( + latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_c)[0] + noise_pred_uncond = self.model( + latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale,**arg_null)[0] + + noise_pred = noise_pred_uncond + guide_scale * ( + noise_pred_cond - noise_pred_uncond) + + temp_x0 = sample_scheduler.step( + noise_pred.unsqueeze(0), + t, + latents[0].unsqueeze(0), + return_dict=False, + generator=seed_g)[0] + latents = [temp_x0.squeeze(0)] + + x0 = 
latents + if offload_model: + self.model.cpu() + torch.cuda.empty_cache() + if self.rank == 0: + videos = self.decode_latent(x0, input_ref_images) + + del noise, latents + del sample_scheduler + if offload_model: + gc.collect() + torch.cuda.synchronize() + if dist.is_initialized(): + dist.barrier() + + return videos[0] if self.rank == 0 else None + + +class WanVaceMP(WanVace): + def __init__( + self, + config, + checkpoint_dir, + use_usp=False, + ulysses_size=None, + ring_size=None + ): + self.config = config + self.checkpoint_dir = checkpoint_dir + self.use_usp = use_usp + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12345' + os.environ['RANK'] = '0' + os.environ['WORLD_SIZE'] = '1' + self.in_q_list = None + self.out_q = None + self.inference_pids = None + self.ulysses_size = ulysses_size + self.ring_size = ring_size + self.dynamic_load() + + self.device = 'cpu' if torch.cuda.is_available() else 'cpu' + self.vid_proc = VaceVideoProcessor( + downsample=tuple([x * y for x, y in zip(config.vae_stride, config.patch_size)]), + min_area=480 * 832, + max_area=480 * 832, + min_fps=self.config.sample_fps, + max_fps=self.config.sample_fps, + zero_start=True, + seq_len=32760, + keep_last=True) + + + def dynamic_load(self): + if hasattr(self, 'inference_pids') and self.inference_pids is not None: + return + gpu_infer = os.environ.get('LOCAL_WORLD_SIZE') or torch.cuda.device_count() + pmi_rank = int(os.environ['RANK']) + pmi_world_size = int(os.environ['WORLD_SIZE']) + in_q_list = [torch.multiprocessing.Manager().Queue() for _ in range(gpu_infer)] + out_q = torch.multiprocessing.Manager().Queue() + initialized_events = [torch.multiprocessing.Manager().Event() for _ in range(gpu_infer)] + context = mp.spawn(self.mp_worker, nprocs=gpu_infer, args=(gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, initialized_events, self), join=False) + all_initialized = False + while not all_initialized: + all_initialized = all(event.is_set() for event in initialized_events) + if not all_initialized: + time.sleep(0.1) + print('Inference model is initialized', flush=True) + self.in_q_list = in_q_list + self.out_q = out_q + self.inference_pids = context.pids() + self.initialized_events = initialized_events + + def transfer_data_to_cuda(self, data, device): + if data is None: + return None + else: + if isinstance(data, torch.Tensor): + data = data.to(device) + elif isinstance(data, list): + data = [self.transfer_data_to_cuda(subdata, device) for subdata in data] + elif isinstance(data, dict): + data = {key: self.transfer_data_to_cuda(val, device) for key, val in data.items()} + return data + + def mp_worker(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, initialized_events, work_env): + try: + world_size = pmi_world_size * gpu_infer + rank = pmi_rank * gpu_infer + gpu + print("world_size", world_size, "rank", rank, flush=True) + + torch.cuda.set_device(gpu) + dist.init_process_group( + backend='nccl', + init_method='env://', + rank=rank, + world_size=world_size + ) + + from xfuser.core.distributed import (initialize_model_parallel, + init_distributed_environment) + init_distributed_environment( + rank=dist.get_rank(), world_size=dist.get_world_size()) + + initialize_model_parallel( + sequence_parallel_degree=dist.get_world_size(), + ring_degree=self.ring_size or 1, + ulysses_degree=self.ulysses_size or 1 + ) + + num_train_timesteps = self.config.num_train_timesteps + param_dtype = self.config.param_dtype + shard_fn = partial(shard_model, device_id=gpu) + text_encoder = 
T5EncoderModel( + text_len=self.config.text_len, + dtype=self.config.t5_dtype, + device=torch.device('cpu'), + checkpoint_path=os.path.join(self.checkpoint_dir, self.config.t5_checkpoint), + tokenizer_path=os.path.join(self.checkpoint_dir, self.config.t5_tokenizer), + shard_fn=shard_fn if True else None) + text_encoder.model.to(gpu) + vae_stride = self.config.vae_stride + patch_size = self.config.patch_size + vae = WanVAE( + vae_pth=os.path.join(self.checkpoint_dir, self.config.vae_checkpoint), + device=gpu) + logging.info(f"Creating VaceWanModel from {self.checkpoint_dir}") + model = VaceWanModel.from_pretrained(self.checkpoint_dir) + model.eval().requires_grad_(False) + + if self.use_usp: + from xfuser.core.distributed import get_sequence_parallel_world_size + from .distributed.xdit_context_parallel import (usp_attn_forward, + usp_dit_forward, + usp_dit_forward_vace) + for block in model.blocks: + block.self_attn.forward = types.MethodType( + usp_attn_forward, block.self_attn) + for block in model.vace_blocks: + block.self_attn.forward = types.MethodType( + usp_attn_forward, block.self_attn) + model.forward = types.MethodType(usp_dit_forward, model) + model.forward_vace = types.MethodType(usp_dit_forward_vace, model) + sp_size = get_sequence_parallel_world_size() + else: + sp_size = 1 + + dist.barrier() + model = shard_fn(model) + sample_neg_prompt = self.config.sample_neg_prompt + + torch.cuda.empty_cache() + event = initialized_events[gpu] + in_q = in_q_list[gpu] + event.set() + + while True: + item = in_q.get() + input_prompt, input_frames, input_masks, input_ref_images, size, frame_num, context_scale, \ + shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model = item + input_frames = self.transfer_data_to_cuda(input_frames, gpu) + input_masks = self.transfer_data_to_cuda(input_masks, gpu) + input_ref_images = self.transfer_data_to_cuda(input_ref_images, gpu) + + if n_prompt == "": + n_prompt = sample_neg_prompt + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=gpu) + seed_g.manual_seed(seed) + + context = text_encoder([input_prompt], gpu) + context_null = text_encoder([n_prompt], gpu) + + # vace context encode + z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, vae=vae) + m0 = self.vace_encode_masks(input_masks, input_ref_images, vae_stride=vae_stride) + z = self.vace_latent(z0, m0) + + target_shape = list(z0[0].shape) + target_shape[0] = int(target_shape[0] / 2) + noise = [ + torch.randn( + target_shape[0], + target_shape[1], + target_shape[2], + target_shape[3], + dtype=torch.float32, + device=gpu, + generator=seed_g) + ] + seq_len = math.ceil((target_shape[2] * target_shape[3]) / + (patch_size[1] * patch_size[2]) * + target_shape[1] / sp_size) * sp_size + + @contextmanager + def noop_no_sync(): + yield + + no_sync = getattr(model, 'no_sync', noop_no_sync) + + # evaluation mode + with amp.autocast(dtype=param_dtype), torch.no_grad(), no_sync(): + + if sample_solver == 'unipc': + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sample_scheduler.set_timesteps( + sampling_steps, device=gpu, shift=shift) + timesteps = sample_scheduler.timesteps + elif sample_solver == 'dpm++': + sample_scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + timesteps, _ = 
retrieve_timesteps( + sample_scheduler, + device=gpu, + sigmas=sampling_sigmas) + else: + raise NotImplementedError("Unsupported solver.") + + # sample videos + latents = noise + + arg_c = {'context': context, 'seq_len': seq_len} + arg_null = {'context': context_null, 'seq_len': seq_len} + + for _, t in enumerate(tqdm(timesteps)): + latent_model_input = latents + timestep = [t] + + timestep = torch.stack(timestep) + + model.to(gpu) + noise_pred_cond = model( + latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_c)[ + 0] + noise_pred_uncond = model( + latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, + **arg_null)[0] + + noise_pred = noise_pred_uncond + guide_scale * ( + noise_pred_cond - noise_pred_uncond) + + temp_x0 = sample_scheduler.step( + noise_pred.unsqueeze(0), + t, + latents[0].unsqueeze(0), + return_dict=False, + generator=seed_g)[0] + latents = [temp_x0.squeeze(0)] + + torch.cuda.empty_cache() + x0 = latents + if rank == 0: + videos = self.decode_latent(x0, input_ref_images, vae=vae) + + del noise, latents + del sample_scheduler + if offload_model: + gc.collect() + torch.cuda.synchronize() + if dist.is_initialized(): + dist.barrier() + + if rank == 0: + out_q.put(videos[0].cpu()) + + except Exception as e: + trace_info = traceback.format_exc() + print(trace_info, flush=True) + print(e, flush=True) + + + + def generate(self, + input_prompt, + input_frames, + input_masks, + input_ref_images, + size=(1280, 720), + frame_num=81, + context_scale=1.0, + shift=5.0, + sample_solver='unipc', + sampling_steps=50, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True): + + input_data = (input_prompt, input_frames, input_masks, input_ref_images, size, frame_num, context_scale, + shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model) + for in_q in self.in_q_list: + in_q.put(input_data) + value_output = self.out_q.get() + + return value_output
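
Reviewer note: below is a minimal usage sketch of the WanVace entry point added in wan/vace.py, mirroring the new "vace" branch in generate.py. It is illustrative only: the checkpoint directory ('./Wan2.1-VACE-1.3B'), the reference-image paths, and the prompt placeholder are assumptions, and a single CUDA device is assumed. The same flow is reachable from the command line through generate.py with --task vace-1.3B plus the new --src_video / --src_mask / --src_ref_images flags.

import torch
import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS

cfg = WAN_CONFIGS['vace-1.3B']
size = SIZE_CONFIGS['480*832']        # 'vace-1.3B' supports 480*832 and 832*480

# Build the pipeline (placeholder checkpoint directory).
wan_vace = wan.WanVace(
    config=cfg,
    checkpoint_dir='./Wan2.1-VACE-1.3B',
    device_id=0,
    rank=0)

# prepare_source zero-fills a missing source video, builds an all-ones mask,
# and letterboxes reference images onto a white canvas at the target size.
src_video, src_mask, src_ref_images = wan_vace.prepare_source(
    [None],                           # no source video
    [None],                           # no source mask
    [['./bag.jpg', './heben.png']],   # reference image paths (placeholders)
    81,                               # frame_num, must be 4n + 1
    size,
    torch.device('cuda:0'))

video = wan_vace.generate(
    "...",                            # e.g. EXAMPLE_PROMPT['vace-1.3B']['prompt']
    src_video,
    src_mask,
    src_ref_images,
    size=size,
    frame_num=81,
    sampling_steps=50,
    guide_scale=5.0,
    seed=42)                          # decoded video tensor, shape (C, N, H, W)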