commit in case there is an unrecoverable code hemorrhage

DeepBeepMeep 2025-09-22 17:11:25 +02:00
parent fc615ffb3c
commit 84010bd861
17 changed files with 1620 additions and 380 deletions

configs/animate.json (new file)

@@ -0,0 +1,15 @@
{
"_class_name": "WanModel",
"_diffusers_version": "0.30.0",
"dim": 5120,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_dim": 36,
"model_type": "i2v",
"num_heads": 40,
"num_layers": 40,
"out_dim": 16,
"text_len": 512,
"motion_encoder_dim": 512
}
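A minimal sketch (not part of the commit) of loading this new config with only the Python standard library; the relative path is an assumption. Per the WanModel changes further down in this diff, a non-zero motion_encoder_dim is what enables the Animate-specific modules.

import json

# Assumed location; adjust to wherever WanGP's configs/ folder lives.
with open("configs/animate.json") as f:
    cfg = json.load(f)

# In the WanModel diff below, `animate = motion_encoder_dim > 0` gates the
# pose patch embedding, motion encoder and face adapter/encoder modules.
is_animate = cfg.get("motion_encoder_dim", 0) > 0
print(cfg["_class_name"], cfg["model_type"], "animate:", is_animate)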

configs/lucy_edit.json (new file)

@@ -0,0 +1,14 @@
{
"_class_name": "WanModel",
"_diffusers_version": "0.33.0",
"dim": 3072,
"eps": 1e-06,
"ffn_dim": 14336,
"freq_dim": 256,
"in_dim": 96,
"model_type": "ti2v2_2",
"num_heads": 24,
"num_layers": 30,
"out_dim": 48,
"text_len": 512
}

defaults/animate.json (new file)

@@ -0,0 +1,13 @@
{
"model": {
"name": "Wan2.2 Animate",
"architecture": "animate",
"description": "Wan-Animate takes a video and a character image as input, and generates a video in either 'animation' or 'replacement' mode.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_bf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_quanto_fp16_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_quanto_bf16_int8.safetensors"
],
"group": "wan2_2"
}
}

defaults/lucy_edit.json (new file)

@@ -0,0 +1,18 @@
{
"model": {
"name": "Wan2.2 Lucy Edit 5B",
"architecture": "lucy_edit",
"description": "Lucy Edit Dev is a video editing model that performs instruction-guided edits on videos using free-text prompts \u2014 it supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_mbf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mbf16_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mfp16_int8.safetensors"
],
"group": "wan2_2"
},
"video_length": 81,
"guidance_scale": 5,
"flow_shift": 5,
"num_inference_steps": 30,
"resolution": "1280x720"
}

View File

@@ -7,6 +7,7 @@
 "loras": ["https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2_2_5B_FastWanFullAttn_lora_rank_128_bf16.safetensors"],
 "group": "wan2_2"
 },
+"prompt" : "Put the person into a clown outfit.",
 "video_length": 121,
 "guidance_scale": 1,
 "flow_shift": 3,
View File

@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from mmgp import offload
 import inspect
 from typing import Any, Callable, Dict, List, Optional, Union
@@ -387,7 +386,8 @@ class QwenImagePipeline(): #DiffusionPipeline
 return latent_image_ids.to(device=device, dtype=dtype)
 @staticmethod
-def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+def _pack_latents(latents):
+batch_size, num_channels_latents, _, height, width = latents.shape
 latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
 latents = latents.permute(0, 2, 4, 1, 3, 5)
 latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
@@ -479,7 +479,7 @@ class QwenImagePipeline(): #DiffusionPipeline
 height = 2 * (int(height) // (self.vae_scale_factor * 2))
 width = 2 * (int(width) // (self.vae_scale_factor * 2))
-shape = (batch_size, 1, num_channels_latents, height, width)
+shape = (batch_size, num_channels_latents, 1, height, width)
 image_latents = None
 if image is not None:
@@ -499,10 +499,7 @@ class QwenImagePipeline(): #DiffusionPipeline
 else:
 image_latents = torch.cat([image_latents], dim=0)
-image_latent_height, image_latent_width = image_latents.shape[3:]
-image_latents = self._pack_latents(
-image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
-)
+image_latents = self._pack_latents(image_latents)
 if isinstance(generator, list) and len(generator) != batch_size:
 raise ValueError(
@@ -511,7 +508,7 @@ class QwenImagePipeline(): #DiffusionPipeline
 )
 if latents is None:
 latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+latents = self._pack_latents(latents)
 else:
 latents = latents.to(device=device, dtype=dtype)
@@ -713,11 +710,12 @@ class QwenImagePipeline(): #DiffusionPipeline
 image_height, image_width = calculate_new_dimensions(height, width, image_height, image_width, False, block_size=multiple_of)
 # image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
 height, width = image_height, image_width
-image_mask_latents = convert_image_to_tensor(image_mask.resize((width // 16, height // 16), resample=Image.Resampling.LANCZOS))
+image_mask_latents = convert_image_to_tensor(image_mask.resize((width // 8, height // 8), resample=Image.Resampling.LANCZOS))
 image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1]
-image_mask_rebuilt = image_mask_latents.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0)
+image_mask_rebuilt = image_mask_latents.repeat_interleave(8, dim=-1).repeat_interleave(8, dim=-2).unsqueeze(0)
 # convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png")
-image_mask_latents = image_mask_latents.reshape(1, -1, 1).to(device)
+image_mask_latents = image_mask_latents.to(device).unsqueeze(0).unsqueeze(0).repeat(1,16,1,1,1)
+image_mask_latents = self._pack_latents(image_mask_latents)
 prompt_image = image
 if image.size != (image_width, image_height):
@@ -822,6 +820,7 @@ class QwenImagePipeline(): #DiffusionPipeline
 negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
 )
 morph, first_step = False, 0
+lanpaint_proc = None
 if image_mask_latents is not None:
 randn = torch.randn_like(original_image_latents)
 if denoising_strength < 1.:
@@ -833,7 +832,8 @@ class QwenImagePipeline(): #DiffusionPipeline
 timesteps = timesteps[first_step:]
 self.scheduler.timesteps = timesteps
 self.scheduler.sigmas= self.scheduler.sigmas[first_step:]
+# from shared.inpainting.lanpaint import LanPaint
+# lanpaint_proc = LanPaint()
 # 6. Denoising loop
 self.scheduler.set_begin_index(0)
 updated_num_steps= len(timesteps)
@@ -847,48 +847,52 @@ class QwenImagePipeline(): #DiffusionPipeline
 offload.set_step_no_for_lora(self.transformer, first_step + i)
 if self.interrupt:
 continue
-self._current_timestep = t
-# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-timestep = t.expand(latents.shape[0]).to(latents.dtype)
 if image_mask_latents is not None and denoising_strength <1. and i == first_step and morph:
 latent_noise_factor = t/1000
 latents = original_image_latents * (1.0 - latent_noise_factor) + latents * latent_noise_factor
-latent_model_input = latents
-if image_latents is not None:
-latent_model_input = torch.cat([latents, image_latents], dim=1)
-if do_true_cfg and joint_pass:
-noise_pred, neg_noise_pred = self.transformer(
-hidden_states=latent_model_input,
-timestep=timestep / 1000,
-guidance=guidance,
-encoder_hidden_states_mask_list=[prompt_embeds_mask,negative_prompt_embeds_mask],
-encoder_hidden_states_list=[prompt_embeds, negative_prompt_embeds],
-img_shapes=img_shapes,
-txt_seq_lens_list=[txt_seq_lens, negative_txt_seq_lens],
-attention_kwargs=self.attention_kwargs,
-**kwargs
-)
-if noise_pred == None: return None
-noise_pred = noise_pred[:, : latents.size(1)]
-neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
-else:
-noise_pred = self.transformer(
-hidden_states=latent_model_input,
-timestep=timestep / 1000,
-guidance=guidance,
-encoder_hidden_states_mask_list=[prompt_embeds_mask],
-encoder_hidden_states_list=[prompt_embeds],
-img_shapes=img_shapes,
-txt_seq_lens_list=[txt_seq_lens],
-attention_kwargs=self.attention_kwargs,
-**kwargs
-)[0]
-if noise_pred == None: return None
-noise_pred = noise_pred[:, : latents.size(1)]
+self._current_timestep = t
+# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+timestep = t.expand(latents.shape[0]).to(latents.dtype)
+latents_dtype = latents.dtype
+# latent_model_input = latents
+def denoise(latent_model_input, true_cfg_scale):
+if image_latents is not None:
+latent_model_input = torch.cat([latents, image_latents], dim=1)
+do_true_cfg = true_cfg_scale > 1
+if do_true_cfg and joint_pass:
+noise_pred, neg_noise_pred = self.transformer(
+hidden_states=latent_model_input,
+timestep=timestep / 1000,
+guidance=guidance, #!!!!
+encoder_hidden_states_mask_list=[prompt_embeds_mask,negative_prompt_embeds_mask],
+encoder_hidden_states_list=[prompt_embeds, negative_prompt_embeds],
+img_shapes=img_shapes,
+txt_seq_lens_list=[txt_seq_lens, negative_txt_seq_lens],
+attention_kwargs=self.attention_kwargs,
+**kwargs
+)
+if noise_pred == None: return None, None
+noise_pred = noise_pred[:, : latents.size(1)]
+neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+else:
+neg_noise_pred = None
+noise_pred = self.transformer(
+hidden_states=latent_model_input,
+timestep=timestep / 1000,
+guidance=guidance,
+encoder_hidden_states_mask_list=[prompt_embeds_mask],
+encoder_hidden_states_list=[prompt_embeds],
+img_shapes=img_shapes,
+txt_seq_lens_list=[txt_seq_lens],
+attention_kwargs=self.attention_kwargs,
+**kwargs
+)[0]
+if noise_pred == None: return None, None
+noise_pred = noise_pred[:, : latents.size(1)]
 if do_true_cfg:
 neg_noise_pred = self.transformer(
@@ -902,27 +906,43 @@ class QwenImagePipeline(): #DiffusionPipeline
 attention_kwargs=self.attention_kwargs,
 **kwargs
 )[0]
-if neg_noise_pred == None: return None
+if neg_noise_pred == None: return None, None
 neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
-if do_true_cfg:
-comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
-if comb_pred == None: return None
-cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
-noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
-noise_pred = comb_pred * (cond_norm / noise_norm)
-neg_noise_pred = None
+return noise_pred, neg_noise_pred
+def cfg_predictions( noise_pred, neg_noise_pred, guidance, t):
+if do_true_cfg:
+comb_pred = neg_noise_pred + guidance * (noise_pred - neg_noise_pred)
+if comb_pred == None: return None
+cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+noise_pred = comb_pred * (cond_norm / noise_norm)
+return noise_pred
+if lanpaint_proc is not None and i<=3:
+latents = lanpaint_proc(denoise, cfg_predictions, true_cfg_scale, 1., latents, original_image_latents, randn, t/1000, image_mask_latents, height=height , width= width, vae_scale_factor= 8)
+if latents is None: return None
+noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale)
+if noise_pred == None: return None
+noise_pred = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t)
+neg_noise_pred = None
 # compute the previous noisy sample x_t -> x_t-1
-latents_dtype = latents.dtype
 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+noise_pred = None
 if image_mask_latents is not None:
-next_t = timesteps[i+1] if i<len(timesteps)-1 else 0
-latent_noise_factor = next_t / 1000
-# noisy_image = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor
-noisy_image = original_image_latents * (1.0 - latent_noise_factor) + randn * latent_noise_factor
-latents = noisy_image * (1-image_mask_latents) + image_mask_latents * latents
-noisy_image = None
+if lanpaint_proc is not None:
+latents = original_image_latents * (1-image_mask_latents) + image_mask_latents * latents
+else:
+next_t = timesteps[i+1] if i<len(timesteps)-1 else 0
+latent_noise_factor = next_t / 1000
+# noisy_image = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor
+noisy_image = original_image_latents * (1.0 - latent_noise_factor) + randn * latent_noise_factor
+latents = noisy_image * (1-image_mask_latents) + image_mask_latents * latents
+noisy_image = None
 if latents.dtype != latents_dtype:
 if torch.backends.mps.is_available():
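For reference, a self-contained sketch (not from the diff) of the 2x2 latent packing that the refactored _pack_latents now performs after reading the dimensions off the tensor itself; the toy tensor shape below is an arbitrary example.

import torch

def pack_latents(latents):
    # Same packing as the refactored QwenImagePipeline._pack_latents above:
    # dimensions are derived from the tensor instead of being passed in.
    batch_size, num_channels, _, height, width = latents.shape
    latents = latents.view(batch_size, num_channels, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(batch_size, (height // 2) * (width // 2), num_channels * 4)

x = torch.randn(1, 16, 1, 64, 64)   # toy (batch, channels, frame, height, width) latent
print(pack_latents(x).shape)        # torch.Size([1, 1024, 64])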

View File

@@ -32,9 +32,10 @@ from shared.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
 from .modules.posemb_layers import get_rotary_pos_embed, get_nd_rotary_pos_embed
 from shared.utils.vace_preprocessor import VaceVideoProcessor
 from shared.utils.basic_flowmatch import FlowMatchScheduler
-from shared.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions, convert_image_to_tensor
+from shared.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions, convert_image_to_tensor, fit_image_into_canvas
 from .multitalk.multitalk_utils import MomentumBuffer, adaptive_projected_guidance, match_and_blend_colors, match_and_blend_colors_with_mask
 from mmgp import safetensors2
+from shared.utils.audio_video import save_video
 def optimized_scale(positive_flat, negative_flat):
@@ -93,7 +94,7 @@ class WanAny2V:
 shard_fn= None)
 # base_model_type = "i2v2_2"
-if hasattr(config, "clip_checkpoint") and not base_model_type in ["i2v_2_2", "i2v_2_2_multitalk"]:
+if hasattr(config, "clip_checkpoint") and not base_model_type in ["i2v_2_2", "i2v_2_2_multitalk"] or base_model_type in ["animate"]:
 self.clip = CLIPModel(
 dtype=config.clip_dtype,
 device=self.device,
@@ -102,7 +103,7 @@
 tokenizer_path=os.path.join(checkpoint_dir , "xlm-roberta-large"))
-if base_model_type in ["ti2v_2_2"]:
+if base_model_type in ["ti2v_2_2", "lucy_edit"]:
 self.vae_stride = (4, 16, 16)
 vae_checkpoint = "Wan2.2_VAE.safetensors"
 vae = Wan2_2_VAE
@@ -190,19 +191,14 @@
 save_quantized_model(self.model2, model_type, model_filename[1], dtype, base_config_file, submodel_no=2)
 self.sample_neg_prompt = config.sample_neg_prompt
-if self.model.config.get("vace_in_dim", None) != None:
-self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
-min_area=480*832,
-max_area=480*832,
-min_fps=config.sample_fps,
-max_fps=config.sample_fps,
-zero_start=True,
-seq_len=32760,
-keep_last=True)
+if hasattr(self.model, "vace_blocks"):
 self.adapt_vace_model(self.model)
 if self.model2 is not None: self.adapt_vace_model(self.model2)
+if hasattr(self.model, "face_adapter"):
+self.adapt_animate_model(self.model)
+if self.model2 is not None: self.adapt_animate_model(self.model2)
 self.num_timesteps = 1000
 self.use_timestep_transform = True
@@ -277,51 +273,6 @@
 def vace_latent(self, z, m):
 return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
-def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, full_frame = False, outpainting_dims = None, return_mask = False):
-from shared.utils.utils import save_image
-ref_width, ref_height = ref_img.size
-if (ref_height, ref_width) == image_size and outpainting_dims == None:
-ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
-canvas = torch.zeros_like(ref_img) if return_mask else None
-else:
-if outpainting_dims != None:
-final_height, final_width = image_size
-canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1)
-else:
-canvas_height, canvas_width = image_size
-if full_frame:
-new_height = canvas_height
-new_width = canvas_width
-top = left = 0
-else:
-# if fill_max and (canvas_height - new_height) < 16:
-# new_height = canvas_height
-# if fill_max and (canvas_width - new_width) < 16:
-# new_width = canvas_width
-scale = min(canvas_height / ref_height, canvas_width / ref_width)
-new_height = int(ref_height * scale)
-new_width = int(ref_width * scale)
-top = (canvas_height - new_height) // 2
-left = (canvas_width - new_width) // 2
-ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
-ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
-if outpainting_dims != None:
-canvas = torch.full((3, 1, final_height, final_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
-canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img
-else:
-canvas = torch.full((3, 1, canvas_height, canvas_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
-canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
-ref_img = canvas
-canvas = None
-if return_mask:
-if outpainting_dims != None:
-canvas = torch.ones((3, 1, final_height, final_width), dtype= torch.float, device=device) # [-1, 1]
-canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = 0
-else:
-canvas = torch.ones((3, 1, canvas_height, canvas_width), dtype= torch.float, device=device) # [-1, 1]
-canvas[:, :, top:top + new_height, left:left + new_width] = 0
-canvas = canvas.to(device)
-return ref_img.to(device), canvas
 def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_video_guide_frames= [], pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False):
 image_sizes = []
@@ -375,7 +326,7 @@
 for k, frame in enumerate(inject_frames):
 if frame != None:
 pos = prepend_count + k
-src_video[i][:, pos:pos+1], src_mask[i][:, pos:pos+1] = self.fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims, return_mask= True)
+src_video[i][:, pos:pos+1], src_mask[i][:, pos:pos+1] = fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims, return_mask= True)
 self.background_mask = None
@@ -386,9 +337,9 @@
 if ref_img is not None and not torch.is_tensor(ref_img):
 if j==0 and any_background_ref:
 if self.background_mask == None: self.background_mask = [None] * len(src_ref_images)
-src_ref_images[i][j], self.background_mask[i] = self.fit_image_into_canvas(ref_img, image_size, 0, device, True, outpainting_dims, return_mask= True)
+src_ref_images[i][j], self.background_mask[i] = fit_image_into_canvas(ref_img, image_size, 0, device, True, outpainting_dims, return_mask= True)
 else:
-src_ref_images[i][j], _ = self.fit_image_into_canvas(ref_img, image_size, 1, device)
+src_ref_images[i][j], _ = fit_image_into_canvas(ref_img, image_size, 1, device)
 if self.background_mask != None:
 self.background_mask = [ item if item != None else self.background_mask[0] for item in self.background_mask ] # deplicate background mask with double control net since first controlnet image ref modifed by ref
 return src_video, src_mask, src_ref_images
@@ -402,12 +353,26 @@
 return torch.cat(ref_vae_latents, dim=1)
+def get_i2v_mask(self, lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=None, lat_t =0, device="cuda"):
+if mask_pixel_values is None:
+msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
+else:
+msk = F.interpolate(mask_pixel_values.to(device), size=(lat_h, lat_w), mode='nearest')
+if nb_frames_unchanged >0:
+msk[:, :nb_frames_unchanged] = 1
+msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
+msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
+msk = msk.transpose(1,2)[0]
+return msk
 def generate(self,
 input_prompt,
 input_frames= None,
 input_masks = None,
 input_ref_images = None,
+input_ref_masks = None,
+input_faces = None,
 input_video = None,
 image_start = None,
 image_end = None,
@@ -541,14 +506,18 @@
 infinitetalk = model_type in ["infinitetalk"]
 standin = model_type in ["standin", "vace_standin_14B"]
 recam = model_type in ["recam_1.3B"]
-ti2v = model_type in ["ti2v_2_2"]
+ti2v = model_type in ["ti2v_2_2", "lucy_edit"]
+lucy_edit= model_type in ["lucy_edit"]
+animate= model_type in ["animate"]
 start_step_no = 0
 ref_images_count = 0
 trim_frames = 0
-extended_overlapped_latents = None
+extended_overlapped_latents = clip_image_start = clip_image_end = None
 no_noise_latents_injection = infinitetalk
 timestep_injection = False
 lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1
+extended_input_dim = 0
+ref_images_before = False
 # image2video
 if model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "flf2v_720p"]:
 any_end_frame = False
@@ -598,17 +567,7 @@
 if image_end is not None:
 img_end_frame = image_end.unsqueeze(1).to(self.device)
-if hasattr(self, "clip"):
-clip_image_size = self.clip.model.image_size
-image_start = resize_lanczos(image_start, clip_image_size, clip_image_size)
-image_end = resize_lanczos(image_end, clip_image_size, clip_image_size) if image_end is not None else image_start
-if model_type == "flf2v_720p":
-clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :] if image_end is not None else image_start[:, None, :, :]])
-else:
-clip_context = self.clip.visual([image_start[:, None, :, :]])
-else:
-clip_context = None
+clip_image_start, clip_image_end = image_start, image_end
 if any_end_frame:
 enc= torch.concat([
@@ -647,21 +606,62 @@
 if infinitetalk:
 lat_y = self.vae.encode([input_video], VAE_tile_size)[0]
 extended_overlapped_latents = lat_y[:, :overlapped_latents_frames_num].clone().unsqueeze(0)
+# if control_pre_frames_count != pre_frames_count:
 lat_y = input_video = None
 kwargs.update({ 'y': y})
-if not clip_context is None:
-kwargs.update({'clip_fea': clip_context})
-# Recam Master
-if recam:
-target_camera = model_mode
-height,width = input_frames.shape[-2:]
-input_frames = input_frames.to(dtype=self.dtype , device=self.device)
-source_latents = self.vae.encode([input_frames])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device)
+# Animate
+if animate:
+pose_pixels = input_frames * input_masks
+input_masks = 1. - input_masks
+pose_pixels -= input_masks
+save_video(pose_pixels, "pose.mp4")
+pose_latents = self.vae.encode([pose_pixels], VAE_tile_size)[0].unsqueeze(0)
+input_frames = input_frames * input_masks
+if not "X" in video_prompt_type: input_frames += input_masks - 1 # masked area should black (-1) in background frames
+if prefix_frames_count > 0:
+input_frames[:, :prefix_frames_count] = input_video
+input_masks[:, :prefix_frames_count] = 1
+save_video(input_frames, "input_frames.mp4")
+save_video(input_masks, "input_masks.mp4", value_range=(0,1))
+lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
+msk_ref = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=1,lat_t=1, device=self.device)
+msk_control = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=input_masks, device=self.device)
+msk = torch.concat([msk_ref, msk_control], dim=1)
+clip_image_start = image_ref = convert_image_to_tensor(input_ref_images[0]).to(self.device)
+lat_y = torch.concat(self.vae.encode([image_ref.unsqueeze(1).to(self.device), input_frames.to(self.device)], VAE_tile_size), dim=1)
+y = torch.concat([msk, lat_y])
+kwargs.update({ 'y': y, 'pose_latents': pose_latents, 'face_pixel_values' : input_faces.unsqueeze(0)})
+lat_y = msk = msk_control = msk_ref = pose_pixels = None
+ref_images_before = True
+ref_images_count = 1
+lat_frames = int((input_frames.shape[1] - 1) // self.vae_stride[0]) + 1
+# Clip image
+if hasattr(self, "clip") and clip_image_start is not None:
+clip_image_size = self.clip.model.image_size
+clip_image_start = resize_lanczos(clip_image_start, clip_image_size, clip_image_size)
+clip_image_end = resize_lanczos(clip_image_end, clip_image_size, clip_image_size) if clip_image_end is not None else clip_image_start
+if model_type == "flf2v_720p":
+clip_context = self.clip.visual([clip_image_start[:, None, :, :], clip_image_end[:, None, :, :] if clip_image_end is not None else clip_image_start[:, None, :, :]])
+else:
+clip_context = self.clip.visual([clip_image_start[:, None, :, :]])
+clip_image_start = clip_image_end = None
+kwargs.update({'clip_fea': clip_context})
+# Recam Master & Lucy Edit
+if recam or lucy_edit:
+frame_num, height,width = input_frames.shape[-3:]
+lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1
+frame_num = (lat_frames -1) * self.vae_stride[0] + 1
+input_frames = input_frames[:, :frame_num].to(dtype=self.dtype , device=self.device)
+extended_latents = self.vae.encode([input_frames])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device)
+extended_input_dim = 2 if recam else 1
 del input_frames
+if recam:
 # Process target camera (recammaster)
+target_camera = model_mode
 from shared.utils.cammmaster_tools import get_camera_embedding
 cam_emb = get_camera_embedding(target_camera)
 cam_emb = cam_emb.to(dtype=self.dtype, device=self.device)
@@ -715,6 +715,8 @@
 height, width = input_video.shape[-2:]
 source_latents = self.vae.encode([input_video], tile_size = VAE_tile_size)[0].unsqueeze(0)
 timestep_injection = True
+if extended_input_dim > 0:
+extended_latents[:, :, :source_latents.shape[2]] = source_latents
 # Vace
 if vace :
@@ -722,6 +724,7 @@
 input_frames = [u.to(self.device) for u in input_frames]
 input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
 input_masks = [u.to(self.device) for u in input_masks]
+ref_images_before = True
 if self.background_mask != None: self.background_mask = [m.to(self.device) for m in self.background_mask]
 z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents )
 m0 = self.vace_encode_masks(input_masks, input_ref_images)
@@ -771,9 +774,9 @@
 expand_shape = [batch_size] + [-1] * len(target_shape)
 # Ropes
-if target_camera != None:
+if extended_input_dim>=2:
 shape = list(target_shape[1:])
-shape[0] *= 2
+shape[extended_input_dim-2] *= 2
 freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
 else:
 freqs = get_rotary_pos_embed(target_shape[1:], enable_RIFLEx= enable_RIFLEx)
@@ -901,8 +904,8 @@
 for zz in z:
 zz[0:16, ref_images_count:extended_overlapped_latents.shape[2] ] = extended_overlapped_latents[0, :, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(extended_overlapped_latents[0, :, ref_images_count:] ) * overlap_noise_factor
-if target_camera != None:
-latent_model_input = torch.cat([latents, source_latents.expand(*expand_shape)], dim=2)
+if extended_input_dim > 0:
+latent_model_input = torch.cat([latents, extended_latents.expand(*expand_shape)], dim=extended_input_dim)
 else:
 latent_model_input = latents
@@ -1030,7 +1033,7 @@
 if callback is not None:
 latents_preview = latents
-if vace and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ]
+if ref_images_before and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ]
 if trim_frames > 0: latents_preview= latents_preview[:, :,:-trim_frames]
 if image_outputs: latents_preview= latents_preview[:, :,:1]
 if len(latents_preview) > 1: latents_preview = latents_preview.transpose(0,2)
@@ -1041,7 +1044,7 @@
 if timestep_injection:
 latents[:, :, :source_latents.shape[2]] = source_latents
-if vace and ref_images_count > 0: latents = latents[:, :, ref_images_count:]
+if ref_images_before and ref_images_count > 0: latents = latents[:, :, ref_images_count:]
 if trim_frames > 0: latents= latents[:, :,:-trim_frames]
 if return_latent_slice != None:
 latent_slice = latents[:, :, return_latent_slice].clone()
@@ -1078,4 +1081,12 @@
 delattr(model, "vace_blocks")
+def adapt_animate_model(self, model):
+modules_dict= { k: m for k, m in model.named_modules()}
+for animate_layer in range(8):
+module = modules_dict[f"face_adapter.fuser_blocks.{animate_layer}"]
+model_layer = animate_layer * 5
+target = modules_dict[f"blocks.{model_layer}"]
+setattr(target, "face_adapter_fuser_blocks", module )
+delattr(model, "face_adapter")
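A standalone sketch (not from the diff) of the mask layout produced by the new get_i2v_mask helper for the simple case without mask_pixel_values; the latent sizes are arbitrary examples. The first pixel frame is repeated 4x to match Wan's temporal VAE stride, then the frames are regrouped into 4 channels per latent frame.

import torch

def i2v_mask(lat_h, lat_w, lat_t, nb_frames_unchanged=0, device="cpu"):
    # Mirrors WanAny2V.get_i2v_mask when mask_pixel_values is None.
    msk = torch.zeros(1, (lat_t - 1) * 4 + 1, lat_h, lat_w, device=device)
    if nb_frames_unchanged > 0:
        msk[:, :nb_frames_unchanged] = 1
    msk = torch.cat([msk[:, 0:1].repeat_interleave(4, dim=1), msk[:, 1:]], dim=1)
    msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w).transpose(1, 2)[0]
    return msk

m = i2v_mask(lat_h=60, lat_w=104, lat_t=21, nb_frames_unchanged=1)
print(m.shape)            # torch.Size([4, 21, 60, 104])
print(m[:, 0].unique())   # tensor([1.]) -> the first latent frame is marked as kept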

View File

@@ -16,6 +16,9 @@ from mmgp.offload import get_cache, clear_caches
 from shared.attention import pay_attention
 from torch.backends.cuda import sdp_kernel
 from ..multitalk.multitalk_utils import get_attn_map_with_target
+from ..animate.motion_encoder import Generator
+from ..animate.face_blocks import FaceAdapter, FaceEncoder
+from ..animate.model_animate import after_patch_embedding
 __all__ = ['WanModel']
@@ -499,6 +502,7 @@ class WanAttentionBlock(nn.Module):
 multitalk_masks=None,
 ref_images_count=0,
 standin_phase=-1,
+motion_vec = None,
 ):
 r"""
 Args:
@@ -616,6 +620,10 @@
 x.add_(hint)
 else:
 x.add_(hint, alpha= scale)
+if motion_vec is not None and self.block_no % 5 == 0:
+x += self.face_adapter_fuser_blocks(x.to(self.face_adapter_fuser_blocks.linear1_kv.weight.dtype), motion_vec, None, False)
 return x
 class AudioProjModel(ModelMixin, ConfigMixin):
@@ -898,6 +906,7 @@ class WanModel(ModelMixin, ConfigMixin):
 norm_input_visual=True,
 norm_output_audio=True,
 standin= False,
+motion_encoder_dim=0,
 ):
 super().__init__()
@@ -922,7 +931,7 @@
 self.flag_causal_attention = False
 self.block_mask = None
 self.inject_sample_info = inject_sample_info
+self.motion_encoder_dim = motion_encoder_dim
 self.norm_output_audio = norm_output_audio
 self.audio_window = audio_window
 self.intermediate_dim = intermediate_dim
@@ -930,6 +939,7 @@
 multitalk = multitalk_output_dim > 0
 self.multitalk = multitalk
+animate = motion_encoder_dim > 0
 # embeddings
 self.patch_embedding = nn.Conv3d(
@@ -1027,6 +1037,25 @@
 block.self_attn.k_loras = LoRALinearLayer(dim, dim, rank=128)
 block.self_attn.v_loras = LoRALinearLayer(dim, dim, rank=128)
+if animate:
+self.pose_patch_embedding = nn.Conv3d(
+16, dim, kernel_size=patch_size, stride=patch_size
+)
+self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)
+self.face_adapter = FaceAdapter(
+heads_num=self.num_heads,
+hidden_dim=self.dim,
+num_adapter_layers=self.num_layers // 5,
+)
+self.face_encoder = FaceEncoder(
+in_dim=motion_encoder_dim,
+hidden_dim=self.dim,
+num_heads=4,
+)
 def lock_layers_dtypes(self, hybrid_dtype = None, dtype = torch.float32):
 layer_list = [self.head, self.head.head, self.patch_embedding]
 target_dype= dtype
@@ -1208,6 +1237,9 @@
 ref_images_count = 0,
 standin_freqs = None,
 standin_ref = None,
+pose_latents=None,
+face_pixel_values=None,
 ):
 # patch_dtype = self.patch_embedding.weight.dtype
 modulation_dtype = self.time_projection[1].weight.dtype
@@ -1240,9 +1272,18 @@
 if bz > 1: y = y.expand(bz, -1, -1, -1, -1)
 x = torch.cat([x, y], dim=1)
 # embeddings
+# x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype)
 x = self.patch_embedding(x).to(modulation_dtype)
 grid_sizes = x.shape[2:]
+x_list[i] = x
+y = None
+motion_vec_list = []
+for i, x in enumerate(x_list):
+# animate embeddings
+motion_vec = None
+if pose_latents is not None:
+x, motion_vec = after_patch_embedding(self, x, pose_latents, face_pixel_values)
+motion_vec_list.append(motion_vec)
 if chipmunk:
 x = x.unsqueeze(-1)
 x_og_shape = x.shape
@@ -1250,7 +1291,7 @@
 else:
 x = x.flatten(2).transpose(1, 2)
 x_list[i] = x
-x, y = None, None
+x = None
 block_mask = None
@@ -1450,9 +1491,9 @@
 continue
 x_list[0] = block(x_list[0], context = context_list[0], audio_scale= audio_scale_list[0], e= e0, **kwargs)
 else:
-for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc)):
+for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc, motion_vec) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc,motion_vec_list)):
 if should_calc:
-x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, **kwargs)
+x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, motion_vec = motion_vec,**kwargs)
 del x
 context = hints = audio_embedding = None
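A toy illustration (using stand-in modules, not the real FaceAdapter API) of the grafting pattern used by adapt_animate_model and the block_no % 5 check above: one fuser module from an adapter container is attached to every fifth block and applied in that block's forward pass.

import torch
from torch import nn

class DummyBlock(nn.Module):
    def __init__(self, block_no):
        super().__init__()
        self.block_no = block_no

    def forward(self, x, motion_vec=None):
        # Only every 5th block that received a fuser applies the motion signal.
        if motion_vec is not None and self.block_no % 5 == 0 and hasattr(self, "fuser"):
            x = x + self.fuser(motion_vec)   # stand-in for face_adapter_fuser_blocks(...)
        return x

blocks = nn.ModuleList([DummyBlock(i) for i in range(40)])
adapter = nn.ModuleList([nn.Linear(8, 8) for _ in range(8)])   # stand-in for face_adapter.fuser_blocks

for k, fuser in enumerate(adapter):            # same idea as adapt_animate_model
    setattr(blocks[k * 5], "fuser", fuser)

x, motion = torch.zeros(1, 8), torch.randn(1, 8)
out = blocks[0](blocks[1](x, motion), motion)  # only block 0 actually fuses
print(out.shape)                               # torch.Size([1, 8])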

View File

@ -3,10 +3,10 @@ import numpy as np
import gradio as gr import gradio as gr
def test_class_i2v(base_model_type): def test_class_i2v(base_model_type):
return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk" ] return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "animate" ]
def text_oneframe_overlap(base_model_type): def text_oneframe_overlap(base_model_type):
return test_class_i2v(base_model_type) and not test_multitalk(base_model_type) return test_class_i2v(base_model_type) and not (test_multitalk(base_model_type) or base_model_type in ["animate"]) or test_wan_5B(base_model_type)
def test_class_1_3B(base_model_type): def test_class_1_3B(base_model_type):
return base_model_type in [ "vace_1.3B", "t2v_1.3B", "recam_1.3B","phantom_1.3B","fun_inp_1.3B"] return base_model_type in [ "vace_1.3B", "t2v_1.3B", "recam_1.3B","phantom_1.3B","fun_inp_1.3B"]
@ -17,6 +17,8 @@ def test_multitalk(base_model_type):
def test_standin(base_model_type): def test_standin(base_model_type):
return base_model_type in ["standin", "vace_standin_14B"] return base_model_type in ["standin", "vace_standin_14B"]
def test_wan_5B(base_model_type):
return base_model_type in ["ti2v_2_2", "lucy_edit"]
class family_handler(): class family_handler():
@staticmethod @staticmethod
@ -36,7 +38,7 @@ class family_handler():
def_mag_ratios = [1.00124, 1.00155, 0.99822, 0.99851, 0.99696, 0.99687, 0.99703, 0.99732, 0.9966, 0.99679, 0.99602, 0.99658, 0.99578, 0.99664, 0.99484, 0.9949, 0.99633, 0.996, 0.99659, 0.99683, 0.99534, 0.99549, 0.99584, 0.99577, 0.99681, 0.99694, 0.99563, 0.99554, 0.9944, 0.99473, 0.99594, 0.9964, 0.99466, 0.99461, 0.99453, 0.99481, 0.99389, 0.99365, 0.99391, 0.99406, 0.99354, 0.99361, 0.99283, 0.99278, 0.99268, 0.99263, 0.99057, 0.99091, 0.99125, 0.99126, 0.65523, 0.65252, 0.98808, 0.98852, 0.98765, 0.98736, 0.9851, 0.98535, 0.98311, 0.98339, 0.9805, 0.9806, 0.97776, 0.97771, 0.97278, 0.97286, 0.96731, 0.96728, 0.95857, 0.95855, 0.94385, 0.94385, 0.92118, 0.921, 0.88108, 0.88076, 0.80263, 0.80181] def_mag_ratios = [1.00124, 1.00155, 0.99822, 0.99851, 0.99696, 0.99687, 0.99703, 0.99732, 0.9966, 0.99679, 0.99602, 0.99658, 0.99578, 0.99664, 0.99484, 0.9949, 0.99633, 0.996, 0.99659, 0.99683, 0.99534, 0.99549, 0.99584, 0.99577, 0.99681, 0.99694, 0.99563, 0.99554, 0.9944, 0.99473, 0.99594, 0.9964, 0.99466, 0.99461, 0.99453, 0.99481, 0.99389, 0.99365, 0.99391, 0.99406, 0.99354, 0.99361, 0.99283, 0.99278, 0.99268, 0.99263, 0.99057, 0.99091, 0.99125, 0.99126, 0.65523, 0.65252, 0.98808, 0.98852, 0.98765, 0.98736, 0.9851, 0.98535, 0.98311, 0.98339, 0.9805, 0.9806, 0.97776, 0.97771, 0.97278, 0.97286, 0.96731, 0.96728, 0.95857, 0.95855, 0.94385, 0.94385, 0.92118, 0.921, 0.88108, 0.88076, 0.80263, 0.80181]
elif base_model_type in ["i2v_2_2"]: elif base_model_type in ["i2v_2_2"]:
def_mag_ratios = [0.99191, 0.99144, 0.99356, 0.99337, 0.99326, 0.99285, 0.99251, 0.99264, 0.99393, 0.99366, 0.9943, 0.9943, 0.99276, 0.99288, 0.99389, 0.99393, 0.99274, 0.99289, 0.99316, 0.9931, 0.99379, 0.99377, 0.99268, 0.99271, 0.99222, 0.99227, 0.99175, 0.9916, 0.91076, 0.91046, 0.98931, 0.98933, 0.99087, 0.99088, 0.98852, 0.98855, 0.98895, 0.98896, 0.98806, 0.98808, 0.9871, 0.98711, 0.98613, 0.98618, 0.98434, 0.98435, 0.983, 0.98307, 0.98185, 0.98187, 0.98131, 0.98131, 0.9783, 0.97835, 0.97619, 0.9762, 0.97264, 0.9727, 0.97088, 0.97098, 0.96568, 0.9658, 0.96045, 0.96055, 0.95322, 0.95335, 0.94579, 0.94594, 0.93297, 0.93311, 0.91699, 0.9172, 0.89174, 0.89202, 0.8541, 0.85446, 0.79823, 0.79902] def_mag_ratios = [0.99191, 0.99144, 0.99356, 0.99337, 0.99326, 0.99285, 0.99251, 0.99264, 0.99393, 0.99366, 0.9943, 0.9943, 0.99276, 0.99288, 0.99389, 0.99393, 0.99274, 0.99289, 0.99316, 0.9931, 0.99379, 0.99377, 0.99268, 0.99271, 0.99222, 0.99227, 0.99175, 0.9916, 0.91076, 0.91046, 0.98931, 0.98933, 0.99087, 0.99088, 0.98852, 0.98855, 0.98895, 0.98896, 0.98806, 0.98808, 0.9871, 0.98711, 0.98613, 0.98618, 0.98434, 0.98435, 0.983, 0.98307, 0.98185, 0.98187, 0.98131, 0.98131, 0.9783, 0.97835, 0.97619, 0.9762, 0.97264, 0.9727, 0.97088, 0.97098, 0.96568, 0.9658, 0.96045, 0.96055, 0.95322, 0.95335, 0.94579, 0.94594, 0.93297, 0.93311, 0.91699, 0.9172, 0.89174, 0.89202, 0.8541, 0.85446, 0.79823, 0.79902]
elif base_model_type in ["ti2v_2_2"]: elif test_wan_5B(base_model_type):
if inputs.get("image_start", None) is not None and inputs.get("video_source", None) is not None : # t2v if inputs.get("image_start", None) is not None and inputs.get("video_source", None) is not None : # t2v
def_mag_ratios = [0.99505, 0.99389, 0.99441, 0.9957, 0.99558, 0.99551, 0.99499, 0.9945, 0.99534, 0.99548, 0.99468, 0.9946, 0.99463, 0.99458, 0.9946, 0.99453, 0.99408, 0.99404, 0.9945, 0.99441, 0.99409, 0.99398, 0.99403, 0.99397, 0.99382, 0.99377, 0.99349, 0.99343, 0.99377, 0.99378, 0.9933, 0.99328, 0.99303, 0.99301, 0.99217, 0.99216, 0.992, 0.99201, 0.99201, 0.99202, 0.99133, 0.99132, 0.99112, 0.9911, 0.99155, 0.99155, 0.98958, 0.98957, 0.98959, 0.98958, 0.98838, 0.98835, 0.98826, 0.98825, 0.9883, 0.98828, 0.98711, 0.98709, 0.98562, 0.98561, 0.98511, 0.9851, 0.98414, 0.98412, 0.98284, 0.98282, 0.98104, 0.98101, 0.97981, 0.97979, 0.97849, 0.97849, 0.97557, 0.97554, 0.97398, 0.97395, 0.97171, 0.97166, 0.96917, 0.96913, 0.96511, 0.96507, 0.96263, 0.96257, 0.95839, 0.95835, 0.95483, 0.95475, 0.94942, 0.94936, 0.9468, 0.94678, 0.94583, 0.94594, 0.94843, 0.94872, 0.96949, 0.97015] def_mag_ratios = [0.99505, 0.99389, 0.99441, 0.9957, 0.99558, 0.99551, 0.99499, 0.9945, 0.99534, 0.99548, 0.99468, 0.9946, 0.99463, 0.99458, 0.9946, 0.99453, 0.99408, 0.99404, 0.9945, 0.99441, 0.99409, 0.99398, 0.99403, 0.99397, 0.99382, 0.99377, 0.99349, 0.99343, 0.99377, 0.99378, 0.9933, 0.99328, 0.99303, 0.99301, 0.99217, 0.99216, 0.992, 0.99201, 0.99201, 0.99202, 0.99133, 0.99132, 0.99112, 0.9911, 0.99155, 0.99155, 0.98958, 0.98957, 0.98959, 0.98958, 0.98838, 0.98835, 0.98826, 0.98825, 0.9883, 0.98828, 0.98711, 0.98709, 0.98562, 0.98561, 0.98511, 0.9851, 0.98414, 0.98412, 0.98284, 0.98282, 0.98104, 0.98101, 0.97981, 0.97979, 0.97849, 0.97849, 0.97557, 0.97554, 0.97398, 0.97395, 0.97171, 0.97166, 0.96917, 0.96913, 0.96511, 0.96507, 0.96263, 0.96257, 0.95839, 0.95835, 0.95483, 0.95475, 0.94942, 0.94936, 0.9468, 0.94678, 0.94583, 0.94594, 0.94843, 0.94872, 0.96949, 0.97015]
else: # i2v else: # i2v
@ -83,11 +85,13 @@ class family_handler():
vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B", "vace_standin_14B"] vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B", "vace_standin_14B"]
extra_model_def["vace_class"] = vace_class extra_model_def["vace_class"] = vace_class
if test_multitalk(base_model_type): if base_model_type in ["animate"]:
fps = 30
elif test_multitalk(base_model_type):
fps = 25 fps = 25
elif base_model_type in ["fantasy"]: elif base_model_type in ["fantasy"]:
fps = 23 fps = 23
elif base_model_type in ["ti2v_2_2"]: elif test_wan_5B(base_model_type):
fps = 24 fps = 24
else: else:
fps = 16 fps = 16
@ -100,14 +104,14 @@ class family_handler():
extra_model_def.update({ extra_model_def.update({
"frames_minimum" : frames_minimum, "frames_minimum" : frames_minimum,
"frames_steps" : frames_steps, "frames_steps" : frames_steps,
"sliding_window" : base_model_type in ["multitalk", "infinitetalk", "t2v", "fantasy"] or test_class_i2v(base_model_type) or vace_class, #"ti2v_2_2", "sliding_window" : base_model_type in ["multitalk", "infinitetalk", "t2v", "fantasy", "animate"] or test_class_i2v(base_model_type) or test_wan_5B(base_model_type) or vace_class, #"ti2v_2_2",
"multiple_submodels" : multiple_submodels, "multiple_submodels" : multiple_submodels,
"guidance_max_phases" : 3, "guidance_max_phases" : 3,
"skip_layer_guidance" : True, "skip_layer_guidance" : True,
"cfg_zero" : True, "cfg_zero" : True,
"cfg_star" : True, "cfg_star" : True,
"adaptive_projected_guidance" : True, "adaptive_projected_guidance" : True,
"tea_cache" : not (base_model_type in ["i2v_2_2", "ti2v_2_2" ] or multiple_submodels), "tea_cache" : not (base_model_type in ["i2v_2_2"] or test_wan_5B(base_model_type) or multiple_submodels),
"mag_cache" : True, "mag_cache" : True,
"keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"], "keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"],
"convert_image_guide_to_video" : True, "convert_image_guide_to_video" : True,
@ -146,6 +150,34 @@ class family_handler():
} }
# extra_model_def["at_least_one_image_ref_needed"] = True # extra_model_def["at_least_one_image_ref_needed"] = True
if base_model_type in ["lucy_edit"]:
extra_model_def["keep_frames_video_guide_not_supported"] = True
extra_model_def["guide_preprocessing"] = {
"selection": ["UV"],
"labels" : { "UV": "Control Video"},
"visible": False,
}
if base_model_type in ["animate"]:
extra_model_def["guide_custom_choices"] = {
"choices":[
("Animate Person in Reference Image using Motion of Person in Control Video", "PVBXAKI"),
("Replace Person in Control Video Person in Reference Image", "PVBAI"),
],
"default": "KI",
"letters_filter": "PVBXAKI",
"label": "Type of Process",
"show_label" : False,
}
extra_model_def["video_guide_outpainting"] = [0,1]
extra_model_def["keep_frames_video_guide_not_supported"] = True
extra_model_def["extract_guide_from_window_start"] = True
extra_model_def["forced_guide_mask_inputs"] = True
extra_model_def["background_removal_label"]= "Remove Backgrounds behind People (Animate Mode Only)"
extra_model_def["background_ref_outpainted"] = False
if vace_class: if vace_class:
extra_model_def["guide_preprocessing"] = { extra_model_def["guide_preprocessing"] = {
"selection": ["", "UV", "PV", "DV", "SV", "LV", "CV", "MV", "V", "PDV", "PSV", "PLV" , "DSV", "DLV", "SLV"], "selection": ["", "UV", "PV", "DV", "SV", "LV", "CV", "MV", "V", "PDV", "PSV", "PLV" , "DSV", "DLV", "SLV"],
@ -157,16 +189,19 @@ class family_handler():
extra_model_def["image_ref_choices"] = { extra_model_def["image_ref_choices"] = {
"choices": [("None", ""), "choices": [("None", ""),
("Inject only People / Objects", "I"), ("People / Objects", "I"),
("Inject Landscape and then People / Objects", "KI"), ("Landscape followed by People / Objects (if any)", "KI"),
("Inject Frames and then People / Objects", "FI"), ("Positioned Frames followed by People / Objects (if any)", "FI"),
], ],
"letters_filter": "KFI", "letters_filter": "KFI",
} }
extra_model_def["lock_image_refs_ratios"] = True extra_model_def["lock_image_refs_ratios"] = True
extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or positioned Frames" extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or Positioned Frames"
extra_model_def["video_guide_outpainting"] = [0,1] extra_model_def["video_guide_outpainting"] = [0,1]
extra_model_def["pad_guide_video"] = True
extra_model_def["guide_inpaint_color"] = 127.5
extra_model_def["forced_guide_mask_inputs"] = True
if base_model_type in ["standin"]: if base_model_type in ["standin"]:
extra_model_def["lock_image_refs_ratios"] = True extra_model_def["lock_image_refs_ratios"] = True
@ -209,10 +244,12 @@ class family_handler():
"visible" : False, "visible" : False,
} }
-        if vace_class or base_model_type in ["infinitetalk"]:
        if vace_class or base_model_type in ["infinitetalk", "animate"]:
            image_prompt_types_allowed = "TVL"
        elif base_model_type in ["ti2v_2_2"]:
            image_prompt_types_allowed = "TSVL"
        elif base_model_type in ["lucy_edit"]:
            image_prompt_types_allowed = "TVL"
        elif test_multitalk(base_model_type) or base_model_type in ["fantasy"]:
            image_prompt_types_allowed = "SVL"
        elif i2v:
@ -234,8 +271,8 @@ class family_handler():
    def query_supported_types():
        return ["multitalk", "infinitetalk", "fantasy", "vace_14B", "vace_multitalk_14B", "vace_standin_14B",
                "t2v_1.3B", "standin", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B",
-                "recam_1.3B",
                "recam_1.3B", "animate",
-                "i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "flf2v_720p", "fun_inp_1.3B", "fun_inp"]
                "i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "lucy_edit", "flf2v_720p", "fun_inp_1.3B", "fun_inp"]
    @staticmethod
@ -265,11 +302,12 @@ class family_handler():
    @staticmethod
    def get_vae_block_size(base_model_type):
-        return 32 if base_model_type == "ti2v_2_2" else 16
        return 32 if test_wan_5B(base_model_type) else 16
    @staticmethod
    def get_rgb_factors(base_model_type ):
        from shared.RGB_factors import get_rgb_factors
        if test_wan_5B(base_model_type): base_model_type = "ti2v_2_2"
        latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("wan", base_model_type)
        return latent_rgb_factors, latent_rgb_factors_bias
@ -283,7 +321,7 @@ class family_handler():
            "fileList" : [ [ "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "sentencepiece.bpe.model", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"], ["special_tokens_map.json", "spiece.model", "tokenizer.json", "tokenizer_config.json"] + computeList(text_encoder_filename) , ["Wan2.1_VAE.safetensors", "fantasy_proj_model.safetensors" ] + computeList(model_filename) ]
        }]
-        if base_model_type == "ti2v_2_2":
        if test_wan_5B(base_model_type):
            download_def += [ {
                "repoId" : "DeepBeepMeep/Wan2.2",
                "sourceFolderList" : [""],
@ -377,7 +415,7 @@ class family_handler():
        ui_defaults.update({
            "sample_solver": "unipc",
        })
-        if test_class_i2v(base_model_type):
        if test_class_i2v(base_model_type) and "S" in model_def["image_prompt_types_allowed"]:
            ui_defaults["image_prompt_type"] = "S"
        if base_model_type in ["fantasy"]:
@ -434,10 +472,15 @@ class family_handler():
                "image_prompt_type": "T",
            })
-        if base_model_type in ["recam_1.3B"]:
        if base_model_type in ["recam_1.3B", "lucy_edit"]:
            ui_defaults.update({
                "video_prompt_type": "UV",
            })
        elif base_model_type in ["animate"]:
            ui_defaults.update({
                "video_prompt_type": "PVBXAKI",
                "mask_expand": 20,
            })
        if text_oneframe_overlap(base_model_type):
            ui_defaults["sliding_window_overlap"] = 1

View File

@ -0,0 +1,342 @@
#!/usr/bin/env python3
"""
Convert a Flux model from Diffusers (folder or single-file) into the original
single-file Flux transformer checkpoint used by Black Forest Labs / ComfyUI.
Input : /path/to/diffusers (root or .../transformer) OR /path/to/*.safetensors (single file)
Output : /path/to/flux1-your-model.safetensors (transformer only)
Usage:
python diffusers_to_flux_transformer.py /path/to/diffusers /out/flux1-dev.safetensors
python diffusers_to_flux_transformer.py /path/to/diffusion_pytorch_model.safetensors /out/flux1-dev.safetensors
# optional quantization:
# --fp8 (float8_e4m3fn, simple)
# --fp8-scaled (scaled float8 for 2D weights; adds .scale_weight tensors)
"""
import argparse
import json
from pathlib import Path
from collections import OrderedDict
import torch
from safetensors import safe_open
import safetensors.torch
from tqdm import tqdm
def parse_args():
ap = argparse.ArgumentParser()
ap.add_argument("diffusers_path", type=str,
help="Path to Diffusers checkpoint folder OR a single .safetensors file.")
ap.add_argument("output_path", type=str,
help="Output .safetensors path for the Flux transformer.")
ap.add_argument("--fp8", action="store_true",
help="Experimental: write weights as float8_e4m3fn via stochastic rounding (transformer only).")
ap.add_argument("--fp8-scaled", action="store_true",
help="Experimental: scaled float8_e4m3fn for 2D weight tensors; adds .scale_weight tensors.")
return ap.parse_args()
# Mapping from original Flux keys -> list of Diffusers keys (per block where applicable).
DIFFUSERS_MAP = {
# global embeds
"time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"],
"time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"],
"time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"],
"time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"],
"vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"],
"vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"],
"vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"],
"vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"],
"guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"],
"guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"],
"guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"],
"guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"],
"txt_in.weight": ["context_embedder.weight"],
"txt_in.bias": ["context_embedder.bias"],
"img_in.weight": ["x_embedder.weight"],
"img_in.bias": ["x_embedder.bias"],
# dual-stream (image/text) blocks
"double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"],
"double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"],
"double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"],
"double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"],
"double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"],
"double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"],
"double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"],
"double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"],
"double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"],
"double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"],
"double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"],
"double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"],
"double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"],
"double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"],
"double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"],
"double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"],
"double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"],
"double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"],
"double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"],
"double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"],
"double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"],
"double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"],
"double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"],
"double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"],
# single-stream blocks
"single_blocks.().modulation.lin.weight": ["norm.linear.weight"],
"single_blocks.().modulation.lin.bias": ["norm.linear.bias"],
"single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"],
"single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"],
"single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"],
"single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"],
"single_blocks.().linear2.weight": ["proj_out.weight"],
"single_blocks.().linear2.bias": ["proj_out.bias"],
# final
"final_layer.linear.weight": ["proj_out.weight"],
"final_layer.linear.bias": ["proj_out.bias"],
# these two are built from norm_out.linear.{weight,bias} by swapping [shift,scale] -> [scale,shift]
"final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"],
"final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"],
}
class DiffusersSource:
"""
Uniform interface over:
1) Folder with index JSON + shards
2) Folder with exactly one .safetensors (no index)
3) Single .safetensors file
Provides .has(key), .get(key)->Tensor, .base_keys (keys with 'model.' stripped for scanning)
"""
POSSIBLE_PREFIXES = ["", "model."] # try in this order
def __init__(self, path: Path):
p = Path(path)
if p.is_dir():
# use 'transformer' subfolder if present
if (p / "transformer").is_dir():
p = p / "transformer"
self._init_from_dir(p)
elif p.is_file() and p.suffix == ".safetensors":
self._init_from_single_file(p)
else:
raise FileNotFoundError(f"Invalid path: {p}")
# ---------- common helpers ----------
@staticmethod
def _strip_prefix(k: str) -> str:
return k[6:] if k.startswith("model.") else k
def _resolve(self, want: str):
"""
Return the actual stored key matching `want` by trying known prefixes.
"""
for pref in self.POSSIBLE_PREFIXES:
k = pref + want
if k in self._all_keys:
return k
return None
def has(self, want: str) -> bool:
return self._resolve(want) is not None
def get(self, want: str) -> torch.Tensor:
real_key = self._resolve(want)
if real_key is None:
raise KeyError(f"Missing key: {want}")
return self._get_by_real_key(real_key).to("cpu")
@property
def base_keys(self):
# keys without 'model.' prefix for scanning
return [self._strip_prefix(k) for k in self._all_keys]
# ---------- modes ----------
def _init_from_single_file(self, file_path: Path):
self._mode = "single"
self._file = file_path
self._handle = safe_open(file_path, framework="pt", device="cpu")
self._all_keys = list(self._handle.keys())
def _get_by_real_key(real_key: str):
return self._handle.get_tensor(real_key)
self._get_by_real_key = _get_by_real_key
def _init_from_dir(self, dpath: Path):
index_json = dpath / "diffusion_pytorch_model.safetensors.index.json"
if index_json.exists():
with open(index_json, "r", encoding="utf-8") as f:
index = json.load(f)
weight_map = index["weight_map"] # full mapping
self._mode = "sharded"
self._dpath = dpath
self._weight_map = {k: dpath / v for k, v in weight_map.items()}
self._all_keys = list(self._weight_map.keys())
self._open_handles = {}
def _get_by_real_key(real_key: str):
fpath = self._weight_map[real_key]
h = self._open_handles.get(fpath)
if h is None:
h = safe_open(fpath, framework="pt", device="cpu")
self._open_handles[fpath] = h
return h.get_tensor(real_key)
self._get_by_real_key = _get_by_real_key
return
# no index: try exactly one safetensors in folder
files = sorted(dpath.glob("*.safetensors"))
if len(files) != 1:
raise FileNotFoundError(
f"No index found and {dpath} does not contain exactly one .safetensors file."
)
self._init_from_single_file(files[0])
def main():
args = parse_args()
src = DiffusersSource(Path(args.diffusers_path))
# Count blocks by scanning base keys (with any 'model.' prefix removed)
num_dual = 0
num_single = 0
for k in src.base_keys:
if k.startswith("transformer_blocks."):
try:
i = int(k.split(".")[1])
num_dual = max(num_dual, i + 1)
except Exception:
pass
elif k.startswith("single_transformer_blocks."):
try:
i = int(k.split(".")[1])
num_single = max(num_single, i + 1)
except Exception:
pass
print(f"Found {num_dual} dual-stream blocks, {num_single} single-stream blocks")
# Swap [shift, scale] -> [scale, shift] (weights are concatenated along dim=0)
def swap_scale_shift(vec: torch.Tensor) -> torch.Tensor:
shift, scale = vec.chunk(2, dim=0)
return torch.cat([scale, shift], dim=0)
orig = {}
# Per-block (dual)
for b in range(num_dual):
prefix = f"transformer_blocks.{b}."
for okey, dvals in DIFFUSERS_MAP.items():
if not okey.startswith("double_blocks."):
continue
dkeys = [prefix + v for v in dvals]
if not all(src.has(k) for k in dkeys):
continue
if len(dkeys) == 1:
orig[okey.replace("()", str(b))] = src.get(dkeys[0])
else:
orig[okey.replace("()", str(b))] = torch.cat([src.get(k) for k in dkeys], dim=0)
# Per-block (single)
for b in range(num_single):
prefix = f"single_transformer_blocks.{b}."
for okey, dvals in DIFFUSERS_MAP.items():
if not okey.startswith("single_blocks."):
continue
dkeys = [prefix + v for v in dvals]
if not all(src.has(k) for k in dkeys):
continue
if len(dkeys) == 1:
orig[okey.replace("()", str(b))] = src.get(dkeys[0])
else:
orig[okey.replace("()", str(b))] = torch.cat([src.get(k) for k in dkeys], dim=0)
# Globals (non-block)
for okey, dvals in DIFFUSERS_MAP.items():
if okey.startswith(("double_blocks.", "single_blocks.")):
continue
dkeys = dvals
if not all(src.has(k) for k in dkeys):
continue
if len(dkeys) == 1:
orig[okey] = src.get(dkeys[0])
else:
orig[okey] = torch.cat([src.get(k) for k in dkeys], dim=0)
# Fix final_layer.adaLN_modulation.1.{weight,bias} by swapping scale/shift halves
if "final_layer.adaLN_modulation.1.weight" in orig:
orig["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(
orig["final_layer.adaLN_modulation.1.weight"]
)
if "final_layer.adaLN_modulation.1.bias" in orig:
orig["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(
orig["final_layer.adaLN_modulation.1.bias"]
)
# Optional FP8 variants (experimental; not required for ComfyUI/BFL)
if args.fp8 or args.fp8_scaled:
dtype = torch.float8_e4m3fn # noqa
minv, maxv = torch.finfo(dtype).min, torch.finfo(dtype).max
def stochastic_round_to(t):
t = t.float().clamp(minv, maxv)
lower = torch.floor(t * 256) / 256
upper = torch.ceil(t * 256) / 256
prob = torch.where(upper != lower, (t - lower) / (upper - lower), torch.zeros_like(t))
rnd = torch.rand_like(t)
out = torch.where(rnd < prob, upper, lower)
return out.to(dtype)
def scale_to_8bit(weight, target_max=416.0):
absmax = weight.abs().max()
scale = absmax / target_max if absmax > 0 else torch.tensor(1.0)
scaled = (weight / scale).clamp(minv, maxv).to(dtype)
return scaled, scale
scales = {}
for k in tqdm(list(orig.keys()), desc="Quantizing to fp8"):
t = orig[k]
if args.fp8:
orig[k] = stochastic_round_to(t)
else:
if k.endswith(".weight") and t.dim() == 2:
qt, s = scale_to_8bit(t)
orig[k] = qt
scales[k[:-len(".weight")] + ".scale_weight"] = s
else:
orig[k] = t.clamp(minv, maxv).to(dtype)
if args.fp8_scaled:
orig.update(scales)
orig["scaled_fp8"] = torch.tensor([], dtype=dtype)
else:
# Default: save in bfloat16
for k in list(orig.keys()):
orig[k] = orig[k].to(torch.bfloat16).cpu()
out_path = Path(args.output_path)
out_path.parent.mkdir(parents=True, exist_ok=True)
meta = OrderedDict()
meta["format"] = "pt"
meta["modelspec.date"] = __import__("datetime").date.today().strftime("%Y-%m-%d")
print(f"Saving transformer to: {out_path}")
safetensors.torch.save_file(orig, str(out_path), metadata=meta)
print("Done.")
if __name__ == "__main__":
main()
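A quick way to sanity-check a converted file is to reopen it and confirm that the per-block keys exist and that the attention projections were fused as expected. This is only an illustrative snippet; the output path is the example name from the usage string above.

from safetensors import safe_open

with safe_open("/out/flux1-dev.safetensors", framework="pt", device="cpu") as f:  # example output path
    keys = list(f.keys())
    n_dual = len({k.split(".")[1] for k in keys if k.startswith("double_blocks.")})
    n_single = len({k.split(".")[1] for k in keys if k.startswith("single_blocks.")})
    print(f"{n_dual} dual-stream blocks, {n_single} single-stream blocks")
    # to_q / to_k / to_v were concatenated along dim 0, so the fused qkv weight
    # should be three times as tall as it is wide
    qkv = f.get_tensor("double_blocks.0.img_attn.qkv.weight")
    print(tuple(qkv.shape))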

View File

View File

@ -0,0 +1,240 @@
import torch
from .utils import *
from functools import partial
# Many thanks to the LanPaint team for this implementation (https://github.com/scraed/LanPaint/)
def _pack_latents(latents):
batch_size, num_channels_latents, _, height, width = latents.shape
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
def _unpack_latents(latents, height, width, vae_scale_factor=8):
batch_size, num_patches, channels = latents.shape
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
return latents
class LanPaint():
def __init__(self, NSteps = 5, Friction = 15, Lambda = 8, Beta = 1, StepSize = 0.15, IS_FLUX = True, IS_FLOW = False):
self.n_steps = NSteps
self.chara_lamb = Lambda
self.IS_FLUX = IS_FLUX
self.IS_FLOW = IS_FLOW
self.step_size = StepSize
self.friction = Friction
self.chara_beta = Beta
self.img_dim_size = None
def add_none_dims(self, array):
# Create a tuple with ':' for the first dimension and 'None' repeated num_nones times
index = (slice(None),) + (None,) * (self.img_dim_size-1)
return array[index]
def remove_none_dims(self, array):
# Index the first dimension fully and take element 0 of every remaining dimension
index = (slice(None),) + (0,) * (self.img_dim_size-1)
return array[index]
def __call__(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, latent_image, noise, sigma, latent_mask, n_steps=None, height =720, width = 1280, vae_scale_factor = 8):
latent_image = _unpack_latents(latent_image, height=height, width=width, vae_scale_factor=vae_scale_factor)
noise = _unpack_latents(noise, height=height, width=width, vae_scale_factor=vae_scale_factor)
x = _unpack_latents(x, height=height, width=width, vae_scale_factor=vae_scale_factor)
latent_mask = _unpack_latents(latent_mask, height=height, width=width, vae_scale_factor=vae_scale_factor)
self.height = height
self.width = width
self.vae_scale_factor = vae_scale_factor
self.img_dim_size = len(x.shape)
self.latent_image = latent_image
self.noise = noise
if n_steps is None:
n_steps = self.n_steps
out = self.LanPaint(denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, self.IS_FLUX, self.IS_FLOW)
out = _pack_latents(out)
return out
def LanPaint(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, IS_FLUX, IS_FLOW):
if IS_FLUX:
cfg_BIG = 1.0
def double_denoise(latents, t):
latents = _pack_latents(latents)
noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale)
if noise_pred == None: return None, None
predict_std = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t)
predict_std = _unpack_latents(predict_std, self.height, self.width, self.vae_scale_factor)
if true_cfg_scale == cfg_BIG:
predict_big = predict_std
else:
predict_big = cfg_predictions(noise_pred, neg_noise_pred, cfg_BIG, t)
predict_big = _unpack_latents(predict_big, self.height, self.width, self.vae_scale_factor)
return predict_std, predict_big
if len(sigma.shape) == 0:
sigma = torch.tensor([sigma.item()])
latent_mask = 1 - latent_mask
if IS_FLUX or IS_FLOW:
Flow_t = sigma
abt = (1 - Flow_t)**2 / ((1 - Flow_t)**2 + Flow_t**2 )
VE_Sigma = Flow_t / (1 - Flow_t)
#print("t", torch.mean( sigma ).item(), "VE_Sigma", torch.mean( VE_Sigma ).item())
else:
VE_Sigma = sigma
abt = 1/( 1+VE_Sigma**2 )
Flow_t = (1-abt)**0.5 / ( (1-abt)**0.5 + abt**0.5 )
# VE_Sigma, abt, Flow_t = current_times
current_times = (VE_Sigma, abt, Flow_t)
step_size = self.step_size * (1 - abt)
step_size = self.add_none_dims(step_size)
# self.inner_model.inner_model.scale_latent_inpaint returns variance exploding x_t values
# This is the replace step
# x = x * (1 - latent_mask) + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image)* latent_mask
noisy_image = self.latent_image * (1.0 - sigma) + self.noise * sigma
x = x * (1 - latent_mask) + noisy_image * latent_mask
if IS_FLUX or IS_FLOW:
x_t = x * ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 )
else:
x_t = x / ( 1+self.add_none_dims(VE_Sigma)**2 )**0.5 # switch to variance preserving x_t values
############ LanPaint Iterations Start ###############
# after noise_scaling, noise = latent_image + noise * sigma, which is x_t in the variance exploding diffusion model notation for the known region.
args = None
for i in range(n_steps):
score_func = partial( self.score_model, y = self.latent_image, mask = latent_mask, abt = self.add_none_dims(abt), sigma = self.add_none_dims(VE_Sigma), tflow = self.add_none_dims(Flow_t), denoise_func = double_denoise )
if score_func is None: return None
x_t, args = self.langevin_dynamics(x_t, score_func , latent_mask, step_size , current_times, sigma_x = self.add_none_dims(self.sigma_x(abt)), sigma_y = self.add_none_dims(self.sigma_y(abt)), args = args)
if IS_FLUX or IS_FLOW:
x = x_t / ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 )
else:
x = x_t * ( 1+self.add_none_dims(VE_Sigma)**2 )**0.5 # switch back to variance exploding x values
############ LanPaint Iterations End ###############
# out is x_0
# out, _ = self.inner_model(x, sigma, model_options=model_options, seed=seed)
# out = out * (1-latent_mask) + self.latent_image * latent_mask
# return out
return x
def score_model(self, x_t, y, mask, abt, sigma, tflow, denoise_func):
lamb = self.chara_lamb
if self.IS_FLUX or self.IS_FLOW:
# compute t for flow model, with a small epsilon compensating for numerical error.
x = x_t / ( abt**0.5 + (1-abt)**0.5 ) # switch to Gaussian flow matching
x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(tflow))
if x_0 is None: return None
else:
x = x_t * ( 1+sigma**2 )**0.5 # switch to variance exploding
x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(sigma))
if x_0 is None: return None
score_x = -(x_t - x_0)
score_y = - (1 + lamb) * ( x_t - y ) + lamb * (x_t - x_0_BIG)
return score_x * (1 - mask) + score_y * mask
def sigma_x(self, abt):
# the time scale for the x_t update
return abt**0
def sigma_y(self, abt):
beta = self.chara_beta * abt ** 0
return beta
def langevin_dynamics(self, x_t, score, mask, step_size, current_times, sigma_x=1, sigma_y=0, args=None):
# prepare the step size and time parameters
with torch.autocast(device_type=x_t.device.type, dtype=torch.float32):
step_sizes = self.prepare_step_size(current_times, step_size, sigma_x, sigma_y)
sigma, abt, dtx, dty, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y = step_sizes
# print('mask',mask.device)
if torch.mean(dtx) <= 0.:
return x_t, args
# -------------------------------------------------------------------------
# Compute the Langevin dynamics update in variance preserving notation
# -------------------------------------------------------------------------
#x0 = self.x0_evalutation(x_t, score, sigma, args)
#C = abt**0.5 * x0 / (1-abt)
A = A_x * (1-mask) + A_y * mask
D = D_x * (1-mask) + D_y * mask
dt = dtx * (1-mask) + dty * mask
Gamma = Gamma_x * (1-mask) + Gamma_y * mask
def Coef_C(x_t):
x0 = self.x0_evalutation(x_t, score, sigma, args)
C = (abt**0.5 * x0 - x_t )/ (1-abt) + A * x_t
return C
def advance_time(x_t, v, dt, Gamma, A, C, D):
dtype = x_t.dtype
with torch.autocast(device_type=x_t.device.type, dtype=torch.float32):
osc = StochasticHarmonicOscillator(Gamma, A, C, D )
x_t, v = osc.dynamics(x_t, v, dt )
x_t = x_t.to(dtype)
v = v.to(dtype)
return x_t, v
if args is None:
#v = torch.zeros_like(x_t)
v = None
C = Coef_C(x_t)
#print(torch.squeeze(dtx), torch.squeeze(dty))
x_t, v = advance_time(x_t, v, dt, Gamma, A, C, D)
else:
v, C = args
x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D)
C_new = Coef_C(x_t)
v = v + Gamma**0.5 * ( C_new - C) *dt
x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D)
C = C_new
return x_t, (v, C)
def prepare_step_size(self, current_times, step_size, sigma_x, sigma_y):
# -------------------------------------------------------------------------
# Unpack current times parameters (sigma and abt)
sigma, abt, flow_t = current_times
sigma = self.add_none_dims(sigma)
abt = self.add_none_dims(abt)
# Compute time step (dtx, dty) for x and y branches.
dtx = 2 * step_size * sigma_x
dty = 2 * step_size * sigma_y
# -------------------------------------------------------------------------
# Define friction parameter Gamma_hat for each branch.
# Using dtx**0 provides a tensor of the proper device/dtype.
Gamma_hat_x = self.friction **2 * self.step_size * sigma_x / 0.1 * sigma**0
Gamma_hat_y = self.friction **2 * self.step_size * sigma_y / 0.1 * sigma**0
#print("Gamma_hat_x", torch.mean(Gamma_hat_x).item(), "Gamma_hat_y", torch.mean(Gamma_hat_y).item())
# adjust dt to match denoise-addnoise steps sizes
Gamma_hat_x /= 2.
Gamma_hat_y /= 2.
A_t_x = (1) / ( 1 - abt ) * dtx / 2
A_t_y = (1+self.chara_lamb) / ( 1 - abt ) * dty / 2
A_x = A_t_x / (dtx/2)
A_y = A_t_y / (dty/2)
Gamma_x = Gamma_hat_x / (dtx/2)
Gamma_y = Gamma_hat_y / (dty/2)
#D_x = (2 * (1 + sigma**2) )**0.5
#D_y = (2 * (1 + sigma**2) )**0.5
D_x = (2 * abt**0 )**0.5
D_y = (2 * abt**0 )**0.5
return sigma, abt, dtx/2, dty/2, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y
def x0_evalutation(self, x_t, score, sigma, args):
x0 = x_t + score(x_t)
return x0
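The _pack_latents / _unpack_latents helpers at the top of this file convert between the packed Flux latent layout (batch, tokens, channels*4) and the unpacked (batch, channels, 1, height/8, width/8) layout that the masked Langevin updates work on. A minimal round-trip check, with arbitrarily chosen shapes and run with the two helpers in scope, illustrates that they are exact inverses:

import torch

lat = torch.randn(1, 16, 1, 64, 64)          # (B, C, 1, H//8, W//8) for a 512x512 image
packed = _pack_latents(lat)                   # -> (1, 1024, 64)
restored = _unpack_latents(packed, height=512, width=512, vae_scale_factor=8)
assert torch.equal(lat, restored)             # packing/unpacking is lossless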

301
shared/inpainting/utils.py Normal file
View File

@ -0,0 +1,301 @@
import torch
def epxm1_x(x):
# Compute the (exp(x) - 1) / x term with a small value to avoid division by zero.
result = torch.special.expm1(x) / x
# replace NaN or inf values with 0
result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
mask = torch.abs(x) < 1e-2
result = torch.where(mask, 1 + x/2. + x**2 / 6., result)
return result
def epxm1mx_x2(x):
# Compute the (exp(x) - 1 - x) / x**2 term with a small value to avoid division by zero.
result = (torch.special.expm1(x) - x) / x**2
# replace NaN or inf values with 0
result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
mask = torch.abs(x**2) < 1e-2
result = torch.where(mask, 1/2. + x/6 + x**2 / 24 + x**3 / 120, result)
return result
def expm1mxmhx2_x3(x):
# Compute the (exp(x) - 1 - x - x**2 / 2) / x**3 term with a small value to avoid division by zero.
result = (torch.special.expm1(x) - x - x**2 / 2) / x**3
# replace NaN or inf values with 0
result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
mask = torch.abs(x**3) < 1e-2
result = torch.where(mask, 1/6 + x/24 + x**2 / 120 + x**3 / 720 + x**4 / 5040, result)
return result
def exp_1mcosh_GD(gamma_t, delta):
"""
Compute e^(-Γt) * (1 - cosh(Γt√Δ)) / ((Γt)**2 Δ)
Parameters:
gamma_t: Γ*t term (could be a scalar or tensor)
delta: Δ term (could be a scalar or tensor)
Returns:
Result of the computation with numerical stability handling
"""
# Main computation
is_positive = delta > 0
sqrt_abs_delta = torch.sqrt(torch.abs(delta))
gamma_t_sqrt_delta = gamma_t * sqrt_abs_delta
numerator_pos = torch.exp(-gamma_t) - (torch.exp(gamma_t * (sqrt_abs_delta - 1)) + torch.exp(gamma_t * (-sqrt_abs_delta - 1))) / 2
numerator_neg = torch.exp(-gamma_t) * ( 1 - torch.cos(gamma_t * sqrt_abs_delta ) )
numerator = torch.where(is_positive, numerator_pos, numerator_neg)
result = numerator / (delta * gamma_t**2 )
# Handle NaN/inf cases
result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
# Handle numerical instability for small delta
mask = torch.abs(gamma_t_sqrt_delta**2) < 5e-2
taylor = ( -0.5 - gamma_t**2 / 24 * delta - gamma_t**4 / 720 * delta**2 ) * torch.exp(-gamma_t)
result = torch.where(mask, taylor, result)
return result
def exp_sinh_GsqrtD(gamma_t, delta):
"""
Compute e^(-Γt) * sinh(Γt√Δ) / (Γt√Δ)
Parameters:
gamma_t: Γ*t term (could be a scalar or tensor)
delta: Δ term (could be a scalar or tensor)
Returns:
Result of the computation with numerical stability handling
"""
# Main computation
is_positive = delta > 0
sqrt_abs_delta = torch.sqrt(torch.abs(delta))
gamma_t_sqrt_delta = gamma_t * sqrt_abs_delta
numerator_pos = (torch.exp(gamma_t * (sqrt_abs_delta - 1)) - torch.exp(gamma_t * (-sqrt_abs_delta - 1))) / 2
denominator_pos = gamma_t_sqrt_delta
result_pos = numerator_pos / gamma_t_sqrt_delta
result_pos = torch.where(torch.isfinite(result_pos), result_pos, torch.zeros_like(result_pos))
# Taylor expansion for small gamma_t_sqrt_delta
mask = torch.abs(gamma_t_sqrt_delta) < 1e-2
taylor = ( 1 + gamma_t**2 / 6 * delta + gamma_t**4 / 120 * delta**2 ) * torch.exp(-gamma_t)
result_pos = torch.where(mask, taylor, result_pos)
# Handle negative delta
result_neg = torch.exp(-gamma_t) * torch.special.sinc(gamma_t_sqrt_delta/torch.pi)
result = torch.where(is_positive, result_pos, result_neg)
return result
def exp_cosh(gamma_t, delta):
"""
Compute e^(-Γt) * cosh(Γt√Δ)
Parameters:
gamma_t: Γ*t term (could be a scalar or tensor)
delta: Δ term (could be a scalar or tensor)
Returns:
Result of the computation with numerical stability handling
"""
exp_1mcosh_GD_result = exp_1mcosh_GD(gamma_t, delta) # e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ )
result = torch.exp(-gamma_t) - gamma_t**2 * delta * exp_1mcosh_GD_result
return result
def exp_sinh_sqrtD(gamma_t, delta):
"""
Compute e^(-Γt) * sinh(Γt√Δ) / √Δ
Parameters:
gamma_t: Γ*t term (could be a scalar or tensor)
delta: Δ term (could be a scalar or tensor)
Returns:
Result of the computation with numerical stability handling
"""
exp_sinh_GsqrtD_result = exp_sinh_GsqrtD(gamma_t, delta) # e^(-Γt) * sinh(Γt√Δ) / (Γt√Δ)
result = gamma_t * exp_sinh_GsqrtD_result
return result
def zeta1(gamma_t, delta):
# Compute hyperbolic terms and exponential
half_gamma_t = gamma_t / 2
exp_cosh_term = exp_cosh(half_gamma_t, delta)
exp_sinh_term = exp_sinh_sqrtD(half_gamma_t, delta)
# Main computation
numerator = 1 - (exp_cosh_term + exp_sinh_term)
denominator = gamma_t * (1 - delta) / 4
result = 1 - numerator / denominator
# Handle numerical instability
result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
# Taylor expansion for small x (similar to your epxm1Dx approach)
mask = torch.abs(denominator) < 5e-3
term1 = epxm1_x(-gamma_t)
term2 = epxm1mx_x2(-gamma_t)
term3 = expm1mxmhx2_x3(-gamma_t)
taylor = term1 + (1/2.+ term1-3*term2)*denominator + (-1/6. + term1/2 - 4 * term2 + 10 * term3) * denominator**2
result = torch.where(mask, taylor, result)
return result
def exp_cosh_minus_terms(gamma_t, delta):
"""
Compute e^(-Γt) * ((cosh(Γt) - 1) - (cosh(Γt√Δ) - 1)/Δ) / (Γt (1 - Δ))
Parameters:
gamma_t: Γ*t term (could be a scalar or tensor)
delta: Δ term (could be a scalar or tensor)
Returns:
Result of the computation with numerical stability handling
"""
exp_term = torch.exp(-gamma_t)
# Compute individual terms
exp_cosh_term = exp_cosh(gamma_t, gamma_t**0) - exp_term # E^(-tΓ) (Cosh[tΓ] - 1) term
exp_cosh_delta_term = - gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) # E^(-tΓ) (Cosh[tΓ√Δ] - 1)/Δ term
#exp_1mcosh_GD e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ )
# Main computation
numerator = exp_cosh_term - exp_cosh_delta_term
denominator = gamma_t * (1 - delta)
result = numerator / denominator
# Handle numerical instability
result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
# Taylor expansion for small gamma_t and delta near 1
mask = (torch.abs(denominator) < 1e-1)
exp_1mcosh_GD_term = exp_1mcosh_GD(gamma_t, delta**0)
taylor = (
gamma_t*exp_1mcosh_GD_term + 0.5 * gamma_t * exp_sinh_GsqrtD(gamma_t, delta**0)
- denominator / 4 * ( 0.5 * exp_cosh(gamma_t, delta**0) - 4 * exp_1mcosh_GD_term - 5 /2 * exp_sinh_GsqrtD(gamma_t, delta**0) )
)
result = torch.where(mask, taylor, result)
return result
def zeta2(gamma_t, delta):
half_gamma_t = gamma_t / 2
return exp_sinh_GsqrtD(half_gamma_t, delta)
def sig11(gamma_t, delta):
return 1 - torch.exp(-gamma_t) + gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) + exp_sinh_sqrtD(gamma_t, delta)
def Zcoefs(gamma_t, delta):
Zeta1 = zeta1(gamma_t, delta)
Zeta2 = zeta2(gamma_t, delta)
sq_total = 1 - Zeta1 + gamma_t * (delta - 1) * (Zeta1 - 1)**2 / 8
amplitude = torch.sqrt(sq_total)
Zcoef1 = ( gamma_t**0.5 * Zeta2 / 2 **0.5 ) / amplitude
Zcoef2 = Zcoef1 * gamma_t *( - 2 * exp_1mcosh_GD(gamma_t, delta) / sig11(gamma_t, delta) ) ** 0.5
#cterm = exp_cosh_minus_terms(gamma_t, delta)
#sterm = exp_sinh_sqrtD(gamma_t, delta**0) + exp_sinh_sqrtD(gamma_t, delta)
#Zcoef3 = 2 * torch.sqrt( cterm / ( gamma_t * (1 - delta) * cterm + sterm ) )
Zcoef3 = torch.sqrt( torch.maximum(1 - Zcoef1**2 - Zcoef2**2, sq_total.new_zeros(sq_total.shape)) )
return Zcoef1 * amplitude, Zcoef2 * amplitude, Zcoef3 * amplitude, amplitude
def Zcoefs_asymp(gamma_t, delta):
A_t = (gamma_t * (1 - delta) )/4
return epxm1_x(- 2 * A_t)
class StochasticHarmonicOscillator:
"""
Simulates a stochastic harmonic oscillator governed by the equations:
dy(t) = q(t) dt
dq(t) = -Γ A y(t) dt + Γ C dt + Γ D dw(t) - Γ q(t) dt
Also define v(t) = q(t) / Γ, which is numerically more stable.
Where:
y(t) - Position variable
q(t) - Velocity variable
Γ - Damping coefficient
A - Harmonic potential strength
C - Constant force term
D - Noise amplitude
dw(t) - Wiener process (Brownian motion)
"""
def __init__(self, Gamma, A, C, D):
self.Gamma = Gamma
self.A = A
self.C = C
self.D = D
self.Delta = 1 - 4 * A / Gamma
def sig11(self, gamma_t, delta):
return 1 - torch.exp(-gamma_t) + gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) + exp_sinh_sqrtD(gamma_t, delta)
def sig22(self, gamma_t, delta):
return 1- zeta1(2*gamma_t, delta) + 2 * gamma_t * exp_1mcosh_GD(gamma_t, delta)
def dynamics(self, y0, v0, t):
"""
Calculates the position and velocity variables at time t.
Parameters:
y0 (float): Initial position
v0 (float): Initial velocity v(0) = q(0) / Γ
t (float): Time at which to evaluate the dynamics
Returns:
tuple: (y(t), v(t))
"""
dummyzero = y0.new_zeros(1) # convert scalar to tensor with same device and dtype as y0
Delta = self.Delta + dummyzero
Gamma_hat = self.Gamma * t + dummyzero
A = self.A + dummyzero
C = self.C + dummyzero
D = self.D + dummyzero
Gamma = self.Gamma + dummyzero
zeta_1 = zeta1( Gamma_hat, Delta)
zeta_2 = zeta2( Gamma_hat, Delta)
EE = 1 - Gamma_hat * zeta_2
if v0 is None:
v0 = torch.randn_like(y0) * D / 2 ** 0.5
#v0 = (C - A * y0)/Gamma**0.5
# Calculate mean position and velocity
term1 = (1 - zeta_1) * (C * t - A * t * y0) + zeta_2 * (Gamma ** 0.5) * v0 * t
y_mean = term1 + y0
v_mean = (1 - EE)*(C - A * y0) / (Gamma ** 0.5) + (EE - A * t * (1 - zeta_1)) * v0
cov_yy = D**2 * t * self.sig22(Gamma_hat, Delta)
cov_vv = D**2 * self.sig11(Gamma_hat, Delta) / 2
cov_yv = (zeta2(Gamma_hat, Delta) * Gamma_hat * D ) **2 / 2 / (Gamma ** 0.5)
# sample new position and velocity with multivariate normal distribution
batch_shape = y0.shape
cov_matrix = torch.zeros(*batch_shape, 2, 2, device=y0.device, dtype=y0.dtype)
cov_matrix[..., 0, 0] = cov_yy
cov_matrix[..., 0, 1] = cov_yv
cov_matrix[..., 1, 0] = cov_yv # symmetric
cov_matrix[..., 1, 1] = cov_vv
# Compute the Cholesky decomposition to get scale_tril
#scale_tril = torch.linalg.cholesky(cov_matrix)
scale_tril = torch.zeros(*batch_shape, 2, 2, device=y0.device, dtype=y0.dtype)
tol = 1e-8
cov_yy = torch.clamp( cov_yy, min = tol )
sd_yy = torch.sqrt( cov_yy )
inv_sd_yy = 1/(sd_yy)
scale_tril[..., 0, 0] = sd_yy
scale_tril[..., 0, 1] = 0.
scale_tril[..., 1, 0] = cov_yv * inv_sd_yy
scale_tril[..., 1, 1] = torch.clamp( cov_vv - cov_yv**2 / cov_yy, min = tol ) ** 0.5
# check if it matches torch.linalg.
#assert torch.allclose(torch.linalg.cholesky(cov_matrix), scale_tril, atol = 1e-4, rtol = 1e-4 )
# Sample correlated noise from multivariate normal
mean = torch.zeros(*batch_shape, 2, device=y0.device, dtype=y0.dtype)
mean[..., 0] = y_mean
mean[..., 1] = v_mean
new_yv = torch.distributions.MultivariateNormal(
loc=mean,
scale_tril=scale_tril
).sample()
return new_yv[...,0], new_yv[...,1]
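Each of the helpers above evaluates a closed-form expression of Γt and Δ and switches to a Taylor branch near the numerically unstable regions. Away from those regions the direct formula should be recovered; a small illustrative check (assuming the module is importable as shared.inpainting.utils):

import torch
from shared.inpainting.utils import exp_cosh, exp_sinh_sqrtD

gamma_t = torch.tensor([0.5, 2.0, 8.0])
delta = torch.tensor([0.9, 0.9, 0.9])

# exp_cosh should equal e^(-Γt) * cosh(Γt√Δ)
direct = torch.exp(-gamma_t) * torch.cosh(gamma_t * delta.sqrt())
assert torch.allclose(exp_cosh(gamma_t, delta), direct, atol=1e-5)

# exp_sinh_sqrtD should equal e^(-Γt) * sinh(Γt√Δ) / √Δ
direct = torch.exp(-gamma_t) * torch.sinh(gamma_t * delta.sqrt()) / delta.sqrt()
assert torch.allclose(exp_sinh_sqrtD(gamma_t, delta), direct, atol=1e-5)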

View File

@ -232,6 +232,9 @@ def save_video(tensor,
               retry=5):
    """Save tensor as video with configurable codec and container options."""
    if torch.is_tensor(tensor) and len(tensor.shape) == 4:
        tensor = tensor.unsqueeze(0)
    suffix = f'.{container}'
    cache_file = osp.join('/tmp', rand_name(suffix=suffix)) if save_file is None else save_file
    if not cache_file.endswith(suffix):

110
shared/utils/download.py Normal file
View File

@ -0,0 +1,110 @@
import sys, time
# Global variables to track download progress
_start_time = None
_last_time = None
_last_downloaded = 0
_speed_history = []
_update_interval = 0.5 # Update speed every 0.5 seconds
def progress_hook(block_num, block_size, total_size, filename=None):
"""
Simple progress bar hook for urlretrieve
Args:
block_num: Number of blocks downloaded so far
block_size: Size of each block in bytes
total_size: Total size of the file in bytes
filename: Name of the file being downloaded (optional)
"""
global _start_time, _last_time, _last_downloaded, _speed_history, _update_interval
current_time = time.time()
downloaded = block_num * block_size
# Initialize timing on first call
if _start_time is None or block_num == 0:
_start_time = current_time
_last_time = current_time
_last_downloaded = 0
_speed_history = []
# Calculate download speed only at specified intervals
speed = 0
if current_time - _last_time >= _update_interval:
if _last_time > 0:
current_speed = (downloaded - _last_downloaded) / (current_time - _last_time)
_speed_history.append(current_speed)
# Keep only last 5 speed measurements for smoothing
if len(_speed_history) > 5:
_speed_history.pop(0)
# Average the recent speeds for smoother display
speed = sum(_speed_history) / len(_speed_history)
_last_time = current_time
_last_downloaded = downloaded
elif _speed_history:
# Use the last calculated average speed
speed = sum(_speed_history) / len(_speed_history)
# Format file sizes and speed
def format_bytes(bytes_val):
for unit in ['B', 'KB', 'MB', 'GB']:
if bytes_val < 1024:
return f"{bytes_val:.1f}{unit}"
bytes_val /= 1024
return f"{bytes_val:.1f}TB"
file_display = filename if filename else "Unknown file"
if total_size <= 0:
# If total size is unknown, show downloaded bytes
speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else ""
line = f"\r{file_display}: {format_bytes(downloaded)}{speed_str}"
# Clear any trailing characters by padding with spaces
sys.stdout.write(line.ljust(80))
sys.stdout.flush()
return
downloaded = block_num * block_size
percent = min(100, (downloaded / total_size) * 100)
# Create progress bar (40 characters wide to leave room for other info)
bar_length = 40
filled = int(bar_length * percent / 100)
bar = '█' * filled + '░' * (bar_length - filled)
# Format file sizes and speed
def format_bytes(bytes_val):
for unit in ['B', 'KB', 'MB', 'GB']:
if bytes_val < 1024:
return f"{bytes_val:.1f}{unit}"
bytes_val /= 1024
return f"{bytes_val:.1f}TB"
speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else ""
# Display progress with filename first
line = f"\r{file_display}: [{bar}] {percent:.1f}% ({format_bytes(downloaded)}/{format_bytes(total_size)}){speed_str}"
# Clear any trailing characters by padding with spaces
sys.stdout.write(line.ljust(100))
sys.stdout.flush()
# Print newline when complete
if percent >= 100:
print()
# Wrapper function to include filename in progress hook
def create_progress_hook(filename):
"""Creates a progress hook with the filename included"""
global _start_time, _last_time, _last_downloaded, _speed_history
# Reset timing variables for new download
_start_time = None
_last_time = None
_last_downloaded = 0
_speed_history = []
def hook(block_num, block_size, total_size):
return progress_hook(block_num, block_size, total_size, filename)
return hook
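The hook follows the reporthook protocol of urllib.request.urlretrieve (block_num, block_size, total_size), so a typical call looks roughly like the following; the URL and file name are placeholders:

from urllib.request import urlretrieve
from shared.utils.download import create_progress_hook

url = "https://example.com/some_model.safetensors"   # placeholder URL
filename = "some_model.safetensors"
urlretrieve(url, filename, reporthook=create_progress_hook(filename))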

View File

@ -1,4 +1,3 @@
-# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os
import os.path as osp
@ -176,8 +175,9 @@ def remove_background(img, session=None):
def convert_image_to_tensor(image):
    return torch.from_numpy(np.array(image).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0)
-def convert_tensor_to_image(t, frame_no = -1):
def convert_tensor_to_image(t, frame_no = 0):
-    t = t[:, frame_no] if frame_no >= 0 else t
    if len(t.shape) == 4:
        t = t[:, frame_no]
    return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy())
def save_image(tensor_image, name, frame_no = -1):
@ -257,16 +257,18 @@ def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fi
    image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
    return image, new_height, new_width
-def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None ):
def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5 ):
    if rm_background:
        session = new_session()
    output_list =[]
    output_mask_list =[]
    for i, img in enumerate(img_list):
        width, height = img.size
        resized_mask = None
        if fit_into_canvas == None or any_background_ref == 1 and i==0 or any_background_ref == 2:
-            if outpainting_dims is not None:
            if outpainting_dims is not None and background_ref_outpainted:
-                resized_image =img
                resized_image, resized_mask = fit_image_into_canvas(img, (budget_height, budget_width), inpaint_color, full_frame = True, outpainting_dims = outpainting_dims, return_mask= True, return_image= True)
            elif img.size != (budget_width, budget_height):
                resized_image= img.resize((budget_width, budget_height), resample=Image.Resampling.LANCZOS)
            else:
@ -290,145 +292,103 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg
            # resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
            resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
            output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200,
-    return output_list
            output_mask_list.append(resized_mask)
    return output_list, output_mask_list
-def str2bool(v):
-    """
-    Convert a string to a boolean.
-    Supported true values: 'yes', 'true', 't', 'y', '1'
-    Supported false values: 'no', 'false', 'f', 'n', '0'
-    Args:
-        v (str): String to convert.
-    Returns:
-        bool: Converted boolean value.
-    Raises:
-        argparse.ArgumentTypeError: If the value cannot be converted to boolean.
-    """
-    if isinstance(v, bool):
-        return v
-    v_lower = v.lower()
-    if v_lower in ('yes', 'true', 't', 'y', '1'):
-        return True
-    elif v_lower in ('no', 'false', 'f', 'n', '0'):
-        return False
-    else:
-        raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
-import sys, time
-# Global variables to track download progress
-_start_time = None
-_last_time = None
-_last_downloaded = 0
-_speed_history = []
-_update_interval = 0.5 # Update speed every 0.5 seconds
-def progress_hook(block_num, block_size, total_size, filename=None):
-    ... (body identical to the new shared/utils/download.py above)
-def create_progress_hook(filename):
-    ... (body identical to the new shared/utils/download.py above)
def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu", full_frame = False, outpainting_dims = None, return_mask = False, return_image = False):
    from shared.utils.utils import save_image
    inpaint_color = canvas_tf_bg / 127.5 - 1
    ref_width, ref_height = ref_img.size
    if (ref_height, ref_width) == image_size and outpainting_dims == None:
        ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
        canvas = torch.zeros_like(ref_img) if return_mask else None
    else:
        if outpainting_dims != None:
            final_height, final_width = image_size
            canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1)
        else:
            canvas_height, canvas_width = image_size
        if full_frame:
            new_height = canvas_height
            new_width = canvas_width
            top = left = 0
        else:
            # if fill_max and (canvas_height - new_height) < 16:
            #     new_height = canvas_height
            # if fill_max and (canvas_width - new_width) < 16:
            #     new_width = canvas_width
            scale = min(canvas_height / ref_height, canvas_width / ref_width)
            new_height = int(ref_height * scale)
            new_width = int(ref_width * scale)
            top = (canvas_height - new_height) // 2
            left = (canvas_width - new_width) // 2
        ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
        ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
        if outpainting_dims != None:
            canvas = torch.full((3, 1, final_height, final_width), inpaint_color, dtype= torch.float, device=device) # [-1, 1]
            canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img
        else:
            canvas = torch.full((3, 1, canvas_height, canvas_width), inpaint_color, dtype= torch.float, device=device) # [-1, 1]
            canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
        ref_img = canvas
        canvas = None
        if return_mask:
            if outpainting_dims != None:
                canvas = torch.ones((1, 1, final_height, final_width), dtype= torch.float, device=device) # [-1, 1]
                canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = 0
            else:
                canvas = torch.ones((1, 1, canvas_height, canvas_width), dtype= torch.float, device=device) # [-1, 1]
                canvas[:, :, top:top + new_height, left:left + new_width] = 0
            canvas = canvas.to(device)
    if return_image:
        return convert_tensor_to_image(ref_img), canvas
    return ref_img.to(device), canvas
def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, extract_guide_from_window_start = False, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None ):
    src_videos, src_masks = [], []
    inpaint_color = guide_inpaint_color/127.5 - 1
    prepend_count = pre_video_guide.shape[1] if not extract_guide_from_window_start and pre_video_guide is not None else 0
    for guide_no, (cur_video_guide, cur_video_mask) in enumerate(zip(video_guides, video_masks)):
        src_video = src_mask = None
        if cur_video_guide is not None:
            src_video = cur_video_guide.permute(3, 0, 1, 2).float().div_(127.5).sub_(1.) # c, f, h, w
        if cur_video_mask is not None and any_mask:
            src_mask = cur_video_mask.permute(3, 0, 1, 2).float().div_(255)[0:1] # c, f, h, w
        if pre_video_guide is not None and not extract_guide_from_window_start:
            src_video = pre_video_guide if src_video is None else torch.cat( [pre_video_guide, src_video], dim=1)
            if any_mask:
                src_mask = torch.zeros_like(pre_video_guide[0:1]) if src_mask is None else torch.cat( [torch.zeros_like(pre_video_guide[0:1]), src_mask], dim=1)
        if src_video is None:
            if any_guide_padding:
                src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color, dtype = torch.float, device= "cpu")
                if any_mask:
                    src_mask = torch.zeros_like(src_video[0:1])
        elif src_video.shape[1] < current_video_length:
            if any_guide_padding:
                src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color, dtype = src_video.dtype, device= src_video.device) ], dim=1)
                if cur_video_mask is not None and any_mask:
                    src_mask = torch.cat([src_mask, torch.full( (1, current_video_length - src_mask.shape[1], *src_mask.shape[-2:] ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1)
            else:
                new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1
                src_video = src_video[:, :new_num_frames]
                if any_mask:
                    src_mask = src_mask[:, :new_num_frames]
        for k, keep in enumerate(keep_video_guide_frames):
            if not keep:
                pos = prepend_count + k
                src_video[:, pos:pos+1] = inpaint_color
                src_mask[:, pos:pos+1] = 1
        for k, frame in enumerate(inject_frames):
            if frame != None:
                pos = prepend_count + k
                src_video[:, pos:pos+1], src_mask[:, pos:pos+1] = fit_image_into_canvas(frame, image_size, inpaint_color, device, True, outpainting_dims, return_mask= True)
        src_videos.append(src_video)
        src_masks.append(src_mask)
    return src_videos, src_masks
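For reference, a minimal use of the new fit_image_into_canvas helper above: letterbox a PIL reference image into a 480x832 canvas and get back the matching inpaint mask. The input file name is a placeholder, and the import assumes this file is shared/utils/utils.py (which the old create_progress_hook import replaced in wgp.py below suggests).

from PIL import Image
from shared.utils.utils import fit_image_into_canvas   # assumed module path

ref = Image.open("reference.png").convert("RGB")        # placeholder input
ref_tensor, ref_mask = fit_image_into_canvas(ref, (480, 832), canvas_tf_bg=127.5, device="cpu", return_mask=True)
print(ref_tensor.shape)   # torch.Size([3, 1, 480, 832]), values in [-1, 1]
print(ref_mask.shape)     # torch.Size([1, 1, 480, 832]), 0 where the image was placed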

222
wgp.py
View File

@ -394,7 +394,7 @@ def process_prompt_and_add_tasks(state, model_choice):
            gr.Info("MMAudio can generate an Audio track only if the Video is at least 1s long")
    if "F" in video_prompt_type:
        if len(frames_positions.strip()) > 0:
-            positions = frames_positions.split(" ")
            positions = frames_positions.replace(","," ").split(" ")
            for pos_str in positions:
                if not pos_str in ["L", "l"] and len(pos_str)>0:
                    if not is_integer(pos_str):
@ -2528,7 +2528,7 @@ def download_models(model_filename = None, model_type= None, module_type = False
    from urllib.request import urlretrieve
-    from shared.utils.utils import create_progress_hook
    from shared.utils.download import create_progress_hook
    shared_def = {
        "repoId" : "DeepBeepMeep/Wan2.1",
@ -3726,6 +3726,60 @@ def preprocess_image_with_mask(input_image, input_mask, height, width, fit_canva
        input_mask = convert_tensor_to_image(full_frame)
    return input_image, input_mask
def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_frames, start_frame, target_fps, size = 512):
if not input_video_path or max_frames <= 0:
return None, None
pad_frames = 0
if start_frame < 0:
pad_frames= -start_frame
max_frames += start_frame
start_frame = 0
any_mask = input_mask_path != None
video = get_resampled_video(input_video_path, start_frame, max_frames, target_fps)
if len(video) == 0: return None
if any_mask:
mask_video = get_resampled_video(input_mask_path, start_frame, max_frames, target_fps)
frame_height, frame_width, _ = video[0].shape
num_frames = min(len(video), len(mask_video))
if num_frames == 0: return None
video, mask_video = video[:num_frames], mask_video[:num_frames]
from preprocessing.face_preprocessor import FaceProcessor
face_processor = FaceProcessor()
face_list = []
for frame_idx in range(num_frames):
frame = video[frame_idx].cpu().numpy()
# video[frame_idx] = None
if any_mask:
mask = Image.fromarray(mask_video[frame_idx].cpu().numpy())
# mask_video[frame_idx] = None
if (frame_width, frame_height) != mask.size:
mask = mask.resize((frame_width, frame_height), resample=Image.Resampling.LANCZOS)
mask = np.array(mask)
alpha_mask = np.zeros((frame_height, frame_width, 3), dtype=np.uint8)
alpha_mask[mask > 127] = 1
frame = frame * alpha_mask
frame = Image.fromarray(frame)
face = face_processor.process(frame, resize_to=size, face_crop_scale = 1)
face_list.append(face)
face_processor = None
gc.collect()
torch.cuda.empty_cache()
face_tensor= torch.tensor(np.stack(face_list, dtype= np.float32) / 127.5 - 1).permute(-1, 0, 1, 2 ) # t h w c -> c t h w
if pad_frames > 0:
face_tensor= torch.cat([face_tensor[:, :1].expand(-1, pad_frames, -1, -1), face_tensor ], dim=1) # pad by repeating the first face frame along the time axis
if args.save_masks:
from preprocessing.dwpose.pose import save_one_video
saved_faces_frames = [np.array(face) for face in face_list ]
save_one_video(f"faces.mp4", saved_faces_frames, fps=target_fps, quality=8, macro_block_size=None)
return face_tensor
def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1):
@ -3742,6 +3796,12 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
        box = [xmin, ymin, xmax, ymax]
        box = [int(x) for x in box]
        return box
    inpaint_color = int(inpaint_color)
    pad_frames = 0
    if start_frame < 0:
        pad_frames= -start_frame
        max_frames += start_frame
        start_frame = 0
    if not input_video_path or max_frames <= 0:
        return None, None
@ -3909,6 +3969,9 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
    preproc_outside = None
    gc.collect()
    torch.cuda.empty_cache()
    if pad_frames > 0:
        masked_frames = [masked_frames[0]] * pad_frames + masked_frames
        if any_mask: masks = [masks[0]] * pad_frames + masks
    return torch.stack(masked_frames), torch.stack(masks) if any_mask else None
@ -4646,7 +4709,8 @@ def generate_video(
current_video_length = video_length current_video_length = video_length
# VAE Tiling # VAE Tiling
device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576 device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576
guide_inpaint_color = model_def.get("guide_inpaint_color", 127.5)
extract_guide_from_window_start = model_def.get("extract_guide_from_window_start", False)
i2v = test_class_i2v(model_type) i2v = test_class_i2v(model_type)
diffusion_forcing = "diffusion_forcing" in model_filename diffusion_forcing = "diffusion_forcing" in model_filename
t2v = base_model_type in ["t2v"] t2v = base_model_type in ["t2v"]
@ -4662,6 +4726,7 @@ def generate_video(
multitalk = model_def.get("multitalk_class", False) multitalk = model_def.get("multitalk_class", False)
standin = model_def.get("standin_class", False) standin = model_def.get("standin_class", False)
infinitetalk = base_model_type in ["infinitetalk"] infinitetalk = base_model_type in ["infinitetalk"]
animate = base_model_type in ["animate"]
if "B" in audio_prompt_type or "X" in audio_prompt_type: if "B" in audio_prompt_type or "X" in audio_prompt_type:
from models.wan.multitalk.multitalk import parse_speakers_locations from models.wan.multitalk.multitalk import parse_speakers_locations
@ -4822,7 +4887,6 @@ def generate_video(
repeat_no = 0 repeat_no = 0
extra_generation = 0 extra_generation = 0
initial_total_windows = 0 initial_total_windows = 0
discard_last_frames = sliding_window_discard_last_frames discard_last_frames = sliding_window_discard_last_frames
default_requested_frames_to_generate = current_video_length default_requested_frames_to_generate = current_video_length
if sliding_window: if sliding_window:
@@ -4843,7 +4907,7 @@ def generate_video(
if repeat_no >= total_generation: break
repeat_no += 1
gen["repeat_no"] = repeat_no
-src_video = src_mask = src_ref_images = new_image_guide = new_image_mask = None
+src_video = src_mask = src_ref_images = new_image_guide = new_image_mask = src_faces = None
prefix_video = pre_video_frame = None
source_video_overlap_frames_count = 0 # number of frames overlapped in source video for first window
source_video_frames_count = 0 # number of frames to use in source video (processing starts source_video_overlap_frames_count frames before)
@@ -4899,7 +4963,7 @@ def generate_video(
return_latent_slice = slice(-(reuse_frames - 1 + discard_last_frames ) // latent_size - 1, None if discard_last_frames == 0 else -(discard_last_frames // latent_size) )
refresh_preview = {"image_guide" : image_guide, "image_mask" : image_mask} if image_mode >= 1 else {}
-src_ref_images = image_refs
+src_ref_images, src_ref_masks = image_refs, None
image_start_tensor = image_end_tensor = None
if window_no == 1 and (video_source is not None or image_start is not None):
    if image_start is not None:
@@ -4943,16 +5007,52 @@ def generate_video(
from models.wan.multitalk.multitalk import get_window_audio_embeddings
# special treatment of the start frame position when alignment to the first frame is requested, as otherwise the start frame number would be negative because of the overlapped frames (previously compensated later with padding)
audio_proj_split = get_window_audio_embeddings(audio_proj_full, audio_start_idx= aligned_window_start_frame + (source_video_overlap_frames_count if reset_control_aligment else 0 ), clip_length = current_video_length)
-if vace:
-    video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None
+if repeat_no == 1 and window_no == 1 and image_refs is not None and len(image_refs) > 0:
+    frames_positions_list = []
+    if frames_positions is not None and len(frames_positions) > 0:
+        positions = frames_positions.replace(","," ").split(" ")
+        cur_end_pos = -1 + (source_video_frames_count - source_video_overlap_frames_count)
+        last_frame_no = requested_frames_to_generate + source_video_frames_count - source_video_overlap_frames_count
+        joker_used = False
+        project_window_no = 1
+        for pos in positions:
+            if len(pos) > 0:
+                if pos in ["L", "l"]:
+                    cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length
+                    if cur_end_pos >= last_frame_no and not joker_used:
+                        joker_used = True
+                        cur_end_pos = last_frame_no - 1
+                    project_window_no += 1
+                    frames_positions_list.append(cur_end_pos)
+                    cur_end_pos -= sliding_window_discard_last_frames + reuse_frames
+                else:
+                    frames_positions_list.append(int(pos)-1 + alignment_shift)
+    frames_positions_list = frames_positions_list[:len(image_refs)]
+    nb_frames_positions = len(frames_positions_list)
+    if nb_frames_positions > 0:
+        frames_to_inject = [None] * (max(frames_positions_list) + 1)
+        for i, pos in enumerate(frames_positions_list):
+            frames_to_inject[pos] = image_refs[i]
if video_guide is not None:
-    keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate)
+    keep_frames_parsed_full, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate)
    if len(error) > 0:
        raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
-    keep_frames_parsed = keep_frames_parsed[aligned_guide_start_frame: aligned_guide_end_frame ]
+    guide_frames_extract_start = aligned_window_start_frame if extract_guide_from_window_start else aligned_guide_start_frame
+    keep_frames_parsed = [True] * -guide_frames_extract_start if guide_frames_extract_start < 0 else []
+    keep_frames_parsed += keep_frames_parsed_full[max(0, guide_frames_extract_start): aligned_guide_end_frame ]
+    guide_frames_extract_count = len(keep_frames_parsed)
-    if vace:
+    if "B" in video_prompt_type:
+        send_cmd("progress", [0, get_latest_status(state, "Extracting Face Movements")])
+        src_faces = extract_faces_from_video_with_mask(video_guide, video_mask, max_frames= guide_frames_extract_count, start_frame= guide_frames_extract_start, size= 512, target_fps = fps)
+        if src_faces is not None and src_faces.shape[1] < current_video_length:
+            src_faces = torch.cat([src_faces, torch.full( (3, current_video_length - src_faces.shape[1], 512, 512 ), -1, dtype = src_faces.dtype, device= src_faces.device) ], dim=1)
+    if vace or animate:
+        video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None
        context_scale = [ control_net_weight]
        if "V" in video_prompt_type:
            process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None)
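The frames_positions string parsed in the hunk above accepts absolute 1-based frame indices plus the letter L/l for "last frame of a window". A simplified, self-contained sketch of that parsing, ignoring the sliding-window reuse and alignment shift handled by the real code (the function name is illustrative only):

    def parse_frame_positions(spec: str, num_refs: int, video_length: int) -> list:
        # "1, 5 L" -> [0, 4, video_length - 1]; commas are accepted as separators.
        positions = []
        for token in spec.replace(",", " ").split(" "):
            if not token:
                continue
            if token in ("L", "l"):
                positions.append(video_length - 1)
            else:
                positions.append(int(token) - 1)
        return positions[:num_refs]

    print(parse_frame_positions("1, 5 L", num_refs=3, video_length=81))  # [0, 4, 80]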
@@ -4971,10 +5071,10 @@ def generate_video(
if preprocess_type2 is not None:
    context_scale = [ control_net_weight /2, control_net_weight2 /2]
send_cmd("progress", [0, get_latest_status(state, status_info)])
-inpaint_color = 0 if preprocess_type=="pose" and process_outside_mask=="inpaint" else 127
+inpaint_color = 0 if preprocess_type=="pose" else guide_inpaint_color
-video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) , start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color )
+video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color )
if preprocess_type2 != None:
-    video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 )
+    video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 )
if video_guide_processed != None:
    if sample_fit_canvas != None:
@@ -4985,6 +5085,36 @@ def generate_video(
refresh_preview["video_guide"] = [refresh_preview["video_guide"], Image.fromarray(video_guide_processed2[0].cpu().numpy())]
if video_mask_processed != None:
    refresh_preview["video_mask"] = Image.fromarray(video_mask_processed[0].cpu().numpy())
+frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame]
+if not vace and (any_letters(video_prompt_type ,"FV") or model_def.get("forced_guide_mask_inputs", False)):
+    any_mask = True
+    any_guide_padding = model_def.get("pad_guide_video", False)
+    from shared.utils.utils import prepare_video_guide_and_mask
+    src_videos, src_masks = prepare_video_guide_and_mask( [video_guide_processed, video_guide_processed2],
+                                                          [video_mask_processed, video_mask_processed2],
+                                                          pre_video_guide, image_size, current_video_length, latent_size,
+                                                          any_mask, any_guide_padding, guide_inpaint_color, extract_guide_from_window_start,
+                                                          keep_frames_parsed, frames_to_inject_parsed, outpainting_dims)
+    src_video, src_video2 = src_videos
+    src_mask, src_mask2 = src_masks
+    if src_video is None:
+        abort = True
+        break
+    if src_faces is not None:
+        if src_faces.shape[1] < src_video.shape[1]:
+            src_faces = torch.concat( [src_faces, src_faces[:, -1:].repeat(1, src_video.shape[1] - src_faces.shape[1], 1, 1)], dim =1)
+        else:
+            src_faces = src_faces[:, :src_video.shape[1]]
+    if args.save_masks:
+        save_video( src_video, "masked_frames.mp4", fps)
+        if src_video2 is not None:
+            save_video( src_video2, "masked_frames2.mp4", fps)
+        if any_mask:
+            save_video( src_mask, "masks.mp4", fps, value_range=(0, 1))
elif ltxv:
    preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw")
    status_info = "Extracting " + processes_names[preprocess_type]
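The face track extracted above is padded or truncated so its time dimension matches the control video. A standalone sketch of that alignment, assuming a channel-first (C, T, H, W) layout as in the hunk (the function name is illustrative):

    import torch

    def align_time_dim(faces: torch.Tensor, target_frames: int) -> torch.Tensor:
        # faces: (C, T, H, W); repeat the last frame or truncate so T == target_frames.
        t = faces.shape[1]
        if t < target_frames:
            last = faces[:, -1:].repeat(1, target_frames - t, 1, 1)
            faces = torch.cat([faces, last], dim=1)
        else:
            faces = faces[:, :target_frames]
        return faces

    faces = torch.zeros(3, 10, 512, 512)
    print(align_time_dim(faces, 17).shape)  # torch.Size([3, 17, 512, 512])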
@@ -5023,7 +5153,7 @@ def generate_video(
    sample_fit_canvas = None
else: # video to video
-    video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps)
+    video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, block_size= block_size)
    if video_guide_processed is None:
        src_video = pre_video_guide
    else:
@@ -5043,29 +5173,6 @@ def generate_video(
    refresh_preview["image_mask"] = new_image_mask
if window_no == 1 and image_refs is not None and len(image_refs) > 0:
-    if repeat_no == 1:
-        frames_positions_list = []
-        if frames_positions is not None and len(frames_positions)> 0:
-            positions = frames_positions.split(" ")
-            cur_end_pos = -1 + (source_video_frames_count - source_video_overlap_frames_count) #if reset_control_aligment else 0
-            last_frame_no = requested_frames_to_generate + source_video_frames_count - source_video_overlap_frames_count
-            joker_used = False
-            project_window_no = 1
-            for pos in positions :
-                if len(pos) > 0:
-                    if pos in ["L", "l"]:
-                        cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length
-                        if cur_end_pos >= last_frame_no and not joker_used:
-                            joker_used = True
-                            cur_end_pos = last_frame_no -1
-                        project_window_no += 1
-                        frames_positions_list.append(cur_end_pos)
-                        cur_end_pos -= sliding_window_discard_last_frames + reuse_frames
-                    else:
-                        frames_positions_list.append(int(pos)-1 + alignment_shift)
-        frames_positions_list = frames_positions_list[:len(image_refs)]
-        nb_frames_positions = len(frames_positions_list)
    if sample_fit_canvas is not None and (nb_frames_positions > 0 or "K" in video_prompt_type) :
        from shared.utils.utils import get_outpainting_full_area_dimensions
        w, h = image_refs[0].size
@@ -5089,20 +5196,16 @@ def generate_video(
if remove_background_images_ref > 0:
    send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")])
# keep image ratios if there is a background image ref (we will let the model preprocessor decide what to do) but remove bg if requested
-image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , image_size[1], image_size[0],
+image_refs[nb_frames_positions:], src_ref_masks = resize_and_remove_background(image_refs[nb_frames_positions:] , image_size[1], image_size[0],
                                  remove_background_images_ref > 0, any_background_ref,
                                  fit_into_canvas= 0 if (any_background_ref > 0 or model_def.get("lock_image_refs_ratios", False)) else 1,
                                  block_size=block_size,
-                                 outpainting_dims =outpainting_dims )
+                                 outpainting_dims =outpainting_dims,
+                                 background_ref_outpainted = model_def.get("background_ref_outpainted", True) )
refresh_preview["image_refs"] = image_refs
-if nb_frames_positions > 0:
-    frames_to_inject = [None] * (max(frames_positions_list) + 1)
-    for i, pos in enumerate(frames_positions_list):
-        frames_to_inject[pos] = image_refs[i]
if vace :
-    frames_to_inject_parsed = frames_to_inject[guide_start_frame: guide_end_frame]
    image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source does in-place modifications
    src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_processed] if video_guide_processed2 == None else [video_guide_processed, video_guide_processed2],
@@ -5116,7 +5219,7 @@ def generate_video(
        any_background_ref = any_background_ref
        )
    if len(frames_to_inject_parsed) or any_background_ref:
-        new_image_refs = [convert_tensor_to_image(src_video[0], frame_no + aligned_guide_start_frame - aligned_window_start_frame) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject]
+        new_image_refs = [convert_tensor_to_image(src_video[0], frame_no + (0 if extract_guide_from_window_start else (aligned_guide_start_frame - aligned_window_start_frame))) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject]
        if any_background_ref:
            new_image_refs += [convert_tensor_to_image(image_refs_copy[0], 0)] + image_refs[nb_frames_positions+1:]
        else:
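The frame-index expression above is parenthesized deliberately: Python's conditional expression binds looser than +, so without the parentheses the frame_no offset is silently dropped in the else branch. A minimal illustration of the pitfall:

    frame_no, offset = 5, 3
    flag = False
    wrong = frame_no + 0 if flag else offset      # parsed as (frame_no + 0) if flag else offset -> 3
    right = frame_no + (0 if flag else offset)    # -> 8
    print(wrong, right)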
@@ -5166,9 +5269,13 @@ def generate_video(
    image_start = image_start_tensor,
    image_end = image_end_tensor,
    input_frames = src_video,
+   input_frames2 = src_video2,
    input_ref_images= src_ref_images,
+   input_ref_masks = src_ref_masks,
    input_masks = src_mask,
+   input_masks2 = src_mask2,
    input_video= pre_video_guide,
+   input_faces = src_faces,
    denoising_strength=denoising_strength,
    prefix_frames_count = source_video_overlap_frames_count if window_no <= 1 else reuse_frames,
    frame_num= (current_video_length // latent_size)* latent_size + 1,
@@ -5302,6 +5409,7 @@ def generate_video(
    send_cmd("output")
else:
    sample = samples.cpu()
+   abort = not is_image and sample.shape[1] < current_video_length
# if True: # for testing
#     torch.save(sample, "output.pt")
# else:
@@ -6980,7 +7088,7 @@ def refresh_video_prompt_type_alignment(state, video_prompt_type, video_prompt_t
def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ):
    old_video_prompt_type = video_prompt_type
-   video_prompt_type = del_in_sequence(video_prompt_type, "PDESLCMUV")
+   video_prompt_type = del_in_sequence(video_prompt_type, "PDESLCMUVB")
    video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
    visible = "V" in video_prompt_type
    model_type = state["model_type"]
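video_prompt_type is a short string of single-letter feature flags; del_in_sequence and add_to_sequence (as used here) appear to behave roughly like the sketch below. The real helpers live elsewhere in the repo, so treat this as an assumption rather than the actual implementation:

    def del_in_sequence(sequence: str, letters: str) -> str:
        # Drop every flag letter that appears in `letters`.
        return "".join(ch for ch in sequence if ch not in letters)

    def add_to_sequence(sequence: str, letters: str) -> str:
        # Append only the flag letters not already present.
        return sequence + "".join(ch for ch in letters if ch not in sequence)

    print(add_to_sequence(del_in_sequence("PVA", "PDESLCMUVB"), "DV"))  # "ADV"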
@@ -7437,7 +7545,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
image_prompt_type = gr.Text(value= image_prompt_type_value, visible= False)
image_prompt_type_choices = []
if "T" in image_prompt_types_allowed:
-   image_prompt_type_choices += [("Text Prompt Only", "")]
+   image_prompt_type_choices += [("Text Prompt Only" if "S" in image_prompt_types_allowed else "New Video", "")]
if "S" in image_prompt_types_allowed:
    image_prompt_type_choices += [("Start Video with Image", "S")]
    any_start_image = True
@@ -7516,7 +7624,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
if image_outputs: video_prompt_type_video_guide_label = video_prompt_type_video_guide_label.replace("Video", "Image")
video_prompt_type_video_guide = gr.Dropdown(
    guide_preprocessing_choices,
-   value=filter_letters(video_prompt_type_value, "PDESLCMUV", guide_preprocessing.get("default", "") ),
+   value=filter_letters(video_prompt_type_value, "PDESLCMUVB", guide_preprocessing.get("default", "") ),
    label= video_prompt_type_video_guide_label , scale = 2, visible= guide_preprocessing.get("visible", True) , show_label= True,
)
any_control_video = True
@@ -7560,13 +7668,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
}
mask_preprocessing_choices = []
-mask_preprocessing_labels = guide_preprocessing.get("labels", {})
+mask_preprocessing_labels = mask_preprocessing.get("labels", {})
for process_type in mask_preprocessing["selection"]:
    process_label = mask_preprocessing_labels.get(process_type, None)
    process_label = mask_preprocessing_labels_all.get(process_type, process_type) if process_label is None else process_label
    mask_preprocessing_choices.append( (process_label, process_type) )
-video_prompt_type_video_mask_label = guide_preprocessing.get("label", "Area Processed")
+video_prompt_type_video_mask_label = mask_preprocessing.get("label", "Area Processed")
video_prompt_type_video_mask = gr.Dropdown(
    mask_preprocessing_choices,
    value=filter_letters(video_prompt_type_value, "XYZWNA", mask_preprocessing.get("default", "")),
@@ -7591,7 +7699,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
    choices= image_ref_choices["choices"],
    value=filter_letters(video_prompt_type_value, image_ref_choices["letters_filter"]),
    visible = image_ref_choices.get("visible", True),
-   label=image_ref_choices.get("label", "Ref. Images Type"), show_label= True, scale = 2
+   label=image_ref_choices.get("label", "Inject Reference Images"), show_label= True, scale = 2
)
image_guide = gr.Image(label= "Control Image", height = 800, type ="pil", visible= image_mode_value==1 and "V" in video_prompt_type_value and ("U" in video_prompt_type_value or not "A" in video_prompt_type_value ) , value= ui_defaults.get("image_guide", None))
@@ -7634,7 +7742,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
video_guide_outpainting_value = ui_defaults.get("video_guide_outpainting","#")
video_guide_outpainting = gr.Text(value=video_guide_outpainting_value , visible= False)
with gr.Group():
-   video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Landscape or Injected Reference Frames" if image_mode_value == 0 else "Enable Spatial Outpainting on Control Image", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") )
+   video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Landscape or Positioned Reference Frames" if image_mode_value == 0 else "Enable Spatial Outpainting on Control Image", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") )
    with gr.Row(visible = not video_guide_outpainting_value.startswith("#")) as video_guide_outpainting_row:
        video_guide_outpainting_value = video_guide_outpainting_value[1:] if video_guide_outpainting_value.startswith("#") else video_guide_outpainting_value
        video_guide_outpainting_list = [0] * 4 if len(video_guide_outpainting_value) == 0 else [int(v) for v in video_guide_outpainting_value.split(" ")]
@@ -7649,14 +7757,14 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value )
image_refs_single_image_mode = model_def.get("one_image_ref_needed", False)
-image_refs_label = "Start Image" if hunyuan_video_avatar else ("Reference Image" if image_refs_single_image_mode else "Reference Images") + (" (each Image will start a new Clip)" if infinitetalk else "")
+image_refs_label = "Start Image" if hunyuan_video_avatar else ("Reference Image" if image_refs_single_image_mode else "Reference Images") + (" (each Image will be associated with a Sliding Window)" if infinitetalk else "")
image_refs_row, image_refs, image_refs_extra = get_image_gallery(label= image_refs_label, value = ui_defaults.get("image_refs", None), visible= "I" in video_prompt_type_value, single_image_mode=image_refs_single_image_mode)
frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames (1=first, L=last of a window; no position for other Image Refs)" )
image_refs_relative_size = gr.Slider(20, 100, value=ui_defaults.get("image_refs_relative_size", 50), step=1, label="Internally Rescale Image Ref (% relative to Output Video) to change Output Composition", visible = model_def.get("any_image_refs_relative_size", False) and image_outputs)
no_background_removal = model_def.get("no_background_removal", False) or image_ref_choices is None
-background_removal_label = model_def.get("background_removal_label", "Remove Backgrounds behind People / Objects")
+background_removal_label = model_def.get("background_removal_label", "Remove Background behind People / Objects")
remove_background_images_ref = gr.Dropdown(
    choices=[
@@ -7664,7 +7772,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
        (background_removal_label, 1),
    ],
    value=0 if no_background_removal else ui_defaults.get("remove_background_images_ref",1),
-   label="Automatic Removal of Background of People or Objects (Only)", scale = 3, visible= "I" in video_prompt_type_value and not no_background_removal
+   label="Automatic Removal of Background behind People or Objects in Reference Images", scale = 3, visible= "I" in video_prompt_type_value and not no_background_removal
)
any_audio_voices_support = any_audio_track(base_model_type)
@@ -8084,7 +8192,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
    ("Aligned to the beginning of the First Window of the new Video Sample", "T"),
    ],
    value=filter_letters(video_prompt_type_value, "T"),
-   label="Control Video / Injected Frames / Control Audio temporal alignment when any Video to continue",
+   label="Control Video / Control Audio / Positioned Frames Temporal Alignment when continuing a Video",
    visible = vace or ltxv or t2v or infinitetalk
)