diff --git a/configs/animate.json b/configs/animate.json new file mode 100644 index 0000000..7e98ca9 --- /dev/null +++ b/configs/animate.json @@ -0,0 +1,15 @@ +{ + "_class_name": "WanModel", + "_diffusers_version": "0.30.0", + "dim": 5120, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_dim": 36, + "model_type": "i2v", + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "text_len": 512, + "motion_encoder_dim": 512 +} \ No newline at end of file diff --git a/configs/lucy_edit.json b/configs/lucy_edit.json new file mode 100644 index 0000000..4983ced --- /dev/null +++ b/configs/lucy_edit.json @@ -0,0 +1,14 @@ +{ + "_class_name": "WanModel", + "_diffusers_version": "0.33.0", + "dim": 3072, + "eps": 1e-06, + "ffn_dim": 14336, + "freq_dim": 256, + "in_dim": 96, + "model_type": "ti2v2_2", + "num_heads": 24, + "num_layers": 30, + "out_dim": 48, + "text_len": 512 +} diff --git a/defaults/animate.json b/defaults/animate.json new file mode 100644 index 0000000..fee26d7 --- /dev/null +++ b/defaults/animate.json @@ -0,0 +1,13 @@ +{ + "model": { + "name": "Wan2.2 Animate", + "architecture": "animate", + "description": "Wan-Animate takes a video and a character image as input, and generates a video in either 'animation' or 'replacement' mode.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_quanto_fp16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_quanto_bf16_int8.safetensors" + ], + "group": "wan2_2" + } +} \ No newline at end of file diff --git a/defaults/lucy_edit.json b/defaults/lucy_edit.json new file mode 100644 index 0000000..6344dff --- /dev/null +++ b/defaults/lucy_edit.json @@ -0,0 +1,18 @@ +{ + "model": { + "name": "Wan2.2 Lucy Edit 5B", + "architecture": "lucy_edit", + "description": "Lucy Edit Dev is a video editing model that performs instruction-guided edits on videos using free-text prompts \u2014 it supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mfp16_int8.safetensors" + ], + "group": "wan2_2" + }, + "video_length": 81, + "guidance_scale": 5, + "flow_shift": 5, + "num_inference_steps": 30, + "resolution": "1280x720" +} \ No newline at end of file diff --git a/defaults/ti2v_2_2_fastwan.json b/defaults/ti2v_2_2_fastwan.json index 064c2b4..fa69f82 100644 --- a/defaults/ti2v_2_2_fastwan.json +++ b/defaults/ti2v_2_2_fastwan.json @@ -7,6 +7,7 @@ "loras": ["https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2_2_5B_FastWanFullAttn_lora_rank_128_bf16.safetensors"], "group": "wan2_2" }, + "prompt" : "Put the person into a clown outfit.", "video_length": 121, "guidance_scale": 1, "flow_shift": 3, diff --git a/models/qwen/pipeline_qwenimage.py b/models/qwen/pipeline_qwenimage.py index be982aa..1472886 100644 --- a/models/qwen/pipeline_qwenimage.py +++ b/models/qwen/pipeline_qwenimage.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from mmgp import offload import inspect from typing import Any, Callable, Dict, List, Optional, Union @@ -387,7 +386,8 @@ class QwenImagePipeline(): #DiffusionPipeline return latent_image_ids.to(device=device, dtype=dtype) @staticmethod - def _pack_latents(latents, batch_size, num_channels_latents, height, width): + def _pack_latents(latents): + batch_size, num_channels_latents, _, height, width = latents.shape latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) latents = latents.permute(0, 2, 4, 1, 3, 5) latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) @@ -479,7 +479,7 @@ class QwenImagePipeline(): #DiffusionPipeline height = 2 * (int(height) // (self.vae_scale_factor * 2)) width = 2 * (int(width) // (self.vae_scale_factor * 2)) - shape = (batch_size, 1, num_channels_latents, height, width) + shape = (batch_size, num_channels_latents, 1, height, width) image_latents = None if image is not None: @@ -499,10 +499,7 @@ class QwenImagePipeline(): #DiffusionPipeline else: image_latents = torch.cat([image_latents], dim=0) - image_latent_height, image_latent_width = image_latents.shape[3:] - image_latents = self._pack_latents( - image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width - ) + image_latents = self._pack_latents(image_latents) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -511,7 +508,7 @@ class QwenImagePipeline(): #DiffusionPipeline ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + latents = self._pack_latents(latents) else: latents = latents.to(device=device, dtype=dtype) @@ -713,11 +710,12 @@ class QwenImagePipeline(): #DiffusionPipeline image_height, image_width = calculate_new_dimensions(height, width, image_height, image_width, False, block_size=multiple_of) # image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of) height, width = image_height, image_width - image_mask_latents = convert_image_to_tensor(image_mask.resize((width // 16, height // 16), resample=Image.Resampling.LANCZOS)) + image_mask_latents = convert_image_to_tensor(image_mask.resize((width // 8, height // 8), resample=Image.Resampling.LANCZOS)) image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1] - image_mask_rebuilt = image_mask_latents.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0) + image_mask_rebuilt = image_mask_latents.repeat_interleave(8, dim=-1).repeat_interleave(8, dim=-2).unsqueeze(0) # convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png") - image_mask_latents = image_mask_latents.reshape(1, -1, 1).to(device) + image_mask_latents = image_mask_latents.to(device).unsqueeze(0).unsqueeze(0).repeat(1,16,1,1,1) + image_mask_latents = self._pack_latents(image_mask_latents) prompt_image = image if image.size != (image_width, image_height): @@ -822,6 +820,7 @@ class QwenImagePipeline(): #DiffusionPipeline negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None ) morph, first_step = False, 0 + lanpaint_proc = None if image_mask_latents is not None: randn = torch.randn_like(original_image_latents) if denoising_strength < 1.: @@ -833,7 +832,8 @@ class QwenImagePipeline(): #DiffusionPipeline timesteps = timesteps[first_step:] self.scheduler.timesteps = timesteps self.scheduler.sigmas= self.scheduler.sigmas[first_step:] - + # from shared.inpainting.lanpaint import LanPaint + # lanpaint_proc = LanPaint() # 6. Denoising loop self.scheduler.set_begin_index(0) updated_num_steps= len(timesteps) @@ -847,48 +847,52 @@ class QwenImagePipeline(): #DiffusionPipeline offload.set_step_no_for_lora(self.transformer, first_step + i) if self.interrupt: continue + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) if image_mask_latents is not None and denoising_strength <1. and i == first_step and morph: latent_noise_factor = t/1000 latents = original_image_latents * (1.0 - latent_noise_factor) + latents * latent_noise_factor - self._current_timestep = t - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latents.shape[0]).to(latents.dtype) - latent_model_input = latents - if image_latents is not None: - latent_model_input = torch.cat([latents, image_latents], dim=1) + latents_dtype = latents.dtype - if do_true_cfg and joint_pass: - noise_pred, neg_noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask_list=[prompt_embeds_mask,negative_prompt_embeds_mask], - encoder_hidden_states_list=[prompt_embeds, negative_prompt_embeds], - img_shapes=img_shapes, - txt_seq_lens_list=[txt_seq_lens, negative_txt_seq_lens], - attention_kwargs=self.attention_kwargs, - **kwargs - ) - if noise_pred == None: return None - noise_pred = noise_pred[:, : latents.size(1)] - neg_noise_pred = neg_noise_pred[:, : latents.size(1)] - else: - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask_list=[prompt_embeds_mask], - encoder_hidden_states_list=[prompt_embeds], - img_shapes=img_shapes, - txt_seq_lens_list=[txt_seq_lens], - attention_kwargs=self.attention_kwargs, - **kwargs - )[0] - if noise_pred == None: return None - noise_pred = noise_pred[:, : latents.size(1)] + # latent_model_input = latents + def denoise(latent_model_input, true_cfg_scale): + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + do_true_cfg = true_cfg_scale > 1 + if do_true_cfg and joint_pass: + noise_pred, neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, #!!!! + encoder_hidden_states_mask_list=[prompt_embeds_mask,negative_prompt_embeds_mask], + encoder_hidden_states_list=[prompt_embeds, negative_prompt_embeds], + img_shapes=img_shapes, + txt_seq_lens_list=[txt_seq_lens, negative_txt_seq_lens], + attention_kwargs=self.attention_kwargs, + **kwargs + ) + if noise_pred == None: return None, None + noise_pred = noise_pred[:, : latents.size(1)] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + else: + neg_noise_pred = None + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask_list=[prompt_embeds_mask], + encoder_hidden_states_list=[prompt_embeds], + img_shapes=img_shapes, + txt_seq_lens_list=[txt_seq_lens], + attention_kwargs=self.attention_kwargs, + **kwargs + )[0] + if noise_pred == None: return None, None + noise_pred = noise_pred[:, : latents.size(1)] if do_true_cfg: neg_noise_pred = self.transformer( @@ -902,27 +906,43 @@ class QwenImagePipeline(): #DiffusionPipeline attention_kwargs=self.attention_kwargs, **kwargs )[0] - if neg_noise_pred == None: return None + if neg_noise_pred == None: return None, None neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + return noise_pred, neg_noise_pred + def cfg_predictions( noise_pred, neg_noise_pred, guidance, t): + if do_true_cfg: + comb_pred = neg_noise_pred + guidance * (noise_pred - neg_noise_pred) + if comb_pred == None: return None - if do_true_cfg: - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - if comb_pred == None: return None + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - neg_noise_pred = None + return noise_pred + + + if lanpaint_proc is not None and i<=3: + latents = lanpaint_proc(denoise, cfg_predictions, true_cfg_scale, 1., latents, original_image_latents, randn, t/1000, image_mask_latents, height=height , width= width, vae_scale_factor= 8) + if latents is None: return None + + noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale) + if noise_pred == None: return None + noise_pred = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t) + neg_noise_pred = None # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + noise_pred = None + if image_mask_latents is not None: - next_t = timesteps[i+1] if i0: + msk[:, :nb_frames_unchanged] = 1 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1,2)[0] + return msk + def generate(self, input_prompt, input_frames= None, input_masks = None, - input_ref_images = None, + input_ref_images = None, + input_ref_masks = None, + input_faces = None, input_video = None, image_start = None, image_end = None, @@ -541,14 +506,18 @@ class WanAny2V: infinitetalk = model_type in ["infinitetalk"] standin = model_type in ["standin", "vace_standin_14B"] recam = model_type in ["recam_1.3B"] - ti2v = model_type in ["ti2v_2_2"] + ti2v = model_type in ["ti2v_2_2", "lucy_edit"] + lucy_edit= model_type in ["lucy_edit"] + animate= model_type in ["animate"] start_step_no = 0 ref_images_count = 0 trim_frames = 0 - extended_overlapped_latents = None + extended_overlapped_latents = clip_image_start = clip_image_end = None no_noise_latents_injection = infinitetalk timestep_injection = False lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1 + extended_input_dim = 0 + ref_images_before = False # image2video if model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "flf2v_720p"]: any_end_frame = False @@ -598,17 +567,7 @@ class WanAny2V: if image_end is not None: img_end_frame = image_end.unsqueeze(1).to(self.device) - - if hasattr(self, "clip"): - clip_image_size = self.clip.model.image_size - image_start = resize_lanczos(image_start, clip_image_size, clip_image_size) - image_end = resize_lanczos(image_end, clip_image_size, clip_image_size) if image_end is not None else image_start - if model_type == "flf2v_720p": - clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :] if image_end is not None else image_start[:, None, :, :]]) - else: - clip_context = self.clip.visual([image_start[:, None, :, :]]) - else: - clip_context = None + clip_image_start, clip_image_end = image_start, image_end if any_end_frame: enc= torch.concat([ @@ -647,21 +606,62 @@ class WanAny2V: if infinitetalk: lat_y = self.vae.encode([input_video], VAE_tile_size)[0] extended_overlapped_latents = lat_y[:, :overlapped_latents_frames_num].clone().unsqueeze(0) - # if control_pre_frames_count != pre_frames_count: lat_y = input_video = None kwargs.update({ 'y': y}) - if not clip_context is None: - kwargs.update({'clip_fea': clip_context}) - # Recam Master - if recam: - target_camera = model_mode - height,width = input_frames.shape[-2:] - input_frames = input_frames.to(dtype=self.dtype , device=self.device) - source_latents = self.vae.encode([input_frames])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device) + # Animate + if animate: + pose_pixels = input_frames * input_masks + input_masks = 1. - input_masks + pose_pixels -= input_masks + save_video(pose_pixels, "pose.mp4") + pose_latents = self.vae.encode([pose_pixels], VAE_tile_size)[0].unsqueeze(0) + input_frames = input_frames * input_masks + if not "X" in video_prompt_type: input_frames += input_masks - 1 # masked area should black (-1) in background frames + if prefix_frames_count > 0: + input_frames[:, :prefix_frames_count] = input_video + input_masks[:, :prefix_frames_count] = 1 + save_video(input_frames, "input_frames.mp4") + save_video(input_masks, "input_masks.mp4", value_range=(0,1)) + lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2] + msk_ref = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=1,lat_t=1, device=self.device) + msk_control = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=input_masks, device=self.device) + msk = torch.concat([msk_ref, msk_control], dim=1) + clip_image_start = image_ref = convert_image_to_tensor(input_ref_images[0]).to(self.device) + lat_y = torch.concat(self.vae.encode([image_ref.unsqueeze(1).to(self.device), input_frames.to(self.device)], VAE_tile_size), dim=1) + y = torch.concat([msk, lat_y]) + kwargs.update({ 'y': y, 'pose_latents': pose_latents, 'face_pixel_values' : input_faces.unsqueeze(0)}) + lat_y = msk = msk_control = msk_ref = pose_pixels = None + ref_images_before = True + ref_images_count = 1 + lat_frames = int((input_frames.shape[1] - 1) // self.vae_stride[0]) + 1 + + # Clip image + if hasattr(self, "clip") and clip_image_start is not None: + clip_image_size = self.clip.model.image_size + clip_image_start = resize_lanczos(clip_image_start, clip_image_size, clip_image_size) + clip_image_end = resize_lanczos(clip_image_end, clip_image_size, clip_image_size) if clip_image_end is not None else clip_image_start + if model_type == "flf2v_720p": + clip_context = self.clip.visual([clip_image_start[:, None, :, :], clip_image_end[:, None, :, :] if clip_image_end is not None else clip_image_start[:, None, :, :]]) + else: + clip_context = self.clip.visual([clip_image_start[:, None, :, :]]) + clip_image_start = clip_image_end = None + kwargs.update({'clip_fea': clip_context}) + + # Recam Master & Lucy Edit + if recam or lucy_edit: + frame_num, height,width = input_frames.shape[-3:] + lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1 + frame_num = (lat_frames -1) * self.vae_stride[0] + 1 + input_frames = input_frames[:, :frame_num].to(dtype=self.dtype , device=self.device) + extended_latents = self.vae.encode([input_frames])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device) + extended_input_dim = 2 if recam else 1 del input_frames + + if recam: # Process target camera (recammaster) + target_camera = model_mode from shared.utils.cammmaster_tools import get_camera_embedding cam_emb = get_camera_embedding(target_camera) cam_emb = cam_emb.to(dtype=self.dtype, device=self.device) @@ -715,6 +715,8 @@ class WanAny2V: height, width = input_video.shape[-2:] source_latents = self.vae.encode([input_video], tile_size = VAE_tile_size)[0].unsqueeze(0) timestep_injection = True + if extended_input_dim > 0: + extended_latents[:, :, :source_latents.shape[2]] = source_latents # Vace if vace : @@ -722,6 +724,7 @@ class WanAny2V: input_frames = [u.to(self.device) for u in input_frames] input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images] input_masks = [u.to(self.device) for u in input_masks] + ref_images_before = True if self.background_mask != None: self.background_mask = [m.to(self.device) for m in self.background_mask] z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents ) m0 = self.vace_encode_masks(input_masks, input_ref_images) @@ -771,9 +774,9 @@ class WanAny2V: expand_shape = [batch_size] + [-1] * len(target_shape) # Ropes - if target_camera != None: + if extended_input_dim>=2: shape = list(target_shape[1:]) - shape[0] *= 2 + shape[extended_input_dim-2] *= 2 freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False) else: freqs = get_rotary_pos_embed(target_shape[1:], enable_RIFLEx= enable_RIFLEx) @@ -901,8 +904,8 @@ class WanAny2V: for zz in z: zz[0:16, ref_images_count:extended_overlapped_latents.shape[2] ] = extended_overlapped_latents[0, :, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(extended_overlapped_latents[0, :, ref_images_count:] ) * overlap_noise_factor - if target_camera != None: - latent_model_input = torch.cat([latents, source_latents.expand(*expand_shape)], dim=2) + if extended_input_dim > 0: + latent_model_input = torch.cat([latents, extended_latents.expand(*expand_shape)], dim=extended_input_dim) else: latent_model_input = latents @@ -1030,7 +1033,7 @@ class WanAny2V: if callback is not None: latents_preview = latents - if vace and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ] + if ref_images_before and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ] if trim_frames > 0: latents_preview= latents_preview[:, :,:-trim_frames] if image_outputs: latents_preview= latents_preview[:, :,:1] if len(latents_preview) > 1: latents_preview = latents_preview.transpose(0,2) @@ -1041,7 +1044,7 @@ class WanAny2V: if timestep_injection: latents[:, :, :source_latents.shape[2]] = source_latents - if vace and ref_images_count > 0: latents = latents[:, :, ref_images_count:] + if ref_images_before and ref_images_count > 0: latents = latents[:, :, ref_images_count:] if trim_frames > 0: latents= latents[:, :,:-trim_frames] if return_latent_slice != None: latent_slice = latents[:, :, return_latent_slice].clone() @@ -1078,4 +1081,12 @@ class WanAny2V: delattr(model, "vace_blocks") + def adapt_animate_model(self, model): + modules_dict= { k: m for k, m in model.named_modules()} + for animate_layer in range(8): + module = modules_dict[f"face_adapter.fuser_blocks.{animate_layer}"] + model_layer = animate_layer * 5 + target = modules_dict[f"blocks.{model_layer}"] + setattr(target, "face_adapter_fuser_blocks", module ) + delattr(model, "face_adapter") diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py index d3dc783..c98087b 100644 --- a/models/wan/modules/model.py +++ b/models/wan/modules/model.py @@ -16,6 +16,9 @@ from mmgp.offload import get_cache, clear_caches from shared.attention import pay_attention from torch.backends.cuda import sdp_kernel from ..multitalk.multitalk_utils import get_attn_map_with_target +from ..animate.motion_encoder import Generator +from ..animate.face_blocks import FaceAdapter, FaceEncoder +from ..animate.model_animate import after_patch_embedding __all__ = ['WanModel'] @@ -499,6 +502,7 @@ class WanAttentionBlock(nn.Module): multitalk_masks=None, ref_images_count=0, standin_phase=-1, + motion_vec = None, ): r""" Args: @@ -616,6 +620,10 @@ class WanAttentionBlock(nn.Module): x.add_(hint) else: x.add_(hint, alpha= scale) + + if motion_vec is not None and self.block_no % 5 == 0: + x += self.face_adapter_fuser_blocks(x.to(self.face_adapter_fuser_blocks.linear1_kv.weight.dtype), motion_vec, None, False) + return x class AudioProjModel(ModelMixin, ConfigMixin): @@ -898,6 +906,7 @@ class WanModel(ModelMixin, ConfigMixin): norm_input_visual=True, norm_output_audio=True, standin= False, + motion_encoder_dim=0, ): super().__init__() @@ -922,14 +931,15 @@ class WanModel(ModelMixin, ConfigMixin): self.flag_causal_attention = False self.block_mask = None self.inject_sample_info = inject_sample_info - + self.motion_encoder_dim = motion_encoder_dim self.norm_output_audio = norm_output_audio self.audio_window = audio_window self.intermediate_dim = intermediate_dim self.vae_scale = vae_scale multitalk = multitalk_output_dim > 0 - self.multitalk = multitalk + self.multitalk = multitalk + animate = motion_encoder_dim > 0 # embeddings self.patch_embedding = nn.Conv3d( @@ -1027,6 +1037,25 @@ class WanModel(ModelMixin, ConfigMixin): block.self_attn.k_loras = LoRALinearLayer(dim, dim, rank=128) block.self_attn.v_loras = LoRALinearLayer(dim, dim, rank=128) + if animate: + self.pose_patch_embedding = nn.Conv3d( + 16, dim, kernel_size=patch_size, stride=patch_size + ) + + self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20) + self.face_adapter = FaceAdapter( + heads_num=self.num_heads, + hidden_dim=self.dim, + num_adapter_layers=self.num_layers // 5, + ) + + self.face_encoder = FaceEncoder( + in_dim=motion_encoder_dim, + hidden_dim=self.dim, + num_heads=4, + ) + + def lock_layers_dtypes(self, hybrid_dtype = None, dtype = torch.float32): layer_list = [self.head, self.head.head, self.patch_embedding] target_dype= dtype @@ -1208,6 +1237,9 @@ class WanModel(ModelMixin, ConfigMixin): ref_images_count = 0, standin_freqs = None, standin_ref = None, + pose_latents=None, + face_pixel_values=None, + ): # patch_dtype = self.patch_embedding.weight.dtype modulation_dtype = self.time_projection[1].weight.dtype @@ -1240,9 +1272,18 @@ class WanModel(ModelMixin, ConfigMixin): if bz > 1: y = y.expand(bz, -1, -1, -1, -1) x = torch.cat([x, y], dim=1) # embeddings - # x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype) x = self.patch_embedding(x).to(modulation_dtype) grid_sizes = x.shape[2:] + x_list[i] = x + y = None + + motion_vec_list = [] + for i, x in enumerate(x_list): + # animate embeddings + motion_vec = None + if pose_latents is not None: + x, motion_vec = after_patch_embedding(self, x, pose_latents, face_pixel_values) + motion_vec_list.append(motion_vec) if chipmunk: x = x.unsqueeze(-1) x_og_shape = x.shape @@ -1250,7 +1291,7 @@ class WanModel(ModelMixin, ConfigMixin): else: x = x.flatten(2).transpose(1, 2) x_list[i] = x - x, y = None, None + x = None block_mask = None @@ -1450,9 +1491,9 @@ class WanModel(ModelMixin, ConfigMixin): continue x_list[0] = block(x_list[0], context = context_list[0], audio_scale= audio_scale_list[0], e= e0, **kwargs) else: - for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc)): + for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc, motion_vec) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc,motion_vec_list)): if should_calc: - x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, **kwargs) + x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, motion_vec = motion_vec,**kwargs) del x context = hints = audio_embedding = None diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index 6586176..ff1570f 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -3,10 +3,10 @@ import numpy as np import gradio as gr def test_class_i2v(base_model_type): - return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk" ] + return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "animate" ] def text_oneframe_overlap(base_model_type): - return test_class_i2v(base_model_type) and not test_multitalk(base_model_type) + return test_class_i2v(base_model_type) and not (test_multitalk(base_model_type) or base_model_type in ["animate"]) or test_wan_5B(base_model_type) def test_class_1_3B(base_model_type): return base_model_type in [ "vace_1.3B", "t2v_1.3B", "recam_1.3B","phantom_1.3B","fun_inp_1.3B"] @@ -17,6 +17,8 @@ def test_multitalk(base_model_type): def test_standin(base_model_type): return base_model_type in ["standin", "vace_standin_14B"] +def test_wan_5B(base_model_type): + return base_model_type in ["ti2v_2_2", "lucy_edit"] class family_handler(): @staticmethod @@ -36,7 +38,7 @@ class family_handler(): def_mag_ratios = [1.00124, 1.00155, 0.99822, 0.99851, 0.99696, 0.99687, 0.99703, 0.99732, 0.9966, 0.99679, 0.99602, 0.99658, 0.99578, 0.99664, 0.99484, 0.9949, 0.99633, 0.996, 0.99659, 0.99683, 0.99534, 0.99549, 0.99584, 0.99577, 0.99681, 0.99694, 0.99563, 0.99554, 0.9944, 0.99473, 0.99594, 0.9964, 0.99466, 0.99461, 0.99453, 0.99481, 0.99389, 0.99365, 0.99391, 0.99406, 0.99354, 0.99361, 0.99283, 0.99278, 0.99268, 0.99263, 0.99057, 0.99091, 0.99125, 0.99126, 0.65523, 0.65252, 0.98808, 0.98852, 0.98765, 0.98736, 0.9851, 0.98535, 0.98311, 0.98339, 0.9805, 0.9806, 0.97776, 0.97771, 0.97278, 0.97286, 0.96731, 0.96728, 0.95857, 0.95855, 0.94385, 0.94385, 0.92118, 0.921, 0.88108, 0.88076, 0.80263, 0.80181] elif base_model_type in ["i2v_2_2"]: def_mag_ratios = [0.99191, 0.99144, 0.99356, 0.99337, 0.99326, 0.99285, 0.99251, 0.99264, 0.99393, 0.99366, 0.9943, 0.9943, 0.99276, 0.99288, 0.99389, 0.99393, 0.99274, 0.99289, 0.99316, 0.9931, 0.99379, 0.99377, 0.99268, 0.99271, 0.99222, 0.99227, 0.99175, 0.9916, 0.91076, 0.91046, 0.98931, 0.98933, 0.99087, 0.99088, 0.98852, 0.98855, 0.98895, 0.98896, 0.98806, 0.98808, 0.9871, 0.98711, 0.98613, 0.98618, 0.98434, 0.98435, 0.983, 0.98307, 0.98185, 0.98187, 0.98131, 0.98131, 0.9783, 0.97835, 0.97619, 0.9762, 0.97264, 0.9727, 0.97088, 0.97098, 0.96568, 0.9658, 0.96045, 0.96055, 0.95322, 0.95335, 0.94579, 0.94594, 0.93297, 0.93311, 0.91699, 0.9172, 0.89174, 0.89202, 0.8541, 0.85446, 0.79823, 0.79902] - elif base_model_type in ["ti2v_2_2"]: + elif test_wan_5B(base_model_type): if inputs.get("image_start", None) is not None and inputs.get("video_source", None) is not None : # t2v def_mag_ratios = [0.99505, 0.99389, 0.99441, 0.9957, 0.99558, 0.99551, 0.99499, 0.9945, 0.99534, 0.99548, 0.99468, 0.9946, 0.99463, 0.99458, 0.9946, 0.99453, 0.99408, 0.99404, 0.9945, 0.99441, 0.99409, 0.99398, 0.99403, 0.99397, 0.99382, 0.99377, 0.99349, 0.99343, 0.99377, 0.99378, 0.9933, 0.99328, 0.99303, 0.99301, 0.99217, 0.99216, 0.992, 0.99201, 0.99201, 0.99202, 0.99133, 0.99132, 0.99112, 0.9911, 0.99155, 0.99155, 0.98958, 0.98957, 0.98959, 0.98958, 0.98838, 0.98835, 0.98826, 0.98825, 0.9883, 0.98828, 0.98711, 0.98709, 0.98562, 0.98561, 0.98511, 0.9851, 0.98414, 0.98412, 0.98284, 0.98282, 0.98104, 0.98101, 0.97981, 0.97979, 0.97849, 0.97849, 0.97557, 0.97554, 0.97398, 0.97395, 0.97171, 0.97166, 0.96917, 0.96913, 0.96511, 0.96507, 0.96263, 0.96257, 0.95839, 0.95835, 0.95483, 0.95475, 0.94942, 0.94936, 0.9468, 0.94678, 0.94583, 0.94594, 0.94843, 0.94872, 0.96949, 0.97015] else: # i2v @@ -83,11 +85,13 @@ class family_handler(): vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B", "vace_standin_14B"] extra_model_def["vace_class"] = vace_class - if test_multitalk(base_model_type): + if base_model_type in ["animate"]: + fps = 30 + elif test_multitalk(base_model_type): fps = 25 elif base_model_type in ["fantasy"]: fps = 23 - elif base_model_type in ["ti2v_2_2"]: + elif test_wan_5B(base_model_type): fps = 24 else: fps = 16 @@ -100,14 +104,14 @@ class family_handler(): extra_model_def.update({ "frames_minimum" : frames_minimum, "frames_steps" : frames_steps, - "sliding_window" : base_model_type in ["multitalk", "infinitetalk", "t2v", "fantasy"] or test_class_i2v(base_model_type) or vace_class, #"ti2v_2_2", + "sliding_window" : base_model_type in ["multitalk", "infinitetalk", "t2v", "fantasy", "animate"] or test_class_i2v(base_model_type) or test_wan_5B(base_model_type) or vace_class, #"ti2v_2_2", "multiple_submodels" : multiple_submodels, "guidance_max_phases" : 3, "skip_layer_guidance" : True, "cfg_zero" : True, "cfg_star" : True, "adaptive_projected_guidance" : True, - "tea_cache" : not (base_model_type in ["i2v_2_2", "ti2v_2_2" ] or multiple_submodels), + "tea_cache" : not (base_model_type in ["i2v_2_2"] or test_wan_5B(base_model_type) or multiple_submodels), "mag_cache" : True, "keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"], "convert_image_guide_to_video" : True, @@ -146,6 +150,34 @@ class family_handler(): } # extra_model_def["at_least_one_image_ref_needed"] = True + if base_model_type in ["lucy_edit"]: + extra_model_def["keep_frames_video_guide_not_supported"] = True + extra_model_def["guide_preprocessing"] = { + "selection": ["UV"], + "labels" : { "UV": "Control Video"}, + "visible": False, + } + + if base_model_type in ["animate"]: + extra_model_def["guide_custom_choices"] = { + "choices":[ + ("Animate Person in Reference Image using Motion of Person in Control Video", "PVBXAKI"), + ("Replace Person in Control Video Person in Reference Image", "PVBAI"), + ], + "default": "KI", + "letters_filter": "PVBXAKI", + "label": "Type of Process", + "show_label" : False, + } + extra_model_def["video_guide_outpainting"] = [0,1] + extra_model_def["keep_frames_video_guide_not_supported"] = True + extra_model_def["extract_guide_from_window_start"] = True + extra_model_def["forced_guide_mask_inputs"] = True + extra_model_def["background_removal_label"]= "Remove Backgrounds behind People (Animate Mode Only)" + extra_model_def["background_ref_outpainted"] = False + + + if vace_class: extra_model_def["guide_preprocessing"] = { "selection": ["", "UV", "PV", "DV", "SV", "LV", "CV", "MV", "V", "PDV", "PSV", "PLV" , "DSV", "DLV", "SLV"], @@ -157,16 +189,19 @@ class family_handler(): extra_model_def["image_ref_choices"] = { "choices": [("None", ""), - ("Inject only People / Objects", "I"), - ("Inject Landscape and then People / Objects", "KI"), - ("Inject Frames and then People / Objects", "FI"), + ("People / Objects", "I"), + ("Landscape followed by People / Objects (if any)", "KI"), + ("Positioned Frames followed by People / Objects (if any)", "FI"), ], "letters_filter": "KFI", } extra_model_def["lock_image_refs_ratios"] = True - extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or positioned Frames" + extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or Positioned Frames" extra_model_def["video_guide_outpainting"] = [0,1] + extra_model_def["pad_guide_video"] = True + extra_model_def["guide_inpaint_color"] = 127.5 + extra_model_def["forced_guide_mask_inputs"] = True if base_model_type in ["standin"]: extra_model_def["lock_image_refs_ratios"] = True @@ -209,10 +244,12 @@ class family_handler(): "visible" : False, } - if vace_class or base_model_type in ["infinitetalk"]: + if vace_class or base_model_type in ["infinitetalk", "animate"]: image_prompt_types_allowed = "TVL" elif base_model_type in ["ti2v_2_2"]: image_prompt_types_allowed = "TSVL" + elif base_model_type in ["lucy_edit"]: + image_prompt_types_allowed = "TVL" elif test_multitalk(base_model_type) or base_model_type in ["fantasy"]: image_prompt_types_allowed = "SVL" elif i2v: @@ -234,8 +271,8 @@ class family_handler(): def query_supported_types(): return ["multitalk", "infinitetalk", "fantasy", "vace_14B", "vace_multitalk_14B", "vace_standin_14B", "t2v_1.3B", "standin", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B", - "recam_1.3B", - "i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "flf2v_720p", "fun_inp_1.3B", "fun_inp"] + "recam_1.3B", "animate", + "i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "lucy_edit", "flf2v_720p", "fun_inp_1.3B", "fun_inp"] @staticmethod @@ -265,11 +302,12 @@ class family_handler(): @staticmethod def get_vae_block_size(base_model_type): - return 32 if base_model_type == "ti2v_2_2" else 16 + return 32 if test_wan_5B(base_model_type) else 16 @staticmethod def get_rgb_factors(base_model_type ): from shared.RGB_factors import get_rgb_factors + if test_wan_5B(base_model_type): base_model_type = "ti2v_2_2" latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("wan", base_model_type) return latent_rgb_factors, latent_rgb_factors_bias @@ -283,7 +321,7 @@ class family_handler(): "fileList" : [ [ "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "sentencepiece.bpe.model", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"], ["special_tokens_map.json", "spiece.model", "tokenizer.json", "tokenizer_config.json"] + computeList(text_encoder_filename) , ["Wan2.1_VAE.safetensors", "fantasy_proj_model.safetensors" ] + computeList(model_filename) ] }] - if base_model_type == "ti2v_2_2": + if test_wan_5B(base_model_type): download_def += [ { "repoId" : "DeepBeepMeep/Wan2.2", "sourceFolderList" : [""], @@ -377,8 +415,8 @@ class family_handler(): ui_defaults.update({ "sample_solver": "unipc", }) - if test_class_i2v(base_model_type): - ui_defaults["image_prompt_type"] = "S" + if test_class_i2v(base_model_type) and "S" in model_def["image_prompt_types_allowed"]: + ui_defaults["image_prompt_type"] = "S" if base_model_type in ["fantasy"]: ui_defaults.update({ @@ -434,10 +472,15 @@ class family_handler(): "image_prompt_type": "T", }) - if base_model_type in ["recam_1.3B"]: + if base_model_type in ["recam_1.3B", "lucy_edit"]: ui_defaults.update({ "video_prompt_type": "UV", }) + elif base_model_type in ["animate"]: + ui_defaults.update({ + "video_prompt_type": "PVBXAKI", + "mask_expand": 20, + }) if text_oneframe_overlap(base_model_type): ui_defaults["sliding_window_overlap"] = 1 diff --git a/shared/convert/convert_diffusers_to_flux.py b/shared/convert/convert_diffusers_to_flux.py new file mode 100644 index 0000000..608b176 --- /dev/null +++ b/shared/convert/convert_diffusers_to_flux.py @@ -0,0 +1,342 @@ +#!/usr/bin/env python3 +""" +Convert a Flux model from Diffusers (folder or single-file) into the original +single-file Flux transformer checkpoint used by Black Forest Labs / ComfyUI. + +Input : /path/to/diffusers (root or .../transformer) OR /path/to/*.safetensors (single file) +Output : /path/to/flux1-your-model.safetensors (transformer only) + +Usage: + python diffusers_to_flux_transformer.py /path/to/diffusers /out/flux1-dev.safetensors + python diffusers_to_flux_transformer.py /path/to/diffusion_pytorch_model.safetensors /out/flux1-dev.safetensors + # optional quantization: + # --fp8 (float8_e4m3fn, simple) + # --fp8-scaled (scaled float8 for 2D weights; adds .scale_weight tensors) +""" + +import argparse +import json +from pathlib import Path +from collections import OrderedDict + +import torch +from safetensors import safe_open +import safetensors.torch +from tqdm import tqdm + + +def parse_args(): + ap = argparse.ArgumentParser() + ap.add_argument("diffusers_path", type=str, + help="Path to Diffusers checkpoint folder OR a single .safetensors file.") + ap.add_argument("output_path", type=str, + help="Output .safetensors path for the Flux transformer.") + ap.add_argument("--fp8", action="store_true", + help="Experimental: write weights as float8_e4m3fn via stochastic rounding (transformer only).") + ap.add_argument("--fp8-scaled", action="store_true", + help="Experimental: scaled float8_e4m3fn for 2D weight tensors; adds .scale_weight tensors.") + return ap.parse_args() + + +# Mapping from original Flux keys -> list of Diffusers keys (per block where applicable). +DIFFUSERS_MAP = { + # global embeds + "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"], + "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"], + "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"], + "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"], + + "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"], + "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"], + "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"], + "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"], + + "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"], + "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"], + "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"], + "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"], + + "txt_in.weight": ["context_embedder.weight"], + "txt_in.bias": ["context_embedder.bias"], + "img_in.weight": ["x_embedder.weight"], + "img_in.bias": ["x_embedder.bias"], + + # dual-stream (image/text) blocks + "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"], + "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"], + "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"], + "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"], + + "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"], + "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"], + "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"], + "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"], + + "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"], + "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"], + "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"], + "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"], + + "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"], + "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"], + "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"], + "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"], + + "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"], + "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"], + "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"], + "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"], + + "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"], + "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"], + "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"], + "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"], + + # single-stream blocks + "single_blocks.().modulation.lin.weight": ["norm.linear.weight"], + "single_blocks.().modulation.lin.bias": ["norm.linear.bias"], + "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"], + "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"], + "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"], + "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().linear2.bias": ["proj_out.bias"], + + # final + "final_layer.linear.weight": ["proj_out.weight"], + "final_layer.linear.bias": ["proj_out.bias"], + # these two are built from norm_out.linear.{weight,bias} by swapping [shift,scale] -> [scale,shift] + "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"], + "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"], +} + + +class DiffusersSource: + """ + Uniform interface over: + 1) Folder with index JSON + shards + 2) Folder with exactly one .safetensors (no index) + 3) Single .safetensors file + Provides .has(key), .get(key)->Tensor, .base_keys (keys with 'model.' stripped for scanning) + """ + + POSSIBLE_PREFIXES = ["", "model."] # try in this order + + def __init__(self, path: Path): + p = Path(path) + if p.is_dir(): + # use 'transformer' subfolder if present + if (p / "transformer").is_dir(): + p = p / "transformer" + self._init_from_dir(p) + elif p.is_file() and p.suffix == ".safetensors": + self._init_from_single_file(p) + else: + raise FileNotFoundError(f"Invalid path: {p}") + + # ---------- common helpers ---------- + + @staticmethod + def _strip_prefix(k: str) -> str: + return k[6:] if k.startswith("model.") else k + + def _resolve(self, want: str): + """ + Return the actual stored key matching `want` by trying known prefixes. + """ + for pref in self.POSSIBLE_PREFIXES: + k = pref + want + if k in self._all_keys: + return k + return None + + def has(self, want: str) -> bool: + return self._resolve(want) is not None + + def get(self, want: str) -> torch.Tensor: + real_key = self._resolve(want) + if real_key is None: + raise KeyError(f"Missing key: {want}") + return self._get_by_real_key(real_key).to("cpu") + + @property + def base_keys(self): + # keys without 'model.' prefix for scanning + return [self._strip_prefix(k) for k in self._all_keys] + + # ---------- modes ---------- + + def _init_from_single_file(self, file_path: Path): + self._mode = "single" + self._file = file_path + self._handle = safe_open(file_path, framework="pt", device="cpu") + self._all_keys = list(self._handle.keys()) + + def _get_by_real_key(real_key: str): + return self._handle.get_tensor(real_key) + + self._get_by_real_key = _get_by_real_key + + def _init_from_dir(self, dpath: Path): + index_json = dpath / "diffusion_pytorch_model.safetensors.index.json" + if index_json.exists(): + with open(index_json, "r", encoding="utf-8") as f: + index = json.load(f) + weight_map = index["weight_map"] # full mapping + self._mode = "sharded" + self._dpath = dpath + self._weight_map = {k: dpath / v for k, v in weight_map.items()} + self._all_keys = list(self._weight_map.keys()) + self._open_handles = {} + + def _get_by_real_key(real_key: str): + fpath = self._weight_map[real_key] + h = self._open_handles.get(fpath) + if h is None: + h = safe_open(fpath, framework="pt", device="cpu") + self._open_handles[fpath] = h + return h.get_tensor(real_key) + + self._get_by_real_key = _get_by_real_key + return + + # no index: try exactly one safetensors in folder + files = sorted(dpath.glob("*.safetensors")) + if len(files) != 1: + raise FileNotFoundError( + f"No index found and {dpath} does not contain exactly one .safetensors file." + ) + self._init_from_single_file(files[0]) + + +def main(): + args = parse_args() + src = DiffusersSource(Path(args.diffusers_path)) + + # Count blocks by scanning base keys (with any 'model.' prefix removed) + num_dual = 0 + num_single = 0 + for k in src.base_keys: + if k.startswith("transformer_blocks."): + try: + i = int(k.split(".")[1]) + num_dual = max(num_dual, i + 1) + except Exception: + pass + elif k.startswith("single_transformer_blocks."): + try: + i = int(k.split(".")[1]) + num_single = max(num_single, i + 1) + except Exception: + pass + print(f"Found {num_dual} dual-stream blocks, {num_single} single-stream blocks") + + # Swap [shift, scale] -> [scale, shift] (weights are concatenated along dim=0) + def swap_scale_shift(vec: torch.Tensor) -> torch.Tensor: + shift, scale = vec.chunk(2, dim=0) + return torch.cat([scale, shift], dim=0) + + orig = {} + + # Per-block (dual) + for b in range(num_dual): + prefix = f"transformer_blocks.{b}." + for okey, dvals in DIFFUSERS_MAP.items(): + if not okey.startswith("double_blocks."): + continue + dkeys = [prefix + v for v in dvals] + if not all(src.has(k) for k in dkeys): + continue + if len(dkeys) == 1: + orig[okey.replace("()", str(b))] = src.get(dkeys[0]) + else: + orig[okey.replace("()", str(b))] = torch.cat([src.get(k) for k in dkeys], dim=0) + + # Per-block (single) + for b in range(num_single): + prefix = f"single_transformer_blocks.{b}." + for okey, dvals in DIFFUSERS_MAP.items(): + if not okey.startswith("single_blocks."): + continue + dkeys = [prefix + v for v in dvals] + if not all(src.has(k) for k in dkeys): + continue + if len(dkeys) == 1: + orig[okey.replace("()", str(b))] = src.get(dkeys[0]) + else: + orig[okey.replace("()", str(b))] = torch.cat([src.get(k) for k in dkeys], dim=0) + + # Globals (non-block) + for okey, dvals in DIFFUSERS_MAP.items(): + if okey.startswith(("double_blocks.", "single_blocks.")): + continue + dkeys = dvals + if not all(src.has(k) for k in dkeys): + continue + if len(dkeys) == 1: + orig[okey] = src.get(dkeys[0]) + else: + orig[okey] = torch.cat([src.get(k) for k in dkeys], dim=0) + + # Fix final_layer.adaLN_modulation.1.{weight,bias} by swapping scale/shift halves + if "final_layer.adaLN_modulation.1.weight" in orig: + orig["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift( + orig["final_layer.adaLN_modulation.1.weight"] + ) + if "final_layer.adaLN_modulation.1.bias" in orig: + orig["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift( + orig["final_layer.adaLN_modulation.1.bias"] + ) + + # Optional FP8 variants (experimental; not required for ComfyUI/BFL) + if args.fp8 or args.fp8_scaled: + dtype = torch.float8_e4m3fn # noqa + minv, maxv = torch.finfo(dtype).min, torch.finfo(dtype).max + + def stochastic_round_to(t): + t = t.float().clamp(minv, maxv) + lower = torch.floor(t * 256) / 256 + upper = torch.ceil(t * 256) / 256 + prob = torch.where(upper != lower, (t - lower) / (upper - lower), torch.zeros_like(t)) + rnd = torch.rand_like(t) + out = torch.where(rnd < prob, upper, lower) + return out.to(dtype) + + def scale_to_8bit(weight, target_max=416.0): + absmax = weight.abs().max() + scale = absmax / target_max if absmax > 0 else torch.tensor(1.0) + scaled = (weight / scale).clamp(minv, maxv).to(dtype) + return scaled, scale + + scales = {} + for k in tqdm(list(orig.keys()), desc="Quantizing to fp8"): + t = orig[k] + if args.fp8: + orig[k] = stochastic_round_to(t) + else: + if k.endswith(".weight") and t.dim() == 2: + qt, s = scale_to_8bit(t) + orig[k] = qt + scales[k[:-len(".weight")] + ".scale_weight"] = s + else: + orig[k] = t.clamp(minv, maxv).to(dtype) + if args.fp8_scaled: + orig.update(scales) + orig["scaled_fp8"] = torch.tensor([], dtype=dtype) + else: + # Default: save in bfloat16 + for k in list(orig.keys()): + orig[k] = orig[k].to(torch.bfloat16).cpu() + + out_path = Path(args.output_path) + out_path.parent.mkdir(parents=True, exist_ok=True) + meta = OrderedDict() + meta["format"] = "pt" + meta["modelspec.date"] = __import__("datetime").date.today().strftime("%Y-%m-%d") + print(f"Saving transformer to: {out_path}") + safetensors.torch.save_file(orig, str(out_path), metadata=meta) + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/shared/inpainting/__init__.py b/shared/inpainting/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/shared/inpainting/lanpaint.py b/shared/inpainting/lanpaint.py new file mode 100644 index 0000000..3165e7b --- /dev/null +++ b/shared/inpainting/lanpaint.py @@ -0,0 +1,240 @@ +import torch +from .utils import * +from functools import partial + +# Many thanks to the LanPaint team for this implementation (https://github.com/scraed/LanPaint/) + +def _pack_latents(latents): + batch_size, num_channels_latents, _, height, width = latents.shape + + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + +def _unpack_latents(latents, height, width, vae_scale_factor=8): + batch_size, num_patches, channels = latents.shape + + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + +class LanPaint(): + def __init__(self, NSteps = 5, Friction = 15, Lambda = 8, Beta = 1, StepSize = 0.15, IS_FLUX = True, IS_FLOW = False): + self.n_steps = NSteps + self.chara_lamb = Lambda + self.IS_FLUX = IS_FLUX + self.IS_FLOW = IS_FLOW + self.step_size = StepSize + self.friction = Friction + self.chara_beta = Beta + self.img_dim_size = None + def add_none_dims(self, array): + # Create a tuple with ':' for the first dimension and 'None' repeated num_nones times + index = (slice(None),) + (None,) * (self.img_dim_size-1) + return array[index] + def remove_none_dims(self, array): + # Create a tuple with ':' for the first dimension and 'None' repeated num_nones times + index = (slice(None),) + (0,) * (self.img_dim_size-1) + return array[index] + def __call__(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, latent_image, noise, sigma, latent_mask, n_steps=None, height =720, width = 1280, vae_scale_factor = 8): + latent_image = _unpack_latents(latent_image, height=height, width=width, vae_scale_factor=vae_scale_factor) + noise = _unpack_latents(noise, height=height, width=width, vae_scale_factor=vae_scale_factor) + x = _unpack_latents(x, height=height, width=width, vae_scale_factor=vae_scale_factor) + latent_mask = _unpack_latents(latent_mask, height=height, width=width, vae_scale_factor=vae_scale_factor) + self.height = height + self.width = width + self.vae_scale_factor = vae_scale_factor + self.img_dim_size = len(x.shape) + self.latent_image = latent_image + self.noise = noise + if n_steps is None: + n_steps = self.n_steps + out = self.LanPaint(denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, self.IS_FLUX, self.IS_FLOW) + out = _pack_latents(out) + return out + def LanPaint(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, IS_FLUX, IS_FLOW): + if IS_FLUX: + cfg_BIG = 1.0 + + def double_denoise(latents, t): + latents = _pack_latents(latents) + noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale) + if noise_pred == None: return None, None + predict_std = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t) + predict_std = _unpack_latents(predict_std, self.height, self.width, self.vae_scale_factor) + if true_cfg_scale == cfg_BIG: + predict_big = predict_std + else: + predict_big = cfg_predictions(noise_pred, neg_noise_pred, cfg_BIG, t) + predict_big = _unpack_latents(predict_big, self.height, self.width, self.vae_scale_factor) + return predict_std, predict_big + + if len(sigma.shape) == 0: + sigma = torch.tensor([sigma.item()]) + latent_mask = 1 - latent_mask + if IS_FLUX or IS_FLOW: + Flow_t = sigma + abt = (1 - Flow_t)**2 / ((1 - Flow_t)**2 + Flow_t**2 ) + VE_Sigma = Flow_t / (1 - Flow_t) + #print("t", torch.mean( sigma ).item(), "VE_Sigma", torch.mean( VE_Sigma ).item()) + else: + VE_Sigma = sigma + abt = 1/( 1+VE_Sigma**2 ) + Flow_t = (1-abt)**0.5 / ( (1-abt)**0.5 + abt**0.5 ) + # VE_Sigma, abt, Flow_t = current_times + current_times = (VE_Sigma, abt, Flow_t) + + step_size = self.step_size * (1 - abt) + step_size = self.add_none_dims(step_size) + # self.inner_model.inner_model.scale_latent_inpaint returns variance exploding x_t values + # This is the replace step + # x = x * (1 - latent_mask) + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image)* latent_mask + + noisy_image = self.latent_image * (1.0 - sigma) + self.noise * sigma + x = x * (1 - latent_mask) + noisy_image * latent_mask + + if IS_FLUX or IS_FLOW: + x_t = x * ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 ) + else: + x_t = x / ( 1+self.add_none_dims(VE_Sigma)**2 )**0.5 # switch to variance perserving x_t values + + ############ LanPaint Iterations Start ############### + # after noise_scaling, noise = latent_image + noise * sigma, which is x_t in the variance exploding diffusion model notation for the known region. + args = None + for i in range(n_steps): + score_func = partial( self.score_model, y = self.latent_image, mask = latent_mask, abt = self.add_none_dims(abt), sigma = self.add_none_dims(VE_Sigma), tflow = self.add_none_dims(Flow_t), denoise_func = double_denoise ) + if score_func is None: return None + x_t, args = self.langevin_dynamics(x_t, score_func , latent_mask, step_size , current_times, sigma_x = self.add_none_dims(self.sigma_x(abt)), sigma_y = self.add_none_dims(self.sigma_y(abt)), args = args) + if IS_FLUX or IS_FLOW: + x = x_t / ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 ) + else: + x = x_t * ( 1+self.add_none_dims(VE_Sigma)**2 )**0.5 # switch to variance perserving x_t values + ############ LanPaint Iterations End ############### + # out is x_0 + # out, _ = self.inner_model(x, sigma, model_options=model_options, seed=seed) + # out = out * (1-latent_mask) + self.latent_image * latent_mask + # return out + return x + + def score_model(self, x_t, y, mask, abt, sigma, tflow, denoise_func): + + lamb = self.chara_lamb + if self.IS_FLUX or self.IS_FLOW: + # compute t for flow model, with a small epsilon compensating for numerical error. + x = x_t / ( abt**0.5 + (1-abt)**0.5 ) # switch to Gaussian flow matching + x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(tflow)) + if x_0 is None: return None + else: + x = x_t * ( 1+sigma**2 )**0.5 # switch to variance exploding + x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(sigma)) + if x_0 is None: return None + + score_x = -(x_t - x_0) + score_y = - (1 + lamb) * ( x_t - y ) + lamb * (x_t - x_0_BIG) + return score_x * (1 - mask) + score_y * mask + def sigma_x(self, abt): + # the time scale for the x_t update + return abt**0 + def sigma_y(self, abt): + beta = self.chara_beta * abt ** 0 + return beta + + def langevin_dynamics(self, x_t, score, mask, step_size, current_times, sigma_x=1, sigma_y=0, args=None): + # prepare the step size and time parameters + with torch.autocast(device_type=x_t.device.type, dtype=torch.float32): + step_sizes = self.prepare_step_size(current_times, step_size, sigma_x, sigma_y) + sigma, abt, dtx, dty, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y = step_sizes + # print('mask',mask.device) + if torch.mean(dtx) <= 0.: + return x_t, args + # ------------------------------------------------------------------------- + # Compute the Langevin dynamics update in variance perserving notation + # ------------------------------------------------------------------------- + #x0 = self.x0_evalutation(x_t, score, sigma, args) + #C = abt**0.5 * x0 / (1-abt) + A = A_x * (1-mask) + A_y * mask + D = D_x * (1-mask) + D_y * mask + dt = dtx * (1-mask) + dty * mask + Gamma = Gamma_x * (1-mask) + Gamma_y * mask + + + def Coef_C(x_t): + x0 = self.x0_evalutation(x_t, score, sigma, args) + C = (abt**0.5 * x0 - x_t )/ (1-abt) + A * x_t + return C + def advance_time(x_t, v, dt, Gamma, A, C, D): + dtype = x_t.dtype + with torch.autocast(device_type=x_t.device.type, dtype=torch.float32): + osc = StochasticHarmonicOscillator(Gamma, A, C, D ) + x_t, v = osc.dynamics(x_t, v, dt ) + x_t = x_t.to(dtype) + v = v.to(dtype) + return x_t, v + if args is None: + #v = torch.zeros_like(x_t) + v = None + C = Coef_C(x_t) + #print(torch.squeeze(dtx), torch.squeeze(dty)) + x_t, v = advance_time(x_t, v, dt, Gamma, A, C, D) + else: + v, C = args + + x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D) + + C_new = Coef_C(x_t) + v = v + Gamma**0.5 * ( C_new - C) *dt + + x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D) + + C = C_new + + return x_t, (v, C) + + def prepare_step_size(self, current_times, step_size, sigma_x, sigma_y): + # ------------------------------------------------------------------------- + # Unpack current times parameters (sigma and abt) + sigma, abt, flow_t = current_times + sigma = self.add_none_dims(sigma) + abt = self.add_none_dims(abt) + # Compute time step (dtx, dty) for x and y branches. + dtx = 2 * step_size * sigma_x + dty = 2 * step_size * sigma_y + + # ------------------------------------------------------------------------- + # Define friction parameter Gamma_hat for each branch. + # Using dtx**0 provides a tensor of the proper device/dtype. + + Gamma_hat_x = self.friction **2 * self.step_size * sigma_x / 0.1 * sigma**0 + Gamma_hat_y = self.friction **2 * self.step_size * sigma_y / 0.1 * sigma**0 + #print("Gamma_hat_x", torch.mean(Gamma_hat_x).item(), "Gamma_hat_y", torch.mean(Gamma_hat_y).item()) + # adjust dt to match denoise-addnoise steps sizes + Gamma_hat_x /= 2. + Gamma_hat_y /= 2. + A_t_x = (1) / ( 1 - abt ) * dtx / 2 + A_t_y = (1+self.chara_lamb) / ( 1 - abt ) * dty / 2 + + + A_x = A_t_x / (dtx/2) + A_y = A_t_y / (dty/2) + Gamma_x = Gamma_hat_x / (dtx/2) + Gamma_y = Gamma_hat_y / (dty/2) + + #D_x = (2 * (1 + sigma**2) )**0.5 + #D_y = (2 * (1 + sigma**2) )**0.5 + D_x = (2 * abt**0 )**0.5 + D_y = (2 * abt**0 )**0.5 + return sigma, abt, dtx/2, dty/2, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y + + + + def x0_evalutation(self, x_t, score, sigma, args): + x0 = x_t + score(x_t) + return x0 \ No newline at end of file diff --git a/shared/inpainting/utils.py b/shared/inpainting/utils.py new file mode 100644 index 0000000..c017ab0 --- /dev/null +++ b/shared/inpainting/utils.py @@ -0,0 +1,301 @@ +import torch +def epxm1_x(x): + # Compute the (exp(x) - 1) / x term with a small value to avoid division by zero. + result = torch.special.expm1(x) / x + # replace NaN or inf values with 0 + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + mask = torch.abs(x) < 1e-2 + result = torch.where(mask, 1 + x/2. + x**2 / 6., result) + return result +def epxm1mx_x2(x): + # Compute the (exp(x) - 1 - x) / x**2 term with a small value to avoid division by zero. + result = (torch.special.expm1(x) - x) / x**2 + # replace NaN or inf values with 0 + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + mask = torch.abs(x**2) < 1e-2 + result = torch.where(mask, 1/2. + x/6 + x**2 / 24 + x**3 / 120, result) + return result + +def expm1mxmhx2_x3(x): + # Compute the (exp(x) - 1 - x - x**2 / 2) / x**3 term with a small value to avoid division by zero. + result = (torch.special.expm1(x) - x - x**2 / 2) / x**3 + # replace NaN or inf values with 0 + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + mask = torch.abs(x**3) < 1e-2 + result = torch.where(mask, 1/6 + x/24 + x**2 / 120 + x**3 / 720 + x**4 / 5040, result) + return result + +def exp_1mcosh_GD(gamma_t, delta): + """ + Compute e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ ) + + Parameters: + gamma_t: Γ*t term (could be a scalar or tensor) + delta: Δ term (could be a scalar or tensor) + + Returns: + Result of the computation with numerical stability handling + """ + # Main computation + is_positive = delta > 0 + sqrt_abs_delta = torch.sqrt(torch.abs(delta)) + gamma_t_sqrt_delta = gamma_t * sqrt_abs_delta + numerator_pos = torch.exp(-gamma_t) - (torch.exp(gamma_t * (sqrt_abs_delta - 1)) + torch.exp(gamma_t * (-sqrt_abs_delta - 1))) / 2 + numerator_neg = torch.exp(-gamma_t) * ( 1 - torch.cos(gamma_t * sqrt_abs_delta ) ) + numerator = torch.where(is_positive, numerator_pos, numerator_neg) + result = numerator / (delta * gamma_t**2 ) + # Handle NaN/inf cases + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + # Handle numerical instability for small delta + mask = torch.abs(gamma_t_sqrt_delta**2) < 5e-2 + taylor = ( -0.5 - gamma_t**2 / 24 * delta - gamma_t**4 / 720 * delta**2 ) * torch.exp(-gamma_t) + result = torch.where(mask, taylor, result) + return result + +def exp_sinh_GsqrtD(gamma_t, delta): + """ + Compute e^(-Γt) * sinh(Γt√Δ) / (Γt√Δ) + + Parameters: + gamma_t: Γ*t term (could be a scalar or tensor) + delta: Δ term (could be a scalar or tensor) + + Returns: + Result of the computation with numerical stability handling + """ + # Main computation + is_positive = delta > 0 + sqrt_abs_delta = torch.sqrt(torch.abs(delta)) + gamma_t_sqrt_delta = gamma_t * sqrt_abs_delta + numerator_pos = (torch.exp(gamma_t * (sqrt_abs_delta - 1)) - torch.exp(gamma_t * (-sqrt_abs_delta - 1))) / 2 + denominator_pos = gamma_t_sqrt_delta + result_pos = numerator_pos / gamma_t_sqrt_delta + result_pos = torch.where(torch.isfinite(result_pos), result_pos, torch.zeros_like(result_pos)) + + # Taylor expansion for small gamma_t_sqrt_delta + mask = torch.abs(gamma_t_sqrt_delta) < 1e-2 + taylor = ( 1 + gamma_t**2 / 6 * delta + gamma_t**4 / 120 * delta**2 ) * torch.exp(-gamma_t) + result_pos = torch.where(mask, taylor, result_pos) + + # Handle negative delta + result_neg = torch.exp(-gamma_t) * torch.special.sinc(gamma_t_sqrt_delta/torch.pi) + result = torch.where(is_positive, result_pos, result_neg) + return result + +def exp_cosh(gamma_t, delta): + """ + Compute e^(-Γt) * cosh(Γt√Δ) + + Parameters: + gamma_t: Γ*t term (could be a scalar or tensor) + delta: Δ term (could be a scalar or tensor) + + Returns: + Result of the computation with numerical stability handling + """ + exp_1mcosh_GD_result = exp_1mcosh_GD(gamma_t, delta) # e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ ) + result = torch.exp(-gamma_t) - gamma_t**2 * delta * exp_1mcosh_GD_result + return result +def exp_sinh_sqrtD(gamma_t, delta): + """ + Compute e^(-Γt) * sinh(Γt√Δ) / √Δ + Parameters: + gamma_t: Γ*t term (could be a scalar or tensor) + delta: Δ term (could be a scalar or tensor) + Returns: + Result of the computation with numerical stability handling + """ + exp_sinh_GsqrtD_result = exp_sinh_GsqrtD(gamma_t, delta) # e^(-Γt) * sinh(Γt√Δ) / (Γt√Δ) + result = gamma_t * exp_sinh_GsqrtD_result + return result + + + +def zeta1(gamma_t, delta): + # Compute hyperbolic terms and exponential + half_gamma_t = gamma_t / 2 + exp_cosh_term = exp_cosh(half_gamma_t, delta) + exp_sinh_term = exp_sinh_sqrtD(half_gamma_t, delta) + + + # Main computation + numerator = 1 - (exp_cosh_term + exp_sinh_term) + denominator = gamma_t * (1 - delta) / 4 + result = 1 - numerator / denominator + + # Handle numerical instability + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + + # Taylor expansion for small x (similar to your epxm1Dx approach) + mask = torch.abs(denominator) < 5e-3 + term1 = epxm1_x(-gamma_t) + term2 = epxm1mx_x2(-gamma_t) + term3 = expm1mxmhx2_x3(-gamma_t) + taylor = term1 + (1/2.+ term1-3*term2)*denominator + (-1/6. + term1/2 - 4 * term2 + 10 * term3) * denominator**2 + result = torch.where(mask, taylor, result) + + return result + +def exp_cosh_minus_terms(gamma_t, delta): + """ + Compute E^(-tΓ) * (Cosh[tΓ] - 1 - (Cosh[tΓ√Δ] - 1)/Δ) / (tΓ(1 - Δ)) + + Parameters: + gamma_t: Γ*t term (could be a scalar or tensor) + delta: Δ term (could be a scalar or tensor) + + Returns: + Result of the computation with numerical stability handling + """ + exp_term = torch.exp(-gamma_t) + # Compute individual terms + exp_cosh_term = exp_cosh(gamma_t, gamma_t**0) - exp_term # E^(-tΓ) (Cosh[tΓ] - 1) term + exp_cosh_delta_term = - gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) # E^(-tΓ) (Cosh[tΓ√Δ] - 1)/Δ term + + #exp_1mcosh_GD e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ ) + # Main computation + numerator = exp_cosh_term - exp_cosh_delta_term + denominator = gamma_t * (1 - delta) + + result = numerator / denominator + + # Handle numerical instability + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + + # Taylor expansion for small gamma_t and delta near 1 + mask = (torch.abs(denominator) < 1e-1) + exp_1mcosh_GD_term = exp_1mcosh_GD(gamma_t, delta**0) + taylor = ( + gamma_t*exp_1mcosh_GD_term + 0.5 * gamma_t * exp_sinh_GsqrtD(gamma_t, delta**0) + - denominator / 4 * ( 0.5 * exp_cosh(gamma_t, delta**0) - 4 * exp_1mcosh_GD_term - 5 /2 * exp_sinh_GsqrtD(gamma_t, delta**0) ) + ) + result = torch.where(mask, taylor, result) + + return result + + +def zeta2(gamma_t, delta): + half_gamma_t = gamma_t / 2 + return exp_sinh_GsqrtD(half_gamma_t, delta) + +def sig11(gamma_t, delta): + return 1 - torch.exp(-gamma_t) + gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) + exp_sinh_sqrtD(gamma_t, delta) + + +def Zcoefs(gamma_t, delta): + Zeta1 = zeta1(gamma_t, delta) + Zeta2 = zeta2(gamma_t, delta) + + sq_total = 1 - Zeta1 + gamma_t * (delta - 1) * (Zeta1 - 1)**2 / 8 + amplitude = torch.sqrt(sq_total) + Zcoef1 = ( gamma_t**0.5 * Zeta2 / 2 **0.5 ) / amplitude + Zcoef2 = Zcoef1 * gamma_t *( - 2 * exp_1mcosh_GD(gamma_t, delta) / sig11(gamma_t, delta) ) ** 0.5 + #cterm = exp_cosh_minus_terms(gamma_t, delta) + #sterm = exp_sinh_sqrtD(gamma_t, delta**0) + exp_sinh_sqrtD(gamma_t, delta) + #Zcoef3 = 2 * torch.sqrt( cterm / ( gamma_t * (1 - delta) * cterm + sterm ) ) + Zcoef3 = torch.sqrt( torch.maximum(1 - Zcoef1**2 - Zcoef2**2, sq_total.new_zeros(sq_total.shape)) ) + + return Zcoef1 * amplitude, Zcoef2 * amplitude, Zcoef3 * amplitude, amplitude + +def Zcoefs_asymp(gamma_t, delta): + A_t = (gamma_t * (1 - delta) )/4 + return epxm1_x(- 2 * A_t) + +class StochasticHarmonicOscillator: + """ + Simulates a stochastic harmonic oscillator governed by the equations: + dy(t) = q(t) dt + dq(t) = -Γ A y(t) dt + Γ C dt + Γ D dw(t) - Γ q(t) dt + + Also define v(t) = q(t) / √Γ, which is numerically more stable. + + Where: + y(t) - Position variable + q(t) - Velocity variable + Γ - Damping coefficient + A - Harmonic potential strength + C - Constant force term + D - Noise amplitude + dw(t) - Wiener process (Brownian motion) + """ + def __init__(self, Gamma, A, C, D): + self.Gamma = Gamma + self.A = A + self.C = C + self.D = D + self.Delta = 1 - 4 * A / Gamma + def sig11(self, gamma_t, delta): + return 1 - torch.exp(-gamma_t) + gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) + exp_sinh_sqrtD(gamma_t, delta) + def sig22(self, gamma_t, delta): + return 1- zeta1(2*gamma_t, delta) + 2 * gamma_t * exp_1mcosh_GD(gamma_t, delta) + def dynamics(self, y0, v0, t): + """ + Calculates the position and velocity variables at time t. + + Parameters: + y0 (float): Initial position + v0 (float): Initial velocity v(0) = q(0) / √Γ + t (float): Time at which to evaluate the dynamics + Returns: + tuple: (y(t), v(t)) + """ + + dummyzero = y0.new_zeros(1) # convert scalar to tensor with same device and dtype as y0 + Delta = self.Delta + dummyzero + Gamma_hat = self.Gamma * t + dummyzero + A = self.A + dummyzero + C = self.C + dummyzero + D = self.D + dummyzero + Gamma = self.Gamma + dummyzero + zeta_1 = zeta1( Gamma_hat, Delta) + zeta_2 = zeta2( Gamma_hat, Delta) + EE = 1 - Gamma_hat * zeta_2 + + if v0 is None: + v0 = torch.randn_like(y0) * D / 2 ** 0.5 + #v0 = (C - A * y0)/Gamma**0.5 + + # Calculate mean position and velocity + term1 = (1 - zeta_1) * (C * t - A * t * y0) + zeta_2 * (Gamma ** 0.5) * v0 * t + y_mean = term1 + y0 + v_mean = (1 - EE)*(C - A * y0) / (Gamma ** 0.5) + (EE - A * t * (1 - zeta_1)) * v0 + + cov_yy = D**2 * t * self.sig22(Gamma_hat, Delta) + cov_vv = D**2 * self.sig11(Gamma_hat, Delta) / 2 + cov_yv = (zeta2(Gamma_hat, Delta) * Gamma_hat * D ) **2 / 2 / (Gamma ** 0.5) + + # sample new position and velocity with multivariate normal distribution + + batch_shape = y0.shape + cov_matrix = torch.zeros(*batch_shape, 2, 2, device=y0.device, dtype=y0.dtype) + cov_matrix[..., 0, 0] = cov_yy + cov_matrix[..., 0, 1] = cov_yv + cov_matrix[..., 1, 0] = cov_yv # symmetric + cov_matrix[..., 1, 1] = cov_vv + + + + # Compute the Cholesky decomposition to get scale_tril + #scale_tril = torch.linalg.cholesky(cov_matrix) + scale_tril = torch.zeros(*batch_shape, 2, 2, device=y0.device, dtype=y0.dtype) + tol = 1e-8 + cov_yy = torch.clamp( cov_yy, min = tol ) + sd_yy = torch.sqrt( cov_yy ) + inv_sd_yy = 1/(sd_yy) + + scale_tril[..., 0, 0] = sd_yy + scale_tril[..., 0, 1] = 0. + scale_tril[..., 1, 0] = cov_yv * inv_sd_yy + scale_tril[..., 1, 1] = torch.clamp( cov_vv - cov_yv**2 / cov_yy, min = tol ) ** 0.5 + # check if it matches torch.linalg. + #assert torch.allclose(torch.linalg.cholesky(cov_matrix), scale_tril, atol = 1e-4, rtol = 1e-4 ) + # Sample correlated noise from multivariate normal + mean = torch.zeros(*batch_shape, 2, device=y0.device, dtype=y0.dtype) + mean[..., 0] = y_mean + mean[..., 1] = v_mean + new_yv = torch.distributions.MultivariateNormal( + loc=mean, + scale_tril=scale_tril + ).sample() + + return new_yv[...,0], new_yv[...,1] \ No newline at end of file diff --git a/shared/utils/audio_video.py b/shared/utils/audio_video.py index b24530d..224cf34 100644 --- a/shared/utils/audio_video.py +++ b/shared/utils/audio_video.py @@ -232,6 +232,9 @@ def save_video(tensor, retry=5): """Save tensor as video with configurable codec and container options.""" + if torch.is_tensor(tensor) and len(tensor.shape) == 4: + tensor = tensor.unsqueeze(0) + suffix = f'.{container}' cache_file = osp.join('/tmp', rand_name(suffix=suffix)) if save_file is None else save_file if not cache_file.endswith(suffix): diff --git a/shared/utils/download.py b/shared/utils/download.py new file mode 100644 index 0000000..ed035c0 --- /dev/null +++ b/shared/utils/download.py @@ -0,0 +1,110 @@ +import sys, time + +# Global variables to track download progress +_start_time = None +_last_time = None +_last_downloaded = 0 +_speed_history = [] +_update_interval = 0.5 # Update speed every 0.5 seconds + +def progress_hook(block_num, block_size, total_size, filename=None): + """ + Simple progress bar hook for urlretrieve + + Args: + block_num: Number of blocks downloaded so far + block_size: Size of each block in bytes + total_size: Total size of the file in bytes + filename: Name of the file being downloaded (optional) + """ + global _start_time, _last_time, _last_downloaded, _speed_history, _update_interval + + current_time = time.time() + downloaded = block_num * block_size + + # Initialize timing on first call + if _start_time is None or block_num == 0: + _start_time = current_time + _last_time = current_time + _last_downloaded = 0 + _speed_history = [] + + # Calculate download speed only at specified intervals + speed = 0 + if current_time - _last_time >= _update_interval: + if _last_time > 0: + current_speed = (downloaded - _last_downloaded) / (current_time - _last_time) + _speed_history.append(current_speed) + # Keep only last 5 speed measurements for smoothing + if len(_speed_history) > 5: + _speed_history.pop(0) + # Average the recent speeds for smoother display + speed = sum(_speed_history) / len(_speed_history) + + _last_time = current_time + _last_downloaded = downloaded + elif _speed_history: + # Use the last calculated average speed + speed = sum(_speed_history) / len(_speed_history) + # Format file sizes and speed + def format_bytes(bytes_val): + for unit in ['B', 'KB', 'MB', 'GB']: + if bytes_val < 1024: + return f"{bytes_val:.1f}{unit}" + bytes_val /= 1024 + return f"{bytes_val:.1f}TB" + + file_display = filename if filename else "Unknown file" + + if total_size <= 0: + # If total size is unknown, show downloaded bytes + speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else "" + line = f"\r{file_display}: {format_bytes(downloaded)}{speed_str}" + # Clear any trailing characters by padding with spaces + sys.stdout.write(line.ljust(80)) + sys.stdout.flush() + return + + downloaded = block_num * block_size + percent = min(100, (downloaded / total_size) * 100) + + # Create progress bar (40 characters wide to leave room for other info) + bar_length = 40 + filled = int(bar_length * percent / 100) + bar = '█' * filled + '░' * (bar_length - filled) + + # Format file sizes and speed + def format_bytes(bytes_val): + for unit in ['B', 'KB', 'MB', 'GB']: + if bytes_val < 1024: + return f"{bytes_val:.1f}{unit}" + bytes_val /= 1024 + return f"{bytes_val:.1f}TB" + + speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else "" + + # Display progress with filename first + line = f"\r{file_display}: [{bar}] {percent:.1f}% ({format_bytes(downloaded)}/{format_bytes(total_size)}){speed_str}" + # Clear any trailing characters by padding with spaces + sys.stdout.write(line.ljust(100)) + sys.stdout.flush() + + # Print newline when complete + if percent >= 100: + print() + +# Wrapper function to include filename in progress hook +def create_progress_hook(filename): + """Creates a progress hook with the filename included""" + global _start_time, _last_time, _last_downloaded, _speed_history + # Reset timing variables for new download + _start_time = None + _last_time = None + _last_downloaded = 0 + _speed_history = [] + + def hook(block_num, block_size, total_size): + return progress_hook(block_num, block_size, total_size, filename) + return hook + + diff --git a/shared/utils/utils.py b/shared/utils/utils.py index 17b9dde..6e8a98b 100644 --- a/shared/utils/utils.py +++ b/shared/utils/utils.py @@ -1,4 +1,3 @@ -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import argparse import os import os.path as osp @@ -176,8 +175,9 @@ def remove_background(img, session=None): def convert_image_to_tensor(image): return torch.from_numpy(np.array(image).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0) -def convert_tensor_to_image(t, frame_no = -1): - t = t[:, frame_no] if frame_no >= 0 else t +def convert_tensor_to_image(t, frame_no = 0): + if len(t.shape) == 4: + t = t[:, frame_no] return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy()) def save_image(tensor_image, name, frame_no = -1): @@ -257,16 +257,18 @@ def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fi image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) return image, new_height, new_width -def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None ): +def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5 ): if rm_background: session = new_session() output_list =[] + output_mask_list =[] for i, img in enumerate(img_list): width, height = img.size + resized_mask = None if fit_into_canvas == None or any_background_ref == 1 and i==0 or any_background_ref == 2: - if outpainting_dims is not None: - resized_image =img + if outpainting_dims is not None and background_ref_outpainted: + resized_image, resized_mask = fit_image_into_canvas(img, (budget_height, budget_width), inpaint_color, full_frame = True, outpainting_dims = outpainting_dims, return_mask= True, return_image= True) elif img.size != (budget_width, budget_height): resized_image= img.resize((budget_width, budget_height), resample=Image.Resampling.LANCZOS) else: @@ -290,145 +292,103 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg # resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200, - return output_list + output_mask_list.append(resized_mask) + return output_list, output_mask_list +def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu", full_frame = False, outpainting_dims = None, return_mask = False, return_image = False): + from shared.utils.utils import save_image + inpaint_color = canvas_tf_bg / 127.5 - 1 - - -def str2bool(v): - """ - Convert a string to a boolean. - - Supported true values: 'yes', 'true', 't', 'y', '1' - Supported false values: 'no', 'false', 'f', 'n', '0' - - Args: - v (str): String to convert. - - Returns: - bool: Converted boolean value. - - Raises: - argparse.ArgumentTypeError: If the value cannot be converted to boolean. - """ - if isinstance(v, bool): - return v - v_lower = v.lower() - if v_lower in ('yes', 'true', 't', 'y', '1'): - return True - elif v_lower in ('no', 'false', 'f', 'n', '0'): - return False + ref_width, ref_height = ref_img.size + if (ref_height, ref_width) == image_size and outpainting_dims == None: + ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) + canvas = torch.zeros_like(ref_img) if return_mask else None else: - raise argparse.ArgumentTypeError('Boolean value expected (True/False)') - - -import sys, time - -# Global variables to track download progress -_start_time = None -_last_time = None -_last_downloaded = 0 -_speed_history = [] -_update_interval = 0.5 # Update speed every 0.5 seconds - -def progress_hook(block_num, block_size, total_size, filename=None): - """ - Simple progress bar hook for urlretrieve - - Args: - block_num: Number of blocks downloaded so far - block_size: Size of each block in bytes - total_size: Total size of the file in bytes - filename: Name of the file being downloaded (optional) - """ - global _start_time, _last_time, _last_downloaded, _speed_history, _update_interval - - current_time = time.time() - downloaded = block_num * block_size - - # Initialize timing on first call - if _start_time is None or block_num == 0: - _start_time = current_time - _last_time = current_time - _last_downloaded = 0 - _speed_history = [] - - # Calculate download speed only at specified intervals - speed = 0 - if current_time - _last_time >= _update_interval: - if _last_time > 0: - current_speed = (downloaded - _last_downloaded) / (current_time - _last_time) - _speed_history.append(current_speed) - # Keep only last 5 speed measurements for smoothing - if len(_speed_history) > 5: - _speed_history.pop(0) - # Average the recent speeds for smoother display - speed = sum(_speed_history) / len(_speed_history) - - _last_time = current_time - _last_downloaded = downloaded - elif _speed_history: - # Use the last calculated average speed - speed = sum(_speed_history) / len(_speed_history) - # Format file sizes and speed - def format_bytes(bytes_val): - for unit in ['B', 'KB', 'MB', 'GB']: - if bytes_val < 1024: - return f"{bytes_val:.1f}{unit}" - bytes_val /= 1024 - return f"{bytes_val:.1f}TB" - - file_display = filename if filename else "Unknown file" - - if total_size <= 0: - # If total size is unknown, show downloaded bytes - speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else "" - line = f"\r{file_display}: {format_bytes(downloaded)}{speed_str}" - # Clear any trailing characters by padding with spaces - sys.stdout.write(line.ljust(80)) - sys.stdout.flush() - return - - downloaded = block_num * block_size - percent = min(100, (downloaded / total_size) * 100) - - # Create progress bar (40 characters wide to leave room for other info) - bar_length = 40 - filled = int(bar_length * percent / 100) - bar = '█' * filled + '░' * (bar_length - filled) - - # Format file sizes and speed - def format_bytes(bytes_val): - for unit in ['B', 'KB', 'MB', 'GB']: - if bytes_val < 1024: - return f"{bytes_val:.1f}{unit}" - bytes_val /= 1024 - return f"{bytes_val:.1f}TB" - - speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else "" - - # Display progress with filename first - line = f"\r{file_display}: [{bar}] {percent:.1f}% ({format_bytes(downloaded)}/{format_bytes(total_size)}){speed_str}" - # Clear any trailing characters by padding with spaces - sys.stdout.write(line.ljust(100)) - sys.stdout.flush() - - # Print newline when complete - if percent >= 100: - print() - -# Wrapper function to include filename in progress hook -def create_progress_hook(filename): - """Creates a progress hook with the filename included""" - global _start_time, _last_time, _last_downloaded, _speed_history - # Reset timing variables for new download - _start_time = None - _last_time = None - _last_downloaded = 0 - _speed_history = [] - - def hook(block_num, block_size, total_size): - return progress_hook(block_num, block_size, total_size, filename) - return hook + if outpainting_dims != None: + final_height, final_width = image_size + canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1) + else: + canvas_height, canvas_width = image_size + if full_frame: + new_height = canvas_height + new_width = canvas_width + top = left = 0 + else: + # if fill_max and (canvas_height - new_height) < 16: + # new_height = canvas_height + # if fill_max and (canvas_width - new_width) < 16: + # new_width = canvas_width + scale = min(canvas_height / ref_height, canvas_width / ref_width) + new_height = int(ref_height * scale) + new_width = int(ref_width * scale) + top = (canvas_height - new_height) // 2 + left = (canvas_width - new_width) // 2 + ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) + ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) + if outpainting_dims != None: + canvas = torch.full((3, 1, final_height, final_width), inpaint_color, dtype= torch.float, device=device) # [-1, 1] + canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img + else: + canvas = torch.full((3, 1, canvas_height, canvas_width), inpaint_color, dtype= torch.float, device=device) # [-1, 1] + canvas[:, :, top:top + new_height, left:left + new_width] = ref_img + ref_img = canvas + canvas = None + if return_mask: + if outpainting_dims != None: + canvas = torch.ones((1, 1, final_height, final_width), dtype= torch.float, device=device) # [-1, 1] + canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = 0 + else: + canvas = torch.ones((1, 1, canvas_height, canvas_width), dtype= torch.float, device=device) # [-1, 1] + canvas[:, :, top:top + new_height, left:left + new_width] = 0 + canvas = canvas.to(device) + if return_image: + return convert_tensor_to_image(ref_img), canvas + + return ref_img.to(device), canvas + +def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, extract_guide_from_window_start = False, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None ): + src_videos, src_masks = [], [] + inpaint_color = guide_inpaint_color/127.5 - 1 + prepend_count = pre_video_guide.shape[1] if not extract_guide_from_window_start and pre_video_guide is not None else 0 + for guide_no, (cur_video_guide, cur_video_mask) in enumerate(zip(video_guides, video_masks)): + src_video = src_mask = None + if cur_video_guide is not None: + src_video = cur_video_guide.permute(3, 0, 1, 2).float().div_(127.5).sub_(1.) # c, f, h, w + if cur_video_mask is not None and any_mask: + src_mask = cur_video_mask.permute(3, 0, 1, 2).float().div_(255)[0:1] # c, f, h, w + if pre_video_guide is not None and not extract_guide_from_window_start: + src_video = pre_video_guide if src_video is None else torch.cat( [pre_video_guide, src_video], dim=1) + if any_mask: + src_mask = torch.zeros_like(pre_video_guide[0:1]) if src_mask is None else torch.cat( [torch.zeros_like(pre_video_guide[0:1]), src_mask], dim=1) + if src_video is None: + if any_guide_padding: + src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color, dtype = torch.float, device= "cpu") + if any_mask: + src_mask = torch.zeros_like(src_video[0:1]) + elif src_video.shape[1] < current_video_length: + if any_guide_padding: + src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color, dtype = src_video.dtype, device= src_video.device) ], dim=1) + if cur_video_mask is not None and any_mask: + src_mask = torch.cat([src_mask, torch.full( (1, current_video_length - src_mask.shape[1], *src_mask.shape[-2:] ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1) + else: + new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1 + src_video = src_video[:, :new_num_frames] + if any_mask: + src_mask = src_mask[:, :new_num_frames] + + for k, keep in enumerate(keep_video_guide_frames): + if not keep: + pos = prepend_count + k + src_video[:, pos:pos+1] = inpaint_color + src_mask[:, pos:pos+1] = 1 + + for k, frame in enumerate(inject_frames): + if frame != None: + pos = prepend_count + k + src_video[:, pos:pos+1], src_mask[:, pos:pos+1] = fit_image_into_canvas(frame, image_size, inpaint_color, device, True, outpainting_dims, return_mask= True) + + src_videos.append(src_video) + src_masks.append(src_mask) + return src_videos, src_masks diff --git a/wgp.py b/wgp.py index ee6b5f6..0e6a0b2 100644 --- a/wgp.py +++ b/wgp.py @@ -394,7 +394,7 @@ def process_prompt_and_add_tasks(state, model_choice): gr.Info("MMAudio can generate an Audio track only if the Video is at least 1s long") if "F" in video_prompt_type: if len(frames_positions.strip()) > 0: - positions = frames_positions.split(" ") + positions = frames_positions.replace(","," ").split(" ") for pos_str in positions: if not pos_str in ["L", "l"] and len(pos_str)>0: if not is_integer(pos_str): @@ -2528,7 +2528,7 @@ def download_models(model_filename = None, model_type= None, module_type = False from urllib.request import urlretrieve - from shared.utils.utils import create_progress_hook + from shared.utils.download import create_progress_hook shared_def = { "repoId" : "DeepBeepMeep/Wan2.1", @@ -3726,6 +3726,60 @@ def preprocess_image_with_mask(input_image, input_mask, height, width, fit_canva input_mask = convert_tensor_to_image(full_frame) return input_image, input_mask +def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_frames, start_frame, target_fps, size = 512): + if not input_video_path or max_frames <= 0: + return None, None + pad_frames = 0 + if start_frame < 0: + pad_frames= -start_frame + max_frames += start_frame + start_frame = 0 + + any_mask = input_mask_path != None + video = get_resampled_video(input_video_path, start_frame, max_frames, target_fps) + if len(video) == 0: return None + if any_mask: + mask_video = get_resampled_video(input_mask_path, start_frame, max_frames, target_fps) + frame_height, frame_width, _ = video[0].shape + + num_frames = min(len(video), len(mask_video)) + if num_frames == 0: return None + video, mask_video = video[:num_frames], mask_video[:num_frames] + + from preprocessing.face_preprocessor import FaceProcessor + face_processor = FaceProcessor() + + face_list = [] + for frame_idx in range(num_frames): + frame = video[frame_idx].cpu().numpy() + # video[frame_idx] = None + if any_mask: + mask = Image.fromarray(mask_video[frame_idx].cpu().numpy()) + # mask_video[frame_idx] = None + if (frame_width, frame_height) != mask.size: + mask = mask.resize((frame_width, frame_height), resample=Image.Resampling.LANCZOS) + mask = np.array(mask) + alpha_mask = np.zeros((frame_height, frame_width, 3), dtype=np.uint8) + alpha_mask[mask > 127] = 1 + frame = frame * alpha_mask + frame = Image.fromarray(frame) + face = face_processor.process(frame, resize_to=size, face_crop_scale = 1) + face_list.append(face) + + face_processor = None + gc.collect() + torch.cuda.empty_cache() + + face_tensor= torch.tensor(np.stack(face_list, dtype= np.float32) / 127.5 - 1).permute(-1, 0, 1, 2 ) # t h w c -> c t h w + if pad_frames > 0: + face_tensor= torch.cat([face_tensor[:, -1:].expand(-1, pad_frames, -1, -1), face_tensor ], dim=2) + + if args.save_masks: + from preprocessing.dwpose.pose import save_one_video + saved_faces_frames = [np.array(face) for face in face_list ] + save_one_video(f"faces.mp4", saved_faces_frames, fps=target_fps, quality=8, macro_block_size=None) + return face_tensor + def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1): @@ -3742,7 +3796,13 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, box = [xmin, ymin, xmax, ymax] box = [int(x) for x in box] return box - + inpaint_color = int(inpaint_color) + pad_frames = 0 + if start_frame < 0: + pad_frames= -start_frame + max_frames += start_frame + start_frame = 0 + if not input_video_path or max_frames <= 0: return None, None any_mask = input_mask_path != None @@ -3909,6 +3969,9 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, preproc_outside = None gc.collect() torch.cuda.empty_cache() + if pad_frames > 0: + masked_frames = masked_frames[0] * pad_frames + masked_frames + if any_mask: masked_frames = masks[0] * pad_frames + masks return torch.stack(masked_frames), torch.stack(masks) if any_mask else None @@ -4646,7 +4709,8 @@ def generate_video( current_video_length = video_length # VAE Tiling device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576 - + guide_inpaint_color = model_def.get("guide_inpaint_color", 127.5) + extract_guide_from_window_start = model_def.get("extract_guide_from_window_start", False) i2v = test_class_i2v(model_type) diffusion_forcing = "diffusion_forcing" in model_filename t2v = base_model_type in ["t2v"] @@ -4662,6 +4726,7 @@ def generate_video( multitalk = model_def.get("multitalk_class", False) standin = model_def.get("standin_class", False) infinitetalk = base_model_type in ["infinitetalk"] + animate = base_model_type in ["animate"] if "B" in audio_prompt_type or "X" in audio_prompt_type: from models.wan.multitalk.multitalk import parse_speakers_locations @@ -4822,7 +4887,6 @@ def generate_video( repeat_no = 0 extra_generation = 0 initial_total_windows = 0 - discard_last_frames = sliding_window_discard_last_frames default_requested_frames_to_generate = current_video_length if sliding_window: @@ -4843,7 +4907,7 @@ def generate_video( if repeat_no >= total_generation: break repeat_no +=1 gen["repeat_no"] = repeat_no - src_video = src_mask = src_ref_images = new_image_guide = new_image_mask = None + src_video = src_mask = src_ref_images = new_image_guide = new_image_mask = src_faces = None prefix_video = pre_video_frame = None source_video_overlap_frames_count = 0 # number of frames overalapped in source video for first window source_video_frames_count = 0 # number of frames to use in source video (processing starts source_video_overlap_frames_count frames before ) @@ -4899,7 +4963,7 @@ def generate_video( return_latent_slice = slice(-(reuse_frames - 1 + discard_last_frames ) // latent_size - 1, None if discard_last_frames == 0 else -(discard_last_frames // latent_size) ) refresh_preview = {"image_guide" : image_guide, "image_mask" : image_mask} if image_mode >= 1 else {} - src_ref_images = image_refs + src_ref_images, src_ref_masks = image_refs, None image_start_tensor = image_end_tensor = None if window_no == 1 and (video_source is not None or image_start is not None): if image_start is not None: @@ -4943,16 +5007,52 @@ def generate_video( from models.wan.multitalk.multitalk import get_window_audio_embeddings # special treatment for start frame pos when alignement to first frame requested as otherwise the start frame number will be negative due to overlapped frames (has been previously compensated later with padding) audio_proj_split = get_window_audio_embeddings(audio_proj_full, audio_start_idx= aligned_window_start_frame + (source_video_overlap_frames_count if reset_control_aligment else 0 ), clip_length = current_video_length) - if vace: - video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None + + if repeat_no == 1 and window_no == 1 and image_refs is not None and len(image_refs) > 0: + frames_positions_list = [] + if frames_positions is not None and len(frames_positions)> 0: + positions = frames_positions.replace(","," ").split(" ") + cur_end_pos = -1 + (source_video_frames_count - source_video_overlap_frames_count) + last_frame_no = requested_frames_to_generate + source_video_frames_count - source_video_overlap_frames_count + joker_used = False + project_window_no = 1 + for pos in positions : + if len(pos) > 0: + if pos in ["L", "l"]: + cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length + if cur_end_pos >= last_frame_no and not joker_used: + joker_used = True + cur_end_pos = last_frame_no -1 + project_window_no += 1 + frames_positions_list.append(cur_end_pos) + cur_end_pos -= sliding_window_discard_last_frames + reuse_frames + else: + frames_positions_list.append(int(pos)-1 + alignment_shift) + frames_positions_list = frames_positions_list[:len(image_refs)] + nb_frames_positions = len(frames_positions_list) + if nb_frames_positions > 0: + frames_to_inject = [None] * (max(frames_positions_list) + 1) + for i, pos in enumerate(frames_positions_list): + frames_to_inject[pos] = image_refs[i] + if video_guide is not None: - keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate) + keep_frames_parsed_full, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate) if len(error) > 0: raise gr.Error(f"invalid keep frames {keep_frames_video_guide}") - keep_frames_parsed = keep_frames_parsed[aligned_guide_start_frame: aligned_guide_end_frame ] + guide_frames_extract_start = aligned_window_start_frame if extract_guide_from_window_start else aligned_guide_start_frame + keep_frames_parsed = [True] * -guide_frames_extract_start if guide_frames_extract_start <0 else [] + keep_frames_parsed += keep_frames_parsed_full[max(0, guide_frames_extract_start): aligned_guide_end_frame ] + guide_frames_extract_count = len(keep_frames_parsed) - if vace: + if "B" in video_prompt_type: + send_cmd("progress", [0, get_latest_status(state, "Extracting Face Movements")]) + src_faces = extract_faces_from_video_with_mask(video_guide, video_mask, max_frames= guide_frames_extract_count, start_frame= guide_frames_extract_start, size= 512, target_fps = fps) + if src_faces is not None and src_faces.shape[1] < current_video_length: + src_faces = torch.cat([src_faces, torch.full( (3, current_video_length - src_faces.shape[1], 512, 512 ), -1, dtype = src_faces.dtype, device= src_faces.device) ], dim=1) + + if vace or animate: + video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None context_scale = [ control_net_weight] if "V" in video_prompt_type: process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None) @@ -4971,10 +5071,10 @@ def generate_video( if preprocess_type2 is not None: context_scale = [ control_net_weight /2, control_net_weight2 /2] send_cmd("progress", [0, get_latest_status(state, status_info)]) - inpaint_color = 0 if preprocess_type=="pose" and process_outside_mask=="inpaint" else 127 - video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) , start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color ) + inpaint_color = 0 if preprocess_type=="pose" else guide_inpaint_color + video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color ) if preprocess_type2 != None: - video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 ) + video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 ) if video_guide_processed != None: if sample_fit_canvas != None: @@ -4985,7 +5085,37 @@ def generate_video( refresh_preview["video_guide"] = [refresh_preview["video_guide"], Image.fromarray(video_guide_processed2[0].cpu().numpy())] if video_mask_processed != None: refresh_preview["video_mask"] = Image.fromarray(video_mask_processed[0].cpu().numpy()) - elif ltxv: + + frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame] + + if not vace and (any_letters(video_prompt_type ,"FV") or model_def.get("forced_guide_mask_inputs", False)): + any_mask = True + any_guide_padding = model_def.get("pad_guide_video", False) + from shared.utils.utils import prepare_video_guide_and_mask + src_videos, src_masks = prepare_video_guide_and_mask( [video_guide_processed, video_guide_processed2], + [video_mask_processed, video_mask_processed2], + pre_video_guide, image_size, current_video_length, latent_size, + any_mask, any_guide_padding, guide_inpaint_color, extract_guide_from_window_start, + keep_frames_parsed, frames_to_inject_parsed , outpainting_dims) + + src_video, src_video2 = src_videos + src_mask, src_mask2 = src_masks + if src_video is None: + abort = True + break + if src_faces is not None: + if src_faces.shape[1] < src_video.shape[1]: + src_faces = torch.concat( [src_faces, src_faces[:, -1:].repeat(1, src_video.shape[1] - src_faces.shape[1], 1,1)], dim =1) + else: + src_faces = src_faces[:, :src_video.shape[1]] + if args.save_masks: + save_video( src_video, "masked_frames.mp4", fps) + if src_video2 is not None: + save_video( src_video2, "masked_frames2.mp4", fps) + if any_mask: + save_video( src_mask, "masks.mp4", fps, value_range=(0, 1)) + + elif ltxv: preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw") status_info = "Extracting " + processes_names[preprocess_type] send_cmd("progress", [0, get_latest_status(state, status_info)]) @@ -5023,7 +5153,7 @@ def generate_video( sample_fit_canvas = None else: # video to video - video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps) + video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, block_size= block_size) if video_guide_processed is None: src_video = pre_video_guide else: @@ -5043,29 +5173,6 @@ def generate_video( refresh_preview["image_mask"] = new_image_mask if window_no == 1 and image_refs is not None and len(image_refs) > 0: - if repeat_no == 1: - frames_positions_list = [] - if frames_positions is not None and len(frames_positions)> 0: - positions = frames_positions.split(" ") - cur_end_pos = -1 + (source_video_frames_count - source_video_overlap_frames_count) #if reset_control_aligment else 0 - last_frame_no = requested_frames_to_generate + source_video_frames_count - source_video_overlap_frames_count - joker_used = False - project_window_no = 1 - for pos in positions : - if len(pos) > 0: - if pos in ["L", "l"]: - cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length - if cur_end_pos >= last_frame_no and not joker_used: - joker_used = True - cur_end_pos = last_frame_no -1 - project_window_no += 1 - frames_positions_list.append(cur_end_pos) - cur_end_pos -= sliding_window_discard_last_frames + reuse_frames - else: - frames_positions_list.append(int(pos)-1 + alignment_shift) - frames_positions_list = frames_positions_list[:len(image_refs)] - nb_frames_positions = len(frames_positions_list) - if sample_fit_canvas is not None and (nb_frames_positions > 0 or "K" in video_prompt_type) : from shared.utils.utils import get_outpainting_full_area_dimensions w, h = image_refs[0].size @@ -5089,20 +5196,16 @@ def generate_video( if remove_background_images_ref > 0: send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")]) # keep image ratios if there is a background image ref (we will let the model preprocessor decide what to do) but remove bg if requested - image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , image_size[1], image_size[0], + image_refs[nb_frames_positions:], src_ref_masks = resize_and_remove_background(image_refs[nb_frames_positions:] , image_size[1], image_size[0], remove_background_images_ref > 0, any_background_ref, fit_into_canvas= 0 if (any_background_ref > 0 or model_def.get("lock_image_refs_ratios", False)) else 1, block_size=block_size, - outpainting_dims =outpainting_dims ) + outpainting_dims =outpainting_dims, + background_ref_outpainted = model_def.get("background_ref_outpainted", True) ) refresh_preview["image_refs"] = image_refs - if nb_frames_positions > 0: - frames_to_inject = [None] * (max(frames_positions_list) + 1) - for i, pos in enumerate(frames_positions_list): - frames_to_inject[pos] = image_refs[i] if vace : - frames_to_inject_parsed = frames_to_inject[guide_start_frame: guide_end_frame] image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_processed] if video_guide_processed2 == None else [video_guide_processed, video_guide_processed2], @@ -5116,7 +5219,7 @@ def generate_video( any_background_ref = any_background_ref ) if len(frames_to_inject_parsed) or any_background_ref: - new_image_refs = [convert_tensor_to_image(src_video[0], frame_no + aligned_guide_start_frame - aligned_window_start_frame) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject] + new_image_refs = [convert_tensor_to_image(src_video[0], frame_no + 0 if extract_guide_from_window_start else (aligned_guide_start_frame - aligned_window_start_frame) ) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject] if any_background_ref: new_image_refs += [convert_tensor_to_image(image_refs_copy[0], 0)] + image_refs[nb_frames_positions+1:] else: @@ -5165,10 +5268,14 @@ def generate_video( input_prompt = prompt, image_start = image_start_tensor, image_end = image_end_tensor, - input_frames = src_video, + input_frames = src_video, + input_frames2 = src_video2, input_ref_images= src_ref_images, + input_ref_masks = src_ref_masks, input_masks = src_mask, + input_masks2 = src_mask2, input_video= pre_video_guide, + input_faces = src_faces, denoising_strength=denoising_strength, prefix_frames_count = source_video_overlap_frames_count if window_no <= 1 else reuse_frames, frame_num= (current_video_length // latent_size)* latent_size + 1, @@ -5302,6 +5409,7 @@ def generate_video( send_cmd("output") else: sample = samples.cpu() + abort = not is_image and sample.shape[1] < current_video_length # if True: # for testing # torch.save(sample, "output.pt") # else: @@ -6980,7 +7088,7 @@ def refresh_video_prompt_type_alignment(state, video_prompt_type, video_prompt_t def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ): old_video_prompt_type = video_prompt_type - video_prompt_type = del_in_sequence(video_prompt_type, "PDESLCMUV") + video_prompt_type = del_in_sequence(video_prompt_type, "PDESLCMUVB") video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide) visible = "V" in video_prompt_type model_type = state["model_type"] @@ -7437,7 +7545,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non image_prompt_type = gr.Text(value= image_prompt_type_value, visible= False) image_prompt_type_choices = [] if "T" in image_prompt_types_allowed: - image_prompt_type_choices += [("Text Prompt Only", "")] + image_prompt_type_choices += [("Text Prompt Only" if "S" in image_prompt_types_allowed else "New Video", "")] if "S" in image_prompt_types_allowed: image_prompt_type_choices += [("Start Video with Image", "S")] any_start_image = True @@ -7516,7 +7624,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non if image_outputs: video_prompt_type_video_guide_label = video_prompt_type_video_guide_label.replace("Video", "Image") video_prompt_type_video_guide = gr.Dropdown( guide_preprocessing_choices, - value=filter_letters(video_prompt_type_value, "PDESLCMUV", guide_preprocessing.get("default", "") ), + value=filter_letters(video_prompt_type_value, "PDESLCMUVB", guide_preprocessing.get("default", "") ), label= video_prompt_type_video_guide_label , scale = 2, visible= guide_preprocessing.get("visible", True) , show_label= True, ) any_control_video = True @@ -7560,13 +7668,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non } mask_preprocessing_choices = [] - mask_preprocessing_labels = guide_preprocessing.get("labels", {}) + mask_preprocessing_labels = mask_preprocessing.get("labels", {}) for process_type in mask_preprocessing["selection"]: process_label = mask_preprocessing_labels.get(process_type, None) process_label = mask_preprocessing_labels_all.get(process_type, process_type) if process_label is None else process_label mask_preprocessing_choices.append( (process_label, process_type) ) - video_prompt_type_video_mask_label = guide_preprocessing.get("label", "Area Processed") + video_prompt_type_video_mask_label = mask_preprocessing.get("label", "Area Processed") video_prompt_type_video_mask = gr.Dropdown( mask_preprocessing_choices, value=filter_letters(video_prompt_type_value, "XYZWNA", mask_preprocessing.get("default", "")), @@ -7591,7 +7699,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non choices= image_ref_choices["choices"], value=filter_letters(video_prompt_type_value, image_ref_choices["letters_filter"]), visible = image_ref_choices.get("visible", True), - label=image_ref_choices.get("label", "Ref. Images Type"), show_label= True, scale = 2 + label=image_ref_choices.get("label", "Inject Reference Images"), show_label= True, scale = 2 ) image_guide = gr.Image(label= "Control Image", height = 800, type ="pil", visible= image_mode_value==1 and "V" in video_prompt_type_value and ("U" in video_prompt_type_value or not "A" in video_prompt_type_value ) , value= ui_defaults.get("image_guide", None)) @@ -7634,7 +7742,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non video_guide_outpainting_value = ui_defaults.get("video_guide_outpainting","#") video_guide_outpainting = gr.Text(value=video_guide_outpainting_value , visible= False) with gr.Group(): - video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Landscape or Injected Reference Frames" if image_mode_value == 0 else "Enable Spatial Outpainting on Control Image", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") ) + video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Landscape or Positioned Reference Frames" if image_mode_value == 0 else "Enable Spatial Outpainting on Control Image", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") ) with gr.Row(visible = not video_guide_outpainting_value.startswith("#")) as video_guide_outpainting_row: video_guide_outpainting_value = video_guide_outpainting_value[1:] if video_guide_outpainting_value.startswith("#") else video_guide_outpainting_value video_guide_outpainting_list = [0] * 4 if len(video_guide_outpainting_value) == 0 else [int(v) for v in video_guide_outpainting_value.split(" ")] @@ -7649,14 +7757,14 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value ) image_refs_single_image_mode = model_def.get("one_image_ref_needed", False) - image_refs_label = "Start Image" if hunyuan_video_avatar else ("Reference Image" if image_refs_single_image_mode else "Reference Images") + (" (each Image will start a new Clip)" if infinitetalk else "") + image_refs_label = "Start Image" if hunyuan_video_avatar else ("Reference Image" if image_refs_single_image_mode else "Reference Images") + (" (each Image will be associated to a Sliding Window)" if infinitetalk else "") image_refs_row, image_refs, image_refs_extra = get_image_gallery(label= image_refs_label, value = ui_defaults.get("image_refs", None), visible= "I" in video_prompt_type_value, single_image_mode=image_refs_single_image_mode) frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames (1=first, L=last of a window) no position for other Image Refs)" ) image_refs_relative_size = gr.Slider(20, 100, value=ui_defaults.get("image_refs_relative_size", 50), step=1, label="Rescale Internaly Image Ref (% in relation to Output Video) to change Output Composition", visible = model_def.get("any_image_refs_relative_size", False) and image_outputs) no_background_removal = model_def.get("no_background_removal", False) or image_ref_choices is None - background_removal_label = model_def.get("background_removal_label", "Remove Backgrounds behind People / Objects") + background_removal_label = model_def.get("background_removal_label", "Remove Background behind People / Objects") remove_background_images_ref = gr.Dropdown( choices=[ @@ -7664,7 +7772,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non (background_removal_label, 1), ], value=0 if no_background_removal else ui_defaults.get("remove_background_images_ref",1), - label="Automatic Removal of Background of People or Objects (Only)", scale = 3, visible= "I" in video_prompt_type_value and not no_background_removal + label="Automatic Removal of Background behind People or Objects in Reference Images", scale = 3, visible= "I" in video_prompt_type_value and not no_background_removal ) any_audio_voices_support = any_audio_track(base_model_type) @@ -8084,7 +8192,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ("Aligned to the beginning of the First Window of the new Video Sample", "T"), ], value=filter_letters(video_prompt_type_value, "T"), - label="Control Video / Injected Frames / Control Audio temporal alignment when any Video to continue", + label="Control Video / Control Audio / Positioned Frames Temporal Alignment when any Video to continue", visible = vace or ltxv or t2v or infinitetalk )