Mirror of https://github.com/Wan-Video/Wan2.1.git, synced 2025-11-04 14:16:57 +00:00

commit in case there is an unrecoverable code hemorrhage

This commit is contained in:
parent fc615ffb3c
commit 84010bd861
configs/animate.json (new file, 15 lines)
@@ -0,0 +1,15 @@
+{
+  "_class_name": "WanModel",
+  "_diffusers_version": "0.30.0",
+  "dim": 5120,
+  "eps": 1e-06,
+  "ffn_dim": 13824,
+  "freq_dim": 256,
+  "in_dim": 36,
+  "model_type": "i2v",
+  "num_heads": 40,
+  "num_layers": 40,
+  "out_dim": 16,
+  "text_len": 512,
+  "motion_encoder_dim": 512
+}
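Note (not part of the commit): a minimal sketch of reading one of these configs with the standard library; the relative path is an assumption, and the printed fields are just the ones that distinguish the Animate variant (a 14B i2v backbone plus a 512-dim motion encoder).

import json

# Hypothetical path relative to the repository root.
with open("configs/animate.json") as f:
    cfg = json.load(f)

# 14B-class backbone (dim 5120, 40 layers / 40 heads) with an extra motion encoder.
print(cfg["model_type"], cfg["dim"], cfg["num_layers"], cfg.get("motion_encoder_dim", 0))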
configs/lucy_edit.json (new file, 14 lines)
@@ -0,0 +1,14 @@
+{
+  "_class_name": "WanModel",
+  "_diffusers_version": "0.33.0",
+  "dim": 3072,
+  "eps": 1e-06,
+  "ffn_dim": 14336,
+  "freq_dim": 256,
+  "in_dim": 96,
+  "model_type": "ti2v2_2",
+  "num_heads": 24,
+  "num_layers": 30,
+  "out_dim": 48,
+  "text_len": 512
+}
defaults/animate.json (new file, 13 lines)
@@ -0,0 +1,13 @@
+{
+  "model": {
+    "name": "Wan2.2 Animate",
+    "architecture": "animate",
+    "description": "Wan-Animate takes a video and a character image as input, and generates a video in either 'animation' or 'replacement' mode.",
+    "URLs": [
+      "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_bf16.safetensors",
+      "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_quanto_fp16_int8.safetensors",
+      "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_quanto_bf16_int8.safetensors"
+    ],
+    "group": "wan2_2"
+  }
+}
defaults/lucy_edit.json (new file, 18 lines)
@@ -0,0 +1,18 @@
+{
+  "model": {
+    "name": "Wan2.2 Lucy Edit 5B",
+    "architecture": "lucy_edit",
+    "description": "Lucy Edit Dev is a video editing model that performs instruction-guided edits on videos using free-text prompts \u2014 it supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.",
+    "URLs": [
+      "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_mbf16.safetensors",
+      "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mbf16_int8.safetensors",
+      "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mfp16_int8.safetensors"
+    ],
+    "group": "wan2_2"
+  },
+  "video_length": 81,
+  "guidance_scale": 5,
+  "flow_shift": 5,
+  "num_inference_steps": 30,
+  "resolution": "1280x720"
+}
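Note (illustration only, not part of the commit): with the Wan2.2 5B VAE temporal stride of 4 that this commit also selects for lucy_edit, the default video_length of 81 maps to 21 latent frames; the helper below simply mirrors the lat_frames formula used later in generate().

def latent_frames(frame_num: int, temporal_stride: int = 4) -> int:
    # Mirrors: lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1
    return (frame_num - 1) // temporal_stride + 1

print(latent_frames(81))   # 21 latent frames for the Lucy Edit default above
print(latent_frames(121))  # 31 for the 121-frame default in the next hunk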
@ -7,6 +7,7 @@
|
|||||||
"loras": ["https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2_2_5B_FastWanFullAttn_lora_rank_128_bf16.safetensors"],
|
"loras": ["https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2_2_5B_FastWanFullAttn_lora_rank_128_bf16.safetensors"],
|
||||||
"group": "wan2_2"
|
"group": "wan2_2"
|
||||||
},
|
},
|
||||||
|
"prompt" : "Put the person into a clown outfit.",
|
||||||
"video_length": 121,
|
"video_length": 121,
|
||||||
"guidance_scale": 1,
|
"guidance_scale": 1,
|
||||||
"flow_shift": 3,
|
"flow_shift": 3,
|
||||||
|
|||||||
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from mmgp import offload
 import inspect
 from typing import Any, Callable, Dict, List, Optional, Union
@@ -387,7 +386,8 @@ class QwenImagePipeline(): #DiffusionPipeline
         return latent_image_ids.to(device=device, dtype=dtype)

     @staticmethod
-    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+    def _pack_latents(latents):
+        batch_size, num_channels_latents, _, height, width = latents.shape
         latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
         latents = latents.permute(0, 2, 4, 1, 3, 5)
         latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
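Note (sketch, not part of the commit): _pack_latents now infers every size from the tensor itself, so callers pass a single 5D latent of shape (batch, channels, 1, height, width). A quick shape check with made-up sizes:

import torch

latents = torch.randn(1, 16, 1, 64, 64)               # (B, C, T=1, H, W); sizes are arbitrary
b, c, _, h, w = latents.shape
packed = latents.view(b, c, h // 2, 2, w // 2, 2).permute(0, 2, 4, 1, 3, 5)
packed = packed.reshape(b, (h // 2) * (w // 2), c * 4)
print(packed.shape)                                    # torch.Size([1, 1024, 64])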
@@ -479,7 +479,7 @@ class QwenImagePipeline(): #DiffusionPipeline
         height = 2 * (int(height) // (self.vae_scale_factor * 2))
         width = 2 * (int(width) // (self.vae_scale_factor * 2))

-        shape = (batch_size, 1, num_channels_latents, height, width)
+        shape = (batch_size, num_channels_latents, 1, height, width)

         image_latents = None
         if image is not None:
@@ -499,10 +499,7 @@ class QwenImagePipeline(): #DiffusionPipeline
             else:
                 image_latents = torch.cat([image_latents], dim=0)

-            image_latent_height, image_latent_width = image_latents.shape[3:]
-            image_latents = self._pack_latents(
-                image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
-            )
+            image_latents = self._pack_latents(image_latents)

         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
@@ -511,7 +508,7 @@ class QwenImagePipeline(): #DiffusionPipeline
             )
         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-            latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+            latents = self._pack_latents(latents)
         else:
             latents = latents.to(device=device, dtype=dtype)

@@ -713,11 +710,12 @@ class QwenImagePipeline(): #DiffusionPipeline
             image_height, image_width = calculate_new_dimensions(height, width, image_height, image_width, False, block_size=multiple_of)
             # image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
             height, width = image_height, image_width
-            image_mask_latents = convert_image_to_tensor(image_mask.resize((width // 16, height // 16), resample=Image.Resampling.LANCZOS))
+            image_mask_latents = convert_image_to_tensor(image_mask.resize((width // 8, height // 8), resample=Image.Resampling.LANCZOS))
             image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1]
-            image_mask_rebuilt = image_mask_latents.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0)
+            image_mask_rebuilt = image_mask_latents.repeat_interleave(8, dim=-1).repeat_interleave(8, dim=-2).unsqueeze(0)
             # convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png")
-            image_mask_latents = image_mask_latents.reshape(1, -1, 1).to(device)
+            image_mask_latents = image_mask_latents.to(device).unsqueeze(0).unsqueeze(0).repeat(1,16,1,1,1)
+            image_mask_latents = self._pack_latents(image_mask_latents)

             prompt_image = image
             if image.size != (image_width, image_height):
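Note (illustration, not from the commit): the mask is now built at 1/8 resolution (the VAE scale factor) instead of 1/16, and the pixel-space mask is rebuilt by repeating each latent cell 8x8 accordingly. In isolation:

import torch

latent_mask = torch.tensor([[0., 1.], [1., 0.]])    # toy 2x2 latent-resolution mask
pixel_mask = latent_mask.repeat_interleave(8, dim=-1).repeat_interleave(8, dim=-2)
print(pixel_mask.shape)                              # torch.Size([16, 16]); each latent cell covers an 8x8 pixel patch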
@@ -822,6 +820,7 @@ class QwenImagePipeline(): #DiffusionPipeline
             negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
         )
         morph, first_step = False, 0
+        lanpaint_proc = None
         if image_mask_latents is not None:
             randn = torch.randn_like(original_image_latents)
             if denoising_strength < 1.:
@@ -833,7 +832,8 @@ class QwenImagePipeline(): #DiffusionPipeline
                 timesteps = timesteps[first_step:]
                 self.scheduler.timesteps = timesteps
                 self.scheduler.sigmas= self.scheduler.sigmas[first_step:]
-
+        # from shared.inpainting.lanpaint import LanPaint
+        # lanpaint_proc = LanPaint()
         # 6. Denoising loop
         self.scheduler.set_begin_index(0)
         updated_num_steps= len(timesteps)
@@ -847,48 +847,52 @@ class QwenImagePipeline(): #DiffusionPipeline
                 offload.set_step_no_for_lora(self.transformer, first_step + i)
                 if self.interrupt:
                     continue
+                self._current_timestep = t
+                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                 if image_mask_latents is not None and denoising_strength <1. and i == first_step and morph:
                     latent_noise_factor = t/1000
                     latents = original_image_latents * (1.0 - latent_noise_factor) + latents * latent_noise_factor

-                self._current_timestep = t
-                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latents.shape[0]).to(latents.dtype)

-                latent_model_input = latents
-                if image_latents is not None:
-                    latent_model_input = torch.cat([latents, image_latents], dim=1)
+                latents_dtype = latents.dtype

-                if do_true_cfg and joint_pass:
-                    noise_pred, neg_noise_pred = self.transformer(
-                        hidden_states=latent_model_input,
-                        timestep=timestep / 1000,
-                        guidance=guidance,
-                        encoder_hidden_states_mask_list=[prompt_embeds_mask,negative_prompt_embeds_mask],
-                        encoder_hidden_states_list=[prompt_embeds, negative_prompt_embeds],
-                        img_shapes=img_shapes,
-                        txt_seq_lens_list=[txt_seq_lens, negative_txt_seq_lens],
-                        attention_kwargs=self.attention_kwargs,
-                        **kwargs
-                    )
-                    if noise_pred == None: return None
-                    noise_pred = noise_pred[:, : latents.size(1)]
-                    neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
-                else:
-                    noise_pred = self.transformer(
-                        hidden_states=latent_model_input,
-                        timestep=timestep / 1000,
-                        guidance=guidance,
-                        encoder_hidden_states_mask_list=[prompt_embeds_mask],
-                        encoder_hidden_states_list=[prompt_embeds],
-                        img_shapes=img_shapes,
-                        txt_seq_lens_list=[txt_seq_lens],
-                        attention_kwargs=self.attention_kwargs,
-                        **kwargs
-                    )[0]
-                    if noise_pred == None: return None
-                    noise_pred = noise_pred[:, : latents.size(1)]
+                # latent_model_input = latents
+                def denoise(latent_model_input, true_cfg_scale):
+                    if image_latents is not None:
+                        latent_model_input = torch.cat([latents, image_latents], dim=1)
+                    do_true_cfg = true_cfg_scale > 1
+                    if do_true_cfg and joint_pass:
+                        noise_pred, neg_noise_pred = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep / 1000,
+                            guidance=guidance, #!!!!
+                            encoder_hidden_states_mask_list=[prompt_embeds_mask,negative_prompt_embeds_mask],
+                            encoder_hidden_states_list=[prompt_embeds, negative_prompt_embeds],
+                            img_shapes=img_shapes,
+                            txt_seq_lens_list=[txt_seq_lens, negative_txt_seq_lens],
+                            attention_kwargs=self.attention_kwargs,
+                            **kwargs
+                        )
+                        if noise_pred == None: return None, None
+                        noise_pred = noise_pred[:, : latents.size(1)]
+                        neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+                    else:
+                        neg_noise_pred = None
+                        noise_pred = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep / 1000,
+                            guidance=guidance,
+                            encoder_hidden_states_mask_list=[prompt_embeds_mask],
+                            encoder_hidden_states_list=[prompt_embeds],
+                            img_shapes=img_shapes,
+                            txt_seq_lens_list=[txt_seq_lens],
+                            attention_kwargs=self.attention_kwargs,
+                            **kwargs
+                        )[0]
+                        if noise_pred == None: return None, None
+                        noise_pred = noise_pred[:, : latents.size(1)]

                         if do_true_cfg:
                             neg_noise_pred = self.transformer(
@@ -902,27 +906,43 @@ class QwenImagePipeline(): #DiffusionPipeline
                                 attention_kwargs=self.attention_kwargs,
                                 **kwargs
                             )[0]
-                            if neg_noise_pred == None: return None
+                            if neg_noise_pred == None: return None, None
                             neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+                    return noise_pred, neg_noise_pred
+
+                def cfg_predictions( noise_pred, neg_noise_pred, guidance, t):
+                    if do_true_cfg:
+                        comb_pred = neg_noise_pred + guidance * (noise_pred - neg_noise_pred)
+                        if comb_pred == None: return None

-                if do_true_cfg:
-                    comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
-                    if comb_pred == None: return None
-
-                    cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
-                    noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
-                    noise_pred = comb_pred * (cond_norm / noise_norm)
-                    neg_noise_pred = None
+                        cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+                        noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+                        noise_pred = comb_pred * (cond_norm / noise_norm)
+                    return noise_pred
+
+                if lanpaint_proc is not None and i<=3:
+                    latents = lanpaint_proc(denoise, cfg_predictions, true_cfg_scale, 1., latents, original_image_latents, randn, t/1000, image_mask_latents, height=height , width= width, vae_scale_factor= 8)
+                    if latents is None: return None
+
+                noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale)
+                if noise_pred == None: return None
+                noise_pred = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t)
+                neg_noise_pred = None
                 # compute the previous noisy sample x_t -> x_t-1
-                latents_dtype = latents.dtype
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+                noise_pred = None

                 if image_mask_latents is not None:
-                    next_t = timesteps[i+1] if i<len(timesteps)-1 else 0
-                    latent_noise_factor = next_t / 1000
-                    # noisy_image = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor
-                    noisy_image = original_image_latents * (1.0 - latent_noise_factor) + randn * latent_noise_factor
-                    latents = noisy_image * (1-image_mask_latents) + image_mask_latents * latents
-                    noisy_image = None
+                    if lanpaint_proc is not None:
+                        latents = original_image_latents * (1-image_mask_latents) + image_mask_latents * latents
+                    else:
+                        next_t = timesteps[i+1] if i<len(timesteps)-1 else 0
+                        latent_noise_factor = next_t / 1000
+                        # noisy_image = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor
+                        noisy_image = original_image_latents * (1.0 - latent_noise_factor) + randn * latent_noise_factor
+                        latents = noisy_image * (1-image_mask_latents) + image_mask_latents * latents
+                        noisy_image = None

                 if latents.dtype != latents_dtype:
                     if torch.backends.mps.is_available():
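Note (sketch, not part of the commit): cfg_predictions applies true classifier-free guidance and then rescales the combined prediction back to the norm of the conditional one, which is what the cond_norm / noise_norm factor does. The same operation in isolation, with arbitrary tensor sizes:

import torch

def cfg_rescale(noise_pred, neg_noise_pred, guidance_scale):
    # Classifier-free guidance followed by a norm-preserving rescale.
    comb = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)
    cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
    comb_norm = torch.norm(comb, dim=-1, keepdim=True)
    return comb * (cond_norm / comb_norm)

pred, neg = torch.randn(1, 1024, 64), torch.randn(1, 1024, 64)   # packed-latent shaped tensors
out = cfg_rescale(pred, neg, guidance_scale=4.0)
print(out.shape, torch.allclose(torch.norm(out, dim=-1), torch.norm(pred, dim=-1), atol=1e-4))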
@@ -32,9 +32,10 @@ from shared.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
 from .modules.posemb_layers import get_rotary_pos_embed, get_nd_rotary_pos_embed
 from shared.utils.vace_preprocessor import VaceVideoProcessor
 from shared.utils.basic_flowmatch import FlowMatchScheduler
-from shared.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions, convert_image_to_tensor
+from shared.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions, convert_image_to_tensor, fit_image_into_canvas
 from .multitalk.multitalk_utils import MomentumBuffer, adaptive_projected_guidance, match_and_blend_colors, match_and_blend_colors_with_mask
 from mmgp import safetensors2
+from shared.utils.audio_video import save_video

 def optimized_scale(positive_flat, negative_flat):

@@ -93,7 +94,7 @@ class WanAny2V:
                 shard_fn= None)

         # base_model_type = "i2v2_2"
-        if hasattr(config, "clip_checkpoint") and not base_model_type in ["i2v_2_2", "i2v_2_2_multitalk"]:
+        if hasattr(config, "clip_checkpoint") and not base_model_type in ["i2v_2_2", "i2v_2_2_multitalk"] or base_model_type in ["animate"]:
             self.clip = CLIPModel(
                 dtype=config.clip_dtype,
                 device=self.device,
@@ -102,7 +103,7 @@ class WanAny2V:
                 tokenizer_path=os.path.join(checkpoint_dir , "xlm-roberta-large"))


-        if base_model_type in ["ti2v_2_2"]:
+        if base_model_type in ["ti2v_2_2", "lucy_edit"]:
             self.vae_stride = (4, 16, 16)
             vae_checkpoint = "Wan2.2_VAE.safetensors"
             vae = Wan2_2_VAE
@@ -190,19 +191,14 @@ class WanAny2V:
                 save_quantized_model(self.model2, model_type, model_filename[1], dtype, base_config_file, submodel_no=2)
         self.sample_neg_prompt = config.sample_neg_prompt

-        if self.model.config.get("vace_in_dim", None) != None:
-            self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
-                                min_area=480*832,
-                                max_area=480*832,
-                                min_fps=config.sample_fps,
-                                max_fps=config.sample_fps,
-                                zero_start=True,
-                                seq_len=32760,
-                                keep_last=True)
-
+        if hasattr(self.model, "vace_blocks"):
             self.adapt_vace_model(self.model)
             if self.model2 is not None: self.adapt_vace_model(self.model2)

+        if hasattr(self.model, "face_adapter"):
+            self.adapt_animate_model(self.model)
+            if self.model2 is not None: self.adapt_animate_model(self.model2)
+
         self.num_timesteps = 1000
         self.use_timestep_transform = True
@@ -277,51 +273,6 @@ class WanAny2V:
     def vace_latent(self, z, m):
         return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]

-    def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, full_frame = False, outpainting_dims = None, return_mask = False):
-        from shared.utils.utils import save_image
-        ref_width, ref_height = ref_img.size
-        if (ref_height, ref_width) == image_size and outpainting_dims == None:
-            ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
-            canvas = torch.zeros_like(ref_img) if return_mask else None
-        else:
-            if outpainting_dims != None:
-                final_height, final_width = image_size
-                canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1)
-            else:
-                canvas_height, canvas_width = image_size
-            if full_frame:
-                new_height = canvas_height
-                new_width = canvas_width
-                top = left = 0
-            else:
-                # if fill_max and (canvas_height - new_height) < 16:
-                #     new_height = canvas_height
-                # if fill_max and (canvas_width - new_width) < 16:
-                #     new_width = canvas_width
-                scale = min(canvas_height / ref_height, canvas_width / ref_width)
-                new_height = int(ref_height * scale)
-                new_width = int(ref_width * scale)
-                top = (canvas_height - new_height) // 2
-                left = (canvas_width - new_width) // 2
-            ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
-            ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
-            if outpainting_dims != None:
-                canvas = torch.full((3, 1, final_height, final_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
-                canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img
-            else:
-                canvas = torch.full((3, 1, canvas_height, canvas_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
-                canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
-            ref_img = canvas
-            canvas = None
-        if return_mask:
-            if outpainting_dims != None:
-                canvas = torch.ones((3, 1, final_height, final_width), dtype= torch.float, device=device) # [-1, 1]
-                canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = 0
-            else:
-                canvas = torch.ones((3, 1, canvas_height, canvas_width), dtype= torch.float, device=device) # [-1, 1]
-                canvas[:, :, top:top + new_height, left:left + new_width] = 0
-            canvas = canvas.to(device)
-        return ref_img.to(device), canvas
-
     def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_video_guide_frames= [], pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False):
         image_sizes = []
@@ -375,7 +326,7 @@ class WanAny2V:
         for k, frame in enumerate(inject_frames):
             if frame != None:
                 pos = prepend_count + k
-                src_video[i][:, pos:pos+1], src_mask[i][:, pos:pos+1] = self.fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims, return_mask= True)
+                src_video[i][:, pos:pos+1], src_mask[i][:, pos:pos+1] = fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims, return_mask= True)


         self.background_mask = None
@@ -386,9 +337,9 @@ class WanAny2V:
                 if ref_img is not None and not torch.is_tensor(ref_img):
                     if j==0 and any_background_ref:
                         if self.background_mask == None: self.background_mask = [None] * len(src_ref_images)
-                        src_ref_images[i][j], self.background_mask[i] = self.fit_image_into_canvas(ref_img, image_size, 0, device, True, outpainting_dims, return_mask= True)
+                        src_ref_images[i][j], self.background_mask[i] = fit_image_into_canvas(ref_img, image_size, 0, device, True, outpainting_dims, return_mask= True)
                     else:
-                        src_ref_images[i][j], _ = self.fit_image_into_canvas(ref_img, image_size, 1, device)
+                        src_ref_images[i][j], _ = fit_image_into_canvas(ref_img, image_size, 1, device)
         if self.background_mask != None:
             self.background_mask = [ item if item != None else self.background_mask[0] for item in self.background_mask ] # deplicate background mask with double control net since first controlnet image ref modifed by ref
         return src_video, src_mask, src_ref_images
@@ -402,12 +353,26 @@ class WanAny2V:

         return torch.cat(ref_vae_latents, dim=1)

+    def get_i2v_mask(self, lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=None, lat_t =0, device="cuda"):
+        if mask_pixel_values is None:
+            msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
+        else:
+            msk = F.interpolate(mask_pixel_values.to(device), size=(lat_h, lat_w), mode='nearest')
+
+        if nb_frames_unchanged >0:
+            msk[:, :nb_frames_unchanged] = 1
+        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
+        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
+        msk = msk.transpose(1,2)[0]
+        return msk
+
     def generate(self,
         input_prompt,
         input_frames= None,
         input_masks = None,
         input_ref_images = None,
+        input_ref_masks = None,
+        input_faces = None,
         input_video = None,
         image_start = None,
         image_end = None,
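Note (illustration, not from the commit): get_i2v_mask packs the per-frame mask the same way the causal VAE packs time, i.e. the first frame is repeated four times and the sequence is then grouped into blocks of four, giving four mask channels per latent frame. With made-up sizes:

import torch

lat_h, lat_w, frames = 8, 8, 81                        # arbitrary latent height/width and pixel frame count
msk = torch.zeros(1, frames, lat_h, lat_w)
msk[:, :1] = 1                                          # e.g. keep the first frame unchanged
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)   # (1, latent_frames, 4, H, W)
msk = msk.transpose(1, 2)[0]                            # (4, latent_frames, H, W)
print(msk.shape)                                        # torch.Size([4, 21, 8, 8])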
@@ -541,14 +506,18 @@ class WanAny2V:
         infinitetalk = model_type in ["infinitetalk"]
         standin = model_type in ["standin", "vace_standin_14B"]
         recam = model_type in ["recam_1.3B"]
-        ti2v = model_type in ["ti2v_2_2"]
+        ti2v = model_type in ["ti2v_2_2", "lucy_edit"]
+        lucy_edit= model_type in ["lucy_edit"]
+        animate= model_type in ["animate"]
         start_step_no = 0
         ref_images_count = 0
         trim_frames = 0
-        extended_overlapped_latents = None
+        extended_overlapped_latents = clip_image_start = clip_image_end = None
         no_noise_latents_injection = infinitetalk
         timestep_injection = False
         lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1
+        extended_input_dim = 0
+        ref_images_before = False
         # image2video
         if model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "flf2v_720p"]:
             any_end_frame = False
@@ -598,17 +567,7 @@ class WanAny2V:

             if image_end is not None:
                 img_end_frame = image_end.unsqueeze(1).to(self.device)
-
-            if hasattr(self, "clip"):
-                clip_image_size = self.clip.model.image_size
-                image_start = resize_lanczos(image_start, clip_image_size, clip_image_size)
-                image_end = resize_lanczos(image_end, clip_image_size, clip_image_size) if image_end is not None else image_start
-                if model_type == "flf2v_720p":
-                    clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :] if image_end is not None else image_start[:, None, :, :]])
-                else:
-                    clip_context = self.clip.visual([image_start[:, None, :, :]])
-            else:
-                clip_context = None
+            clip_image_start, clip_image_end = image_start, image_end

             if any_end_frame:
                 enc= torch.concat([
@@ -647,21 +606,62 @@ class WanAny2V:
             if infinitetalk:
                 lat_y = self.vae.encode([input_video], VAE_tile_size)[0]
                 extended_overlapped_latents = lat_y[:, :overlapped_latents_frames_num].clone().unsqueeze(0)
-                # if control_pre_frames_count != pre_frames_count:

             lat_y = input_video = None
             kwargs.update({ 'y': y})
-            if not clip_context is None:
-                kwargs.update({'clip_fea': clip_context})

-        # Recam Master
-        if recam:
-            target_camera = model_mode
-            height,width = input_frames.shape[-2:]
-            input_frames = input_frames.to(dtype=self.dtype , device=self.device)
-            source_latents = self.vae.encode([input_frames])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device)
+        # Animate
+        if animate:
+            pose_pixels = input_frames * input_masks
+            input_masks = 1. - input_masks
+            pose_pixels -= input_masks
+            save_video(pose_pixels, "pose.mp4")
+            pose_latents = self.vae.encode([pose_pixels], VAE_tile_size)[0].unsqueeze(0)
+            input_frames = input_frames * input_masks
+            if not "X" in video_prompt_type: input_frames += input_masks - 1 # masked area should black (-1) in background frames
+            if prefix_frames_count > 0:
+                input_frames[:, :prefix_frames_count] = input_video
+                input_masks[:, :prefix_frames_count] = 1
+            save_video(input_frames, "input_frames.mp4")
+            save_video(input_masks, "input_masks.mp4", value_range=(0,1))
+            lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
+            msk_ref = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=1,lat_t=1, device=self.device)
+            msk_control = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=input_masks, device=self.device)
+            msk = torch.concat([msk_ref, msk_control], dim=1)
+            clip_image_start = image_ref = convert_image_to_tensor(input_ref_images[0]).to(self.device)
+            lat_y = torch.concat(self.vae.encode([image_ref.unsqueeze(1).to(self.device), input_frames.to(self.device)], VAE_tile_size), dim=1)
+            y = torch.concat([msk, lat_y])
+            kwargs.update({ 'y': y, 'pose_latents': pose_latents, 'face_pixel_values' : input_faces.unsqueeze(0)})
+            lat_y = msk = msk_control = msk_ref = pose_pixels = None
+            ref_images_before = True
+            ref_images_count = 1
+            lat_frames = int((input_frames.shape[1] - 1) // self.vae_stride[0]) + 1

+        # Clip image
+        if hasattr(self, "clip") and clip_image_start is not None:
+            clip_image_size = self.clip.model.image_size
+            clip_image_start = resize_lanczos(clip_image_start, clip_image_size, clip_image_size)
+            clip_image_end = resize_lanczos(clip_image_end, clip_image_size, clip_image_size) if clip_image_end is not None else clip_image_start
+            if model_type == "flf2v_720p":
+                clip_context = self.clip.visual([clip_image_start[:, None, :, :], clip_image_end[:, None, :, :] if clip_image_end is not None else clip_image_start[:, None, :, :]])
+            else:
+                clip_context = self.clip.visual([clip_image_start[:, None, :, :]])
+            clip_image_start = clip_image_end = None
+            kwargs.update({'clip_fea': clip_context})

+        # Recam Master & Lucy Edit
+        if recam or lucy_edit:
+            frame_num, height,width = input_frames.shape[-3:]
+            lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1
+            frame_num = (lat_frames -1) * self.vae_stride[0] + 1
+            input_frames = input_frames[:, :frame_num].to(dtype=self.dtype , device=self.device)
+            extended_latents = self.vae.encode([input_frames])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device)
+            extended_input_dim = 2 if recam else 1
             del input_frames

+        if recam:
             # Process target camera (recammaster)
+            target_camera = model_mode
             from shared.utils.cammmaster_tools import get_camera_embedding
             cam_emb = get_camera_embedding(target_camera)
             cam_emb = cam_emb.to(dtype=self.dtype, device=self.device)
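Note (sketch, not part of the commit): extended_input_dim decides where the guide latents are concatenated onto the noisy latents of shape (batch, channels, frames, height, width): dim=2 (ReCam) stacks them along time, so the rotary embedding grid is doubled on the frame axis, while dim=1 (Lucy Edit) stacks them along channels and leaves the sequence length untouched. With arbitrary sizes:

import torch

latents = torch.randn(1, 16, 21, 60, 104)           # (B, C, T, H, W); made-up sizes
extended = torch.randn(1, 16, 21, 60, 104)

print(torch.cat([latents, extended], dim=2).shape)   # ReCam:     torch.Size([1, 16, 42, 60, 104])
print(torch.cat([latents, extended], dim=1).shape)   # Lucy Edit: torch.Size([1, 32, 21, 60, 104])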
@@ -715,6 +715,8 @@ class WanAny2V:
             height, width = input_video.shape[-2:]
             source_latents = self.vae.encode([input_video], tile_size = VAE_tile_size)[0].unsqueeze(0)
             timestep_injection = True
+            if extended_input_dim > 0:
+                extended_latents[:, :, :source_latents.shape[2]] = source_latents

         # Vace
         if vace :
@@ -722,6 +724,7 @@ class WanAny2V:
             input_frames = [u.to(self.device) for u in input_frames]
             input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
             input_masks = [u.to(self.device) for u in input_masks]
+            ref_images_before = True
             if self.background_mask != None: self.background_mask = [m.to(self.device) for m in self.background_mask]
             z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents )
             m0 = self.vace_encode_masks(input_masks, input_ref_images)
@@ -771,9 +774,9 @@ class WanAny2V:

         expand_shape = [batch_size] + [-1] * len(target_shape)
         # Ropes
-        if target_camera != None:
+        if extended_input_dim>=2:
             shape = list(target_shape[1:])
-            shape[0] *= 2
+            shape[extended_input_dim-2] *= 2
             freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
         else:
             freqs = get_rotary_pos_embed(target_shape[1:], enable_RIFLEx= enable_RIFLEx)
@@ -901,8 +904,8 @@ class WanAny2V:
                     for zz in z:
                         zz[0:16, ref_images_count:extended_overlapped_latents.shape[2] ] = extended_overlapped_latents[0, :, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(extended_overlapped_latents[0, :, ref_images_count:] ) * overlap_noise_factor

-            if target_camera != None:
-                latent_model_input = torch.cat([latents, source_latents.expand(*expand_shape)], dim=2)
+            if extended_input_dim > 0:
+                latent_model_input = torch.cat([latents, extended_latents.expand(*expand_shape)], dim=extended_input_dim)
             else:
                 latent_model_input = latents

@@ -1030,7 +1033,7 @@ class WanAny2V:

             if callback is not None:
                 latents_preview = latents
-                if vace and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ]
+                if ref_images_before and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ]
                 if trim_frames > 0: latents_preview= latents_preview[:, :,:-trim_frames]
                 if image_outputs: latents_preview= latents_preview[:, :,:1]
                 if len(latents_preview) > 1: latents_preview = latents_preview.transpose(0,2)
@@ -1041,7 +1044,7 @@ class WanAny2V:
         if timestep_injection:
             latents[:, :, :source_latents.shape[2]] = source_latents

-        if vace and ref_images_count > 0: latents = latents[:, :, ref_images_count:]
+        if ref_images_before and ref_images_count > 0: latents = latents[:, :, ref_images_count:]
         if trim_frames > 0: latents= latents[:, :,:-trim_frames]
         if return_latent_slice != None:
             latent_slice = latents[:, :, return_latent_slice].clone()
@@ -1078,4 +1081,12 @@ class WanAny2V:
                 delattr(model, "vace_blocks")


+    def adapt_animate_model(self, model):
+        modules_dict= { k: m for k, m in model.named_modules()}
+        for animate_layer in range(8):
+            module = modules_dict[f"face_adapter.fuser_blocks.{animate_layer}"]
+            model_layer = animate_layer * 5
+            target = modules_dict[f"blocks.{model_layer}"]
+            setattr(target, "face_adapter_fuser_blocks", module )
+        delattr(model, "face_adapter")
+
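Note (illustration only): adapt_animate_model moves each of the 8 face-adapter fuser blocks onto every fifth transformer block (40 layers / 5 = 8 adapters), the same stride the attention block later checks with block_no % 5 == 0. The mapping on its own:

# Adapter index -> transformer block index that receives face_adapter_fuser_blocks
mapping = {animate_layer: animate_layer * 5 for animate_layer in range(8)}
print(mapping)   # {0: 0, 1: 5, 2: 10, 3: 15, 4: 20, 5: 25, 6: 30, 7: 35}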
@@ -16,6 +16,9 @@ from mmgp.offload import get_cache, clear_caches
 from shared.attention import pay_attention
 from torch.backends.cuda import sdp_kernel
 from ..multitalk.multitalk_utils import get_attn_map_with_target
+from ..animate.motion_encoder import Generator
+from ..animate.face_blocks import FaceAdapter, FaceEncoder
+from ..animate.model_animate import after_patch_embedding

 __all__ = ['WanModel']

@@ -499,6 +502,7 @@ class WanAttentionBlock(nn.Module):
         multitalk_masks=None,
         ref_images_count=0,
         standin_phase=-1,
+        motion_vec = None,
     ):
         r"""
         Args:
@@ -616,6 +620,10 @@ class WanAttentionBlock(nn.Module):
                 x.add_(hint)
             else:
                 x.add_(hint, alpha= scale)
+
+        if motion_vec is not None and self.block_no % 5 == 0:
+            x += self.face_adapter_fuser_blocks(x.to(self.face_adapter_fuser_blocks.linear1_kv.weight.dtype), motion_vec, None, False)
+
         return x

 class AudioProjModel(ModelMixin, ConfigMixin):
@@ -898,6 +906,7 @@ class WanModel(ModelMixin, ConfigMixin):
                  norm_input_visual=True,
                  norm_output_audio=True,
                  standin= False,
+                 motion_encoder_dim=0,
                  ):

         super().__init__()
@@ -922,14 +931,15 @@ class WanModel(ModelMixin, ConfigMixin):
         self.flag_causal_attention = False
         self.block_mask = None
         self.inject_sample_info = inject_sample_info
+        self.motion_encoder_dim = motion_encoder_dim
         self.norm_output_audio = norm_output_audio
         self.audio_window = audio_window
         self.intermediate_dim = intermediate_dim
         self.vae_scale = vae_scale

         multitalk = multitalk_output_dim > 0
         self.multitalk = multitalk
+        animate = motion_encoder_dim > 0

         # embeddings
         self.patch_embedding = nn.Conv3d(
@@ -1027,6 +1037,25 @@ class WanModel(ModelMixin, ConfigMixin):
                 block.self_attn.k_loras = LoRALinearLayer(dim, dim, rank=128)
                 block.self_attn.v_loras = LoRALinearLayer(dim, dim, rank=128)

+        if animate:
+            self.pose_patch_embedding = nn.Conv3d(
+                16, dim, kernel_size=patch_size, stride=patch_size
+            )
+
+            self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)
+            self.face_adapter = FaceAdapter(
+                heads_num=self.num_heads,
+                hidden_dim=self.dim,
+                num_adapter_layers=self.num_layers // 5,
+            )
+
+            self.face_encoder = FaceEncoder(
+                in_dim=motion_encoder_dim,
+                hidden_dim=self.dim,
+                num_heads=4,
+            )
+
+
     def lock_layers_dtypes(self, hybrid_dtype = None, dtype = torch.float32):
         layer_list = [self.head, self.head.head, self.patch_embedding]
         target_dype= dtype
@@ -1208,6 +1237,9 @@ class WanModel(ModelMixin, ConfigMixin):
         ref_images_count = 0,
         standin_freqs = None,
         standin_ref = None,
+        pose_latents=None,
+        face_pixel_values=None,
+
     ):
         # patch_dtype = self.patch_embedding.weight.dtype
         modulation_dtype = self.time_projection[1].weight.dtype
@@ -1240,9 +1272,18 @@ class WanModel(ModelMixin, ConfigMixin):
                 if bz > 1: y = y.expand(bz, -1, -1, -1, -1)
                 x = torch.cat([x, y], dim=1)
             # embeddings
-            # x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype)
             x = self.patch_embedding(x).to(modulation_dtype)
             grid_sizes = x.shape[2:]
+            x_list[i] = x
+        y = None
+
+        motion_vec_list = []
+        for i, x in enumerate(x_list):
+            # animate embeddings
+            motion_vec = None
+            if pose_latents is not None:
+                x, motion_vec = after_patch_embedding(self, x, pose_latents, face_pixel_values)
+            motion_vec_list.append(motion_vec)
             if chipmunk:
                 x = x.unsqueeze(-1)
                 x_og_shape = x.shape
@@ -1250,7 +1291,7 @@ class WanModel(ModelMixin, ConfigMixin):
             else:
                 x = x.flatten(2).transpose(1, 2)
             x_list[i] = x
-        x, y = None, None
+        x = None


         block_mask = None
@@ -1450,9 +1491,9 @@ class WanModel(ModelMixin, ConfigMixin):
                     continue
                 x_list[0] = block(x_list[0], context = context_list[0], audio_scale= audio_scale_list[0], e= e0, **kwargs)
             else:
-                for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc)):
+                for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc, motion_vec) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc,motion_vec_list)):
                     if should_calc:
-                        x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, **kwargs)
+                        x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, motion_vec = motion_vec,**kwargs)
                     del x
             context = hints = audio_embedding = None

@@ -3,10 +3,10 @@ import numpy as np
 import gradio as gr

 def test_class_i2v(base_model_type):
-    return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk" ]
+    return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "animate" ]

 def text_oneframe_overlap(base_model_type):
-    return test_class_i2v(base_model_type) and not test_multitalk(base_model_type)
+    return test_class_i2v(base_model_type) and not (test_multitalk(base_model_type) or base_model_type in ["animate"]) or test_wan_5B(base_model_type)

 def test_class_1_3B(base_model_type):
     return base_model_type in [ "vace_1.3B", "t2v_1.3B", "recam_1.3B","phantom_1.3B","fun_inp_1.3B"]
@@ -17,6 +17,8 @@ def test_multitalk(base_model_type):
 def test_standin(base_model_type):
     return base_model_type in ["standin", "vace_standin_14B"]

+def test_wan_5B(base_model_type):
+    return base_model_type in ["ti2v_2_2", "lucy_edit"]
 class family_handler():

     @staticmethod
@@ -36,7 +38,7 @@ class family_handler():
             def_mag_ratios = [1.00124, 1.00155, 0.99822, 0.99851, 0.99696, 0.99687, 0.99703, 0.99732, 0.9966, 0.99679, 0.99602, 0.99658, 0.99578, 0.99664, 0.99484, 0.9949, 0.99633, 0.996, 0.99659, 0.99683, 0.99534, 0.99549, 0.99584, 0.99577, 0.99681, 0.99694, 0.99563, 0.99554, 0.9944, 0.99473, 0.99594, 0.9964, 0.99466, 0.99461, 0.99453, 0.99481, 0.99389, 0.99365, 0.99391, 0.99406, 0.99354, 0.99361, 0.99283, 0.99278, 0.99268, 0.99263, 0.99057, 0.99091, 0.99125, 0.99126, 0.65523, 0.65252, 0.98808, 0.98852, 0.98765, 0.98736, 0.9851, 0.98535, 0.98311, 0.98339, 0.9805, 0.9806, 0.97776, 0.97771, 0.97278, 0.97286, 0.96731, 0.96728, 0.95857, 0.95855, 0.94385, 0.94385, 0.92118, 0.921, 0.88108, 0.88076, 0.80263, 0.80181]
         elif base_model_type in ["i2v_2_2"]:
             def_mag_ratios = [0.99191, 0.99144, 0.99356, 0.99337, 0.99326, 0.99285, 0.99251, 0.99264, 0.99393, 0.99366, 0.9943, 0.9943, 0.99276, 0.99288, 0.99389, 0.99393, 0.99274, 0.99289, 0.99316, 0.9931, 0.99379, 0.99377, 0.99268, 0.99271, 0.99222, 0.99227, 0.99175, 0.9916, 0.91076, 0.91046, 0.98931, 0.98933, 0.99087, 0.99088, 0.98852, 0.98855, 0.98895, 0.98896, 0.98806, 0.98808, 0.9871, 0.98711, 0.98613, 0.98618, 0.98434, 0.98435, 0.983, 0.98307, 0.98185, 0.98187, 0.98131, 0.98131, 0.9783, 0.97835, 0.97619, 0.9762, 0.97264, 0.9727, 0.97088, 0.97098, 0.96568, 0.9658, 0.96045, 0.96055, 0.95322, 0.95335, 0.94579, 0.94594, 0.93297, 0.93311, 0.91699, 0.9172, 0.89174, 0.89202, 0.8541, 0.85446, 0.79823, 0.79902]
-        elif base_model_type in ["ti2v_2_2"]:
+        elif test_wan_5B(base_model_type):
             if inputs.get("image_start", None) is not None and inputs.get("video_source", None) is not None : # t2v
                 def_mag_ratios = [0.99505, 0.99389, 0.99441, 0.9957, 0.99558, 0.99551, 0.99499, 0.9945, 0.99534, 0.99548, 0.99468, 0.9946, 0.99463, 0.99458, 0.9946, 0.99453, 0.99408, 0.99404, 0.9945, 0.99441, 0.99409, 0.99398, 0.99403, 0.99397, 0.99382, 0.99377, 0.99349, 0.99343, 0.99377, 0.99378, 0.9933, 0.99328, 0.99303, 0.99301, 0.99217, 0.99216, 0.992, 0.99201, 0.99201, 0.99202, 0.99133, 0.99132, 0.99112, 0.9911, 0.99155, 0.99155, 0.98958, 0.98957, 0.98959, 0.98958, 0.98838, 0.98835, 0.98826, 0.98825, 0.9883, 0.98828, 0.98711, 0.98709, 0.98562, 0.98561, 0.98511, 0.9851, 0.98414, 0.98412, 0.98284, 0.98282, 0.98104, 0.98101, 0.97981, 0.97979, 0.97849, 0.97849, 0.97557, 0.97554, 0.97398, 0.97395, 0.97171, 0.97166, 0.96917, 0.96913, 0.96511, 0.96507, 0.96263, 0.96257, 0.95839, 0.95835, 0.95483, 0.95475, 0.94942, 0.94936, 0.9468, 0.94678, 0.94583, 0.94594, 0.94843, 0.94872, 0.96949, 0.97015]
             else: # i2v
@@ -83,11 +85,13 @@ class family_handler():
         vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B", "vace_standin_14B"]
         extra_model_def["vace_class"] = vace_class

-        if test_multitalk(base_model_type):
+        if base_model_type in ["animate"]:
+            fps = 30
+        elif test_multitalk(base_model_type):
             fps = 25
         elif base_model_type in ["fantasy"]:
             fps = 23
-        elif base_model_type in ["ti2v_2_2"]:
+        elif test_wan_5B(base_model_type):
             fps = 24
         else:
             fps = 16
@ -100,14 +104,14 @@ class family_handler():
|
|||||||
extra_model_def.update({
|
extra_model_def.update({
|
||||||
"frames_minimum" : frames_minimum,
|
"frames_minimum" : frames_minimum,
|
||||||
"frames_steps" : frames_steps,
|
"frames_steps" : frames_steps,
|
||||||
"sliding_window" : base_model_type in ["multitalk", "infinitetalk", "t2v", "fantasy"] or test_class_i2v(base_model_type) or vace_class, #"ti2v_2_2",
|
"sliding_window" : base_model_type in ["multitalk", "infinitetalk", "t2v", "fantasy", "animate"] or test_class_i2v(base_model_type) or test_wan_5B(base_model_type) or vace_class, #"ti2v_2_2",
|
||||||
"multiple_submodels" : multiple_submodels,
|
"multiple_submodels" : multiple_submodels,
|
||||||
"guidance_max_phases" : 3,
|
"guidance_max_phases" : 3,
|
||||||
"skip_layer_guidance" : True,
|
"skip_layer_guidance" : True,
|
||||||
"cfg_zero" : True,
|
"cfg_zero" : True,
|
||||||
"cfg_star" : True,
|
"cfg_star" : True,
|
||||||
"adaptive_projected_guidance" : True,
|
"adaptive_projected_guidance" : True,
|
||||||
"tea_cache" : not (base_model_type in ["i2v_2_2", "ti2v_2_2" ] or multiple_submodels),
|
"tea_cache" : not (base_model_type in ["i2v_2_2"] or test_wan_5B(base_model_type) or multiple_submodels),
|
||||||
"mag_cache" : True,
|
"mag_cache" : True,
|
||||||
"keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"],
|
"keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"],
|
||||||
"convert_image_guide_to_video" : True,
|
"convert_image_guide_to_video" : True,
|
||||||
@ -146,6 +150,34 @@ class family_handler():
|
|||||||
}
|
}
|
||||||
|
|
||||||
# extra_model_def["at_least_one_image_ref_needed"] = True
|
# extra_model_def["at_least_one_image_ref_needed"] = True
|
||||||
|
if base_model_type in ["lucy_edit"]:
|
||||||
|
extra_model_def["keep_frames_video_guide_not_supported"] = True
|
||||||
|
extra_model_def["guide_preprocessing"] = {
|
||||||
|
"selection": ["UV"],
|
||||||
|
"labels" : { "UV": "Control Video"},
|
||||||
|
"visible": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
if base_model_type in ["animate"]:
|
||||||
|
extra_model_def["guide_custom_choices"] = {
|
||||||
|
"choices":[
|
||||||
|
("Animate Person in Reference Image using Motion of Person in Control Video", "PVBXAKI"),
|
||||||
|
("Replace Person in Control Video Person in Reference Image", "PVBAI"),
|
||||||
|
],
|
||||||
|
"default": "KI",
|
||||||
|
"letters_filter": "PVBXAKI",
|
||||||
|
"label": "Type of Process",
|
||||||
|
"show_label" : False,
|
||||||
|
}
|
||||||
|
extra_model_def["video_guide_outpainting"] = [0,1]
|
||||||
|
extra_model_def["keep_frames_video_guide_not_supported"] = True
|
||||||
|
extra_model_def["extract_guide_from_window_start"] = True
|
||||||
|
extra_model_def["forced_guide_mask_inputs"] = True
|
||||||
|
extra_model_def["background_removal_label"]= "Remove Backgrounds behind People (Animate Mode Only)"
|
||||||
|
extra_model_def["background_ref_outpainted"] = False
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if vace_class:
|
if vace_class:
|
||||||
extra_model_def["guide_preprocessing"] = {
|
extra_model_def["guide_preprocessing"] = {
|
||||||
"selection": ["", "UV", "PV", "DV", "SV", "LV", "CV", "MV", "V", "PDV", "PSV", "PLV" , "DSV", "DLV", "SLV"],
|
"selection": ["", "UV", "PV", "DV", "SV", "LV", "CV", "MV", "V", "PDV", "PSV", "PLV" , "DSV", "DLV", "SLV"],
|
||||||
@ -157,16 +189,19 @@ class family_handler():
|
|||||||
|
|
||||||
extra_model_def["image_ref_choices"] = {
|
extra_model_def["image_ref_choices"] = {
|
||||||
"choices": [("None", ""),
|
"choices": [("None", ""),
|
||||||
("Inject only People / Objects", "I"),
|
("People / Objects", "I"),
|
||||||
("Inject Landscape and then People / Objects", "KI"),
|
("Landscape followed by People / Objects (if any)", "KI"),
|
||||||
("Inject Frames and then People / Objects", "FI"),
|
("Positioned Frames followed by People / Objects (if any)", "FI"),
|
||||||
],
|
],
|
||||||
"letters_filter": "KFI",
|
"letters_filter": "KFI",
|
||||||
}
|
}
|
||||||
|
|
||||||
extra_model_def["lock_image_refs_ratios"] = True
|
extra_model_def["lock_image_refs_ratios"] = True
|
||||||
extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or positioned Frames"
|
extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or Positioned Frames"
|
||||||
extra_model_def["video_guide_outpainting"] = [0,1]
|
extra_model_def["video_guide_outpainting"] = [0,1]
|
||||||
|
extra_model_def["pad_guide_video"] = True
|
||||||
|
extra_model_def["guide_inpaint_color"] = 127.5
|
||||||
|
extra_model_def["forced_guide_mask_inputs"] = True
|
||||||
|
|
||||||
if base_model_type in ["standin"]:
|
if base_model_type in ["standin"]:
|
||||||
extra_model_def["lock_image_refs_ratios"] = True
|
extra_model_def["lock_image_refs_ratios"] = True
|
||||||
@ -209,10 +244,12 @@ class family_handler():
|
|||||||
"visible" : False,
|
"visible" : False,
|
||||||
}
|
}
|
||||||
|
|
||||||
if vace_class or base_model_type in ["infinitetalk"]:
|
if vace_class or base_model_type in ["infinitetalk", "animate"]:
|
||||||
image_prompt_types_allowed = "TVL"
|
image_prompt_types_allowed = "TVL"
|
||||||
elif base_model_type in ["ti2v_2_2"]:
|
elif base_model_type in ["ti2v_2_2"]:
|
||||||
image_prompt_types_allowed = "TSVL"
|
image_prompt_types_allowed = "TSVL"
|
||||||
|
elif base_model_type in ["lucy_edit"]:
|
||||||
|
image_prompt_types_allowed = "TVL"
|
||||||
elif test_multitalk(base_model_type) or base_model_type in ["fantasy"]:
|
elif test_multitalk(base_model_type) or base_model_type in ["fantasy"]:
|
||||||
image_prompt_types_allowed = "SVL"
|
image_prompt_types_allowed = "SVL"
|
||||||
elif i2v:
|
elif i2v:
|
||||||
@ -234,8 +271,8 @@ class family_handler():
|
|||||||
def query_supported_types():
|
def query_supported_types():
|
||||||
return ["multitalk", "infinitetalk", "fantasy", "vace_14B", "vace_multitalk_14B", "vace_standin_14B",
|
return ["multitalk", "infinitetalk", "fantasy", "vace_14B", "vace_multitalk_14B", "vace_standin_14B",
|
||||||
"t2v_1.3B", "standin", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B",
|
"t2v_1.3B", "standin", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B",
|
||||||
"recam_1.3B",
|
"recam_1.3B", "animate",
|
||||||
"i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "flf2v_720p", "fun_inp_1.3B", "fun_inp"]
|
"i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "lucy_edit", "flf2v_720p", "fun_inp_1.3B", "fun_inp"]
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -265,11 +302,12 @@ class family_handler():
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_vae_block_size(base_model_type):
|
def get_vae_block_size(base_model_type):
|
||||||
return 32 if base_model_type == "ti2v_2_2" else 16
|
return 32 if test_wan_5B(base_model_type) else 16
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_rgb_factors(base_model_type ):
|
def get_rgb_factors(base_model_type ):
|
||||||
from shared.RGB_factors import get_rgb_factors
|
from shared.RGB_factors import get_rgb_factors
|
||||||
|
if test_wan_5B(base_model_type): base_model_type = "ti2v_2_2"
|
||||||
latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("wan", base_model_type)
|
latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("wan", base_model_type)
|
||||||
return latent_rgb_factors, latent_rgb_factors_bias
|
return latent_rgb_factors, latent_rgb_factors_bias
|
||||||
|
|
||||||
@ -283,7 +321,7 @@ class family_handler():
|
|||||||
"fileList" : [ [ "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "sentencepiece.bpe.model", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"], ["special_tokens_map.json", "spiece.model", "tokenizer.json", "tokenizer_config.json"] + computeList(text_encoder_filename) , ["Wan2.1_VAE.safetensors", "fantasy_proj_model.safetensors" ] + computeList(model_filename) ]
|
"fileList" : [ [ "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "sentencepiece.bpe.model", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"], ["special_tokens_map.json", "spiece.model", "tokenizer.json", "tokenizer_config.json"] + computeList(text_encoder_filename) , ["Wan2.1_VAE.safetensors", "fantasy_proj_model.safetensors" ] + computeList(model_filename) ]
|
||||||
}]
|
}]
|
||||||
|
|
||||||
if base_model_type == "ti2v_2_2":
|
if test_wan_5B(base_model_type):
|
||||||
download_def += [ {
|
download_def += [ {
|
||||||
"repoId" : "DeepBeepMeep/Wan2.2",
|
"repoId" : "DeepBeepMeep/Wan2.2",
|
||||||
"sourceFolderList" : [""],
|
"sourceFolderList" : [""],
|
||||||
@ -377,8 +415,8 @@ class family_handler():
|
|||||||
ui_defaults.update({
|
ui_defaults.update({
|
||||||
"sample_solver": "unipc",
|
"sample_solver": "unipc",
|
||||||
})
|
})
|
||||||
if test_class_i2v(base_model_type):
|
if test_class_i2v(base_model_type) and "S" in model_def["image_prompt_types_allowed"]:
|
||||||
ui_defaults["image_prompt_type"] = "S"
|
ui_defaults["image_prompt_type"] = "S"
|
||||||
|
|
||||||
if base_model_type in ["fantasy"]:
|
if base_model_type in ["fantasy"]:
|
||||||
ui_defaults.update({
|
ui_defaults.update({
|
||||||
@ -434,10 +472,15 @@ class family_handler():
|
|||||||
"image_prompt_type": "T",
|
"image_prompt_type": "T",
|
||||||
})
|
})
|
||||||
|
|
||||||
if base_model_type in ["recam_1.3B"]:
|
if base_model_type in ["recam_1.3B", "lucy_edit"]:
|
||||||
ui_defaults.update({
|
ui_defaults.update({
|
||||||
"video_prompt_type": "UV",
|
"video_prompt_type": "UV",
|
||||||
})
|
})
|
||||||
|
elif base_model_type in ["animate"]:
|
||||||
|
ui_defaults.update({
|
||||||
|
"video_prompt_type": "PVBXAKI",
|
||||||
|
"mask_expand": 20,
|
||||||
|
})
|
||||||
|
|
||||||
if text_oneframe_overlap(base_model_type):
|
if text_oneframe_overlap(base_model_type):
|
||||||
ui_defaults["sliding_window_overlap"] = 1
|
ui_defaults["sliding_window_overlap"] = 1
|
||||||
|
|||||||
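Note: the hunks above replace direct ti2v_2_2 comparisons with a test_wan_5B(...) helper so that the new Lucy Edit model, which reuses the Wan2.2 5B ti2v backbone (see configs/lucy_edit.json), inherits the same fps, VAE block size, cache and download behaviour. The helper itself is not part of this excerpt; the following is a minimal sketch of what it presumably does, with the membership set being an assumption:

# Hypothetical sketch, not taken from the commit: group the Wan2.2 5B variants.
WAN_5B_TYPES = {"ti2v_2_2", "lucy_edit"}  # assumed membership

def test_wan_5B(base_model_type):
    # True for any model type that shares the 5B ti2v_2_2 backbone
    return base_model_type in WAN_5B_TYPES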
342
shared/convert/convert_diffusers_to_flux.py
Normal file
@ -0,0 +1,342 @@
#!/usr/bin/env python3
"""
Convert a Flux model from Diffusers (folder or single-file) into the original
single-file Flux transformer checkpoint used by Black Forest Labs / ComfyUI.

Input  : /path/to/diffusers (root or .../transformer) OR /path/to/*.safetensors (single file)
Output : /path/to/flux1-your-model.safetensors (transformer only)

Usage:
  python diffusers_to_flux_transformer.py /path/to/diffusers /out/flux1-dev.safetensors
  python diffusers_to_flux_transformer.py /path/to/diffusion_pytorch_model.safetensors /out/flux1-dev.safetensors
  # optional quantization:
  #   --fp8 (float8_e4m3fn, simple)
  #   --fp8-scaled (scaled float8 for 2D weights; adds .scale_weight tensors)
"""

import argparse
import json
from pathlib import Path
from collections import OrderedDict

import torch
from safetensors import safe_open
import safetensors.torch
from tqdm import tqdm


def parse_args():
    ap = argparse.ArgumentParser()
    ap.add_argument("diffusers_path", type=str,
                    help="Path to Diffusers checkpoint folder OR a single .safetensors file.")
    ap.add_argument("output_path", type=str,
                    help="Output .safetensors path for the Flux transformer.")
    ap.add_argument("--fp8", action="store_true",
                    help="Experimental: write weights as float8_e4m3fn via stochastic rounding (transformer only).")
    ap.add_argument("--fp8-scaled", action="store_true",
                    help="Experimental: scaled float8_e4m3fn for 2D weight tensors; adds .scale_weight tensors.")
    return ap.parse_args()


# Mapping from original Flux keys -> list of Diffusers keys (per block where applicable).
DIFFUSERS_MAP = {
    # global embeds
    "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"],
    "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"],
    "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"],
    "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"],

    "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"],
    "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"],
    "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"],
    "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"],

    "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"],
    "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"],
    "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"],
    "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"],

    "txt_in.weight": ["context_embedder.weight"],
    "txt_in.bias": ["context_embedder.bias"],
    "img_in.weight": ["x_embedder.weight"],
    "img_in.bias": ["x_embedder.bias"],

    # dual-stream (image/text) blocks
    "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"],
    "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"],
    "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"],
    "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"],

    "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"],
    "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"],
    "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"],
    "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"],

    "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"],
    "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"],
    "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"],
    "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"],

    "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"],
    "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"],
    "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"],
    "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"],

    "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"],
    "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"],
    "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"],
    "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"],

    "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"],
    "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"],
    "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"],
    "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"],

    # single-stream blocks
    "single_blocks.().modulation.lin.weight": ["norm.linear.weight"],
    "single_blocks.().modulation.lin.bias": ["norm.linear.bias"],
    "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"],
    "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"],
    "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"],
    "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"],
    "single_blocks.().linear2.weight": ["proj_out.weight"],
    "single_blocks.().linear2.bias": ["proj_out.bias"],

    # final
    "final_layer.linear.weight": ["proj_out.weight"],
    "final_layer.linear.bias": ["proj_out.bias"],
    # these two are built from norm_out.linear.{weight,bias} by swapping [shift,scale] -> [scale,shift]
    "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"],
    "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"],
}


class DiffusersSource:
    """
    Uniform interface over:
      1) Folder with index JSON + shards
      2) Folder with exactly one .safetensors (no index)
      3) Single .safetensors file
    Provides .has(key), .get(key)->Tensor, .base_keys (keys with 'model.' stripped for scanning)
    """

    POSSIBLE_PREFIXES = ["", "model."]  # try in this order

    def __init__(self, path: Path):
        p = Path(path)
        if p.is_dir():
            # use 'transformer' subfolder if present
            if (p / "transformer").is_dir():
                p = p / "transformer"
            self._init_from_dir(p)
        elif p.is_file() and p.suffix == ".safetensors":
            self._init_from_single_file(p)
        else:
            raise FileNotFoundError(f"Invalid path: {p}")

    # ---------- common helpers ----------

    @staticmethod
    def _strip_prefix(k: str) -> str:
        return k[6:] if k.startswith("model.") else k

    def _resolve(self, want: str):
        """
        Return the actual stored key matching `want` by trying known prefixes.
        """
        for pref in self.POSSIBLE_PREFIXES:
            k = pref + want
            if k in self._all_keys:
                return k
        return None

    def has(self, want: str) -> bool:
        return self._resolve(want) is not None

    def get(self, want: str) -> torch.Tensor:
        real_key = self._resolve(want)
        if real_key is None:
            raise KeyError(f"Missing key: {want}")
        return self._get_by_real_key(real_key).to("cpu")

    @property
    def base_keys(self):
        # keys without 'model.' prefix for scanning
        return [self._strip_prefix(k) for k in self._all_keys]

    # ---------- modes ----------

    def _init_from_single_file(self, file_path: Path):
        self._mode = "single"
        self._file = file_path
        self._handle = safe_open(file_path, framework="pt", device="cpu")
        self._all_keys = list(self._handle.keys())

        def _get_by_real_key(real_key: str):
            return self._handle.get_tensor(real_key)

        self._get_by_real_key = _get_by_real_key

    def _init_from_dir(self, dpath: Path):
        index_json = dpath / "diffusion_pytorch_model.safetensors.index.json"
        if index_json.exists():
            with open(index_json, "r", encoding="utf-8") as f:
                index = json.load(f)
            weight_map = index["weight_map"]  # full mapping
            self._mode = "sharded"
            self._dpath = dpath
            self._weight_map = {k: dpath / v for k, v in weight_map.items()}
            self._all_keys = list(self._weight_map.keys())
            self._open_handles = {}

            def _get_by_real_key(real_key: str):
                fpath = self._weight_map[real_key]
                h = self._open_handles.get(fpath)
                if h is None:
                    h = safe_open(fpath, framework="pt", device="cpu")
                    self._open_handles[fpath] = h
                return h.get_tensor(real_key)

            self._get_by_real_key = _get_by_real_key
            return

        # no index: try exactly one safetensors in folder
        files = sorted(dpath.glob("*.safetensors"))
        if len(files) != 1:
            raise FileNotFoundError(
                f"No index found and {dpath} does not contain exactly one .safetensors file."
            )
        self._init_from_single_file(files[0])


def main():
    args = parse_args()
    src = DiffusersSource(Path(args.diffusers_path))

    # Count blocks by scanning base keys (with any 'model.' prefix removed)
    num_dual = 0
    num_single = 0
    for k in src.base_keys:
        if k.startswith("transformer_blocks."):
            try:
                i = int(k.split(".")[1])
                num_dual = max(num_dual, i + 1)
            except Exception:
                pass
        elif k.startswith("single_transformer_blocks."):
            try:
                i = int(k.split(".")[1])
                num_single = max(num_single, i + 1)
            except Exception:
                pass
    print(f"Found {num_dual} dual-stream blocks, {num_single} single-stream blocks")

    # Swap [shift, scale] -> [scale, shift] (weights are concatenated along dim=0)
    def swap_scale_shift(vec: torch.Tensor) -> torch.Tensor:
        shift, scale = vec.chunk(2, dim=0)
        return torch.cat([scale, shift], dim=0)

    orig = {}

    # Per-block (dual)
    for b in range(num_dual):
        prefix = f"transformer_blocks.{b}."
        for okey, dvals in DIFFUSERS_MAP.items():
            if not okey.startswith("double_blocks."):
                continue
            dkeys = [prefix + v for v in dvals]
            if not all(src.has(k) for k in dkeys):
                continue
            if len(dkeys) == 1:
                orig[okey.replace("()", str(b))] = src.get(dkeys[0])
            else:
                orig[okey.replace("()", str(b))] = torch.cat([src.get(k) for k in dkeys], dim=0)

    # Per-block (single)
    for b in range(num_single):
        prefix = f"single_transformer_blocks.{b}."
        for okey, dvals in DIFFUSERS_MAP.items():
            if not okey.startswith("single_blocks."):
                continue
            dkeys = [prefix + v for v in dvals]
            if not all(src.has(k) for k in dkeys):
                continue
            if len(dkeys) == 1:
                orig[okey.replace("()", str(b))] = src.get(dkeys[0])
            else:
                orig[okey.replace("()", str(b))] = torch.cat([src.get(k) for k in dkeys], dim=0)

    # Globals (non-block)
    for okey, dvals in DIFFUSERS_MAP.items():
        if okey.startswith(("double_blocks.", "single_blocks.")):
            continue
        dkeys = dvals
        if not all(src.has(k) for k in dkeys):
            continue
        if len(dkeys) == 1:
            orig[okey] = src.get(dkeys[0])
        else:
            orig[okey] = torch.cat([src.get(k) for k in dkeys], dim=0)

    # Fix final_layer.adaLN_modulation.1.{weight,bias} by swapping scale/shift halves
    if "final_layer.adaLN_modulation.1.weight" in orig:
        orig["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(
            orig["final_layer.adaLN_modulation.1.weight"]
        )
    if "final_layer.adaLN_modulation.1.bias" in orig:
        orig["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(
            orig["final_layer.adaLN_modulation.1.bias"]
        )

    # Optional FP8 variants (experimental; not required for ComfyUI/BFL)
    if args.fp8 or args.fp8_scaled:
        dtype = torch.float8_e4m3fn  # noqa
        minv, maxv = torch.finfo(dtype).min, torch.finfo(dtype).max

        def stochastic_round_to(t):
            t = t.float().clamp(minv, maxv)
            lower = torch.floor(t * 256) / 256
            upper = torch.ceil(t * 256) / 256
            prob = torch.where(upper != lower, (t - lower) / (upper - lower), torch.zeros_like(t))
            rnd = torch.rand_like(t)
            out = torch.where(rnd < prob, upper, lower)
            return out.to(dtype)

        def scale_to_8bit(weight, target_max=416.0):
            absmax = weight.abs().max()
            scale = absmax / target_max if absmax > 0 else torch.tensor(1.0)
            scaled = (weight / scale).clamp(minv, maxv).to(dtype)
            return scaled, scale

        scales = {}
        for k in tqdm(list(orig.keys()), desc="Quantizing to fp8"):
            t = orig[k]
            if args.fp8:
                orig[k] = stochastic_round_to(t)
            else:
                if k.endswith(".weight") and t.dim() == 2:
                    qt, s = scale_to_8bit(t)
                    orig[k] = qt
                    scales[k[:-len(".weight")] + ".scale_weight"] = s
                else:
                    orig[k] = t.clamp(minv, maxv).to(dtype)
        if args.fp8_scaled:
            orig.update(scales)
            orig["scaled_fp8"] = torch.tensor([], dtype=dtype)
    else:
        # Default: save in bfloat16
        for k in list(orig.keys()):
            orig[k] = orig[k].to(torch.bfloat16).cpu()

    out_path = Path(args.output_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    meta = OrderedDict()
    meta["format"] = "pt"
    meta["modelspec.date"] = __import__("datetime").date.today().strftime("%Y-%m-%d")
    print(f"Saving transformer to: {out_path}")
    safetensors.torch.save_file(orig, str(out_path), metadata=meta)
    print("Done.")


if __name__ == "__main__":
    main()
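To make the key mapping concrete, here is a small standalone sketch (illustrative tensor shapes only, not part of the converter) of how a dual-stream template key from DIFFUSERS_MAP expands for block 0 and how the separate q/k/v projections are fused along dim 0, mirroring the per-block loop in main():

import torch

# Toy "Diffusers" block with a tiny hidden size, for illustration only.
toy = {
    "transformer_blocks.0.attn.to_q.weight": torch.randn(8, 8),
    "transformer_blocks.0.attn.to_k.weight": torch.randn(8, 8),
    "transformer_blocks.0.attn.to_v.weight": torch.randn(8, 8),
}
template = "double_blocks.().img_attn.qkv.weight"
dvals = ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"]
b = 0
dkeys = [f"transformer_blocks.{b}." + v for v in dvals]
fused = torch.cat([toy[k] for k in dkeys], dim=0)  # shape (24, 8): q, k, v stacked row-wise
flux_key = template.replace("()", str(b))          # "double_blocks.0.img_attn.qkv.weight"
print(flux_key, tuple(fused.shape))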
0
shared/inpainting/__init__.py
Normal file
240
shared/inpainting/lanpaint.py
Normal file
@ -0,0 +1,240 @@
import torch
from .utils import *
from functools import partial

# Many thanks to the LanPaint team for this implementation (https://github.com/scraed/LanPaint/)

def _pack_latents(latents):
    batch_size, num_channels_latents, _, height, width = latents.shape

    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

    return latents

def _unpack_latents(latents, height, width, vae_scale_factor=8):
    batch_size, num_patches, channels = latents.shape

    height = 2 * (int(height) // (vae_scale_factor * 2))
    width = 2 * (int(width) // (vae_scale_factor * 2))

    latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)

    latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)

    return latents

class LanPaint():
    def __init__(self, NSteps = 5, Friction = 15, Lambda = 8, Beta = 1, StepSize = 0.15, IS_FLUX = True, IS_FLOW = False):
        self.n_steps = NSteps
        self.chara_lamb = Lambda
        self.IS_FLUX = IS_FLUX
        self.IS_FLOW = IS_FLOW
        self.step_size = StepSize
        self.friction = Friction
        self.chara_beta = Beta
        self.img_dim_size = None

    def add_none_dims(self, array):
        # Create a tuple with ':' for the first dimension and 'None' repeated num_nones times
        index = (slice(None),) + (None,) * (self.img_dim_size-1)
        return array[index]

    def remove_none_dims(self, array):
        # Create a tuple with ':' for the first dimension and 'None' repeated num_nones times
        index = (slice(None),) + (0,) * (self.img_dim_size-1)
        return array[index]

    def __call__(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, latent_image, noise, sigma, latent_mask, n_steps=None, height=720, width=1280, vae_scale_factor=8):
        latent_image = _unpack_latents(latent_image, height=height, width=width, vae_scale_factor=vae_scale_factor)
        noise = _unpack_latents(noise, height=height, width=width, vae_scale_factor=vae_scale_factor)
        x = _unpack_latents(x, height=height, width=width, vae_scale_factor=vae_scale_factor)
        latent_mask = _unpack_latents(latent_mask, height=height, width=width, vae_scale_factor=vae_scale_factor)
        self.height = height
        self.width = width
        self.vae_scale_factor = vae_scale_factor
        self.img_dim_size = len(x.shape)
        self.latent_image = latent_image
        self.noise = noise
        if n_steps is None:
            n_steps = self.n_steps
        out = self.LanPaint(denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, self.IS_FLUX, self.IS_FLOW)
        out = _pack_latents(out)
        return out

    def LanPaint(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, IS_FLUX, IS_FLOW):
        if IS_FLUX:
            cfg_BIG = 1.0

        def double_denoise(latents, t):
            latents = _pack_latents(latents)
            noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale)
            if noise_pred == None: return None, None
            predict_std = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t)
            predict_std = _unpack_latents(predict_std, self.height, self.width, self.vae_scale_factor)
            if true_cfg_scale == cfg_BIG:
                predict_big = predict_std
            else:
                predict_big = cfg_predictions(noise_pred, neg_noise_pred, cfg_BIG, t)
                predict_big = _unpack_latents(predict_big, self.height, self.width, self.vae_scale_factor)
            return predict_std, predict_big

        if len(sigma.shape) == 0:
            sigma = torch.tensor([sigma.item()])
        latent_mask = 1 - latent_mask
        if IS_FLUX or IS_FLOW:
            Flow_t = sigma
            abt = (1 - Flow_t)**2 / ((1 - Flow_t)**2 + Flow_t**2 )
            VE_Sigma = Flow_t / (1 - Flow_t)
            #print("t", torch.mean( sigma ).item(), "VE_Sigma", torch.mean( VE_Sigma ).item())
        else:
            VE_Sigma = sigma
            abt = 1/( 1+VE_Sigma**2 )
            Flow_t = (1-abt)**0.5 / ( (1-abt)**0.5 + abt**0.5 )
        # VE_Sigma, abt, Flow_t = current_times
        current_times = (VE_Sigma, abt, Flow_t)

        step_size = self.step_size * (1 - abt)
        step_size = self.add_none_dims(step_size)
        # self.inner_model.inner_model.scale_latent_inpaint returns variance exploding x_t values
        # This is the replace step
        # x = x * (1 - latent_mask) + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image)* latent_mask

        noisy_image = self.latent_image * (1.0 - sigma) + self.noise * sigma
        x = x * (1 - latent_mask) + noisy_image * latent_mask

        if IS_FLUX or IS_FLOW:
            x_t = x * ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 )
        else:
            x_t = x / ( 1+self.add_none_dims(VE_Sigma)**2 )**0.5 # switch to variance perserving x_t values

        ############ LanPaint Iterations Start ###############
        # after noise_scaling, noise = latent_image + noise * sigma, which is x_t in the variance exploding diffusion model notation for the known region.
        args = None
        for i in range(n_steps):
            score_func = partial( self.score_model, y = self.latent_image, mask = latent_mask, abt = self.add_none_dims(abt), sigma = self.add_none_dims(VE_Sigma), tflow = self.add_none_dims(Flow_t), denoise_func = double_denoise )
            if score_func is None: return None
            x_t, args = self.langevin_dynamics(x_t, score_func , latent_mask, step_size , current_times, sigma_x = self.add_none_dims(self.sigma_x(abt)), sigma_y = self.add_none_dims(self.sigma_y(abt)), args = args)
        if IS_FLUX or IS_FLOW:
            x = x_t / ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 )
        else:
            x = x_t * ( 1+self.add_none_dims(VE_Sigma)**2 )**0.5 # switch to variance perserving x_t values
        ############ LanPaint Iterations End ###############
        # out is x_0
        # out, _ = self.inner_model(x, sigma, model_options=model_options, seed=seed)
        # out = out * (1-latent_mask) + self.latent_image * latent_mask
        # return out
        return x

    def score_model(self, x_t, y, mask, abt, sigma, tflow, denoise_func):

        lamb = self.chara_lamb
        if self.IS_FLUX or self.IS_FLOW:
            # compute t for flow model, with a small epsilon compensating for numerical error.
            x = x_t / ( abt**0.5 + (1-abt)**0.5 ) # switch to Gaussian flow matching
            x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(tflow))
            if x_0 is None: return None
        else:
            x = x_t * ( 1+sigma**2 )**0.5 # switch to variance exploding
            x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(sigma))
            if x_0 is None: return None

        score_x = -(x_t - x_0)
        score_y = - (1 + lamb) * ( x_t - y ) + lamb * (x_t - x_0_BIG)
        return score_x * (1 - mask) + score_y * mask

    def sigma_x(self, abt):
        # the time scale for the x_t update
        return abt**0

    def sigma_y(self, abt):
        beta = self.chara_beta * abt ** 0
        return beta

    def langevin_dynamics(self, x_t, score, mask, step_size, current_times, sigma_x=1, sigma_y=0, args=None):
        # prepare the step size and time parameters
        with torch.autocast(device_type=x_t.device.type, dtype=torch.float32):
            step_sizes = self.prepare_step_size(current_times, step_size, sigma_x, sigma_y)
        sigma, abt, dtx, dty, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y = step_sizes
        # print('mask',mask.device)
        if torch.mean(dtx) <= 0.:
            return x_t, args
        # -------------------------------------------------------------------------
        # Compute the Langevin dynamics update in variance perserving notation
        # -------------------------------------------------------------------------
        #x0 = self.x0_evalutation(x_t, score, sigma, args)
        #C = abt**0.5 * x0 / (1-abt)
        A = A_x * (1-mask) + A_y * mask
        D = D_x * (1-mask) + D_y * mask
        dt = dtx * (1-mask) + dty * mask
        Gamma = Gamma_x * (1-mask) + Gamma_y * mask

        def Coef_C(x_t):
            x0 = self.x0_evalutation(x_t, score, sigma, args)
            C = (abt**0.5 * x0 - x_t )/ (1-abt) + A * x_t
            return C

        def advance_time(x_t, v, dt, Gamma, A, C, D):
            dtype = x_t.dtype
            with torch.autocast(device_type=x_t.device.type, dtype=torch.float32):
                osc = StochasticHarmonicOscillator(Gamma, A, C, D )
                x_t, v = osc.dynamics(x_t, v, dt )
            x_t = x_t.to(dtype)
            v = v.to(dtype)
            return x_t, v

        if args is None:
            #v = torch.zeros_like(x_t)
            v = None
            C = Coef_C(x_t)
            #print(torch.squeeze(dtx), torch.squeeze(dty))
            x_t, v = advance_time(x_t, v, dt, Gamma, A, C, D)
        else:
            v, C = args

            x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D)

            C_new = Coef_C(x_t)
            v = v + Gamma**0.5 * ( C_new - C) *dt

            x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D)

            C = C_new

        return x_t, (v, C)

    def prepare_step_size(self, current_times, step_size, sigma_x, sigma_y):
        # -------------------------------------------------------------------------
        # Unpack current times parameters (sigma and abt)
        sigma, abt, flow_t = current_times
        sigma = self.add_none_dims(sigma)
        abt = self.add_none_dims(abt)
        # Compute time step (dtx, dty) for x and y branches.
        dtx = 2 * step_size * sigma_x
        dty = 2 * step_size * sigma_y

        # -------------------------------------------------------------------------
        # Define friction parameter Gamma_hat for each branch.
        # Using dtx**0 provides a tensor of the proper device/dtype.

        Gamma_hat_x = self.friction **2 * self.step_size * sigma_x / 0.1 * sigma**0
        Gamma_hat_y = self.friction **2 * self.step_size * sigma_y / 0.1 * sigma**0
        #print("Gamma_hat_x", torch.mean(Gamma_hat_x).item(), "Gamma_hat_y", torch.mean(Gamma_hat_y).item())
        # adjust dt to match denoise-addnoise steps sizes
        Gamma_hat_x /= 2.
        Gamma_hat_y /= 2.
        A_t_x = (1) / ( 1 - abt ) * dtx / 2
        A_t_y = (1+self.chara_lamb) / ( 1 - abt ) * dty / 2

        A_x = A_t_x / (dtx/2)
        A_y = A_t_y / (dty/2)
        Gamma_x = Gamma_hat_x / (dtx/2)
        Gamma_y = Gamma_hat_y / (dty/2)

        #D_x = (2 * (1 + sigma**2) )**0.5
        #D_y = (2 * (1 + sigma**2) )**0.5
        D_x = (2 * abt**0 )**0.5
        D_y = (2 * abt**0 )**0.5
        return sigma, abt, dtx/2, dty/2, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y

    def x0_evalutation(self, x_t, score, sigma, args):
        x0 = x_t + score(x_t)
        return x0
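A quick way to sanity-check the packing convention used by LanPaint (a sketch using the two helpers defined above; the toy latent shape is arbitrary and corresponds to 720x1280 at vae_scale_factor=8):

import torch

lat = torch.randn(1, 16, 1, 90, 160)         # (batch, channels, 1 frame, H/8, W/8)
packed = _pack_latents(lat)                  # -> (1, 45*80, 64): 2x2 spatial patches folded into channels
restored = _unpack_latents(packed, height=720, width=1280, vae_scale_factor=8)
assert restored.shape == lat.shape and torch.allclose(restored, lat)  # exact round trip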
301
shared/inpainting/utils.py
Normal file
@ -0,0 +1,301 @@
import torch

def epxm1_x(x):
    # Compute the (exp(x) - 1) / x term with a small value to avoid division by zero.
    result = torch.special.expm1(x) / x
    # replace NaN or inf values with 0
    result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
    mask = torch.abs(x) < 1e-2
    result = torch.where(mask, 1 + x/2. + x**2 / 6., result)
    return result

def epxm1mx_x2(x):
    # Compute the (exp(x) - 1 - x) / x**2 term with a small value to avoid division by zero.
    result = (torch.special.expm1(x) - x) / x**2
    # replace NaN or inf values with 0
    result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
    mask = torch.abs(x**2) < 1e-2
    result = torch.where(mask, 1/2. + x/6 + x**2 / 24 + x**3 / 120, result)
    return result

def expm1mxmhx2_x3(x):
    # Compute the (exp(x) - 1 - x - x**2 / 2) / x**3 term with a small value to avoid division by zero.
    result = (torch.special.expm1(x) - x - x**2 / 2) / x**3
    # replace NaN or inf values with 0
    result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
    mask = torch.abs(x**3) < 1e-2
    result = torch.where(mask, 1/6 + x/24 + x**2 / 120 + x**3 / 720 + x**4 / 5040, result)
    return result

def exp_1mcosh_GD(gamma_t, delta):
    """
    Compute e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ )

    Parameters:
    gamma_t: Γ*t term (could be a scalar or tensor)
    delta: Δ term (could be a scalar or tensor)

    Returns:
    Result of the computation with numerical stability handling
    """
    # Main computation
    is_positive = delta > 0
    sqrt_abs_delta = torch.sqrt(torch.abs(delta))
    gamma_t_sqrt_delta = gamma_t * sqrt_abs_delta
    numerator_pos = torch.exp(-gamma_t) - (torch.exp(gamma_t * (sqrt_abs_delta - 1)) + torch.exp(gamma_t * (-sqrt_abs_delta - 1))) / 2
    numerator_neg = torch.exp(-gamma_t) * ( 1 - torch.cos(gamma_t * sqrt_abs_delta ) )
    numerator = torch.where(is_positive, numerator_pos, numerator_neg)
    result = numerator / (delta * gamma_t**2 )
    # Handle NaN/inf cases
    result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
    # Handle numerical instability for small delta
    mask = torch.abs(gamma_t_sqrt_delta**2) < 5e-2
    taylor = ( -0.5 - gamma_t**2 / 24 * delta - gamma_t**4 / 720 * delta**2 ) * torch.exp(-gamma_t)
    result = torch.where(mask, taylor, result)
    return result

def exp_sinh_GsqrtD(gamma_t, delta):
    """
    Compute e^(-Γt) * sinh(Γt√Δ) / (Γt√Δ)

    Parameters:
    gamma_t: Γ*t term (could be a scalar or tensor)
    delta: Δ term (could be a scalar or tensor)

    Returns:
    Result of the computation with numerical stability handling
    """
    # Main computation
    is_positive = delta > 0
    sqrt_abs_delta = torch.sqrt(torch.abs(delta))
    gamma_t_sqrt_delta = gamma_t * sqrt_abs_delta
    numerator_pos = (torch.exp(gamma_t * (sqrt_abs_delta - 1)) - torch.exp(gamma_t * (-sqrt_abs_delta - 1))) / 2
    denominator_pos = gamma_t_sqrt_delta
    result_pos = numerator_pos / gamma_t_sqrt_delta
    result_pos = torch.where(torch.isfinite(result_pos), result_pos, torch.zeros_like(result_pos))

    # Taylor expansion for small gamma_t_sqrt_delta
    mask = torch.abs(gamma_t_sqrt_delta) < 1e-2
    taylor = ( 1 + gamma_t**2 / 6 * delta + gamma_t**4 / 120 * delta**2 ) * torch.exp(-gamma_t)
    result_pos = torch.where(mask, taylor, result_pos)

    # Handle negative delta
    result_neg = torch.exp(-gamma_t) * torch.special.sinc(gamma_t_sqrt_delta/torch.pi)
    result = torch.where(is_positive, result_pos, result_neg)
    return result

def exp_cosh(gamma_t, delta):
    """
    Compute e^(-Γt) * cosh(Γt√Δ)

    Parameters:
    gamma_t: Γ*t term (could be a scalar or tensor)
    delta: Δ term (could be a scalar or tensor)

    Returns:
    Result of the computation with numerical stability handling
    """
    exp_1mcosh_GD_result = exp_1mcosh_GD(gamma_t, delta)  # e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ )
    result = torch.exp(-gamma_t) - gamma_t**2 * delta * exp_1mcosh_GD_result
    return result

def exp_sinh_sqrtD(gamma_t, delta):
    """
    Compute e^(-Γt) * sinh(Γt√Δ) / √Δ
    Parameters:
    gamma_t: Γ*t term (could be a scalar or tensor)
    delta: Δ term (could be a scalar or tensor)
    Returns:
    Result of the computation with numerical stability handling
    """
    exp_sinh_GsqrtD_result = exp_sinh_GsqrtD(gamma_t, delta)  # e^(-Γt) * sinh(Γt√Δ) / (Γt√Δ)
    result = gamma_t * exp_sinh_GsqrtD_result
    return result


def zeta1(gamma_t, delta):
    # Compute hyperbolic terms and exponential
    half_gamma_t = gamma_t / 2
    exp_cosh_term = exp_cosh(half_gamma_t, delta)
    exp_sinh_term = exp_sinh_sqrtD(half_gamma_t, delta)

    # Main computation
    numerator = 1 - (exp_cosh_term + exp_sinh_term)
    denominator = gamma_t * (1 - delta) / 4
    result = 1 - numerator / denominator

    # Handle numerical instability
    result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))

    # Taylor expansion for small x (similar to your epxm1Dx approach)
    mask = torch.abs(denominator) < 5e-3
    term1 = epxm1_x(-gamma_t)
    term2 = epxm1mx_x2(-gamma_t)
    term3 = expm1mxmhx2_x3(-gamma_t)
    taylor = term1 + (1/2.+ term1-3*term2)*denominator + (-1/6. + term1/2 - 4 * term2 + 10 * term3) * denominator**2
    result = torch.where(mask, taylor, result)

    return result

def exp_cosh_minus_terms(gamma_t, delta):
    """
    Compute E^(-tΓ) * (Cosh[tΓ] - 1 - (Cosh[tΓ√Δ] - 1)/Δ) / (tΓ(1 - Δ))

    Parameters:
    gamma_t: Γ*t term (could be a scalar or tensor)
    delta: Δ term (could be a scalar or tensor)

    Returns:
    Result of the computation with numerical stability handling
    """
    exp_term = torch.exp(-gamma_t)
    # Compute individual terms
    exp_cosh_term = exp_cosh(gamma_t, gamma_t**0) - exp_term  # E^(-tΓ) (Cosh[tΓ] - 1) term
    exp_cosh_delta_term = - gamma_t**2 * exp_1mcosh_GD(gamma_t, delta)  # E^(-tΓ) (Cosh[tΓ√Δ] - 1)/Δ term

    #exp_1mcosh_GD e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ )
    # Main computation
    numerator = exp_cosh_term - exp_cosh_delta_term
    denominator = gamma_t * (1 - delta)

    result = numerator / denominator

    # Handle numerical instability
    result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))

    # Taylor expansion for small gamma_t and delta near 1
    mask = (torch.abs(denominator) < 1e-1)
    exp_1mcosh_GD_term = exp_1mcosh_GD(gamma_t, delta**0)
    taylor = (
        gamma_t*exp_1mcosh_GD_term + 0.5 * gamma_t * exp_sinh_GsqrtD(gamma_t, delta**0)
        - denominator / 4 * ( 0.5 * exp_cosh(gamma_t, delta**0) - 4 * exp_1mcosh_GD_term - 5 /2 * exp_sinh_GsqrtD(gamma_t, delta**0) )
    )
    result = torch.where(mask, taylor, result)

    return result


def zeta2(gamma_t, delta):
    half_gamma_t = gamma_t / 2
    return exp_sinh_GsqrtD(half_gamma_t, delta)

def sig11(gamma_t, delta):
    return 1 - torch.exp(-gamma_t) + gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) + exp_sinh_sqrtD(gamma_t, delta)


def Zcoefs(gamma_t, delta):
    Zeta1 = zeta1(gamma_t, delta)
    Zeta2 = zeta2(gamma_t, delta)

    sq_total = 1 - Zeta1 + gamma_t * (delta - 1) * (Zeta1 - 1)**2 / 8
    amplitude = torch.sqrt(sq_total)
    Zcoef1 = ( gamma_t**0.5 * Zeta2 / 2 **0.5 ) / amplitude
    Zcoef2 = Zcoef1 * gamma_t *( - 2 * exp_1mcosh_GD(gamma_t, delta) / sig11(gamma_t, delta) ) ** 0.5
    #cterm = exp_cosh_minus_terms(gamma_t, delta)
    #sterm = exp_sinh_sqrtD(gamma_t, delta**0) + exp_sinh_sqrtD(gamma_t, delta)
    #Zcoef3 = 2 * torch.sqrt( cterm / ( gamma_t * (1 - delta) * cterm + sterm ) )
    Zcoef3 = torch.sqrt( torch.maximum(1 - Zcoef1**2 - Zcoef2**2, sq_total.new_zeros(sq_total.shape)) )

    return Zcoef1 * amplitude, Zcoef2 * amplitude, Zcoef3 * amplitude, amplitude

def Zcoefs_asymp(gamma_t, delta):
    A_t = (gamma_t * (1 - delta) )/4
    return epxm1_x(- 2 * A_t)

class StochasticHarmonicOscillator:
    """
    Simulates a stochastic harmonic oscillator governed by the equations:
    dy(t) = q(t) dt
    dq(t) = -Γ A y(t) dt + Γ C dt + Γ D dw(t) - Γ q(t) dt

    Also define v(t) = q(t) / √Γ, which is numerically more stable.

    Where:
    y(t) - Position variable
    q(t) - Velocity variable
    Γ - Damping coefficient
    A - Harmonic potential strength
    C - Constant force term
    D - Noise amplitude
    dw(t) - Wiener process (Brownian motion)
    """
    def __init__(self, Gamma, A, C, D):
        self.Gamma = Gamma
        self.A = A
        self.C = C
        self.D = D
        self.Delta = 1 - 4 * A / Gamma

    def sig11(self, gamma_t, delta):
        return 1 - torch.exp(-gamma_t) + gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) + exp_sinh_sqrtD(gamma_t, delta)

    def sig22(self, gamma_t, delta):
        return 1- zeta1(2*gamma_t, delta) + 2 * gamma_t * exp_1mcosh_GD(gamma_t, delta)

    def dynamics(self, y0, v0, t):
        """
        Calculates the position and velocity variables at time t.

        Parameters:
        y0 (float): Initial position
        v0 (float): Initial velocity v(0) = q(0) / √Γ
        t (float): Time at which to evaluate the dynamics
        Returns:
        tuple: (y(t), v(t))
        """

        dummyzero = y0.new_zeros(1)  # convert scalar to tensor with same device and dtype as y0
        Delta = self.Delta + dummyzero
        Gamma_hat = self.Gamma * t + dummyzero
        A = self.A + dummyzero
        C = self.C + dummyzero
        D = self.D + dummyzero
        Gamma = self.Gamma + dummyzero
        zeta_1 = zeta1( Gamma_hat, Delta)
        zeta_2 = zeta2( Gamma_hat, Delta)
        EE = 1 - Gamma_hat * zeta_2

        if v0 is None:
            v0 = torch.randn_like(y0) * D / 2 ** 0.5
            #v0 = (C - A * y0)/Gamma**0.5

        # Calculate mean position and velocity
        term1 = (1 - zeta_1) * (C * t - A * t * y0) + zeta_2 * (Gamma ** 0.5) * v0 * t
        y_mean = term1 + y0
        v_mean = (1 - EE)*(C - A * y0) / (Gamma ** 0.5) + (EE - A * t * (1 - zeta_1)) * v0

        cov_yy = D**2 * t * self.sig22(Gamma_hat, Delta)
        cov_vv = D**2 * self.sig11(Gamma_hat, Delta) / 2
        cov_yv = (zeta2(Gamma_hat, Delta) * Gamma_hat * D ) **2 / 2 / (Gamma ** 0.5)

        # sample new position and velocity with multivariate normal distribution

        batch_shape = y0.shape
        cov_matrix = torch.zeros(*batch_shape, 2, 2, device=y0.device, dtype=y0.dtype)
        cov_matrix[..., 0, 0] = cov_yy
        cov_matrix[..., 0, 1] = cov_yv
        cov_matrix[..., 1, 0] = cov_yv  # symmetric
        cov_matrix[..., 1, 1] = cov_vv

        # Compute the Cholesky decomposition to get scale_tril
        #scale_tril = torch.linalg.cholesky(cov_matrix)
        scale_tril = torch.zeros(*batch_shape, 2, 2, device=y0.device, dtype=y0.dtype)
        tol = 1e-8
        cov_yy = torch.clamp( cov_yy, min = tol )
        sd_yy = torch.sqrt( cov_yy )
        inv_sd_yy = 1/(sd_yy)

        scale_tril[..., 0, 0] = sd_yy
        scale_tril[..., 0, 1] = 0.
        scale_tril[..., 1, 0] = cov_yv * inv_sd_yy
        scale_tril[..., 1, 1] = torch.clamp( cov_vv - cov_yv**2 / cov_yy, min = tol ) ** 0.5
        # check if it matches torch.linalg.
        #assert torch.allclose(torch.linalg.cholesky(cov_matrix), scale_tril, atol = 1e-4, rtol = 1e-4 )
        # Sample correlated noise from multivariate normal
        mean = torch.zeros(*batch_shape, 2, device=y0.device, dtype=y0.dtype)
        mean[..., 0] = y_mean
        mean[..., 1] = v_mean
        new_yv = torch.distributions.MultivariateNormal(
            loc=mean,
            scale_tril=scale_tril
        ).sample()

        return new_yv[...,0], new_yv[...,1]
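A minimal usage sketch of the sampler above (parameter values are arbitrary; with Gamma=4 and A=0.5 the discriminant is Delta = 1 - 4*A/Gamma = 0.5):

import torch

osc = StochasticHarmonicOscillator(torch.tensor(4.0), torch.tensor(0.5),
                                   torch.tensor(0.0), torch.tensor(1.0))
y0 = torch.zeros(2, 3)                    # batch of initial positions
v0 = torch.zeros(2, 3)                    # initial velocities v(0) = q(0)/sqrt(Gamma)
y1, v1 = osc.dynamics(y0, v0, torch.tensor(0.1))
print(y1.shape, v1.shape)                 # both torch.Size([2, 3])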
@ -232,6 +232,9 @@ def save_video(tensor,
               retry=5):
    """Save tensor as video with configurable codec and container options."""
+
+    if torch.is_tensor(tensor) and len(tensor.shape) == 4:
+        tensor = tensor.unsqueeze(0)
+
    suffix = f'.{container}'
    cache_file = osp.join('/tmp', rand_name(suffix=suffix)) if save_file is None else save_file
    if not cache_file.endswith(suffix):
110
shared/utils/download.py
Normal file
@ -0,0 +1,110 @@
import sys, time
|
||||||
|
|
||||||
|
# Global variables to track download progress
_start_time = None
_last_time = None
_last_downloaded = 0
_speed_history = []
_update_interval = 0.5  # Update speed every 0.5 seconds

def progress_hook(block_num, block_size, total_size, filename=None):
    """
    Simple progress bar hook for urlretrieve

    Args:
        block_num: Number of blocks downloaded so far
        block_size: Size of each block in bytes
        total_size: Total size of the file in bytes
        filename: Name of the file being downloaded (optional)
    """
    global _start_time, _last_time, _last_downloaded, _speed_history, _update_interval

    current_time = time.time()
    downloaded = block_num * block_size

    # Initialize timing on first call
    if _start_time is None or block_num == 0:
        _start_time = current_time
        _last_time = current_time
        _last_downloaded = 0
        _speed_history = []

    # Calculate download speed only at specified intervals
    speed = 0
    if current_time - _last_time >= _update_interval:
        if _last_time > 0:
            current_speed = (downloaded - _last_downloaded) / (current_time - _last_time)
            _speed_history.append(current_speed)
            # Keep only last 5 speed measurements for smoothing
            if len(_speed_history) > 5:
                _speed_history.pop(0)
            # Average the recent speeds for smoother display
            speed = sum(_speed_history) / len(_speed_history)

        _last_time = current_time
        _last_downloaded = downloaded
    elif _speed_history:
        # Use the last calculated average speed
        speed = sum(_speed_history) / len(_speed_history)

    # Format file sizes and speed
    def format_bytes(bytes_val):
        for unit in ['B', 'KB', 'MB', 'GB']:
            if bytes_val < 1024:
                return f"{bytes_val:.1f}{unit}"
            bytes_val /= 1024
        return f"{bytes_val:.1f}TB"

    file_display = filename if filename else "Unknown file"

    if total_size <= 0:
        # If total size is unknown, show downloaded bytes
        speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else ""
        line = f"\r{file_display}: {format_bytes(downloaded)}{speed_str}"
        # Clear any trailing characters by padding with spaces
        sys.stdout.write(line.ljust(80))
        sys.stdout.flush()
        return

    downloaded = block_num * block_size
    percent = min(100, (downloaded / total_size) * 100)

    # Create progress bar (40 characters wide to leave room for other info)
    bar_length = 40
    filled = int(bar_length * percent / 100)
    bar = '█' * filled + '░' * (bar_length - filled)

    # Format file sizes and speed
    def format_bytes(bytes_val):
        for unit in ['B', 'KB', 'MB', 'GB']:
            if bytes_val < 1024:
                return f"{bytes_val:.1f}{unit}"
            bytes_val /= 1024
        return f"{bytes_val:.1f}TB"

    speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else ""

    # Display progress with filename first
    line = f"\r{file_display}: [{bar}] {percent:.1f}% ({format_bytes(downloaded)}/{format_bytes(total_size)}){speed_str}"
    # Clear any trailing characters by padding with spaces
    sys.stdout.write(line.ljust(100))
    sys.stdout.flush()

    # Print newline when complete
    if percent >= 100:
        print()


# Wrapper function to include filename in progress hook
def create_progress_hook(filename):
    """Creates a progress hook with the filename included"""
    global _start_time, _last_time, _last_downloaded, _speed_history
    # Reset timing variables for new download
    _start_time = None
    _last_time = None
    _last_downloaded = 0
    _speed_history = []

    def hook(block_num, block_size, total_size):
        return progress_hook(block_num, block_size, total_size, filename)
    return hook
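For reference, a minimal usage sketch of these helpers. The import path is taken from the wgp.py change further down in this commit; the URL and local filename below are placeholders, not values from the repository:

    from urllib.request import urlretrieve
    from shared.utils.download import create_progress_hook  # module path as imported by wgp.py in this commit

    url = "https://example.com/model.safetensors"   # placeholder URL
    target = "model.safetensors"                    # placeholder local filename
    # urlretrieve invokes the hook as hook(block_num, block_size, total_size) after each block
    urlretrieve(url, target, create_progress_hook(target))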
shared/utils/utils.py

@@ -1,4 +1,3 @@
-# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import argparse
 import os
 import os.path as osp
@@ -176,8 +175,9 @@ def remove_background(img, session=None):
 def convert_image_to_tensor(image):
     return torch.from_numpy(np.array(image).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0)
 
-def convert_tensor_to_image(t, frame_no = -1):
-    t = t[:, frame_no] if frame_no >= 0 else t
+def convert_tensor_to_image(t, frame_no = 0):
+    if len(t.shape) == 4:
+        t = t[:, frame_no]
     return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy())
 
 def save_image(tensor_image, name, frame_no = -1):
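A short sketch of what the revised convert_tensor_to_image accepts; the tensor shapes are assumptions based on the (c, f, h, w) layout used elsewhere in this file and the values are dummies:

    import torch
    from shared.utils.utils import convert_tensor_to_image

    # 4-D (c, f, h, w): the requested frame is selected first (frame_no now defaults to 0)
    clip = torch.zeros(3, 8, 64, 64)
    img = convert_tensor_to_image(clip, frame_no=2)   # PIL image of frame 2
    # 3-D (c, h, w): used as-is, no frame indexing
    single = torch.zeros(3, 64, 64)
    img2 = convert_tensor_to_image(single)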
@@ -257,16 +257,18 @@ def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fi
     image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
     return image, new_height, new_width
 
-def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None ):
+def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5 ):
     if rm_background:
         session = new_session()
 
     output_list =[]
+    output_mask_list =[]
     for i, img in enumerate(img_list):
         width, height = img.size
+        resized_mask = None
         if fit_into_canvas == None or any_background_ref == 1 and i==0 or any_background_ref == 2:
-            if outpainting_dims is not None:
-                resized_image =img
+            if outpainting_dims is not None and background_ref_outpainted:
+                resized_image, resized_mask = fit_image_into_canvas(img, (budget_height, budget_width), inpaint_color, full_frame = True, outpainting_dims = outpainting_dims, return_mask= True, return_image= True)
         elif img.size != (budget_width, budget_height):
             resized_image= img.resize((budget_width, budget_height), resample=Image.Resampling.LANCZOS)
         else:
@@ -290,145 +292,103 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg
             # resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
             resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
         output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200,
-    return output_list
-
-def str2bool(v):
-    """
-    Convert a string to a boolean.
-
-    Supported true values: 'yes', 'true', 't', 'y', '1'
-    Supported false values: 'no', 'false', 'f', 'n', '0'
-
-    Args:
-        v (str): String to convert.
-
-    Returns:
-        bool: Converted boolean value.
-
-    Raises:
-        argparse.ArgumentTypeError: If the value cannot be converted to boolean.
-    """
-    if isinstance(v, bool):
-        return v
-    v_lower = v.lower()
-    if v_lower in ('yes', 'true', 't', 'y', '1'):
-        return True
-    elif v_lower in ('no', 'false', 'f', 'n', '0'):
-        return False
-    else:
-        raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
-
-import sys, time
-
-# Global variables to track download progress
-_start_time = None
-... (the remaining deleted lines are the _last_time/_last_downloaded/_speed_history globals and the progress_hook()/create_progress_hook() helpers, removed from this file verbatim; they now live in the new shared/utils/download.py shown above)
+        output_mask_list.append(resized_mask)
+    return output_list, output_mask_list
+
+def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu", full_frame = False, outpainting_dims = None, return_mask = False, return_image = False):
+    from shared.utils.utils import save_image
+    inpaint_color = canvas_tf_bg / 127.5 - 1
+
+    ref_width, ref_height = ref_img.size
+    if (ref_height, ref_width) == image_size and outpainting_dims == None:
+        ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
+        canvas = torch.zeros_like(ref_img) if return_mask else None
+    else:
+        if outpainting_dims != None:
+            final_height, final_width = image_size
+            canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1)
+        else:
+            canvas_height, canvas_width = image_size
+        if full_frame:
+            new_height = canvas_height
+            new_width = canvas_width
+            top = left = 0
+        else:
+            # if fill_max and (canvas_height - new_height) < 16:
+            #     new_height = canvas_height
+            # if fill_max and (canvas_width - new_width) < 16:
+            #     new_width = canvas_width
+            scale = min(canvas_height / ref_height, canvas_width / ref_width)
+            new_height = int(ref_height * scale)
+            new_width = int(ref_width * scale)
+            top = (canvas_height - new_height) // 2
+            left = (canvas_width - new_width) // 2
+        ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
+        ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
+        if outpainting_dims != None:
+            canvas = torch.full((3, 1, final_height, final_width), inpaint_color, dtype= torch.float, device=device) # [-1, 1]
+            canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img
+        else:
+            canvas = torch.full((3, 1, canvas_height, canvas_width), inpaint_color, dtype= torch.float, device=device) # [-1, 1]
+            canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
+        ref_img = canvas
+        canvas = None
+    if return_mask:
+        if outpainting_dims != None:
+            canvas = torch.ones((1, 1, final_height, final_width), dtype= torch.float, device=device) # [-1, 1]
+            canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = 0
+        else:
+            canvas = torch.ones((1, 1, canvas_height, canvas_width), dtype= torch.float, device=device) # [-1, 1]
+            canvas[:, :, top:top + new_height, left:left + new_width] = 0
+        canvas = canvas.to(device)
+    if return_image:
+        return convert_tensor_to_image(ref_img), canvas
+
+    return ref_img.to(device), canvas
+
+def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, extract_guide_from_window_start = False, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None ):
+    src_videos, src_masks = [], []
+    inpaint_color = guide_inpaint_color/127.5 - 1
+    prepend_count = pre_video_guide.shape[1] if not extract_guide_from_window_start and pre_video_guide is not None else 0
+    for guide_no, (cur_video_guide, cur_video_mask) in enumerate(zip(video_guides, video_masks)):
+        src_video = src_mask = None
+        if cur_video_guide is not None:
+            src_video = cur_video_guide.permute(3, 0, 1, 2).float().div_(127.5).sub_(1.) # c, f, h, w
+        if cur_video_mask is not None and any_mask:
+            src_mask = cur_video_mask.permute(3, 0, 1, 2).float().div_(255)[0:1] # c, f, h, w
+        if pre_video_guide is not None and not extract_guide_from_window_start:
+            src_video = pre_video_guide if src_video is None else torch.cat( [pre_video_guide, src_video], dim=1)
+            if any_mask:
+                src_mask = torch.zeros_like(pre_video_guide[0:1]) if src_mask is None else torch.cat( [torch.zeros_like(pre_video_guide[0:1]), src_mask], dim=1)
+        if src_video is None:
+            if any_guide_padding:
+                src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color, dtype = torch.float, device= "cpu")
+                if any_mask:
+                    src_mask = torch.zeros_like(src_video[0:1])
+        elif src_video.shape[1] < current_video_length:
+            if any_guide_padding:
+                src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color, dtype = src_video.dtype, device= src_video.device) ], dim=1)
+                if cur_video_mask is not None and any_mask:
+                    src_mask = torch.cat([src_mask, torch.full( (1, current_video_length - src_mask.shape[1], *src_mask.shape[-2:] ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1)
+            else:
+                new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1
+                src_video = src_video[:, :new_num_frames]
+                if any_mask:
+                    src_mask = src_mask[:, :new_num_frames]
+
+        for k, keep in enumerate(keep_video_guide_frames):
+            if not keep:
+                pos = prepend_count + k
+                src_video[:, pos:pos+1] = inpaint_color
+                src_mask[:, pos:pos+1] = 1
+
+        for k, frame in enumerate(inject_frames):
+            if frame != None:
+                pos = prepend_count + k
+                src_video[:, pos:pos+1], src_mask[:, pos:pos+1] = fit_image_into_canvas(frame, image_size, inpaint_color, device, True, outpainting_dims, return_mask= True)
+
+        src_videos.append(src_video)
+        src_masks.append(src_mask)
+    return src_videos, src_masks
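To make the new helper's contract easier to follow, here is a minimal calling sketch modelled on the wgp.py call added later in this commit. The tensor shapes, sizes and dummy inputs are illustrative assumptions, not values taken from the repository:

    import torch
    from shared.utils.utils import prepare_video_guide_and_mask

    # one guide video and mask as (f, h, w, c) uint8 frames, mirroring what the preprocessors return (assumption)
    guide = torch.zeros(17, 128, 128, 3, dtype=torch.uint8)
    mask = torch.zeros(17, 128, 128, 3, dtype=torch.uint8)

    src_videos, src_masks = prepare_video_guide_and_mask(
        [guide], [mask],
        pre_video_guide=None, image_size=(128, 128),
        current_video_length=17, latent_size=4,
        any_mask=True, any_guide_padding=False,
        guide_inpaint_color=127.5,
        keep_video_guide_frames=[True] * 17,
        inject_frames=[None] * 17,
    )
    src_video, = src_videos   # (c, f, h, w) float tensor in [-1, 1]
    src_mask, = src_masks     # (1, f, h, w) float tensor in [0, 1]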
wgp.py (228 changed lines)
@@ -394,7 +394,7 @@ def process_prompt_and_add_tasks(state, model_choice):
             gr.Info("MMAudio can generate an Audio track only if the Video is at least 1s long")
     if "F" in video_prompt_type:
         if len(frames_positions.strip()) > 0:
-            positions = frames_positions.split(" ")
+            positions = frames_positions.replace(","," ").split(" ")
            for pos_str in positions:
                if not pos_str in ["L", "l"] and len(pos_str)>0:
                    if not is_integer(pos_str):
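A quick illustration of what the relaxed parsing above accepts (the sample strings are made up):

    # both of these now yield the same position tokens
    "1 20 L".replace(",", " ").split(" ")     # -> ['1', '20', 'L']
    "1,20,L".replace(",", " ").split(" ")     # -> ['1', '20', 'L']
    # mixed separators produce empty tokens, which the loop above skips via len(pos_str) > 0
    "1, 20".replace(",", " ").split(" ")      # -> ['1', '', '20']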
@@ -2528,7 +2528,7 @@ def download_models(model_filename = None, model_type= None, module_type = False
 
 
     from urllib.request import urlretrieve
-    from shared.utils.utils import create_progress_hook
+    from shared.utils.download import create_progress_hook
 
     shared_def = {
         "repoId" : "DeepBeepMeep/Wan2.1",
@@ -3726,6 +3726,60 @@ def preprocess_image_with_mask(input_image, input_mask, height, width, fit_canva
         input_mask = convert_tensor_to_image(full_frame)
 
     return input_image, input_mask
+
+def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_frames, start_frame, target_fps, size = 512):
+    if not input_video_path or max_frames <= 0:
+        return None, None
+    pad_frames = 0
+    if start_frame < 0:
+        pad_frames= -start_frame
+        max_frames += start_frame
+        start_frame = 0
+
+    any_mask = input_mask_path != None
+    video = get_resampled_video(input_video_path, start_frame, max_frames, target_fps)
+    if len(video) == 0: return None
+    if any_mask:
+        mask_video = get_resampled_video(input_mask_path, start_frame, max_frames, target_fps)
+    frame_height, frame_width, _ = video[0].shape
+
+    num_frames = min(len(video), len(mask_video))
+    if num_frames == 0: return None
+    video, mask_video = video[:num_frames], mask_video[:num_frames]
+
+    from preprocessing.face_preprocessor import FaceProcessor
+    face_processor = FaceProcessor()
+
+    face_list = []
+    for frame_idx in range(num_frames):
+        frame = video[frame_idx].cpu().numpy()
+        # video[frame_idx] = None
+        if any_mask:
+            mask = Image.fromarray(mask_video[frame_idx].cpu().numpy())
+            # mask_video[frame_idx] = None
+            if (frame_width, frame_height) != mask.size:
+                mask = mask.resize((frame_width, frame_height), resample=Image.Resampling.LANCZOS)
+            mask = np.array(mask)
+            alpha_mask = np.zeros((frame_height, frame_width, 3), dtype=np.uint8)
+            alpha_mask[mask > 127] = 1
+            frame = frame * alpha_mask
+        frame = Image.fromarray(frame)
+        face = face_processor.process(frame, resize_to=size, face_crop_scale = 1)
+        face_list.append(face)
+
+    face_processor = None
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    face_tensor= torch.tensor(np.stack(face_list, dtype= np.float32) / 127.5 - 1).permute(-1, 0, 1, 2 ) # t h w c -> c t h w
+    if pad_frames > 0:
+        face_tensor= torch.cat([face_tensor[:, -1:].expand(-1, pad_frames, -1, -1), face_tensor ], dim=2)
+
+    if args.save_masks:
+        from preprocessing.dwpose.pose import save_one_video
+        saved_faces_frames = [np.array(face) for face in face_list ]
+        save_one_video(f"faces.mp4", saved_faces_frames, fps=target_fps, quality=8, macro_block_size=None)
+    return face_tensor
+
+
 def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1):
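A short sketch of how the new face-extraction helper is meant to be consumed, mirroring the call added to generate_video() later in this commit. The file paths are placeholders, and running it requires the repository's FaceProcessor dependencies:

    # placeholder paths; in wgp.py the control video and control mask are passed here
    src_faces = extract_faces_from_video_with_mask(
        "control_video.mp4", "control_mask.mp4",
        max_frames=81, start_frame=0, target_fps=16, size=512,
    )
    if src_faces is not None:
        # (c, f, 512, 512) float tensor in [-1, 1]; generate_video() pads it with -1 frames
        # when it comes out shorter than the requested clip length
        print(src_faces.shape)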
@@ -3742,7 +3796,13 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
         box = [xmin, ymin, xmax, ymax]
         box = [int(x) for x in box]
         return box
+    inpaint_color = int(inpaint_color)
+    pad_frames = 0
+    if start_frame < 0:
+        pad_frames= -start_frame
+        max_frames += start_frame
+        start_frame = 0
+
     if not input_video_path or max_frames <= 0:
         return None, None
     any_mask = input_mask_path != None
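A worked example of the new negative start_frame handling above (the numbers are illustrative):

    start_frame, max_frames = -8, 81      # an overlapped window that begins before the source video
    pad_frames = 0
    if start_frame < 0:
        pad_frames = -start_frame         # 8 frames to pad back in afterwards
        max_frames += start_frame         # 73 real frames are actually read
        start_frame = 0                   # reading starts at the first available frame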
@@ -3909,6 +3969,9 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
     preproc_outside = None
     gc.collect()
     torch.cuda.empty_cache()
+    if pad_frames > 0:
+        masked_frames = masked_frames[0] * pad_frames + masked_frames
+        if any_mask: masked_frames = masks[0] * pad_frames + masks
 
     return torch.stack(masked_frames), torch.stack(masks) if any_mask else None
@@ -4646,7 +4709,8 @@ def generate_video(
     current_video_length = video_length
     # VAE Tiling
     device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576
+    guide_inpaint_color = model_def.get("guide_inpaint_color", 127.5)
+    extract_guide_from_window_start = model_def.get("extract_guide_from_window_start", False)
     i2v = test_class_i2v(model_type)
     diffusion_forcing = "diffusion_forcing" in model_filename
     t2v = base_model_type in ["t2v"]
@@ -4662,6 +4726,7 @@ def generate_video(
     multitalk = model_def.get("multitalk_class", False)
     standin = model_def.get("standin_class", False)
     infinitetalk = base_model_type in ["infinitetalk"]
+    animate = base_model_type in ["animate"]
 
     if "B" in audio_prompt_type or "X" in audio_prompt_type:
         from models.wan.multitalk.multitalk import parse_speakers_locations
@@ -4822,7 +4887,6 @@ def generate_video(
     repeat_no = 0
     extra_generation = 0
     initial_total_windows = 0
-
     discard_last_frames = sliding_window_discard_last_frames
     default_requested_frames_to_generate = current_video_length
     if sliding_window:
@@ -4843,7 +4907,7 @@ def generate_video(
         if repeat_no >= total_generation: break
         repeat_no +=1
         gen["repeat_no"] = repeat_no
-        src_video = src_mask = src_ref_images = new_image_guide = new_image_mask = None
+        src_video = src_mask = src_ref_images = new_image_guide = new_image_mask = src_faces = None
         prefix_video = pre_video_frame = None
         source_video_overlap_frames_count = 0 # number of frames overalapped in source video for first window
         source_video_frames_count = 0 # number of frames to use in source video (processing starts source_video_overlap_frames_count frames before )
@@ -4899,7 +4963,7 @@ def generate_video(
             return_latent_slice = slice(-(reuse_frames - 1 + discard_last_frames ) // latent_size - 1, None if discard_last_frames == 0 else -(discard_last_frames // latent_size) )
         refresh_preview = {"image_guide" : image_guide, "image_mask" : image_mask} if image_mode >= 1 else {}
 
-        src_ref_images = image_refs
+        src_ref_images, src_ref_masks = image_refs, None
         image_start_tensor = image_end_tensor = None
         if window_no == 1 and (video_source is not None or image_start is not None):
             if image_start is not None:
@@ -4943,16 +5007,52 @@ def generate_video(
            from models.wan.multitalk.multitalk import get_window_audio_embeddings
            # special treatment for start frame pos when alignement to first frame requested as otherwise the start frame number will be negative due to overlapped frames (has been previously compensated later with padding)
            audio_proj_split = get_window_audio_embeddings(audio_proj_full, audio_start_idx= aligned_window_start_frame + (source_video_overlap_frames_count if reset_control_aligment else 0 ), clip_length = current_video_length)
-        if vace:
-            video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None
+        if repeat_no == 1 and window_no == 1 and image_refs is not None and len(image_refs) > 0:
+            frames_positions_list = []
+            if frames_positions is not None and len(frames_positions)> 0:
+                positions = frames_positions.replace(","," ").split(" ")
+                cur_end_pos = -1 + (source_video_frames_count - source_video_overlap_frames_count)
+                last_frame_no = requested_frames_to_generate + source_video_frames_count - source_video_overlap_frames_count
+                joker_used = False
+                project_window_no = 1
+                for pos in positions :
+                    if len(pos) > 0:
+                        if pos in ["L", "l"]:
+                            cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length
+                            if cur_end_pos >= last_frame_no and not joker_used:
+                                joker_used = True
+                                cur_end_pos = last_frame_no -1
+                            project_window_no += 1
+                            frames_positions_list.append(cur_end_pos)
+                            cur_end_pos -= sliding_window_discard_last_frames + reuse_frames
+                        else:
+                            frames_positions_list.append(int(pos)-1 + alignment_shift)
+            frames_positions_list = frames_positions_list[:len(image_refs)]
+            nb_frames_positions = len(frames_positions_list)
+            if nb_frames_positions > 0:
+                frames_to_inject = [None] * (max(frames_positions_list) + 1)
+                for i, pos in enumerate(frames_positions_list):
+                    frames_to_inject[pos] = image_refs[i]
+
+
         if video_guide is not None:
-            keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate)
+            keep_frames_parsed_full, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate)
             if len(error) > 0:
                 raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
-            keep_frames_parsed = keep_frames_parsed[aligned_guide_start_frame: aligned_guide_end_frame ]
+            guide_frames_extract_start = aligned_window_start_frame if extract_guide_from_window_start else aligned_guide_start_frame
+            keep_frames_parsed = [True] * -guide_frames_extract_start if guide_frames_extract_start <0 else []
+            keep_frames_parsed += keep_frames_parsed_full[max(0, guide_frames_extract_start): aligned_guide_end_frame ]
+            guide_frames_extract_count = len(keep_frames_parsed)
+
-        if vace:
+            if "B" in video_prompt_type:
+                send_cmd("progress", [0, get_latest_status(state, "Extracting Face Movements")])
+                src_faces = extract_faces_from_video_with_mask(video_guide, video_mask, max_frames= guide_frames_extract_count, start_frame= guide_frames_extract_start, size= 512, target_fps = fps)
+                if src_faces is not None and src_faces.shape[1] < current_video_length:
+                    src_faces = torch.cat([src_faces, torch.full( (3, current_video_length - src_faces.shape[1], 512, 512 ), -1, dtype = src_faces.dtype, device= src_faces.device) ], dim=1)
+
+        if vace or animate:
+            video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None
             context_scale = [ control_net_weight]
             if "V" in video_prompt_type:
                 process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None)
@@ -4971,10 +5071,10 @@ def generate_video(
                 if preprocess_type2 is not None:
                     context_scale = [ control_net_weight /2, control_net_weight2 /2]
                 send_cmd("progress", [0, get_latest_status(state, status_info)])
-                inpaint_color = 0 if preprocess_type=="pose" and process_outside_mask=="inpaint" else 127
-                video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) , start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color )
+                inpaint_color = 0 if preprocess_type=="pose" else guide_inpaint_color
+                video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color )
                 if preprocess_type2 != None:
-                    video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 )
+                    video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 )
 
                 if video_guide_processed != None:
                     if sample_fit_canvas != None:
@@ -4985,7 +5085,37 @@ def generate_video(
                         refresh_preview["video_guide"] = [refresh_preview["video_guide"], Image.fromarray(video_guide_processed2[0].cpu().numpy())]
                     if video_mask_processed != None:
                         refresh_preview["video_mask"] = Image.fromarray(video_mask_processed[0].cpu().numpy())
-        elif ltxv:
+
+            frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame]
+
+            if not vace and (any_letters(video_prompt_type ,"FV") or model_def.get("forced_guide_mask_inputs", False)):
+                any_mask = True
+                any_guide_padding = model_def.get("pad_guide_video", False)
+                from shared.utils.utils import prepare_video_guide_and_mask
+                src_videos, src_masks = prepare_video_guide_and_mask( [video_guide_processed, video_guide_processed2],
+                                                                      [video_mask_processed, video_mask_processed2],
+                                                                      pre_video_guide, image_size, current_video_length, latent_size,
+                                                                      any_mask, any_guide_padding, guide_inpaint_color, extract_guide_from_window_start,
+                                                                      keep_frames_parsed, frames_to_inject_parsed , outpainting_dims)
+
+                src_video, src_video2 = src_videos
+                src_mask, src_mask2 = src_masks
+                if src_video is None:
+                    abort = True
+                    break
+                if src_faces is not None:
+                    if src_faces.shape[1] < src_video.shape[1]:
+                        src_faces = torch.concat( [src_faces, src_faces[:, -1:].repeat(1, src_video.shape[1] - src_faces.shape[1], 1,1)], dim =1)
+                    else:
+                        src_faces = src_faces[:, :src_video.shape[1]]
+                if args.save_masks:
+                    save_video( src_video, "masked_frames.mp4", fps)
+                    if src_video2 is not None:
+                        save_video( src_video2, "masked_frames2.mp4", fps)
+                    if any_mask:
+                        save_video( src_mask, "masks.mp4", fps, value_range=(0, 1))
+
+        elif ltxv:
             preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw")
             status_info = "Extracting " + processes_names[preprocess_type]
             send_cmd("progress", [0, get_latest_status(state, status_info)])
@@ -5023,7 +5153,7 @@ def generate_video(
                     sample_fit_canvas = None
 
             else: # video to video
-                video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps)
+                video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, block_size= block_size)
                 if video_guide_processed is None:
                     src_video = pre_video_guide
                 else:
@@ -5043,29 +5173,6 @@ def generate_video(
                 refresh_preview["image_mask"] = new_image_mask
 
         if window_no == 1 and image_refs is not None and len(image_refs) > 0:
-            if repeat_no == 1:
-                frames_positions_list = []
-                if frames_positions is not None and len(frames_positions)> 0:
-                    positions = frames_positions.split(" ")
-                    cur_end_pos = -1 + (source_video_frames_count - source_video_overlap_frames_count) #if reset_control_aligment else 0
-                    last_frame_no = requested_frames_to_generate + source_video_frames_count - source_video_overlap_frames_count
-                    joker_used = False
-                    project_window_no = 1
-                    for pos in positions :
-                        if len(pos) > 0:
-                            if pos in ["L", "l"]:
-                                cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length
-                                if cur_end_pos >= last_frame_no and not joker_used:
-                                    joker_used = True
-                                    cur_end_pos = last_frame_no -1
-                                project_window_no += 1
-                                frames_positions_list.append(cur_end_pos)
-                                cur_end_pos -= sliding_window_discard_last_frames + reuse_frames
-                            else:
-                                frames_positions_list.append(int(pos)-1 + alignment_shift)
-                frames_positions_list = frames_positions_list[:len(image_refs)]
-                nb_frames_positions = len(frames_positions_list)
-
             if sample_fit_canvas is not None and (nb_frames_positions > 0 or "K" in video_prompt_type) :
                 from shared.utils.utils import get_outpainting_full_area_dimensions
                 w, h = image_refs[0].size
@@ -5089,20 +5196,16 @@ def generate_video(
             if remove_background_images_ref > 0:
                 send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")])
             # keep image ratios if there is a background image ref (we will let the model preprocessor decide what to do) but remove bg if requested
-            image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , image_size[1], image_size[0],
+            image_refs[nb_frames_positions:], src_ref_masks = resize_and_remove_background(image_refs[nb_frames_positions:] , image_size[1], image_size[0],
                                                                             remove_background_images_ref > 0, any_background_ref,
                                                                             fit_into_canvas= 0 if (any_background_ref > 0 or model_def.get("lock_image_refs_ratios", False)) else 1,
                                                                             block_size=block_size,
-                                                                            outpainting_dims =outpainting_dims )
+                                                                            outpainting_dims =outpainting_dims,
+                                                                            background_ref_outpainted = model_def.get("background_ref_outpainted", True) )
             refresh_preview["image_refs"] = image_refs
 
-        if nb_frames_positions > 0:
-            frames_to_inject = [None] * (max(frames_positions_list) + 1)
-            for i, pos in enumerate(frames_positions_list):
-                frames_to_inject[pos] = image_refs[i]
-
         if vace :
-            frames_to_inject_parsed = frames_to_inject[guide_start_frame: guide_end_frame]
             image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications
 
             src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_processed] if video_guide_processed2 == None else [video_guide_processed, video_guide_processed2],
@@ -5116,7 +5219,7 @@ def generate_video(
                                                                            any_background_ref = any_background_ref
                                                                            )
             if len(frames_to_inject_parsed) or any_background_ref:
-                new_image_refs = [convert_tensor_to_image(src_video[0], frame_no + aligned_guide_start_frame - aligned_window_start_frame) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject]
+                new_image_refs = [convert_tensor_to_image(src_video[0], frame_no + 0 if extract_guide_from_window_start else (aligned_guide_start_frame - aligned_window_start_frame) ) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject]
                 if any_background_ref:
                     new_image_refs += [convert_tensor_to_image(image_refs_copy[0], 0)] + image_refs[nb_frames_positions+1:]
                 else:
@@ -5165,10 +5268,14 @@ def generate_video(
                 input_prompt = prompt,
                 image_start = image_start_tensor,
                 image_end = image_end_tensor,
                 input_frames = src_video,
+                input_frames2 = src_video2,
                 input_ref_images= src_ref_images,
+                input_ref_masks = src_ref_masks,
                 input_masks = src_mask,
+                input_masks2 = src_mask2,
                 input_video= pre_video_guide,
+                input_faces = src_faces,
                 denoising_strength=denoising_strength,
                 prefix_frames_count = source_video_overlap_frames_count if window_no <= 1 else reuse_frames,
                 frame_num= (current_video_length // latent_size)* latent_size + 1,
@@ -5302,6 +5409,7 @@ def generate_video(
                 send_cmd("output")
             else:
                 sample = samples.cpu()
+                abort = not is_image and sample.shape[1] < current_video_length
                 # if True: # for testing
                 #     torch.save(sample, "output.pt")
                 # else:
@@ -6980,7 +7088,7 @@ def refresh_video_prompt_type_alignment(state, video_prompt_type, video_prompt_t
 
 def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ):
     old_video_prompt_type = video_prompt_type
-    video_prompt_type = del_in_sequence(video_prompt_type, "PDESLCMUV")
+    video_prompt_type = del_in_sequence(video_prompt_type, "PDESLCMUVB")
     video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
     visible = "V" in video_prompt_type
     model_type = state["model_type"]
@@ -7437,7 +7545,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                 image_prompt_type = gr.Text(value= image_prompt_type_value, visible= False)
                 image_prompt_type_choices = []
                 if "T" in image_prompt_types_allowed:
-                    image_prompt_type_choices += [("Text Prompt Only", "")]
+                    image_prompt_type_choices += [("Text Prompt Only" if "S" in image_prompt_types_allowed else "New Video", "")]
                 if "S" in image_prompt_types_allowed:
                     image_prompt_type_choices += [("Start Video with Image", "S")]
                     any_start_image = True
@@ -7516,7 +7624,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                     if image_outputs: video_prompt_type_video_guide_label = video_prompt_type_video_guide_label.replace("Video", "Image")
                     video_prompt_type_video_guide = gr.Dropdown(
                         guide_preprocessing_choices,
-                        value=filter_letters(video_prompt_type_value, "PDESLCMUV", guide_preprocessing.get("default", "") ),
+                        value=filter_letters(video_prompt_type_value, "PDESLCMUVB", guide_preprocessing.get("default", "") ),
                         label= video_prompt_type_video_guide_label , scale = 2, visible= guide_preprocessing.get("visible", True) , show_label= True,
                     )
                     any_control_video = True
@@ -7560,13 +7668,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                    }
 
                    mask_preprocessing_choices = []
-                   mask_preprocessing_labels = guide_preprocessing.get("labels", {})
+                   mask_preprocessing_labels = mask_preprocessing.get("labels", {})
                    for process_type in mask_preprocessing["selection"]:
                        process_label = mask_preprocessing_labels.get(process_type, None)
                        process_label = mask_preprocessing_labels_all.get(process_type, process_type) if process_label is None else process_label
                        mask_preprocessing_choices.append( (process_label, process_type) )
 
-                   video_prompt_type_video_mask_label = guide_preprocessing.get("label", "Area Processed")
+                   video_prompt_type_video_mask_label = mask_preprocessing.get("label", "Area Processed")
                    video_prompt_type_video_mask = gr.Dropdown(
                        mask_preprocessing_choices,
                        value=filter_letters(video_prompt_type_value, "XYZWNA", mask_preprocessing.get("default", "")),
@@ -7591,7 +7699,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                        choices= image_ref_choices["choices"],
                        value=filter_letters(video_prompt_type_value, image_ref_choices["letters_filter"]),
                        visible = image_ref_choices.get("visible", True),
-                       label=image_ref_choices.get("label", "Ref. Images Type"), show_label= True, scale = 2
+                       label=image_ref_choices.get("label", "Inject Reference Images"), show_label= True, scale = 2
                    )
 
                image_guide = gr.Image(label= "Control Image", height = 800, type ="pil", visible= image_mode_value==1 and "V" in video_prompt_type_value and ("U" in video_prompt_type_value or not "A" in video_prompt_type_value ) , value= ui_defaults.get("image_guide", None))
@@ -7634,7 +7742,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                video_guide_outpainting_value = ui_defaults.get("video_guide_outpainting","#")
                video_guide_outpainting = gr.Text(value=video_guide_outpainting_value , visible= False)
                with gr.Group():
-                   video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Landscape or Injected Reference Frames" if image_mode_value == 0 else "Enable Spatial Outpainting on Control Image", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") )
+                   video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Landscape or Positioned Reference Frames" if image_mode_value == 0 else "Enable Spatial Outpainting on Control Image", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") )
                    with gr.Row(visible = not video_guide_outpainting_value.startswith("#")) as video_guide_outpainting_row:
                        video_guide_outpainting_value = video_guide_outpainting_value[1:] if video_guide_outpainting_value.startswith("#") else video_guide_outpainting_value
                        video_guide_outpainting_list = [0] * 4 if len(video_guide_outpainting_value) == 0 else [int(v) for v in video_guide_outpainting_value.split(" ")]
@@ -7649,14 +7757,14 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value )
 
                image_refs_single_image_mode = model_def.get("one_image_ref_needed", False)
-               image_refs_label = "Start Image" if hunyuan_video_avatar else ("Reference Image" if image_refs_single_image_mode else "Reference Images") + (" (each Image will start a new Clip)" if infinitetalk else "")
+               image_refs_label = "Start Image" if hunyuan_video_avatar else ("Reference Image" if image_refs_single_image_mode else "Reference Images") + (" (each Image will be associated to a Sliding Window)" if infinitetalk else "")
                image_refs_row, image_refs, image_refs_extra = get_image_gallery(label= image_refs_label, value = ui_defaults.get("image_refs", None), visible= "I" in video_prompt_type_value, single_image_mode=image_refs_single_image_mode)
 
                frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames (1=first, L=last of a window) no position for other Image Refs)" )
                image_refs_relative_size = gr.Slider(20, 100, value=ui_defaults.get("image_refs_relative_size", 50), step=1, label="Rescale Internaly Image Ref (% in relation to Output Video) to change Output Composition", visible = model_def.get("any_image_refs_relative_size", False) and image_outputs)
 
                no_background_removal = model_def.get("no_background_removal", False) or image_ref_choices is None
-               background_removal_label = model_def.get("background_removal_label", "Remove Backgrounds behind People / Objects")
+               background_removal_label = model_def.get("background_removal_label", "Remove Background behind People / Objects")
 
                remove_background_images_ref = gr.Dropdown(
                    choices=[
@@ -7664,7 +7772,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                        (background_removal_label, 1),
                    ],
                    value=0 if no_background_removal else ui_defaults.get("remove_background_images_ref",1),
-                   label="Automatic Removal of Background of People or Objects (Only)", scale = 3, visible= "I" in video_prompt_type_value and not no_background_removal
+                   label="Automatic Removal of Background behind People or Objects in Reference Images", scale = 3, visible= "I" in video_prompt_type_value and not no_background_removal
                )
 
                any_audio_voices_support = any_audio_track(base_model_type)
@@ -8084,7 +8192,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                            ("Aligned to the beginning of the First Window of the new Video Sample", "T"),
                        ],
                        value=filter_letters(video_prompt_type_value, "T"),
-                       label="Control Video / Injected Frames / Control Audio temporal alignment when any Video to continue",
+                       label="Control Video / Control Audio / Positioned Frames Temporal Alignment when any Video to continue",
                        visible = vace or ltxv or t2v or infinitetalk
                    )
 