diff --git a/README.md b/README.md index 5a98b82..1a097a0 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,4 @@ -# Wan2.1 GP - - - 💜 Wan    |    🖥️ GitHub    |   🤗 Hugging Face   |   🤖 ModelScope   |    📑 Paper (Coming soon)    |    📑 Blog    |   💬 WeChat Group   |    📖 Discord   - +# WanGP -----

@@ -15,6 +10,7 @@ ## 🔥 Latest News!! +* April 17 2025: 👋 Wan 2.1GP v4.1: ReCamMaster model support: view a video from a different camera angle. The source video must be at least 81 frames long and you should use at least 15 denoising steps to get good results. * April 13 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you ! - A new UI, tabs were replaced by a Dropdown box to easily switch models - A new queuing system that lets you stack in a queue as many text2video, imag2video tasks, ... as you want. Each task can rely on complete different generation parameters (different number of frames, steps, loras, ...). Many thanks to **Tophness** for being a big contributor on this new feature @@ -254,7 +250,7 @@ Each preset, is a file with ".lset" extension stored in the loras directory and Last but not least you can pre activate Loras corresponding and prefill a prompt (comments only or full prompt) by specifying a preset when launching the gradio server: ```bash -python wgp.py.py --lora-preset mylorapreset.lset # where 'mylorapreset.lset' is a preset stored in the 'loras' folder +python wgp.py --lora-preset mylorapreset.lset # where 'mylorapreset.lset' is a preset stored in the 'loras' folder ``` You will find prebuilt Loras on https://civitai.com/ or you will be able to build them with tools such as kohya or onetrainer. diff --git a/wan/image2video.py b/wan/image2video.py index b688676..99e0959 100644 --- a/wan/image2video.py +++ b/wan/image2video.py @@ -319,7 +319,6 @@ class WanI2V: arg_c = { 'context': [context[0]], 'clip_fea': clip_context, - 'seq_len': max_seq_len, 'y': [y], 'freqs' : freqs, 'pipeline' : self, @@ -329,7 +328,6 @@ class WanI2V: arg_null = { 'context': context_null, 'clip_fea': clip_context, - 'seq_len': max_seq_len, 'y': [y], 'freqs' : freqs, 'pipeline' : self, @@ -340,7 +338,6 @@ class WanI2V: 'context': [context[0]], 'context2': context_null, 'clip_fea': clip_context, - 'seq_len': max_seq_len, 'y': [y], 'freqs' : freqs, 'pipeline' : self, diff --git a/wan/modules/attention.py b/wan/modules/attention.py index b795b06..25a8ef7 100644 --- a/wan/modules/attention.py +++ b/wan/modules/attention.py @@ -152,91 +152,52 @@ __all__ = [ def pay_attention( qkv_list, - # q, - # k, - # v, - q_lens=None, - k_lens=None, dropout_p=0., softmax_scale=None, - q_scale=None, causal=False, window_size=(-1, -1), deterministic=False, - dtype=torch.bfloat16, version=None, force_attention= None, cross_attn= False ): - """ - q: [B, Lq, Nq, C1]. - k: [B, Lk, Nk, C1]. - v: [B, Lk, Nk, C2]. Nq must be divisible by Nk. - q_lens: [B]. - k_lens: [B]. - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - causal: bool. Whether to apply causal attention mask. - window_size: (left right). If not (-1, -1), apply sliding window local attention. - deterministic: bool. If True, slightly slower and uses more memory. - dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
- """ attn = offload.shared_state["_attention"] if force_attention== None else force_attention q,k,v = qkv_list qkv_list.clear() - half_dtypes = (torch.float16, torch.bfloat16) - assert dtype in half_dtypes - assert q.device.type == 'cuda' and q.size(-1) <= 256 # params b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype + assert b==1 + q = q.squeeze(0) + k = k.squeeze(0) + v = v.squeeze(0) - def half(x): - return x if x.dtype in half_dtypes else x.to(dtype) - - # preprocess query - if q_lens is None: - q = half(q.flatten(0, 1)) - q_lens = torch.tensor( - [lq] * b, dtype=torch.int32).to( - device=q.device, non_blocking=True) - else: - q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)])) - - # preprocess key, value - if k_lens is None: - k = half(k.flatten(0, 1)) - v = half(v.flatten(0, 1)) - k_lens = torch.tensor( - [lk] * b, dtype=torch.int32).to( - device=k.device, non_blocking=True) - else: - k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)])) - v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)])) q = q.to(v.dtype) k = k.to(v.dtype) - if q_scale is not None: - q = q * q_scale + # if q_scale is not None: + # q = q * q_scale if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE: warnings.warn( 'Flash attention 3 is not available, use flash attention 2 instead.' ) + if attn=="sage" or attn=="flash": + cu_seqlens_q = torch.tensor([0, lq], dtype=torch.int32, device="cuda") + cu_seqlens_k = torch.tensor([0, lk], dtype=torch.int32, device="cuda") + # apply attention if attn=="sage": x = sageattn_varlen_wrapper( q=q, k=k, v=v, - cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( - 0, dtype=torch.int32).to(q.device, non_blocking=True), - cu_seqlens_kv=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( - 0, dtype=torch.int32).to(q.device, non_blocking=True), + cu_seqlens_q= cu_seqlens_q, + cu_seqlens_kv= cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, ).unflatten(0, (b, lq)) @@ -314,10 +275,8 @@ def pay_attention( q=q, k=k, v=v, - cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( - 0, dtype=torch.int32).to(q.device, non_blocking=True), - cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( - 0, dtype=torch.int32).to(q.device, non_blocking=True), + cu_seqlens_q= cu_seqlens_q, + cu_seqlens_kv= cu_seqlens_k, seqused_q=None, seqused_k=None, max_seqlen_q=lq, @@ -330,10 +289,8 @@ def pay_attention( q=q, k=k, v=v, - cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( - 0, dtype=torch.int32).to(q.device, non_blocking=True), - cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( - 0, dtype=torch.int32).to(q.device, non_blocking=True), + cu_seqlens_q= [0, lq], + cu_seqlens_kv=[0, lk], max_seqlen_q=lq, max_seqlen_k=lk, dropout_p=dropout_p, diff --git a/wan/modules/model.py b/wan/modules/model.py index 0ba16ae..37b8989 100644 --- a/wan/modules/model.py +++ b/wan/modules/model.py @@ -1,6 +1,6 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
import math - +from einops import rearrange import torch import torch.cuda.amp as amp import torch.nn as nn @@ -167,11 +167,10 @@ class WanSelfAttention(nn.Module): self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() - def forward(self, xlist, seq_lens, grid_sizes, freqs): + def forward(self, xlist, grid_sizes, freqs): r""" Args: x(Tensor): Shape [B, L, num_heads, C / num_heads] - seq_lens(Tensor): Shape [B] grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] """ @@ -196,10 +195,6 @@ class WanSelfAttention(nn.Module): del q,k,v x = pay_attention( qkv_list, - # q=q, - # k=k, - # v=v, - # k_lens=seq_lens, window_size=self.window_size) # output x = x.flatten(2) @@ -209,12 +204,11 @@ class WanSelfAttention(nn.Module): class WanT2VCrossAttention(WanSelfAttention): - def forward(self, xlist, context, context_lens): + def forward(self, xlist, context): r""" Args: x(Tensor): Shape [B, L1, C] context(Tensor): Shape [B, L2, C] - context_lens(Tensor): Shape [B] """ x = xlist[0] xlist.clear() @@ -233,7 +227,7 @@ class WanT2VCrossAttention(WanSelfAttention): # compute attention qvl_list=[q, k, v] del q, k, v - x = pay_attention(qvl_list, k_lens=context_lens, cross_attn= True) + x = pay_attention(qvl_list, cross_attn= True) # output x = x.flatten(2) @@ -256,12 +250,11 @@ class WanI2VCrossAttention(WanSelfAttention): # self.alpha = nn.Parameter(torch.zeros((1, ))) self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() - def forward(self, xlist, context, context_lens): + def forward(self, xlist, context): r""" Args: x(Tensor): Shape [B, L1, C] context(Tensor): Shape [B, L2, C] - context_lens(Tensor): Shape [B] """ ##### Enjoy this spagheti VRAM optimizations done by DeepBeepMeep ! 
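Note on the `pay_attention` changes above: once the batch dimension is pinned to one and padding is removed, the per-sample `q_lens`/`k_lens` bookkeeping becomes redundant, because the cumulative-sequence-length tensors the varlen attention kernels expect collapse to `[0, L]`. A minimal sketch of that equivalence in plain `torch` (the token counts below are illustrative, not taken from the models):

```python
import torch

def cu_seqlens_from_lens(lens: torch.Tensor) -> torch.Tensor:
    # Old path: prepend a zero and take the running sum of per-sample lengths.
    return torch.cat([lens.new_zeros([1]), lens]).cumsum(0, dtype=torch.int32)

lq, lk = 32760, 512   # e.g. one video token sequence and one text prompt
old_q = cu_seqlens_from_lens(torch.tensor([lq], dtype=torch.int32))
old_k = cu_seqlens_from_lens(torch.tensor([lk], dtype=torch.int32))

# New path: batch size is asserted to be 1, so the offsets are known up front.
new_q = torch.tensor([0, lq], dtype=torch.int32)
new_k = torch.tensor([0, lk], dtype=torch.int32)

assert torch.equal(old_q, new_q) and torch.equal(old_k, new_k)
```

This is why the sage/flash branches can now be handed precomputed `cu_seqlens_q`/`cu_seqlens_k` tensors instead of rebuilding them from length tensors on every call.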
@@ -287,7 +280,7 @@ class WanI2VCrossAttention(WanSelfAttention): qkv_list = [q, k, v] del k,v - x = pay_attention(qkv_list, k_lens=context_lens) + x = pay_attention(qkv_list) k_img = self.k_img(context_img) self.norm_k_img(k_img) @@ -295,7 +288,7 @@ class WanI2VCrossAttention(WanSelfAttention): v_img = self.v_img(context_img).view(b, -1, n, d) qkv_list = [q, k_img, v_img] del q, k_img, v_img - img_x = pay_attention(qkv_list, k_lens=None) + img_x = pay_attention(qkv_list) # compute attention @@ -362,30 +355,26 @@ class WanAttentionBlock(nn.Module): self, x, e, - seq_lens, grid_sizes, freqs, context, - context_lens, hints= None, context_scale=1.0, + cam_emb= None ): r""" Args: x(Tensor): Shape [B, L, C] e(Tensor): Shape [B, 6, C] - seq_lens(Tensor): Shape [B], length of each sequence in batch grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] """ hint = None if self.block_id is not None and hints is not None: kwargs = { - "seq_lens" : seq_lens, "grid_sizes" : grid_sizes, "freqs" :freqs, "context" : context, - "context_lens" : context_lens, "e" : e, } if self.block_id == 0: @@ -394,20 +383,31 @@ class WanAttentionBlock(nn.Module): hint = self.vace(hints, None, **kwargs) e = (self.modulation + e).chunk(6, dim=1) - + # self-attention x_mod = self.norm1(x) x_mod *= 1 + e[1] x_mod += e[0] + if cam_emb != None: + cam_emb = self.cam_encoder(cam_emb) + cam_emb = cam_emb.repeat(1, 2, 1) + cam_emb = cam_emb.unsqueeze(2).unsqueeze(3).repeat(1, 1, grid_sizes[0][1], grid_sizes[0][2], 1) + cam_emb = rearrange(cam_emb, 'b f h w d -> b (f h w) d') + x_mod += cam_emb + xlist = [x_mod] del x_mod - y = self.self_attn( xlist, seq_lens, grid_sizes,freqs) + y = self.self_attn( xlist, grid_sizes, freqs) + if cam_emb != None: + y = self.projector(y) + # x = x + gate_msa * self.projector(self.self_attn(input_x, freqs)) + x.addcmul_(y, e[2]) del y y = self.norm3(x) ylist= [y] del y - x += self.cross_attn(ylist, context, context_lens) + x += self.cross_attn(ylist, context) y = self.norm2(x) y *= 1 + e[4] @@ -552,6 +552,7 @@ class WanModel(ModelMixin, ConfigMixin): qk_norm=True, cross_attn_norm=True, eps=1e-6, + recammaster = False ): r""" Initialize the diffusion model backbone. 
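For the ReCamMaster additions above: each latent frame carries a 12-dimensional camera pose, `cam_encoder` projects it to the model width, the result is tiled along the doubled frame axis (the target and source halves of the latent receive the same per-frame embedding), broadcast over the spatial grid, and flattened to token order before being added to the modulated hidden states. A self-contained shape check of that broadcast (the width and latent grid below are made-up values, not the 1.3B model's actual configuration):

```python
import torch
import torch.nn as nn
from einops import rearrange

dim, f, h, w = 1536, 21, 30, 52      # hypothetical model width and latent grid (F, H, W)
cam_encoder = nn.Linear(12, dim)     # per-frame 12-dim camera pose -> model width

cam = torch.randn(1, f, 12)          # [B, F, 12] camera poses for one clip
emb = cam_encoder(cam)               # [B, F, dim]
emb = emb.repeat(1, 2, 1)            # frame axis doubles: target + source latents
emb = emb.unsqueeze(2).unsqueeze(3).repeat(1, 1, h, w, 1)   # broadcast over H and W
emb = rearrange(emb, 'b f h w d -> b (f h w) d')            # token order, same as x_mod

assert emb.shape == (1, 2 * f * h * w, dim)
```

Also worth noting: `cam_encoder` is zero-initialized and `projector` starts as an identity in `WanModel.__init__`, so the extra layers are a no-op until the ReCamMaster checkpoint weights are loaded on top.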
@@ -666,6 +667,15 @@ class WanModel(ModelMixin, ConfigMixin): self.vace_patch_embedding = nn.Conv3d( self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size ) + if recammaster : + dim=self.blocks[0].self_attn.q.weight.shape[0] + for block in self.blocks: + block.cam_encoder = nn.Linear(12, dim) + block.projector = nn.Linear(dim, dim) + block.cam_encoder.weight.data.zero_() + block.cam_encoder.bias.data.zero_() + block.projector.weight = nn.Parameter(torch.eye(dim)) + block.projector.bias = nn.Parameter(torch.zeros(dim)) def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0): @@ -716,7 +726,6 @@ class WanModel(ModelMixin, ConfigMixin): x, t, context, - seq_len, vace_context = None, vace_context_scale=1.0, clip_fea=None, @@ -729,28 +738,9 @@ class WanModel(ModelMixin, ConfigMixin): max_steps = 0, slg_layers=None, callback = None, + cam_emb: torch.Tensor = None, ): - r""" - Forward pass through the diffusion model - Args: - x (List[Tensor]): - List of input video tensors, each with shape [C_in, F, H, W] - t (Tensor): - Diffusion timesteps tensor of shape [B] - context (List[Tensor]): - List of text embeddings each with shape [L, C] - seq_len (`int`): - Maximum sequence length for positional encoding - clip_fea (Tensor, *optional*): - CLIP image features for image-to-video mode - y (List[Tensor], *optional*): - Conditional video inputs for image-to-video mode, same shape as x - - Returns: - List[Tensor]: - List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] - """ if self.model_type == 'i2v': assert clip_fea is not None and y is not None # params @@ -775,15 +765,7 @@ class WanModel(ModelMixin, ConfigMixin): x = [u.flatten(2).transpose(1, 2) for u in x] - seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) - assert seq_lens.max() <= seq_len - if len(x)==1 and seq_len == x[0].size(1): - x = x[0] - else: - x = torch.cat([ - torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], - dim=1) for u in x - ]) + x = x[0] # time embeddings e = self.time_embedding( @@ -791,7 +773,6 @@ class WanModel(ModelMixin, ConfigMixin): e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype) # context - context_lens = None context = self.text_embedding( torch.stack([ torch.cat( @@ -825,10 +806,9 @@ class WanModel(ModelMixin, ConfigMixin): # arguments kwargs = dict( - seq_lens=seq_lens, grid_sizes=grid_sizes, freqs=freqs, - context_lens=context_lens, + cam_emb = cam_emb ) if vace_context == None: @@ -837,13 +817,7 @@ class WanModel(ModelMixin, ConfigMixin): # embeddings c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] c = [u.flatten(2).transpose(1, 2) for u in c] - if (len(c) == 1 and seq_len == c[0].size(1)): - c = c[0] - else: - c = torch.cat([ - torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], - dim=1) for u in c - ]) + c = c[0] kwargs['context_scale'] = vace_context_scale hints_list = [ [c] for _ in range(len(x_list)) ] diff --git a/wan/text2video.py b/wan/text2video.py index b8140f1..725d51f 100644 --- a/wan/text2video.py +++ b/wan/text2video.py @@ -10,6 +10,7 @@ from contextlib import contextmanager from functools import partial from mmgp import offload import torch +import torch.nn as nn import torch.cuda.amp as amp import torch.distributed as dist from tqdm import tqdm @@ -106,6 +107,9 @@ class WanT2V: from mmgp import offload self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, 
writable_tensors= False) + # offload.load_model_data(self.model, "recam.ckpt") + # self.model.cpu() + # offload.save_model(self.model, "recam.safetensors") if self.dtype == torch.float16 and not "fp16" in model_filename: self.model.to(self.dtype) # offload.save_model(self.model, "t2v_fp16.safetensors",do_quantize=True) @@ -278,7 +282,9 @@ class WanT2V: input_prompt, input_frames= None, input_masks = None, - input_ref_images = None, + input_ref_images = None, + source_video=None, + target_camera=None, context_scale=1.0, size=(1280, 720), frame_num=81, @@ -340,18 +346,19 @@ class WanT2V: seed_g = torch.Generator(device=self.device) seed_g.manual_seed(seed) - if not self.t5_cpu: - # self.text_encoder.model.to(self.device) - context = self.text_encoder([input_prompt], self.device) - context_null = self.text_encoder([n_prompt], self.device) - if offload_model: - self.text_encoder.model.cpu() - else: - context = self.text_encoder([input_prompt], torch.device('cpu')) - context_null = self.text_encoder([n_prompt], torch.device('cpu')) - context = [t.to(self.device) for t in context] - context_null = [t.to(self.device) for t in context_null] - + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) + if target_camera != None: + size = (source_video.shape[2], source_video.shape[1]) + source_video = source_video.to(dtype=self.dtype , device=self.device) + source_video = source_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.) + source_latents = self.vae.encode([source_video]) #.to(dtype=self.dtype, device=self.device) + del source_video + # Process target camera (recammaster) + from wan.utils.cammmaster_tools import get_camera_embedding + cam_emb = get_camera_embedding(target_camera) + cam_emb = cam_emb.to(dtype=self.dtype, device=self.device) + if input_frames != None: # vace context encode input_frames = [u.to(self.device) for u in input_frames] @@ -377,22 +384,7 @@ class WanT2V: context = [u.to(self.dtype) for u in context] context_null = [u.to(self.dtype) for u in context_null] - noise = [ - torch.randn( - target_shape[0], - target_shape[1], - target_shape[2], - target_shape[3], - dtype=torch.float32, - device=self.device, - generator=seed_g) - ] - - @contextmanager - def noop_no_sync(): - yield - - no_sync = getattr(self.model, 'no_sync', noop_no_sync) + noise = [ torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) ] # evaluation mode @@ -419,11 +411,24 @@ class WanT2V: # sample videos latents = noise + del noise batch_size =len(latents) - freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx) - arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self} - arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self} - arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self} + if target_camera != None: + shape = list(latents[0].shape[1:]) + shape[0] *= 2 + freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False) + else: + freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx) + arg_c = {'context': context, 'freqs': freqs, 'pipeline': self} + arg_null = {'context': context_null, 'freqs': freqs, 'pipeline': self} + arg_both = {'context': context, 'context2': context_null, 'freqs': freqs, 'pipeline': self} + + if target_camera != None: + recam_dict = {'cam_emb': cam_emb} + arg_c.update(recam_dict) + arg_null.update(recam_dict) + 
arg_both.update(recam_dict) + if input_frames != None: vace_dict = {'vace_context' : z, 'vace_context_scale' : context_scale} arg_c.update(vace_dict) @@ -435,7 +440,10 @@ class WanT2V: if callback != None: callback(-1, True) for i, t in enumerate(tqdm(timesteps)): - latent_model_input = latents + if target_camera != None: + latent_model_input = [torch.cat([u,v], dim=1) for u,v in zip(latents,source_latents )] + else: + latent_model_input = latents slg_layers_local = None if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps): slg_layers_local = slg_layers @@ -443,7 +451,6 @@ class WanT2V: offload.set_step_no_for_lora(self.model, i) timestep = torch.stack(timestep) - # self.model.to(self.device) if joint_pass: noise_pred_cond, noise_pred_uncond = self.model( latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both) @@ -459,7 +466,7 @@ class WanT2V: if self._interrupt: return None - del latent_model_input + # del latent_model_input # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/ noise_pred_text = noise_pred_cond @@ -478,7 +485,7 @@ class WanT2V: del noise_pred_uncond temp_x0 = sample_scheduler.step( - noise_pred.unsqueeze(0), + noise_pred[:, :target_shape[1]].unsqueeze(0), t, latents[0].unsqueeze(0), return_dict=False, @@ -490,24 +497,14 @@ class WanT2V: callback(i, False) x0 = latents - if offload_model: - self.model.cpu() - torch.cuda.empty_cache() - if self.rank == 0: - if input_frames == None: - videos = self.vae.decode(x0, VAE_tile_size) - else: - videos = self.decode_latent(x0, input_ref_images, VAE_tile_size) + if input_frames == None: + videos = self.vae.decode(x0, VAE_tile_size) + else: + videos = self.decode_latent(x0, input_ref_images, VAE_tile_size) - - del noise, latents + del latents del sample_scheduler - if offload_model: - gc.collect() - torch.cuda.synchronize() - if dist.is_initialized(): - dist.barrier() return videos[0] if self.rank == 0 else None diff --git a/wan/utils/utils.py b/wan/utils/utils.py index 5149464..92a7cb5 100644 --- a/wan/utils/utils.py +++ b/wan/utils/utils.py @@ -69,25 +69,21 @@ def remove_background(img, session=None): -def resize_and_remove_background(img_list, canvas_width, canvas_height, rm_background ): +def resize_and_remove_background(img_list, budget_width, budget_height, rm_background ): if rm_background: session = new_session() output_list =[] for img in img_list: width, height = img.size - white_canvas = np.full( (canvas_height, canvas_width, 3), 255, dtype= np.uint8 ) - scale = min(canvas_height / height, canvas_width / width) - new_height = int(height * scale) - new_width = int(width * scale) + scale = (budget_height * budget_width / (height * width))**(1/2) + new_height = int( round(height * scale / 16) * 16) + new_width = int( round(width * scale / 16) * 16) + resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS) if rm_background: resized_image = remove(resized_image, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') - top = (canvas_height - new_height) // 2 - left = (canvas_width - new_width) // 2 - white_canvas[top:top + new_height, left:left + new_width, :] = np.array(resized_image) - img = Image.fromarray(white_canvas) - output_list.append(img) + output_list.append(resized_image) return output_list diff --git a/wgp.py b/wgp.py index 09dca96..b846106 100644 --- a/wgp.py +++ b/wgp.py @@ -174,7 +174,25 @@ def process_prompt_and_add_tasks(state, model_choice): sliding_window_repeat = 
inputs["sliding_window_repeat"] sliding_window = sliding_window_repeat > 0 - if "Vace" in model_filename: + + if "recam" in model_filename: + video_source = inputs["video_source"] + if video_source == None: + gr.Info("You must provide a Source Video") + return + frames = get_resampled_video(video_source, 0, 81) + if len(frames)<81: + gr.Info("Recammaster source video should be at least 81 frames one the resampling at 16 fps has been done") + return + for single_prompt in prompts: + extra_inputs = { + "prompt" : single_prompt, + "video_source" : video_source, + } + inputs.update(extra_inputs) + add_video_task(**inputs) + + elif "Vace" in model_filename: video_prompt_type = inputs["video_prompt_type"] image_refs = inputs["image_refs"] video_guide = inputs["video_guide"] @@ -334,14 +352,8 @@ def process_prompt_and_add_tasks(state, model_choice): queue= gen.get("queue", []) return update_queue_data(queue) -def add_video_task(**inputs): - global task_id - state = inputs["state"] - gen = get_gen_info(state) - queue = gen["queue"] - task_id += 1 - current_task_id = task_id - inputs_to_query = ["image_start", "image_end", "video_guide", "image_refs","video_mask"] +def get_preview_images(inputs): + inputs_to_query = ["image_start", "image_end", "video_guide", "image_refs","video_mask", "video_source"] start_image_data = None end_image_data = None for name in inputs_to_query: @@ -353,6 +365,17 @@ def add_video_task(**inputs): else: end_image_data = image break + return start_image_data, end_image_data + +def add_video_task(**inputs): + global task_id + state = inputs["state"] + gen = get_gen_info(state) + queue = gen["queue"] + task_id += 1 + current_task_id = task_id + + start_image_data, end_image_data = get_preview_images(inputs) queue.append({ "id": current_task_id, @@ -434,7 +457,7 @@ def save_queue_action(state): task_id_s = task.get('id', f"task_{task_index}") image_keys = ["image_start", "image_end", "image_refs"] - video_keys = ["video_guide", "video_mask"] + video_keys = ["video_guide", "video_mask", "video_source"] for key in image_keys: images_pil = params_copy.get(key) @@ -595,7 +618,7 @@ def load_queue_action(filepath, state): params['state'] = state image_keys = ["image_start", "image_end", "image_refs"] - video_keys = ["video_guide", "video_mask"] + video_keys = ["video_guide", "video_mask", "video_source"] loaded_pil_images = {} loaded_video_paths = {} @@ -652,20 +675,10 @@ def load_queue_action(filepath, state): print(f"[load_queue_action] Error copying video {video_filename_in_zip} to cache: {vid_e}") params.pop(key, None) + primary_preview_pil_list, secondary_preview_pil_list = get_preview_images(params) - primary_preview_pil_list = loaded_pil_images.get("image_start") or loaded_pil_images.get("image_refs") - secondary_preview_pil_list = loaded_pil_images.get("image_end") - - primary_preview_pil = None - if primary_preview_pil_list: - primary_preview_pil = primary_preview_pil_list[0] if isinstance(primary_preview_pil_list, list) else primary_preview_pil_list - - secondary_preview_pil = None - if secondary_preview_pil_list: - secondary_preview_pil = secondary_preview_pil_list[0] if isinstance(secondary_preview_pil_list, list) else secondary_preview_pil_list - - start_b64 = [pil_to_base64_uri(primary_preview_pil, format="jpeg", quality=70)] if primary_preview_pil else None - end_b64 = [pil_to_base64_uri(secondary_preview_pil, format="jpeg", quality=70)] if secondary_preview_pil else None + start_b64 = [pil_to_base64_uri(primary_preview_pil_list[0], format="jpeg", quality=70)] if 
primary_preview_pil_list[0] else None + end_b64 = [pil_to_base64_uri(secondary_preview_pil_list[0], format="jpeg", quality=70)] if secondary_preview_pil_list[0] else None top_level_start_image = params.get("image_start") or params.get("image_refs") top_level_end_image = params.get("image_end") @@ -818,7 +831,7 @@ def autosave_queue(): task_id_s = task.get('id', f"task_{task_index}") image_keys = ["image_start", "image_end", "image_refs"] - video_keys = ["video_guide", "video_mask"] + video_keys = ["video_guide", "video_mask", "video_source"] for key in image_keys: images_pil = params_copy.get(key) @@ -1352,13 +1365,13 @@ quantizeTransformer = args.quantize_transformer check_loras = args.check_loras ==1 advanced = args.advanced -transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_bf16.safetensors"] +transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_bf16.safetensors", "ckpts/wan2.1_recammaster_1.3B_bf16.safetensors"] transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", ] transformer_choices = transformer_choices_t2v + transformer_choices_i2v text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"] server_config_filename = "wgp_config.json" if not os.path.isdir("settings"): - os.mkdir("settings") + os.mkdir("settings") if os.path.isfile("t2v_settings.json"): for f in glob.glob(os.path.join(".", "*_settings.json*")): target_file = os.path.join("settings", Path(f).parts[-1] ) @@ -1391,30 +1404,16 @@ else: server_config = json.loads(text) -model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp"] +model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp", "recam_1.3B"] model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B", - "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B" } + "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "recam_1.3B": "recammaster_1.3B" } def get_model_type(model_filename): - if "text2video" in model_filename and "14B" in model_filename: - return "t2v" - elif "text2video" in model_filename and "1.3B" in model_filename: - return "t2v_1.3B" - elif "Fun_InP" in model_filename and "1.3B" in model_filename: - return "fun_inp_1.3B" - elif "Fun_InP" in model_filename and "14B" in model_filename: - return "fun_inp" - elif "image2video_480p" in model_filename : - return "i2v" - elif "image2video_720p" in model_filename : - return "i2v_720p" - elif "Vace" in model_filename and "1.3B" in model_filename: - return "vace_1.3B" - elif "Vace" in model_filename and "14B" in model_filename: - return "vace" - else: - raise Exception("Unknown model:" + model_filename) + for model_type, signature in model_signatures.items(): + if signature in 
model_filename: + return model_type + raise Exception("Unknown model:" + model_filename) def test_class_i2v(model_filename): return "image2video" in model_filename or "Fun_InP" in model_filename @@ -1862,6 +1861,9 @@ def get_model_name(model_filename): elif "image" in model_filename: model_name = "Wan2.1 image2video" model_name += " 720p" if "720p" in model_filename else " 480p" + elif "recam" in model_filename: + model_name = "ReCamMaster" + model_name += " 14B" if "14B" in model_filename else " 1.3B" else: model_name = "Wan2.1 text2video" model_name += " 14B" if "14B" in model_filename else " 1.3B" @@ -2145,9 +2147,7 @@ def convert_image(image): return cast(Image, ImageOps.exif_transpose(image)) - -def preprocess_video(process_type, height, width, video_in, max_frames, start_frame=0): - +def get_resampled_video(video_in, start_frame, max_frames): from wan.utils.utils import resample import decord @@ -2158,15 +2158,26 @@ def preprocess_video(process_type, height, width, video_in, max_frames, start_fr frame_nos = resample(fps, len(reader), max_target_frames_count= max_frames, target_fps=16, start_target_frame= start_frame) frames_list = reader.get_batch(frame_nos) + return frames_list + +def preprocess_video(process_type, height, width, video_in, max_frames, start_frame=0, fit_canvas = False): + + frames_list = get_resampled_video(video_in, start_frame, max_frames) + if len(frames_list) == 0: return None frame_height, frame_width, _ = frames_list[0].shape - scale = ((height * width ) / (frame_height * frame_width))**(1/2) - # scale = min(height / frame_height, width / frame_width) + if fit_canvas : + scale = min(height / frame_height, width / frame_width) + else: + scale = ((height * width ) / (frame_height * frame_width))**(1/2) new_height = (int(frame_height * scale) // 16) * 16 new_width = (int(frame_width * scale) // 16) * 16 + # if fit_canvas : + # new_height = height + # new_width = width processed_frames_list = [] for frame in frames_list: @@ -2188,12 +2199,17 @@ def preprocess_video(process_type, height, width, video_in, max_frames, start_fr "PRETRAINED_MODEL": "ckpts/depth/dpt_hybrid-midas-501f0c75.pt" } anno_ins = DepthVideoAnnotator(cfg_dict) - else: + elif process_type=="gray": from preprocessing.gray import GrayVideoAnnotator cfg_dict = {} anno_ins = GrayVideoAnnotator(cfg_dict) - - np_frames = anno_ins.forward(processed_frames_list) + else: + anno_ins = None + + if anno_ins == None: + np_frames = [np.array(frame) for frame in processed_frames_list] + else: + np_frames = anno_ins.forward(processed_frames_list) # from preprocessing.dwpose.pose import save_one_video # save_one_video("test.mp4", np_frames, fps=8, quality=8, macro_block_size=None) @@ -2281,6 +2297,8 @@ def generate_video( image_refs, video_guide, video_mask, + camera_type, + video_source, keep_frames, sliding_window_repeat, sliding_window_overlap, @@ -2439,8 +2457,11 @@ def generate_video( trans.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404] else: raise gr.Error("Teacache not supported for this model") - - + source_video = None + target_camera = None + if "recam" in model_filename: + source_video = preprocess_video("", width=width, height=height,video_in=video_source, max_frames= video_length, start_frame = 0, fit_canvas= True) + target_camera = camera_type import random if seed == None or seed <0: seed = random.randint(0, 999999999) @@ -2532,7 +2553,6 @@ def generate_video( prompts_max = gen["prompts_max"] status = get_generation_status(prompt_no, prompts_max, 
repeat_no, total_generation, sliding_window) - yield status gen["progress_status"] = status gen["progress_phase"] = (" - Encoding Prompt", -1 ) @@ -2582,6 +2602,8 @@ def generate_video( input_frames = src_video, input_ref_images= src_ref_images, input_masks = src_mask, + source_video= source_video, + target_camera= target_camera, frame_num=(video_length // 4)* 4 + 1, size=(width, height), shift=flow_shift, @@ -2755,7 +2777,8 @@ def generate_video( state['update_gallery'] = True if not sliding_window: seed += 1 - + yield status + if temp_filename!= None and os.path.isfile(temp_filename): os.remove(temp_filename) offload.unload_loras_from_model(trans) @@ -2858,7 +2881,7 @@ def update_status(state): gen = get_gen_info(state) prompt_no = gen["prompt_no"] prompts_max = gen.get("prompts_max",0) - total_generation = gen["total_generation"] + total_generation = gen.get("total_generation", 1) repeat_no = gen["repeat_no"] sliding_window = gen["sliding_window"] status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation, sliding_window) @@ -3186,7 +3209,7 @@ def prepare_inputs_dict(target, inputs ): if target == "state": return inputs - unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "video_mask"] + unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "video_source", "video_mask"] for k in unsaved_params: inputs.pop(k) @@ -3200,6 +3223,9 @@ def prepare_inputs_dict(target, inputs ): inputs.pop("image_prompt_type") + if not "recam" in model_filename: + inputs.pop("camera_type") + if not "Vace" in model_filename: unsaved_params = ["video_prompt_type", "keep_frames", "remove_background_image_ref", "sliding_window_repeat", "sliding_window_overlap", "sliding_window_discard_last_frames"] for k in unsaved_params: @@ -3243,6 +3269,8 @@ def save_inputs( image_refs, video_guide, video_mask, + camera_type, + video_source, keep_frames, sliding_window_repeat, sliding_window_overlap, @@ -3551,6 +3579,26 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non else: image_end = gr.Image(label= "Last Image for a new video", type ="pil", visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None)) + with gr.Column(visible= "recam" in model_filename ) as recam_column: + camera_type = gr.Dropdown( + choices=[ + ("Pan Right", 1), + ("Pan Left", 2), + ("Tilt Up", 3), + ("Tilt Down", 4), + ("Zoom In", 5), + ("Zoom Out", 6), + ("Translate Up (with rotation)", 7), + ("Translate Down (with rotation)", 8), + ("Arc Left (with rotation)", 9), + ("Arc Right (with rotation)", 10), + ], + value=ui_defaults.get("camera_type", 1), + label="Camera Movement Type", scale = 3 + ) + video_source = gr.Video(label= "Video Source", value= ui_defaults.get("video_source", None),) + + with gr.Column(visible= "Vace" in model_filename ) as video_prompt_column: video_prompt_type_value= ui_defaults.get("video_prompt_type","") video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False) @@ -3657,7 +3705,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non label="Resolution" ) with gr.Row(): - video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)") + if "recam" in model_filename: + video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s), locked", interactive= False) + else: + video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, 
label="Number of frames (16 = 1s)", interactive= True) + num_inference_steps = gr.Slider(1, 100, value=ui_defaults.get("num_inference_steps",30), step=1, label="Number of Inference Steps") @@ -3794,7 +3846,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non sliding_window_overlap = gr.Slider(1, 32, value=ui_defaults.get("sliding_window_overlap",16), step=1, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)") sliding_window_discard_last_frames = gr.Slider(1, 10, value=ui_defaults.get("sliding_window_discard_last_frames", 4), step=1, label="Discard Last Frames of a Window (that may have bad quality)") - with gr.Tab("Miscellaneous"): + with gr.Tab("Miscellaneous", visible= not "recam" in model_filename): gr.Markdown("With Riflex you can generate videos longer than 5s which is the default duration of videos used to train the model") RIFLEx_setting = gr.Dropdown( choices=[ @@ -4007,7 +4059,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ) extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column, - prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row, sliding_window_tab] # show_advanced presets_column, + prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row, sliding_window_tab, + video_prompt_type_video_guide, video_prompt_type_image_refs, recam_column] # show_advanced presets_column, if update_form: locals_dict = locals() gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs @@ -4615,7 +4668,7 @@ def create_demo(): theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md") with gr.Blocks(css=css, theme=theme, title= "Wan2GP") as demo: - gr.Markdown("

WanGP v4.0 by DeepBeepMeep ") # (Updates) ") + gr.Markdown("WanGP v4.1 by DeepBeepMeep ") # (Updates) ") global model_list tab_state = gr.State({ "tab_no":0 })