commit in case there is an unrecoverable code hemorragy

This commit is contained in:
DeepBeepMeep 2025-09-22 17:11:25 +02:00
parent fc615ffb3c
commit 84010bd861
17 changed files with 1620 additions and 380 deletions

15
configs/animate.json Normal file
View File

@ -0,0 +1,15 @@
{
"_class_name": "WanModel",
"_diffusers_version": "0.30.0",
"dim": 5120,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_dim": 36,
"model_type": "i2v",
"num_heads": 40,
"num_layers": 40,
"out_dim": 16,
"text_len": 512,
"motion_encoder_dim": 512
}

14
configs/lucy_edit.json Normal file
View File

@ -0,0 +1,14 @@
{
"_class_name": "WanModel",
"_diffusers_version": "0.33.0",
"dim": 3072,
"eps": 1e-06,
"ffn_dim": 14336,
"freq_dim": 256,
"in_dim": 96,
"model_type": "ti2v2_2",
"num_heads": 24,
"num_layers": 30,
"out_dim": 48,
"text_len": 512
}

13
defaults/animate.json Normal file
View File

@ -0,0 +1,13 @@
{
"model": {
"name": "Wan2.2 Animate",
"architecture": "animate",
"description": "Wan-Animate takes a video and a character image as input, and generates a video in either 'animation' or 'replacement' mode.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_bf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_quanto_fp16_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_quanto_bf16_int8.safetensors"
],
"group": "wan2_2"
}
}

18
defaults/lucy_edit.json Normal file
View File

@ -0,0 +1,18 @@
{
"model": {
"name": "Wan2.2 Lucy Edit 5B",
"architecture": "lucy_edit",
"description": "Lucy Edit Dev is a video editing model that performs instruction-guided edits on videos using free-text prompts \u2014 it supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_mbf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mbf16_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mfp16_int8.safetensors"
],
"group": "wan2_2"
},
"video_length": 81,
"guidance_scale": 5,
"flow_shift": 5,
"num_inference_steps": 30,
"resolution": "1280x720"
}

View File

@ -7,6 +7,7 @@
"loras": ["https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2_2_5B_FastWanFullAttn_lora_rank_128_bf16.safetensors"],
"group": "wan2_2"
},
"prompt" : "Put the person into a clown outfit.",
"video_length": 121,
"guidance_scale": 1,
"flow_shift": 3,

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mmgp import offload
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
@ -387,7 +386,8 @@ class QwenImagePipeline(): #DiffusionPipeline
return latent_image_ids.to(device=device, dtype=dtype)
@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
def _pack_latents(latents):
batch_size, num_channels_latents, _, height, width = latents.shape
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
@ -479,7 +479,7 @@ class QwenImagePipeline(): #DiffusionPipeline
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, 1, num_channels_latents, height, width)
shape = (batch_size, num_channels_latents, 1, height, width)
image_latents = None
if image is not None:
@ -499,10 +499,7 @@ class QwenImagePipeline(): #DiffusionPipeline
else:
image_latents = torch.cat([image_latents], dim=0)
image_latent_height, image_latent_width = image_latents.shape[3:]
image_latents = self._pack_latents(
image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
)
image_latents = self._pack_latents(image_latents)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
@ -511,7 +508,7 @@ class QwenImagePipeline(): #DiffusionPipeline
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
latents = self._pack_latents(latents)
else:
latents = latents.to(device=device, dtype=dtype)
@ -713,11 +710,12 @@ class QwenImagePipeline(): #DiffusionPipeline
image_height, image_width = calculate_new_dimensions(height, width, image_height, image_width, False, block_size=multiple_of)
# image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
height, width = image_height, image_width
image_mask_latents = convert_image_to_tensor(image_mask.resize((width // 16, height // 16), resample=Image.Resampling.LANCZOS))
image_mask_latents = convert_image_to_tensor(image_mask.resize((width // 8, height // 8), resample=Image.Resampling.LANCZOS))
image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1]
image_mask_rebuilt = image_mask_latents.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0)
image_mask_rebuilt = image_mask_latents.repeat_interleave(8, dim=-1).repeat_interleave(8, dim=-2).unsqueeze(0)
# convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png")
image_mask_latents = image_mask_latents.reshape(1, -1, 1).to(device)
image_mask_latents = image_mask_latents.to(device).unsqueeze(0).unsqueeze(0).repeat(1,16,1,1,1)
image_mask_latents = self._pack_latents(image_mask_latents)
prompt_image = image
if image.size != (image_width, image_height):
@ -822,6 +820,7 @@ class QwenImagePipeline(): #DiffusionPipeline
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)
morph, first_step = False, 0
lanpaint_proc = None
if image_mask_latents is not None:
randn = torch.randn_like(original_image_latents)
if denoising_strength < 1.:
@ -833,7 +832,8 @@ class QwenImagePipeline(): #DiffusionPipeline
timesteps = timesteps[first_step:]
self.scheduler.timesteps = timesteps
self.scheduler.sigmas= self.scheduler.sigmas[first_step:]
# from shared.inpainting.lanpaint import LanPaint
# lanpaint_proc = LanPaint()
# 6. Denoising loop
self.scheduler.set_begin_index(0)
updated_num_steps= len(timesteps)
@ -847,48 +847,52 @@ class QwenImagePipeline(): #DiffusionPipeline
offload.set_step_no_for_lora(self.transformer, first_step + i)
if self.interrupt:
continue
self._current_timestep = t
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
if image_mask_latents is not None and denoising_strength <1. and i == first_step and morph:
latent_noise_factor = t/1000
latents = original_image_latents * (1.0 - latent_noise_factor) + latents * latent_noise_factor
self._current_timestep = t
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
latent_model_input = latents
if image_latents is not None:
latent_model_input = torch.cat([latents, image_latents], dim=1)
latents_dtype = latents.dtype
if do_true_cfg and joint_pass:
noise_pred, neg_noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
encoder_hidden_states_mask_list=[prompt_embeds_mask,negative_prompt_embeds_mask],
encoder_hidden_states_list=[prompt_embeds, negative_prompt_embeds],
img_shapes=img_shapes,
txt_seq_lens_list=[txt_seq_lens, negative_txt_seq_lens],
attention_kwargs=self.attention_kwargs,
**kwargs
)
if noise_pred == None: return None
noise_pred = noise_pred[:, : latents.size(1)]
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
else:
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
encoder_hidden_states_mask_list=[prompt_embeds_mask],
encoder_hidden_states_list=[prompt_embeds],
img_shapes=img_shapes,
txt_seq_lens_list=[txt_seq_lens],
attention_kwargs=self.attention_kwargs,
**kwargs
)[0]
if noise_pred == None: return None
noise_pred = noise_pred[:, : latents.size(1)]
# latent_model_input = latents
def denoise(latent_model_input, true_cfg_scale):
if image_latents is not None:
latent_model_input = torch.cat([latents, image_latents], dim=1)
do_true_cfg = true_cfg_scale > 1
if do_true_cfg and joint_pass:
noise_pred, neg_noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance, #!!!!
encoder_hidden_states_mask_list=[prompt_embeds_mask,negative_prompt_embeds_mask],
encoder_hidden_states_list=[prompt_embeds, negative_prompt_embeds],
img_shapes=img_shapes,
txt_seq_lens_list=[txt_seq_lens, negative_txt_seq_lens],
attention_kwargs=self.attention_kwargs,
**kwargs
)
if noise_pred == None: return None, None
noise_pred = noise_pred[:, : latents.size(1)]
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
else:
neg_noise_pred = None
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
encoder_hidden_states_mask_list=[prompt_embeds_mask],
encoder_hidden_states_list=[prompt_embeds],
img_shapes=img_shapes,
txt_seq_lens_list=[txt_seq_lens],
attention_kwargs=self.attention_kwargs,
**kwargs
)[0]
if noise_pred == None: return None, None
noise_pred = noise_pred[:, : latents.size(1)]
if do_true_cfg:
neg_noise_pred = self.transformer(
@ -902,27 +906,43 @@ class QwenImagePipeline(): #DiffusionPipeline
attention_kwargs=self.attention_kwargs,
**kwargs
)[0]
if neg_noise_pred == None: return None
if neg_noise_pred == None: return None, None
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
return noise_pred, neg_noise_pred
def cfg_predictions( noise_pred, neg_noise_pred, guidance, t):
if do_true_cfg:
comb_pred = neg_noise_pred + guidance * (noise_pred - neg_noise_pred)
if comb_pred == None: return None
if do_true_cfg:
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
if comb_pred == None: return None
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
noise_pred = comb_pred * (cond_norm / noise_norm)
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
noise_pred = comb_pred * (cond_norm / noise_norm)
neg_noise_pred = None
return noise_pred
if lanpaint_proc is not None and i<=3:
latents = lanpaint_proc(denoise, cfg_predictions, true_cfg_scale, 1., latents, original_image_latents, randn, t/1000, image_mask_latents, height=height , width= width, vae_scale_factor= 8)
if latents is None: return None
noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale)
if noise_pred == None: return None
noise_pred = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t)
neg_noise_pred = None
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
noise_pred = None
if image_mask_latents is not None:
next_t = timesteps[i+1] if i<len(timesteps)-1 else 0
latent_noise_factor = next_t / 1000
# noisy_image = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor
noisy_image = original_image_latents * (1.0 - latent_noise_factor) + randn * latent_noise_factor
latents = noisy_image * (1-image_mask_latents) + image_mask_latents * latents
noisy_image = None
if lanpaint_proc is not None:
latents = original_image_latents * (1-image_mask_latents) + image_mask_latents * latents
else:
next_t = timesteps[i+1] if i<len(timesteps)-1 else 0
latent_noise_factor = next_t / 1000
# noisy_image = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor
noisy_image = original_image_latents * (1.0 - latent_noise_factor) + randn * latent_noise_factor
latents = noisy_image * (1-image_mask_latents) + image_mask_latents * latents
noisy_image = None
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():

View File

@ -32,9 +32,10 @@ from shared.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from .modules.posemb_layers import get_rotary_pos_embed, get_nd_rotary_pos_embed
from shared.utils.vace_preprocessor import VaceVideoProcessor
from shared.utils.basic_flowmatch import FlowMatchScheduler
from shared.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions, convert_image_to_tensor
from shared.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions, convert_image_to_tensor, fit_image_into_canvas
from .multitalk.multitalk_utils import MomentumBuffer, adaptive_projected_guidance, match_and_blend_colors, match_and_blend_colors_with_mask
from mmgp import safetensors2
from shared.utils.audio_video import save_video
def optimized_scale(positive_flat, negative_flat):
@ -93,7 +94,7 @@ class WanAny2V:
shard_fn= None)
# base_model_type = "i2v2_2"
if hasattr(config, "clip_checkpoint") and not base_model_type in ["i2v_2_2", "i2v_2_2_multitalk"]:
if hasattr(config, "clip_checkpoint") and not base_model_type in ["i2v_2_2", "i2v_2_2_multitalk"] or base_model_type in ["animate"]:
self.clip = CLIPModel(
dtype=config.clip_dtype,
device=self.device,
@ -102,7 +103,7 @@ class WanAny2V:
tokenizer_path=os.path.join(checkpoint_dir , "xlm-roberta-large"))
if base_model_type in ["ti2v_2_2"]:
if base_model_type in ["ti2v_2_2", "lucy_edit"]:
self.vae_stride = (4, 16, 16)
vae_checkpoint = "Wan2.2_VAE.safetensors"
vae = Wan2_2_VAE
@ -190,19 +191,14 @@ class WanAny2V:
save_quantized_model(self.model2, model_type, model_filename[1], dtype, base_config_file, submodel_no=2)
self.sample_neg_prompt = config.sample_neg_prompt
if self.model.config.get("vace_in_dim", None) != None:
self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
min_area=480*832,
max_area=480*832,
min_fps=config.sample_fps,
max_fps=config.sample_fps,
zero_start=True,
seq_len=32760,
keep_last=True)
if hasattr(self.model, "vace_blocks"):
self.adapt_vace_model(self.model)
if self.model2 is not None: self.adapt_vace_model(self.model2)
if hasattr(self.model, "face_adapter"):
self.adapt_animate_model(self.model)
if self.model2 is not None: self.adapt_animate_model(self.model2)
self.num_timesteps = 1000
self.use_timestep_transform = True
@ -277,51 +273,6 @@ class WanAny2V:
def vace_latent(self, z, m):
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, full_frame = False, outpainting_dims = None, return_mask = False):
from shared.utils.utils import save_image
ref_width, ref_height = ref_img.size
if (ref_height, ref_width) == image_size and outpainting_dims == None:
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
canvas = torch.zeros_like(ref_img) if return_mask else None
else:
if outpainting_dims != None:
final_height, final_width = image_size
canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1)
else:
canvas_height, canvas_width = image_size
if full_frame:
new_height = canvas_height
new_width = canvas_width
top = left = 0
else:
# if fill_max and (canvas_height - new_height) < 16:
# new_height = canvas_height
# if fill_max and (canvas_width - new_width) < 16:
# new_width = canvas_width
scale = min(canvas_height / ref_height, canvas_width / ref_width)
new_height = int(ref_height * scale)
new_width = int(ref_width * scale)
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
if outpainting_dims != None:
canvas = torch.full((3, 1, final_height, final_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img
else:
canvas = torch.full((3, 1, canvas_height, canvas_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
ref_img = canvas
canvas = None
if return_mask:
if outpainting_dims != None:
canvas = torch.ones((3, 1, final_height, final_width), dtype= torch.float, device=device) # [-1, 1]
canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = 0
else:
canvas = torch.ones((3, 1, canvas_height, canvas_width), dtype= torch.float, device=device) # [-1, 1]
canvas[:, :, top:top + new_height, left:left + new_width] = 0
canvas = canvas.to(device)
return ref_img.to(device), canvas
def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_video_guide_frames= [], pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False):
image_sizes = []
@ -375,7 +326,7 @@ class WanAny2V:
for k, frame in enumerate(inject_frames):
if frame != None:
pos = prepend_count + k
src_video[i][:, pos:pos+1], src_mask[i][:, pos:pos+1] = self.fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims, return_mask= True)
src_video[i][:, pos:pos+1], src_mask[i][:, pos:pos+1] = fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims, return_mask= True)
self.background_mask = None
@ -386,9 +337,9 @@ class WanAny2V:
if ref_img is not None and not torch.is_tensor(ref_img):
if j==0 and any_background_ref:
if self.background_mask == None: self.background_mask = [None] * len(src_ref_images)
src_ref_images[i][j], self.background_mask[i] = self.fit_image_into_canvas(ref_img, image_size, 0, device, True, outpainting_dims, return_mask= True)
src_ref_images[i][j], self.background_mask[i] = fit_image_into_canvas(ref_img, image_size, 0, device, True, outpainting_dims, return_mask= True)
else:
src_ref_images[i][j], _ = self.fit_image_into_canvas(ref_img, image_size, 1, device)
src_ref_images[i][j], _ = fit_image_into_canvas(ref_img, image_size, 1, device)
if self.background_mask != None:
self.background_mask = [ item if item != None else self.background_mask[0] for item in self.background_mask ] # deplicate background mask with double control net since first controlnet image ref modifed by ref
return src_video, src_mask, src_ref_images
@ -402,12 +353,26 @@ class WanAny2V:
return torch.cat(ref_vae_latents, dim=1)
def get_i2v_mask(self, lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=None, lat_t =0, device="cuda"):
if mask_pixel_values is None:
msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
else:
msk = F.interpolate(mask_pixel_values.to(device), size=(lat_h, lat_w), mode='nearest')
if nb_frames_unchanged >0:
msk[:, :nb_frames_unchanged] = 1
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1,2)[0]
return msk
def generate(self,
input_prompt,
input_frames= None,
input_masks = None,
input_ref_images = None,
input_ref_images = None,
input_ref_masks = None,
input_faces = None,
input_video = None,
image_start = None,
image_end = None,
@ -541,14 +506,18 @@ class WanAny2V:
infinitetalk = model_type in ["infinitetalk"]
standin = model_type in ["standin", "vace_standin_14B"]
recam = model_type in ["recam_1.3B"]
ti2v = model_type in ["ti2v_2_2"]
ti2v = model_type in ["ti2v_2_2", "lucy_edit"]
lucy_edit= model_type in ["lucy_edit"]
animate= model_type in ["animate"]
start_step_no = 0
ref_images_count = 0
trim_frames = 0
extended_overlapped_latents = None
extended_overlapped_latents = clip_image_start = clip_image_end = None
no_noise_latents_injection = infinitetalk
timestep_injection = False
lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1
extended_input_dim = 0
ref_images_before = False
# image2video
if model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "flf2v_720p"]:
any_end_frame = False
@ -598,17 +567,7 @@ class WanAny2V:
if image_end is not None:
img_end_frame = image_end.unsqueeze(1).to(self.device)
if hasattr(self, "clip"):
clip_image_size = self.clip.model.image_size
image_start = resize_lanczos(image_start, clip_image_size, clip_image_size)
image_end = resize_lanczos(image_end, clip_image_size, clip_image_size) if image_end is not None else image_start
if model_type == "flf2v_720p":
clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :] if image_end is not None else image_start[:, None, :, :]])
else:
clip_context = self.clip.visual([image_start[:, None, :, :]])
else:
clip_context = None
clip_image_start, clip_image_end = image_start, image_end
if any_end_frame:
enc= torch.concat([
@ -647,21 +606,62 @@ class WanAny2V:
if infinitetalk:
lat_y = self.vae.encode([input_video], VAE_tile_size)[0]
extended_overlapped_latents = lat_y[:, :overlapped_latents_frames_num].clone().unsqueeze(0)
# if control_pre_frames_count != pre_frames_count:
lat_y = input_video = None
kwargs.update({ 'y': y})
if not clip_context is None:
kwargs.update({'clip_fea': clip_context})
# Recam Master
if recam:
target_camera = model_mode
height,width = input_frames.shape[-2:]
input_frames = input_frames.to(dtype=self.dtype , device=self.device)
source_latents = self.vae.encode([input_frames])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device)
# Animate
if animate:
pose_pixels = input_frames * input_masks
input_masks = 1. - input_masks
pose_pixels -= input_masks
save_video(pose_pixels, "pose.mp4")
pose_latents = self.vae.encode([pose_pixels], VAE_tile_size)[0].unsqueeze(0)
input_frames = input_frames * input_masks
if not "X" in video_prompt_type: input_frames += input_masks - 1 # masked area should black (-1) in background frames
if prefix_frames_count > 0:
input_frames[:, :prefix_frames_count] = input_video
input_masks[:, :prefix_frames_count] = 1
save_video(input_frames, "input_frames.mp4")
save_video(input_masks, "input_masks.mp4", value_range=(0,1))
lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
msk_ref = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=1,lat_t=1, device=self.device)
msk_control = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=input_masks, device=self.device)
msk = torch.concat([msk_ref, msk_control], dim=1)
clip_image_start = image_ref = convert_image_to_tensor(input_ref_images[0]).to(self.device)
lat_y = torch.concat(self.vae.encode([image_ref.unsqueeze(1).to(self.device), input_frames.to(self.device)], VAE_tile_size), dim=1)
y = torch.concat([msk, lat_y])
kwargs.update({ 'y': y, 'pose_latents': pose_latents, 'face_pixel_values' : input_faces.unsqueeze(0)})
lat_y = msk = msk_control = msk_ref = pose_pixels = None
ref_images_before = True
ref_images_count = 1
lat_frames = int((input_frames.shape[1] - 1) // self.vae_stride[0]) + 1
# Clip image
if hasattr(self, "clip") and clip_image_start is not None:
clip_image_size = self.clip.model.image_size
clip_image_start = resize_lanczos(clip_image_start, clip_image_size, clip_image_size)
clip_image_end = resize_lanczos(clip_image_end, clip_image_size, clip_image_size) if clip_image_end is not None else clip_image_start
if model_type == "flf2v_720p":
clip_context = self.clip.visual([clip_image_start[:, None, :, :], clip_image_end[:, None, :, :] if clip_image_end is not None else clip_image_start[:, None, :, :]])
else:
clip_context = self.clip.visual([clip_image_start[:, None, :, :]])
clip_image_start = clip_image_end = None
kwargs.update({'clip_fea': clip_context})
# Recam Master & Lucy Edit
if recam or lucy_edit:
frame_num, height,width = input_frames.shape[-3:]
lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1
frame_num = (lat_frames -1) * self.vae_stride[0] + 1
input_frames = input_frames[:, :frame_num].to(dtype=self.dtype , device=self.device)
extended_latents = self.vae.encode([input_frames])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device)
extended_input_dim = 2 if recam else 1
del input_frames
if recam:
# Process target camera (recammaster)
target_camera = model_mode
from shared.utils.cammmaster_tools import get_camera_embedding
cam_emb = get_camera_embedding(target_camera)
cam_emb = cam_emb.to(dtype=self.dtype, device=self.device)
@ -715,6 +715,8 @@ class WanAny2V:
height, width = input_video.shape[-2:]
source_latents = self.vae.encode([input_video], tile_size = VAE_tile_size)[0].unsqueeze(0)
timestep_injection = True
if extended_input_dim > 0:
extended_latents[:, :, :source_latents.shape[2]] = source_latents
# Vace
if vace :
@ -722,6 +724,7 @@ class WanAny2V:
input_frames = [u.to(self.device) for u in input_frames]
input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
input_masks = [u.to(self.device) for u in input_masks]
ref_images_before = True
if self.background_mask != None: self.background_mask = [m.to(self.device) for m in self.background_mask]
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents )
m0 = self.vace_encode_masks(input_masks, input_ref_images)
@ -771,9 +774,9 @@ class WanAny2V:
expand_shape = [batch_size] + [-1] * len(target_shape)
# Ropes
if target_camera != None:
if extended_input_dim>=2:
shape = list(target_shape[1:])
shape[0] *= 2
shape[extended_input_dim-2] *= 2
freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
else:
freqs = get_rotary_pos_embed(target_shape[1:], enable_RIFLEx= enable_RIFLEx)
@ -901,8 +904,8 @@ class WanAny2V:
for zz in z:
zz[0:16, ref_images_count:extended_overlapped_latents.shape[2] ] = extended_overlapped_latents[0, :, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(extended_overlapped_latents[0, :, ref_images_count:] ) * overlap_noise_factor
if target_camera != None:
latent_model_input = torch.cat([latents, source_latents.expand(*expand_shape)], dim=2)
if extended_input_dim > 0:
latent_model_input = torch.cat([latents, extended_latents.expand(*expand_shape)], dim=extended_input_dim)
else:
latent_model_input = latents
@ -1030,7 +1033,7 @@ class WanAny2V:
if callback is not None:
latents_preview = latents
if vace and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ]
if ref_images_before and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ]
if trim_frames > 0: latents_preview= latents_preview[:, :,:-trim_frames]
if image_outputs: latents_preview= latents_preview[:, :,:1]
if len(latents_preview) > 1: latents_preview = latents_preview.transpose(0,2)
@ -1041,7 +1044,7 @@ class WanAny2V:
if timestep_injection:
latents[:, :, :source_latents.shape[2]] = source_latents
if vace and ref_images_count > 0: latents = latents[:, :, ref_images_count:]
if ref_images_before and ref_images_count > 0: latents = latents[:, :, ref_images_count:]
if trim_frames > 0: latents= latents[:, :,:-trim_frames]
if return_latent_slice != None:
latent_slice = latents[:, :, return_latent_slice].clone()
@ -1078,4 +1081,12 @@ class WanAny2V:
delattr(model, "vace_blocks")
def adapt_animate_model(self, model):
modules_dict= { k: m for k, m in model.named_modules()}
for animate_layer in range(8):
module = modules_dict[f"face_adapter.fuser_blocks.{animate_layer}"]
model_layer = animate_layer * 5
target = modules_dict[f"blocks.{model_layer}"]
setattr(target, "face_adapter_fuser_blocks", module )
delattr(model, "face_adapter")

View File

@ -16,6 +16,9 @@ from mmgp.offload import get_cache, clear_caches
from shared.attention import pay_attention
from torch.backends.cuda import sdp_kernel
from ..multitalk.multitalk_utils import get_attn_map_with_target
from ..animate.motion_encoder import Generator
from ..animate.face_blocks import FaceAdapter, FaceEncoder
from ..animate.model_animate import after_patch_embedding
__all__ = ['WanModel']
@ -499,6 +502,7 @@ class WanAttentionBlock(nn.Module):
multitalk_masks=None,
ref_images_count=0,
standin_phase=-1,
motion_vec = None,
):
r"""
Args:
@ -616,6 +620,10 @@ class WanAttentionBlock(nn.Module):
x.add_(hint)
else:
x.add_(hint, alpha= scale)
if motion_vec is not None and self.block_no % 5 == 0:
x += self.face_adapter_fuser_blocks(x.to(self.face_adapter_fuser_blocks.linear1_kv.weight.dtype), motion_vec, None, False)
return x
class AudioProjModel(ModelMixin, ConfigMixin):
@ -898,6 +906,7 @@ class WanModel(ModelMixin, ConfigMixin):
norm_input_visual=True,
norm_output_audio=True,
standin= False,
motion_encoder_dim=0,
):
super().__init__()
@ -922,14 +931,15 @@ class WanModel(ModelMixin, ConfigMixin):
self.flag_causal_attention = False
self.block_mask = None
self.inject_sample_info = inject_sample_info
self.motion_encoder_dim = motion_encoder_dim
self.norm_output_audio = norm_output_audio
self.audio_window = audio_window
self.intermediate_dim = intermediate_dim
self.vae_scale = vae_scale
multitalk = multitalk_output_dim > 0
self.multitalk = multitalk
self.multitalk = multitalk
animate = motion_encoder_dim > 0
# embeddings
self.patch_embedding = nn.Conv3d(
@ -1027,6 +1037,25 @@ class WanModel(ModelMixin, ConfigMixin):
block.self_attn.k_loras = LoRALinearLayer(dim, dim, rank=128)
block.self_attn.v_loras = LoRALinearLayer(dim, dim, rank=128)
if animate:
self.pose_patch_embedding = nn.Conv3d(
16, dim, kernel_size=patch_size, stride=patch_size
)
self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)
self.face_adapter = FaceAdapter(
heads_num=self.num_heads,
hidden_dim=self.dim,
num_adapter_layers=self.num_layers // 5,
)
self.face_encoder = FaceEncoder(
in_dim=motion_encoder_dim,
hidden_dim=self.dim,
num_heads=4,
)
def lock_layers_dtypes(self, hybrid_dtype = None, dtype = torch.float32):
layer_list = [self.head, self.head.head, self.patch_embedding]
target_dype= dtype
@ -1208,6 +1237,9 @@ class WanModel(ModelMixin, ConfigMixin):
ref_images_count = 0,
standin_freqs = None,
standin_ref = None,
pose_latents=None,
face_pixel_values=None,
):
# patch_dtype = self.patch_embedding.weight.dtype
modulation_dtype = self.time_projection[1].weight.dtype
@ -1240,9 +1272,18 @@ class WanModel(ModelMixin, ConfigMixin):
if bz > 1: y = y.expand(bz, -1, -1, -1, -1)
x = torch.cat([x, y], dim=1)
# embeddings
# x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype)
x = self.patch_embedding(x).to(modulation_dtype)
grid_sizes = x.shape[2:]
x_list[i] = x
y = None
motion_vec_list = []
for i, x in enumerate(x_list):
# animate embeddings
motion_vec = None
if pose_latents is not None:
x, motion_vec = after_patch_embedding(self, x, pose_latents, face_pixel_values)
motion_vec_list.append(motion_vec)
if chipmunk:
x = x.unsqueeze(-1)
x_og_shape = x.shape
@ -1250,7 +1291,7 @@ class WanModel(ModelMixin, ConfigMixin):
else:
x = x.flatten(2).transpose(1, 2)
x_list[i] = x
x, y = None, None
x = None
block_mask = None
@ -1450,9 +1491,9 @@ class WanModel(ModelMixin, ConfigMixin):
continue
x_list[0] = block(x_list[0], context = context_list[0], audio_scale= audio_scale_list[0], e= e0, **kwargs)
else:
for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc)):
for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc, motion_vec) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc,motion_vec_list)):
if should_calc:
x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, **kwargs)
x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, motion_vec = motion_vec,**kwargs)
del x
context = hints = audio_embedding = None

View File

@ -3,10 +3,10 @@ import numpy as np
import gradio as gr
def test_class_i2v(base_model_type):
return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk" ]
return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "animate" ]
def text_oneframe_overlap(base_model_type):
return test_class_i2v(base_model_type) and not test_multitalk(base_model_type)
return test_class_i2v(base_model_type) and not (test_multitalk(base_model_type) or base_model_type in ["animate"]) or test_wan_5B(base_model_type)
def test_class_1_3B(base_model_type):
return base_model_type in [ "vace_1.3B", "t2v_1.3B", "recam_1.3B","phantom_1.3B","fun_inp_1.3B"]
@ -17,6 +17,8 @@ def test_multitalk(base_model_type):
def test_standin(base_model_type):
return base_model_type in ["standin", "vace_standin_14B"]
def test_wan_5B(base_model_type):
return base_model_type in ["ti2v_2_2", "lucy_edit"]
class family_handler():
@staticmethod
@ -36,7 +38,7 @@ class family_handler():
def_mag_ratios = [1.00124, 1.00155, 0.99822, 0.99851, 0.99696, 0.99687, 0.99703, 0.99732, 0.9966, 0.99679, 0.99602, 0.99658, 0.99578, 0.99664, 0.99484, 0.9949, 0.99633, 0.996, 0.99659, 0.99683, 0.99534, 0.99549, 0.99584, 0.99577, 0.99681, 0.99694, 0.99563, 0.99554, 0.9944, 0.99473, 0.99594, 0.9964, 0.99466, 0.99461, 0.99453, 0.99481, 0.99389, 0.99365, 0.99391, 0.99406, 0.99354, 0.99361, 0.99283, 0.99278, 0.99268, 0.99263, 0.99057, 0.99091, 0.99125, 0.99126, 0.65523, 0.65252, 0.98808, 0.98852, 0.98765, 0.98736, 0.9851, 0.98535, 0.98311, 0.98339, 0.9805, 0.9806, 0.97776, 0.97771, 0.97278, 0.97286, 0.96731, 0.96728, 0.95857, 0.95855, 0.94385, 0.94385, 0.92118, 0.921, 0.88108, 0.88076, 0.80263, 0.80181]
elif base_model_type in ["i2v_2_2"]:
def_mag_ratios = [0.99191, 0.99144, 0.99356, 0.99337, 0.99326, 0.99285, 0.99251, 0.99264, 0.99393, 0.99366, 0.9943, 0.9943, 0.99276, 0.99288, 0.99389, 0.99393, 0.99274, 0.99289, 0.99316, 0.9931, 0.99379, 0.99377, 0.99268, 0.99271, 0.99222, 0.99227, 0.99175, 0.9916, 0.91076, 0.91046, 0.98931, 0.98933, 0.99087, 0.99088, 0.98852, 0.98855, 0.98895, 0.98896, 0.98806, 0.98808, 0.9871, 0.98711, 0.98613, 0.98618, 0.98434, 0.98435, 0.983, 0.98307, 0.98185, 0.98187, 0.98131, 0.98131, 0.9783, 0.97835, 0.97619, 0.9762, 0.97264, 0.9727, 0.97088, 0.97098, 0.96568, 0.9658, 0.96045, 0.96055, 0.95322, 0.95335, 0.94579, 0.94594, 0.93297, 0.93311, 0.91699, 0.9172, 0.89174, 0.89202, 0.8541, 0.85446, 0.79823, 0.79902]
elif base_model_type in ["ti2v_2_2"]:
elif test_wan_5B(base_model_type):
if inputs.get("image_start", None) is not None and inputs.get("video_source", None) is not None : # t2v
def_mag_ratios = [0.99505, 0.99389, 0.99441, 0.9957, 0.99558, 0.99551, 0.99499, 0.9945, 0.99534, 0.99548, 0.99468, 0.9946, 0.99463, 0.99458, 0.9946, 0.99453, 0.99408, 0.99404, 0.9945, 0.99441, 0.99409, 0.99398, 0.99403, 0.99397, 0.99382, 0.99377, 0.99349, 0.99343, 0.99377, 0.99378, 0.9933, 0.99328, 0.99303, 0.99301, 0.99217, 0.99216, 0.992, 0.99201, 0.99201, 0.99202, 0.99133, 0.99132, 0.99112, 0.9911, 0.99155, 0.99155, 0.98958, 0.98957, 0.98959, 0.98958, 0.98838, 0.98835, 0.98826, 0.98825, 0.9883, 0.98828, 0.98711, 0.98709, 0.98562, 0.98561, 0.98511, 0.9851, 0.98414, 0.98412, 0.98284, 0.98282, 0.98104, 0.98101, 0.97981, 0.97979, 0.97849, 0.97849, 0.97557, 0.97554, 0.97398, 0.97395, 0.97171, 0.97166, 0.96917, 0.96913, 0.96511, 0.96507, 0.96263, 0.96257, 0.95839, 0.95835, 0.95483, 0.95475, 0.94942, 0.94936, 0.9468, 0.94678, 0.94583, 0.94594, 0.94843, 0.94872, 0.96949, 0.97015]
else: # i2v
@ -83,11 +85,13 @@ class family_handler():
vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B", "vace_standin_14B"]
extra_model_def["vace_class"] = vace_class
if test_multitalk(base_model_type):
if base_model_type in ["animate"]:
fps = 30
elif test_multitalk(base_model_type):
fps = 25
elif base_model_type in ["fantasy"]:
fps = 23
elif base_model_type in ["ti2v_2_2"]:
elif test_wan_5B(base_model_type):
fps = 24
else:
fps = 16
@ -100,14 +104,14 @@ class family_handler():
extra_model_def.update({
"frames_minimum" : frames_minimum,
"frames_steps" : frames_steps,
"sliding_window" : base_model_type in ["multitalk", "infinitetalk", "t2v", "fantasy"] or test_class_i2v(base_model_type) or vace_class, #"ti2v_2_2",
"sliding_window" : base_model_type in ["multitalk", "infinitetalk", "t2v", "fantasy", "animate"] or test_class_i2v(base_model_type) or test_wan_5B(base_model_type) or vace_class, #"ti2v_2_2",
"multiple_submodels" : multiple_submodels,
"guidance_max_phases" : 3,
"skip_layer_guidance" : True,
"cfg_zero" : True,
"cfg_star" : True,
"adaptive_projected_guidance" : True,
"tea_cache" : not (base_model_type in ["i2v_2_2", "ti2v_2_2" ] or multiple_submodels),
"tea_cache" : not (base_model_type in ["i2v_2_2"] or test_wan_5B(base_model_type) or multiple_submodels),
"mag_cache" : True,
"keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"],
"convert_image_guide_to_video" : True,
@ -146,6 +150,34 @@ class family_handler():
}
# extra_model_def["at_least_one_image_ref_needed"] = True
if base_model_type in ["lucy_edit"]:
extra_model_def["keep_frames_video_guide_not_supported"] = True
extra_model_def["guide_preprocessing"] = {
"selection": ["UV"],
"labels" : { "UV": "Control Video"},
"visible": False,
}
if base_model_type in ["animate"]:
extra_model_def["guide_custom_choices"] = {
"choices":[
("Animate Person in Reference Image using Motion of Person in Control Video", "PVBXAKI"),
("Replace Person in Control Video Person in Reference Image", "PVBAI"),
],
"default": "KI",
"letters_filter": "PVBXAKI",
"label": "Type of Process",
"show_label" : False,
}
extra_model_def["video_guide_outpainting"] = [0,1]
extra_model_def["keep_frames_video_guide_not_supported"] = True
extra_model_def["extract_guide_from_window_start"] = True
extra_model_def["forced_guide_mask_inputs"] = True
extra_model_def["background_removal_label"]= "Remove Backgrounds behind People (Animate Mode Only)"
extra_model_def["background_ref_outpainted"] = False
if vace_class:
extra_model_def["guide_preprocessing"] = {
"selection": ["", "UV", "PV", "DV", "SV", "LV", "CV", "MV", "V", "PDV", "PSV", "PLV" , "DSV", "DLV", "SLV"],
@ -157,16 +189,19 @@ class family_handler():
extra_model_def["image_ref_choices"] = {
"choices": [("None", ""),
("Inject only People / Objects", "I"),
("Inject Landscape and then People / Objects", "KI"),
("Inject Frames and then People / Objects", "FI"),
("People / Objects", "I"),
("Landscape followed by People / Objects (if any)", "KI"),
("Positioned Frames followed by People / Objects (if any)", "FI"),
],
"letters_filter": "KFI",
}
extra_model_def["lock_image_refs_ratios"] = True
extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or positioned Frames"
extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or Positioned Frames"
extra_model_def["video_guide_outpainting"] = [0,1]
extra_model_def["pad_guide_video"] = True
extra_model_def["guide_inpaint_color"] = 127.5
extra_model_def["forced_guide_mask_inputs"] = True
if base_model_type in ["standin"]:
extra_model_def["lock_image_refs_ratios"] = True
@ -209,10 +244,12 @@ class family_handler():
"visible" : False,
}
if vace_class or base_model_type in ["infinitetalk"]:
if vace_class or base_model_type in ["infinitetalk", "animate"]:
image_prompt_types_allowed = "TVL"
elif base_model_type in ["ti2v_2_2"]:
image_prompt_types_allowed = "TSVL"
elif base_model_type in ["lucy_edit"]:
image_prompt_types_allowed = "TVL"
elif test_multitalk(base_model_type) or base_model_type in ["fantasy"]:
image_prompt_types_allowed = "SVL"
elif i2v:
@ -234,8 +271,8 @@ class family_handler():
def query_supported_types():
return ["multitalk", "infinitetalk", "fantasy", "vace_14B", "vace_multitalk_14B", "vace_standin_14B",
"t2v_1.3B", "standin", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B",
"recam_1.3B",
"i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "flf2v_720p", "fun_inp_1.3B", "fun_inp"]
"recam_1.3B", "animate",
"i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "lucy_edit", "flf2v_720p", "fun_inp_1.3B", "fun_inp"]
@staticmethod
@ -265,11 +302,12 @@ class family_handler():
@staticmethod
def get_vae_block_size(base_model_type):
return 32 if base_model_type == "ti2v_2_2" else 16
return 32 if test_wan_5B(base_model_type) else 16
@staticmethod
def get_rgb_factors(base_model_type ):
from shared.RGB_factors import get_rgb_factors
if test_wan_5B(base_model_type): base_model_type = "ti2v_2_2"
latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("wan", base_model_type)
return latent_rgb_factors, latent_rgb_factors_bias
@ -283,7 +321,7 @@ class family_handler():
"fileList" : [ [ "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "sentencepiece.bpe.model", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"], ["special_tokens_map.json", "spiece.model", "tokenizer.json", "tokenizer_config.json"] + computeList(text_encoder_filename) , ["Wan2.1_VAE.safetensors", "fantasy_proj_model.safetensors" ] + computeList(model_filename) ]
}]
if base_model_type == "ti2v_2_2":
if test_wan_5B(base_model_type):
download_def += [ {
"repoId" : "DeepBeepMeep/Wan2.2",
"sourceFolderList" : [""],
@ -377,8 +415,8 @@ class family_handler():
ui_defaults.update({
"sample_solver": "unipc",
})
if test_class_i2v(base_model_type):
ui_defaults["image_prompt_type"] = "S"
if test_class_i2v(base_model_type) and "S" in model_def["image_prompt_types_allowed"]:
ui_defaults["image_prompt_type"] = "S"
if base_model_type in ["fantasy"]:
ui_defaults.update({
@ -434,10 +472,15 @@ class family_handler():
"image_prompt_type": "T",
})
if base_model_type in ["recam_1.3B"]:
if base_model_type in ["recam_1.3B", "lucy_edit"]:
ui_defaults.update({
"video_prompt_type": "UV",
})
elif base_model_type in ["animate"]:
ui_defaults.update({
"video_prompt_type": "PVBXAKI",
"mask_expand": 20,
})
if text_oneframe_overlap(base_model_type):
ui_defaults["sliding_window_overlap"] = 1

View File

@ -0,0 +1,342 @@
#!/usr/bin/env python3
"""
Convert a Flux model from Diffusers (folder or single-file) into the original
single-file Flux transformer checkpoint used by Black Forest Labs / ComfyUI.
Input : /path/to/diffusers (root or .../transformer) OR /path/to/*.safetensors (single file)
Output : /path/to/flux1-your-model.safetensors (transformer only)
Usage:
python diffusers_to_flux_transformer.py /path/to/diffusers /out/flux1-dev.safetensors
python diffusers_to_flux_transformer.py /path/to/diffusion_pytorch_model.safetensors /out/flux1-dev.safetensors
# optional quantization:
# --fp8 (float8_e4m3fn, simple)
# --fp8-scaled (scaled float8 for 2D weights; adds .scale_weight tensors)
"""
import argparse
import json
from pathlib import Path
from collections import OrderedDict
import torch
from safetensors import safe_open
import safetensors.torch
from tqdm import tqdm
def parse_args():
ap = argparse.ArgumentParser()
ap.add_argument("diffusers_path", type=str,
help="Path to Diffusers checkpoint folder OR a single .safetensors file.")
ap.add_argument("output_path", type=str,
help="Output .safetensors path for the Flux transformer.")
ap.add_argument("--fp8", action="store_true",
help="Experimental: write weights as float8_e4m3fn via stochastic rounding (transformer only).")
ap.add_argument("--fp8-scaled", action="store_true",
help="Experimental: scaled float8_e4m3fn for 2D weight tensors; adds .scale_weight tensors.")
return ap.parse_args()
# Mapping from original Flux keys -> list of Diffusers keys (per block where applicable).
DIFFUSERS_MAP = {
# global embeds
"time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"],
"time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"],
"time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"],
"time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"],
"vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"],
"vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"],
"vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"],
"vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"],
"guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"],
"guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"],
"guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"],
"guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"],
"txt_in.weight": ["context_embedder.weight"],
"txt_in.bias": ["context_embedder.bias"],
"img_in.weight": ["x_embedder.weight"],
"img_in.bias": ["x_embedder.bias"],
# dual-stream (image/text) blocks
"double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"],
"double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"],
"double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"],
"double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"],
"double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"],
"double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"],
"double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"],
"double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"],
"double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"],
"double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"],
"double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"],
"double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"],
"double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"],
"double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"],
"double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"],
"double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"],
"double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"],
"double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"],
"double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"],
"double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"],
"double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"],
"double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"],
"double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"],
"double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"],
# single-stream blocks
"single_blocks.().modulation.lin.weight": ["norm.linear.weight"],
"single_blocks.().modulation.lin.bias": ["norm.linear.bias"],
"single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"],
"single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"],
"single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"],
"single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"],
"single_blocks.().linear2.weight": ["proj_out.weight"],
"single_blocks.().linear2.bias": ["proj_out.bias"],
# final
"final_layer.linear.weight": ["proj_out.weight"],
"final_layer.linear.bias": ["proj_out.bias"],
# these two are built from norm_out.linear.{weight,bias} by swapping [shift,scale] -> [scale,shift]
"final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"],
"final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"],
}
class DiffusersSource:
"""
Uniform interface over:
1) Folder with index JSON + shards
2) Folder with exactly one .safetensors (no index)
3) Single .safetensors file
Provides .has(key), .get(key)->Tensor, .base_keys (keys with 'model.' stripped for scanning)
"""
POSSIBLE_PREFIXES = ["", "model."] # try in this order
def __init__(self, path: Path):
p = Path(path)
if p.is_dir():
# use 'transformer' subfolder if present
if (p / "transformer").is_dir():
p = p / "transformer"
self._init_from_dir(p)
elif p.is_file() and p.suffix == ".safetensors":
self._init_from_single_file(p)
else:
raise FileNotFoundError(f"Invalid path: {p}")
# ---------- common helpers ----------
@staticmethod
def _strip_prefix(k: str) -> str:
return k[6:] if k.startswith("model.") else k
def _resolve(self, want: str):
"""
Return the actual stored key matching `want` by trying known prefixes.
"""
for pref in self.POSSIBLE_PREFIXES:
k = pref + want
if k in self._all_keys:
return k
return None
def has(self, want: str) -> bool:
return self._resolve(want) is not None
def get(self, want: str) -> torch.Tensor:
real_key = self._resolve(want)
if real_key is None:
raise KeyError(f"Missing key: {want}")
return self._get_by_real_key(real_key).to("cpu")
@property
def base_keys(self):
# keys without 'model.' prefix for scanning
return [self._strip_prefix(k) for k in self._all_keys]
# ---------- modes ----------
def _init_from_single_file(self, file_path: Path):
self._mode = "single"
self._file = file_path
self._handle = safe_open(file_path, framework="pt", device="cpu")
self._all_keys = list(self._handle.keys())
def _get_by_real_key(real_key: str):
return self._handle.get_tensor(real_key)
self._get_by_real_key = _get_by_real_key
def _init_from_dir(self, dpath: Path):
index_json = dpath / "diffusion_pytorch_model.safetensors.index.json"
if index_json.exists():
with open(index_json, "r", encoding="utf-8") as f:
index = json.load(f)
weight_map = index["weight_map"] # full mapping
self._mode = "sharded"
self._dpath = dpath
self._weight_map = {k: dpath / v for k, v in weight_map.items()}
self._all_keys = list(self._weight_map.keys())
self._open_handles = {}
def _get_by_real_key(real_key: str):
fpath = self._weight_map[real_key]
h = self._open_handles.get(fpath)
if h is None:
h = safe_open(fpath, framework="pt", device="cpu")
self._open_handles[fpath] = h
return h.get_tensor(real_key)
self._get_by_real_key = _get_by_real_key
return
# no index: try exactly one safetensors in folder
files = sorted(dpath.glob("*.safetensors"))
if len(files) != 1:
raise FileNotFoundError(
f"No index found and {dpath} does not contain exactly one .safetensors file."
)
self._init_from_single_file(files[0])
def main():
args = parse_args()
src = DiffusersSource(Path(args.diffusers_path))
# Count blocks by scanning base keys (with any 'model.' prefix removed)
num_dual = 0
num_single = 0
for k in src.base_keys:
if k.startswith("transformer_blocks."):
try:
i = int(k.split(".")[1])
num_dual = max(num_dual, i + 1)
except Exception:
pass
elif k.startswith("single_transformer_blocks."):
try:
i = int(k.split(".")[1])
num_single = max(num_single, i + 1)
except Exception:
pass
print(f"Found {num_dual} dual-stream blocks, {num_single} single-stream blocks")
# Swap [shift, scale] -> [scale, shift] (weights are concatenated along dim=0)
def swap_scale_shift(vec: torch.Tensor) -> torch.Tensor:
shift, scale = vec.chunk(2, dim=0)
return torch.cat([scale, shift], dim=0)
orig = {}
# Per-block (dual)
for b in range(num_dual):
prefix = f"transformer_blocks.{b}."
for okey, dvals in DIFFUSERS_MAP.items():
if not okey.startswith("double_blocks."):
continue
dkeys = [prefix + v for v in dvals]
if not all(src.has(k) for k in dkeys):
continue
if len(dkeys) == 1:
orig[okey.replace("()", str(b))] = src.get(dkeys[0])
else:
orig[okey.replace("()", str(b))] = torch.cat([src.get(k) for k in dkeys], dim=0)
# Per-block (single)
for b in range(num_single):
prefix = f"single_transformer_blocks.{b}."
for okey, dvals in DIFFUSERS_MAP.items():
if not okey.startswith("single_blocks."):
continue
dkeys = [prefix + v for v in dvals]
if not all(src.has(k) for k in dkeys):
continue
if len(dkeys) == 1:
orig[okey.replace("()", str(b))] = src.get(dkeys[0])
else:
orig[okey.replace("()", str(b))] = torch.cat([src.get(k) for k in dkeys], dim=0)
# Globals (non-block)
for okey, dvals in DIFFUSERS_MAP.items():
if okey.startswith(("double_blocks.", "single_blocks.")):
continue
dkeys = dvals
if not all(src.has(k) for k in dkeys):
continue
if len(dkeys) == 1:
orig[okey] = src.get(dkeys[0])
else:
orig[okey] = torch.cat([src.get(k) for k in dkeys], dim=0)
# Fix final_layer.adaLN_modulation.1.{weight,bias} by swapping scale/shift halves
if "final_layer.adaLN_modulation.1.weight" in orig:
orig["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(
orig["final_layer.adaLN_modulation.1.weight"]
)
if "final_layer.adaLN_modulation.1.bias" in orig:
orig["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(
orig["final_layer.adaLN_modulation.1.bias"]
)
# Optional FP8 variants (experimental; not required for ComfyUI/BFL)
if args.fp8 or args.fp8_scaled:
dtype = torch.float8_e4m3fn # noqa
minv, maxv = torch.finfo(dtype).min, torch.finfo(dtype).max
def stochastic_round_to(t):
t = t.float().clamp(minv, maxv)
lower = torch.floor(t * 256) / 256
upper = torch.ceil(t * 256) / 256
prob = torch.where(upper != lower, (t - lower) / (upper - lower), torch.zeros_like(t))
rnd = torch.rand_like(t)
out = torch.where(rnd < prob, upper, lower)
return out.to(dtype)
def scale_to_8bit(weight, target_max=416.0):
absmax = weight.abs().max()
scale = absmax / target_max if absmax > 0 else torch.tensor(1.0)
scaled = (weight / scale).clamp(minv, maxv).to(dtype)
return scaled, scale
scales = {}
for k in tqdm(list(orig.keys()), desc="Quantizing to fp8"):
t = orig[k]
if args.fp8:
orig[k] = stochastic_round_to(t)
else:
if k.endswith(".weight") and t.dim() == 2:
qt, s = scale_to_8bit(t)
orig[k] = qt
scales[k[:-len(".weight")] + ".scale_weight"] = s
else:
orig[k] = t.clamp(minv, maxv).to(dtype)
if args.fp8_scaled:
orig.update(scales)
orig["scaled_fp8"] = torch.tensor([], dtype=dtype)
else:
# Default: save in bfloat16
for k in list(orig.keys()):
orig[k] = orig[k].to(torch.bfloat16).cpu()
out_path = Path(args.output_path)
out_path.parent.mkdir(parents=True, exist_ok=True)
meta = OrderedDict()
meta["format"] = "pt"
meta["modelspec.date"] = __import__("datetime").date.today().strftime("%Y-%m-%d")
print(f"Saving transformer to: {out_path}")
safetensors.torch.save_file(orig, str(out_path), metadata=meta)
print("Done.")
if __name__ == "__main__":
main()

View File

View File

@ -0,0 +1,240 @@
import torch
from .utils import *
from functools import partial
# Many thanks to the LanPaint team for this implementation (https://github.com/scraed/LanPaint/)
def _pack_latents(latents):
batch_size, num_channels_latents, _, height, width = latents.shape
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
def _unpack_latents(latents, height, width, vae_scale_factor=8):
batch_size, num_patches, channels = latents.shape
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
return latents
class LanPaint():
def __init__(self, NSteps = 5, Friction = 15, Lambda = 8, Beta = 1, StepSize = 0.15, IS_FLUX = True, IS_FLOW = False):
self.n_steps = NSteps
self.chara_lamb = Lambda
self.IS_FLUX = IS_FLUX
self.IS_FLOW = IS_FLOW
self.step_size = StepSize
self.friction = Friction
self.chara_beta = Beta
self.img_dim_size = None
def add_none_dims(self, array):
# Create a tuple with ':' for the first dimension and 'None' repeated num_nones times
index = (slice(None),) + (None,) * (self.img_dim_size-1)
return array[index]
def remove_none_dims(self, array):
# Create a tuple with ':' for the first dimension and 'None' repeated num_nones times
index = (slice(None),) + (0,) * (self.img_dim_size-1)
return array[index]
def __call__(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, latent_image, noise, sigma, latent_mask, n_steps=None, height =720, width = 1280, vae_scale_factor = 8):
latent_image = _unpack_latents(latent_image, height=height, width=width, vae_scale_factor=vae_scale_factor)
noise = _unpack_latents(noise, height=height, width=width, vae_scale_factor=vae_scale_factor)
x = _unpack_latents(x, height=height, width=width, vae_scale_factor=vae_scale_factor)
latent_mask = _unpack_latents(latent_mask, height=height, width=width, vae_scale_factor=vae_scale_factor)
self.height = height
self.width = width
self.vae_scale_factor = vae_scale_factor
self.img_dim_size = len(x.shape)
self.latent_image = latent_image
self.noise = noise
if n_steps is None:
n_steps = self.n_steps
out = self.LanPaint(denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, self.IS_FLUX, self.IS_FLOW)
out = _pack_latents(out)
return out
def LanPaint(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, IS_FLUX, IS_FLOW):
if IS_FLUX:
cfg_BIG = 1.0
def double_denoise(latents, t):
latents = _pack_latents(latents)
noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale)
if noise_pred == None: return None, None
predict_std = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t)
predict_std = _unpack_latents(predict_std, self.height, self.width, self.vae_scale_factor)
if true_cfg_scale == cfg_BIG:
predict_big = predict_std
else:
predict_big = cfg_predictions(noise_pred, neg_noise_pred, cfg_BIG, t)
predict_big = _unpack_latents(predict_big, self.height, self.width, self.vae_scale_factor)
return predict_std, predict_big
if len(sigma.shape) == 0:
sigma = torch.tensor([sigma.item()])
latent_mask = 1 - latent_mask
if IS_FLUX or IS_FLOW:
Flow_t = sigma
abt = (1 - Flow_t)**2 / ((1 - Flow_t)**2 + Flow_t**2 )
VE_Sigma = Flow_t / (1 - Flow_t)
#print("t", torch.mean( sigma ).item(), "VE_Sigma", torch.mean( VE_Sigma ).item())
else:
VE_Sigma = sigma
abt = 1/( 1+VE_Sigma**2 )
Flow_t = (1-abt)**0.5 / ( (1-abt)**0.5 + abt**0.5 )
# VE_Sigma, abt, Flow_t = current_times
current_times = (VE_Sigma, abt, Flow_t)
step_size = self.step_size * (1 - abt)
step_size = self.add_none_dims(step_size)
# self.inner_model.inner_model.scale_latent_inpaint returns variance exploding x_t values
# This is the replace step
# x = x * (1 - latent_mask) + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image)* latent_mask
noisy_image = self.latent_image * (1.0 - sigma) + self.noise * sigma
x = x * (1 - latent_mask) + noisy_image * latent_mask
if IS_FLUX or IS_FLOW:
x_t = x * ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 )
else:
x_t = x / ( 1+self.add_none_dims(VE_Sigma)**2 )**0.5 # switch to variance perserving x_t values
############ LanPaint Iterations Start ###############
# after noise_scaling, noise = latent_image + noise * sigma, which is x_t in the variance exploding diffusion model notation for the known region.
args = None
for i in range(n_steps):
score_func = partial( self.score_model, y = self.latent_image, mask = latent_mask, abt = self.add_none_dims(abt), sigma = self.add_none_dims(VE_Sigma), tflow = self.add_none_dims(Flow_t), denoise_func = double_denoise )
if score_func is None: return None
x_t, args = self.langevin_dynamics(x_t, score_func , latent_mask, step_size , current_times, sigma_x = self.add_none_dims(self.sigma_x(abt)), sigma_y = self.add_none_dims(self.sigma_y(abt)), args = args)
if IS_FLUX or IS_FLOW:
x = x_t / ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 )
else:
x = x_t * ( 1+self.add_none_dims(VE_Sigma)**2 )**0.5 # switch to variance perserving x_t values
############ LanPaint Iterations End ###############
# out is x_0
# out, _ = self.inner_model(x, sigma, model_options=model_options, seed=seed)
# out = out * (1-latent_mask) + self.latent_image * latent_mask
# return out
return x
def score_model(self, x_t, y, mask, abt, sigma, tflow, denoise_func):
lamb = self.chara_lamb
if self.IS_FLUX or self.IS_FLOW:
# compute t for flow model, with a small epsilon compensating for numerical error.
x = x_t / ( abt**0.5 + (1-abt)**0.5 ) # switch to Gaussian flow matching
x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(tflow))
if x_0 is None: return None
else:
x = x_t * ( 1+sigma**2 )**0.5 # switch to variance exploding
x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(sigma))
if x_0 is None: return None
score_x = -(x_t - x_0)
score_y = - (1 + lamb) * ( x_t - y ) + lamb * (x_t - x_0_BIG)
return score_x * (1 - mask) + score_y * mask
def sigma_x(self, abt):
# the time scale for the x_t update
return abt**0
def sigma_y(self, abt):
beta = self.chara_beta * abt ** 0
return beta
def langevin_dynamics(self, x_t, score, mask, step_size, current_times, sigma_x=1, sigma_y=0, args=None):
# prepare the step size and time parameters
with torch.autocast(device_type=x_t.device.type, dtype=torch.float32):
step_sizes = self.prepare_step_size(current_times, step_size, sigma_x, sigma_y)
sigma, abt, dtx, dty, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y = step_sizes
# print('mask',mask.device)
if torch.mean(dtx) <= 0.:
return x_t, args
# -------------------------------------------------------------------------
# Compute the Langevin dynamics update in variance perserving notation
# -------------------------------------------------------------------------
#x0 = self.x0_evalutation(x_t, score, sigma, args)
#C = abt**0.5 * x0 / (1-abt)
A = A_x * (1-mask) + A_y * mask
D = D_x * (1-mask) + D_y * mask
dt = dtx * (1-mask) + dty * mask
Gamma = Gamma_x * (1-mask) + Gamma_y * mask
def Coef_C(x_t):
x0 = self.x0_evalutation(x_t, score, sigma, args)
C = (abt**0.5 * x0 - x_t )/ (1-abt) + A * x_t
return C
def advance_time(x_t, v, dt, Gamma, A, C, D):
dtype = x_t.dtype
with torch.autocast(device_type=x_t.device.type, dtype=torch.float32):
osc = StochasticHarmonicOscillator(Gamma, A, C, D )
x_t, v = osc.dynamics(x_t, v, dt )
x_t = x_t.to(dtype)
v = v.to(dtype)
return x_t, v
if args is None:
#v = torch.zeros_like(x_t)
v = None
C = Coef_C(x_t)
#print(torch.squeeze(dtx), torch.squeeze(dty))
x_t, v = advance_time(x_t, v, dt, Gamma, A, C, D)
else:
v, C = args
x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D)
C_new = Coef_C(x_t)
v = v + Gamma**0.5 * ( C_new - C) *dt
x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D)
C = C_new
return x_t, (v, C)
def prepare_step_size(self, current_times, step_size, sigma_x, sigma_y):
# -------------------------------------------------------------------------
# Unpack current times parameters (sigma and abt)
sigma, abt, flow_t = current_times
sigma = self.add_none_dims(sigma)
abt = self.add_none_dims(abt)
# Compute time step (dtx, dty) for x and y branches.
dtx = 2 * step_size * sigma_x
dty = 2 * step_size * sigma_y
# -------------------------------------------------------------------------
# Define friction parameter Gamma_hat for each branch.
# Using dtx**0 provides a tensor of the proper device/dtype.
Gamma_hat_x = self.friction **2 * self.step_size * sigma_x / 0.1 * sigma**0
Gamma_hat_y = self.friction **2 * self.step_size * sigma_y / 0.1 * sigma**0
#print("Gamma_hat_x", torch.mean(Gamma_hat_x).item(), "Gamma_hat_y", torch.mean(Gamma_hat_y).item())
# adjust dt to match denoise-addnoise steps sizes
Gamma_hat_x /= 2.
Gamma_hat_y /= 2.
A_t_x = (1) / ( 1 - abt ) * dtx / 2
A_t_y = (1+self.chara_lamb) / ( 1 - abt ) * dty / 2
A_x = A_t_x / (dtx/2)
A_y = A_t_y / (dty/2)
Gamma_x = Gamma_hat_x / (dtx/2)
Gamma_y = Gamma_hat_y / (dty/2)
#D_x = (2 * (1 + sigma**2) )**0.5
#D_y = (2 * (1 + sigma**2) )**0.5
D_x = (2 * abt**0 )**0.5
D_y = (2 * abt**0 )**0.5
return sigma, abt, dtx/2, dty/2, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y
def x0_evalutation(self, x_t, score, sigma, args):
x0 = x_t + score(x_t)
return x0

301
shared/inpainting/utils.py Normal file
View File

@ -0,0 +1,301 @@
import torch
def epxm1_x(x):
# Compute the (exp(x) - 1) / x term with a small value to avoid division by zero.
result = torch.special.expm1(x) / x
# replace NaN or inf values with 0
result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
mask = torch.abs(x) < 1e-2
result = torch.where(mask, 1 + x/2. + x**2 / 6., result)
return result
def epxm1mx_x2(x):
# Compute the (exp(x) - 1 - x) / x**2 term with a small value to avoid division by zero.
result = (torch.special.expm1(x) - x) / x**2
# replace NaN or inf values with 0
result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
mask = torch.abs(x**2) < 1e-2
result = torch.where(mask, 1/2. + x/6 + x**2 / 24 + x**3 / 120, result)
return result
def expm1mxmhx2_x3(x):
# Compute the (exp(x) - 1 - x - x**2 / 2) / x**3 term with a small value to avoid division by zero.
result = (torch.special.expm1(x) - x - x**2 / 2) / x**3
# replace NaN or inf values with 0
result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
mask = torch.abs(x**3) < 1e-2
result = torch.where(mask, 1/6 + x/24 + x**2 / 120 + x**3 / 720 + x**4 / 5040, result)
return result
def exp_1mcosh_GD(gamma_t, delta):
"""
Compute e^(-Γt) * (1 - cosh(ΓtΔ))/ ( (Γt)**2 Δ )
Parameters:
gamma_t: Γ*t term (could be a scalar or tensor)
delta: Δ term (could be a scalar or tensor)
Returns:
Result of the computation with numerical stability handling
"""
# Main computation
is_positive = delta > 0
sqrt_abs_delta = torch.sqrt(torch.abs(delta))
gamma_t_sqrt_delta = gamma_t * sqrt_abs_delta
numerator_pos = torch.exp(-gamma_t) - (torch.exp(gamma_t * (sqrt_abs_delta - 1)) + torch.exp(gamma_t * (-sqrt_abs_delta - 1))) / 2
numerator_neg = torch.exp(-gamma_t) * ( 1 - torch.cos(gamma_t * sqrt_abs_delta ) )
numerator = torch.where(is_positive, numerator_pos, numerator_neg)
result = numerator / (delta * gamma_t**2 )
# Handle NaN/inf cases
result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
# Handle numerical instability for small delta
mask = torch.abs(gamma_t_sqrt_delta**2) < 5e-2
taylor = ( -0.5 - gamma_t**2 / 24 * delta - gamma_t**4 / 720 * delta**2 ) * torch.exp(-gamma_t)
result = torch.where(mask, taylor, result)
return result
def exp_sinh_GsqrtD(gamma_t, delta):
"""
Compute e^(-Γt) * sinh(ΓtΔ) / (ΓtΔ)
Parameters:
gamma_t: Γ*t term (could be a scalar or tensor)
delta: Δ term (could be a scalar or tensor)
Returns:
Result of the computation with numerical stability handling
"""
# Main computation
is_positive = delta > 0
sqrt_abs_delta = torch.sqrt(torch.abs(delta))
gamma_t_sqrt_delta = gamma_t * sqrt_abs_delta
numerator_pos = (torch.exp(gamma_t * (sqrt_abs_delta - 1)) - torch.exp(gamma_t * (-sqrt_abs_delta - 1))) / 2
denominator_pos = gamma_t_sqrt_delta
result_pos = numerator_pos / gamma_t_sqrt_delta
result_pos = torch.where(torch.isfinite(result_pos), result_pos, torch.zeros_like(result_pos))
# Taylor expansion for small gamma_t_sqrt_delta
mask = torch.abs(gamma_t_sqrt_delta) < 1e-2
taylor = ( 1 + gamma_t**2 / 6 * delta + gamma_t**4 / 120 * delta**2 ) * torch.exp(-gamma_t)
result_pos = torch.where(mask, taylor, result_pos)
# Handle negative delta
result_neg = torch.exp(-gamma_t) * torch.special.sinc(gamma_t_sqrt_delta/torch.pi)
result = torch.where(is_positive, result_pos, result_neg)
return result
def exp_cosh(gamma_t, delta):
"""
Compute e^(-Γt) * cosh(ΓtΔ)
Parameters:
gamma_t: Γ*t term (could be a scalar or tensor)
delta: Δ term (could be a scalar or tensor)
Returns:
Result of the computation with numerical stability handling
"""
exp_1mcosh_GD_result = exp_1mcosh_GD(gamma_t, delta) # e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ )
result = torch.exp(-gamma_t) - gamma_t**2 * delta * exp_1mcosh_GD_result
return result
def exp_sinh_sqrtD(gamma_t, delta):
"""
Compute e^(-Γt) * sinh(ΓtΔ) / Δ
Parameters:
gamma_t: Γ*t term (could be a scalar or tensor)
delta: Δ term (could be a scalar or tensor)
Returns:
Result of the computation with numerical stability handling
"""
exp_sinh_GsqrtD_result = exp_sinh_GsqrtD(gamma_t, delta) # e^(-Γt) * sinh(Γt√Δ) / (Γt√Δ)
result = gamma_t * exp_sinh_GsqrtD_result
return result
def zeta1(gamma_t, delta):
# Compute hyperbolic terms and exponential
half_gamma_t = gamma_t / 2
exp_cosh_term = exp_cosh(half_gamma_t, delta)
exp_sinh_term = exp_sinh_sqrtD(half_gamma_t, delta)
# Main computation
numerator = 1 - (exp_cosh_term + exp_sinh_term)
denominator = gamma_t * (1 - delta) / 4
result = 1 - numerator / denominator
# Handle numerical instability
result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
# Taylor expansion for small x (similar to your epxm1Dx approach)
mask = torch.abs(denominator) < 5e-3
term1 = epxm1_x(-gamma_t)
term2 = epxm1mx_x2(-gamma_t)
term3 = expm1mxmhx2_x3(-gamma_t)
taylor = term1 + (1/2.+ term1-3*term2)*denominator + (-1/6. + term1/2 - 4 * term2 + 10 * term3) * denominator**2
result = torch.where(mask, taylor, result)
return result
def exp_cosh_minus_terms(gamma_t, delta):
"""
Compute E^(-) * (Cosh[] - 1 - (Cosh[Δ] - 1)/Δ) / ((1 - Δ))
Parameters:
gamma_t: Γ*t term (could be a scalar or tensor)
delta: Δ term (could be a scalar or tensor)
Returns:
Result of the computation with numerical stability handling
"""
exp_term = torch.exp(-gamma_t)
# Compute individual terms
exp_cosh_term = exp_cosh(gamma_t, gamma_t**0) - exp_term # E^(-tΓ) (Cosh[tΓ] - 1) term
exp_cosh_delta_term = - gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) # E^(-tΓ) (Cosh[tΓ√Δ] - 1)/Δ term
#exp_1mcosh_GD e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ )
# Main computation
numerator = exp_cosh_term - exp_cosh_delta_term
denominator = gamma_t * (1 - delta)
result = numerator / denominator
# Handle numerical instability
result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
# Taylor expansion for small gamma_t and delta near 1
mask = (torch.abs(denominator) < 1e-1)
exp_1mcosh_GD_term = exp_1mcosh_GD(gamma_t, delta**0)
taylor = (
gamma_t*exp_1mcosh_GD_term + 0.5 * gamma_t * exp_sinh_GsqrtD(gamma_t, delta**0)
- denominator / 4 * ( 0.5 * exp_cosh(gamma_t, delta**0) - 4 * exp_1mcosh_GD_term - 5 /2 * exp_sinh_GsqrtD(gamma_t, delta**0) )
)
result = torch.where(mask, taylor, result)
return result
def zeta2(gamma_t, delta):
half_gamma_t = gamma_t / 2
return exp_sinh_GsqrtD(half_gamma_t, delta)
def sig11(gamma_t, delta):
return 1 - torch.exp(-gamma_t) + gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) + exp_sinh_sqrtD(gamma_t, delta)
def Zcoefs(gamma_t, delta):
Zeta1 = zeta1(gamma_t, delta)
Zeta2 = zeta2(gamma_t, delta)
sq_total = 1 - Zeta1 + gamma_t * (delta - 1) * (Zeta1 - 1)**2 / 8
amplitude = torch.sqrt(sq_total)
Zcoef1 = ( gamma_t**0.5 * Zeta2 / 2 **0.5 ) / amplitude
Zcoef2 = Zcoef1 * gamma_t *( - 2 * exp_1mcosh_GD(gamma_t, delta) / sig11(gamma_t, delta) ) ** 0.5
#cterm = exp_cosh_minus_terms(gamma_t, delta)
#sterm = exp_sinh_sqrtD(gamma_t, delta**0) + exp_sinh_sqrtD(gamma_t, delta)
#Zcoef3 = 2 * torch.sqrt( cterm / ( gamma_t * (1 - delta) * cterm + sterm ) )
Zcoef3 = torch.sqrt( torch.maximum(1 - Zcoef1**2 - Zcoef2**2, sq_total.new_zeros(sq_total.shape)) )
return Zcoef1 * amplitude, Zcoef2 * amplitude, Zcoef3 * amplitude, amplitude
def Zcoefs_asymp(gamma_t, delta):
A_t = (gamma_t * (1 - delta) )/4
return epxm1_x(- 2 * A_t)
class StochasticHarmonicOscillator:
"""
Simulates a stochastic harmonic oscillator governed by the equations:
dy(t) = q(t) dt
dq(t) = -Γ A y(t) dt + Γ C dt + Γ D dw(t) - Γ q(t) dt
Also define v(t) = q(t) / Γ, which is numerically more stable.
Where:
y(t) - Position variable
q(t) - Velocity variable
Γ - Damping coefficient
A - Harmonic potential strength
C - Constant force term
D - Noise amplitude
dw(t) - Wiener process (Brownian motion)
"""
def __init__(self, Gamma, A, C, D):
self.Gamma = Gamma
self.A = A
self.C = C
self.D = D
self.Delta = 1 - 4 * A / Gamma
def sig11(self, gamma_t, delta):
return 1 - torch.exp(-gamma_t) + gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) + exp_sinh_sqrtD(gamma_t, delta)
def sig22(self, gamma_t, delta):
return 1- zeta1(2*gamma_t, delta) + 2 * gamma_t * exp_1mcosh_GD(gamma_t, delta)
def dynamics(self, y0, v0, t):
"""
Calculates the position and velocity variables at time t.
Parameters:
y0 (float): Initial position
v0 (float): Initial velocity v(0) = q(0) / Γ
t (float): Time at which to evaluate the dynamics
Returns:
tuple: (y(t), v(t))
"""
dummyzero = y0.new_zeros(1) # convert scalar to tensor with same device and dtype as y0
Delta = self.Delta + dummyzero
Gamma_hat = self.Gamma * t + dummyzero
A = self.A + dummyzero
C = self.C + dummyzero
D = self.D + dummyzero
Gamma = self.Gamma + dummyzero
zeta_1 = zeta1( Gamma_hat, Delta)
zeta_2 = zeta2( Gamma_hat, Delta)
EE = 1 - Gamma_hat * zeta_2
if v0 is None:
v0 = torch.randn_like(y0) * D / 2 ** 0.5
#v0 = (C - A * y0)/Gamma**0.5
# Calculate mean position and velocity
term1 = (1 - zeta_1) * (C * t - A * t * y0) + zeta_2 * (Gamma ** 0.5) * v0 * t
y_mean = term1 + y0
v_mean = (1 - EE)*(C - A * y0) / (Gamma ** 0.5) + (EE - A * t * (1 - zeta_1)) * v0
cov_yy = D**2 * t * self.sig22(Gamma_hat, Delta)
cov_vv = D**2 * self.sig11(Gamma_hat, Delta) / 2
cov_yv = (zeta2(Gamma_hat, Delta) * Gamma_hat * D ) **2 / 2 / (Gamma ** 0.5)
# sample new position and velocity with multivariate normal distribution
batch_shape = y0.shape
cov_matrix = torch.zeros(*batch_shape, 2, 2, device=y0.device, dtype=y0.dtype)
cov_matrix[..., 0, 0] = cov_yy
cov_matrix[..., 0, 1] = cov_yv
cov_matrix[..., 1, 0] = cov_yv # symmetric
cov_matrix[..., 1, 1] = cov_vv
# Compute the Cholesky decomposition to get scale_tril
#scale_tril = torch.linalg.cholesky(cov_matrix)
scale_tril = torch.zeros(*batch_shape, 2, 2, device=y0.device, dtype=y0.dtype)
tol = 1e-8
cov_yy = torch.clamp( cov_yy, min = tol )
sd_yy = torch.sqrt( cov_yy )
inv_sd_yy = 1/(sd_yy)
scale_tril[..., 0, 0] = sd_yy
scale_tril[..., 0, 1] = 0.
scale_tril[..., 1, 0] = cov_yv * inv_sd_yy
scale_tril[..., 1, 1] = torch.clamp( cov_vv - cov_yv**2 / cov_yy, min = tol ) ** 0.5
# check if it matches torch.linalg.
#assert torch.allclose(torch.linalg.cholesky(cov_matrix), scale_tril, atol = 1e-4, rtol = 1e-4 )
# Sample correlated noise from multivariate normal
mean = torch.zeros(*batch_shape, 2, device=y0.device, dtype=y0.dtype)
mean[..., 0] = y_mean
mean[..., 1] = v_mean
new_yv = torch.distributions.MultivariateNormal(
loc=mean,
scale_tril=scale_tril
).sample()
return new_yv[...,0], new_yv[...,1]

View File

@ -232,6 +232,9 @@ def save_video(tensor,
retry=5):
"""Save tensor as video with configurable codec and container options."""
if torch.is_tensor(tensor) and len(tensor.shape) == 4:
tensor = tensor.unsqueeze(0)
suffix = f'.{container}'
cache_file = osp.join('/tmp', rand_name(suffix=suffix)) if save_file is None else save_file
if not cache_file.endswith(suffix):

110
shared/utils/download.py Normal file
View File

@ -0,0 +1,110 @@
import sys, time
# Global variables to track download progress
_start_time = None
_last_time = None
_last_downloaded = 0
_speed_history = []
_update_interval = 0.5 # Update speed every 0.5 seconds
def progress_hook(block_num, block_size, total_size, filename=None):
"""
Simple progress bar hook for urlretrieve
Args:
block_num: Number of blocks downloaded so far
block_size: Size of each block in bytes
total_size: Total size of the file in bytes
filename: Name of the file being downloaded (optional)
"""
global _start_time, _last_time, _last_downloaded, _speed_history, _update_interval
current_time = time.time()
downloaded = block_num * block_size
# Initialize timing on first call
if _start_time is None or block_num == 0:
_start_time = current_time
_last_time = current_time
_last_downloaded = 0
_speed_history = []
# Calculate download speed only at specified intervals
speed = 0
if current_time - _last_time >= _update_interval:
if _last_time > 0:
current_speed = (downloaded - _last_downloaded) / (current_time - _last_time)
_speed_history.append(current_speed)
# Keep only last 5 speed measurements for smoothing
if len(_speed_history) > 5:
_speed_history.pop(0)
# Average the recent speeds for smoother display
speed = sum(_speed_history) / len(_speed_history)
_last_time = current_time
_last_downloaded = downloaded
elif _speed_history:
# Use the last calculated average speed
speed = sum(_speed_history) / len(_speed_history)
# Format file sizes and speed
def format_bytes(bytes_val):
for unit in ['B', 'KB', 'MB', 'GB']:
if bytes_val < 1024:
return f"{bytes_val:.1f}{unit}"
bytes_val /= 1024
return f"{bytes_val:.1f}TB"
file_display = filename if filename else "Unknown file"
if total_size <= 0:
# If total size is unknown, show downloaded bytes
speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else ""
line = f"\r{file_display}: {format_bytes(downloaded)}{speed_str}"
# Clear any trailing characters by padding with spaces
sys.stdout.write(line.ljust(80))
sys.stdout.flush()
return
downloaded = block_num * block_size
percent = min(100, (downloaded / total_size) * 100)
# Create progress bar (40 characters wide to leave room for other info)
bar_length = 40
filled = int(bar_length * percent / 100)
bar = '' * filled + '' * (bar_length - filled)
# Format file sizes and speed
def format_bytes(bytes_val):
for unit in ['B', 'KB', 'MB', 'GB']:
if bytes_val < 1024:
return f"{bytes_val:.1f}{unit}"
bytes_val /= 1024
return f"{bytes_val:.1f}TB"
speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else ""
# Display progress with filename first
line = f"\r{file_display}: [{bar}] {percent:.1f}% ({format_bytes(downloaded)}/{format_bytes(total_size)}){speed_str}"
# Clear any trailing characters by padding with spaces
sys.stdout.write(line.ljust(100))
sys.stdout.flush()
# Print newline when complete
if percent >= 100:
print()
# Wrapper function to include filename in progress hook
def create_progress_hook(filename):
"""Creates a progress hook with the filename included"""
global _start_time, _last_time, _last_downloaded, _speed_history
# Reset timing variables for new download
_start_time = None
_last_time = None
_last_downloaded = 0
_speed_history = []
def hook(block_num, block_size, total_size):
return progress_hook(block_num, block_size, total_size, filename)
return hook

View File

@ -1,4 +1,3 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os
import os.path as osp
@ -176,8 +175,9 @@ def remove_background(img, session=None):
def convert_image_to_tensor(image):
return torch.from_numpy(np.array(image).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0)
def convert_tensor_to_image(t, frame_no = -1):
t = t[:, frame_no] if frame_no >= 0 else t
def convert_tensor_to_image(t, frame_no = 0):
if len(t.shape) == 4:
t = t[:, frame_no]
return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy())
def save_image(tensor_image, name, frame_no = -1):
@ -257,16 +257,18 @@ def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fi
image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
return image, new_height, new_width
def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None ):
def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5 ):
if rm_background:
session = new_session()
output_list =[]
output_mask_list =[]
for i, img in enumerate(img_list):
width, height = img.size
resized_mask = None
if fit_into_canvas == None or any_background_ref == 1 and i==0 or any_background_ref == 2:
if outpainting_dims is not None:
resized_image =img
if outpainting_dims is not None and background_ref_outpainted:
resized_image, resized_mask = fit_image_into_canvas(img, (budget_height, budget_width), inpaint_color, full_frame = True, outpainting_dims = outpainting_dims, return_mask= True, return_image= True)
elif img.size != (budget_width, budget_height):
resized_image= img.resize((budget_width, budget_height), resample=Image.Resampling.LANCZOS)
else:
@ -290,145 +292,103 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg
# resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200,
return output_list
output_mask_list.append(resized_mask)
return output_list, output_mask_list
def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu", full_frame = False, outpainting_dims = None, return_mask = False, return_image = False):
from shared.utils.utils import save_image
inpaint_color = canvas_tf_bg / 127.5 - 1
def str2bool(v):
"""
Convert a string to a boolean.
Supported true values: 'yes', 'true', 't', 'y', '1'
Supported false values: 'no', 'false', 'f', 'n', '0'
Args:
v (str): String to convert.
Returns:
bool: Converted boolean value.
Raises:
argparse.ArgumentTypeError: If the value cannot be converted to boolean.
"""
if isinstance(v, bool):
return v
v_lower = v.lower()
if v_lower in ('yes', 'true', 't', 'y', '1'):
return True
elif v_lower in ('no', 'false', 'f', 'n', '0'):
return False
ref_width, ref_height = ref_img.size
if (ref_height, ref_width) == image_size and outpainting_dims == None:
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
canvas = torch.zeros_like(ref_img) if return_mask else None
else:
raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
import sys, time
# Global variables to track download progress
_start_time = None
_last_time = None
_last_downloaded = 0
_speed_history = []
_update_interval = 0.5 # Update speed every 0.5 seconds
def progress_hook(block_num, block_size, total_size, filename=None):
"""
Simple progress bar hook for urlretrieve
Args:
block_num: Number of blocks downloaded so far
block_size: Size of each block in bytes
total_size: Total size of the file in bytes
filename: Name of the file being downloaded (optional)
"""
global _start_time, _last_time, _last_downloaded, _speed_history, _update_interval
current_time = time.time()
downloaded = block_num * block_size
# Initialize timing on first call
if _start_time is None or block_num == 0:
_start_time = current_time
_last_time = current_time
_last_downloaded = 0
_speed_history = []
# Calculate download speed only at specified intervals
speed = 0
if current_time - _last_time >= _update_interval:
if _last_time > 0:
current_speed = (downloaded - _last_downloaded) / (current_time - _last_time)
_speed_history.append(current_speed)
# Keep only last 5 speed measurements for smoothing
if len(_speed_history) > 5:
_speed_history.pop(0)
# Average the recent speeds for smoother display
speed = sum(_speed_history) / len(_speed_history)
_last_time = current_time
_last_downloaded = downloaded
elif _speed_history:
# Use the last calculated average speed
speed = sum(_speed_history) / len(_speed_history)
# Format file sizes and speed
def format_bytes(bytes_val):
for unit in ['B', 'KB', 'MB', 'GB']:
if bytes_val < 1024:
return f"{bytes_val:.1f}{unit}"
bytes_val /= 1024
return f"{bytes_val:.1f}TB"
file_display = filename if filename else "Unknown file"
if total_size <= 0:
# If total size is unknown, show downloaded bytes
speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else ""
line = f"\r{file_display}: {format_bytes(downloaded)}{speed_str}"
# Clear any trailing characters by padding with spaces
sys.stdout.write(line.ljust(80))
sys.stdout.flush()
return
downloaded = block_num * block_size
percent = min(100, (downloaded / total_size) * 100)
# Create progress bar (40 characters wide to leave room for other info)
bar_length = 40
filled = int(bar_length * percent / 100)
bar = '' * filled + '' * (bar_length - filled)
# Format file sizes and speed
def format_bytes(bytes_val):
for unit in ['B', 'KB', 'MB', 'GB']:
if bytes_val < 1024:
return f"{bytes_val:.1f}{unit}"
bytes_val /= 1024
return f"{bytes_val:.1f}TB"
speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else ""
# Display progress with filename first
line = f"\r{file_display}: [{bar}] {percent:.1f}% ({format_bytes(downloaded)}/{format_bytes(total_size)}){speed_str}"
# Clear any trailing characters by padding with spaces
sys.stdout.write(line.ljust(100))
sys.stdout.flush()
# Print newline when complete
if percent >= 100:
print()
# Wrapper function to include filename in progress hook
def create_progress_hook(filename):
"""Creates a progress hook with the filename included"""
global _start_time, _last_time, _last_downloaded, _speed_history
# Reset timing variables for new download
_start_time = None
_last_time = None
_last_downloaded = 0
_speed_history = []
def hook(block_num, block_size, total_size):
return progress_hook(block_num, block_size, total_size, filename)
return hook
if outpainting_dims != None:
final_height, final_width = image_size
canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1)
else:
canvas_height, canvas_width = image_size
if full_frame:
new_height = canvas_height
new_width = canvas_width
top = left = 0
else:
# if fill_max and (canvas_height - new_height) < 16:
# new_height = canvas_height
# if fill_max and (canvas_width - new_width) < 16:
# new_width = canvas_width
scale = min(canvas_height / ref_height, canvas_width / ref_width)
new_height = int(ref_height * scale)
new_width = int(ref_width * scale)
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
if outpainting_dims != None:
canvas = torch.full((3, 1, final_height, final_width), inpaint_color, dtype= torch.float, device=device) # [-1, 1]
canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img
else:
canvas = torch.full((3, 1, canvas_height, canvas_width), inpaint_color, dtype= torch.float, device=device) # [-1, 1]
canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
ref_img = canvas
canvas = None
if return_mask:
if outpainting_dims != None:
canvas = torch.ones((1, 1, final_height, final_width), dtype= torch.float, device=device) # [-1, 1]
canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = 0
else:
canvas = torch.ones((1, 1, canvas_height, canvas_width), dtype= torch.float, device=device) # [-1, 1]
canvas[:, :, top:top + new_height, left:left + new_width] = 0
canvas = canvas.to(device)
if return_image:
return convert_tensor_to_image(ref_img), canvas
return ref_img.to(device), canvas
def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, extract_guide_from_window_start = False, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None ):
src_videos, src_masks = [], []
inpaint_color = guide_inpaint_color/127.5 - 1
prepend_count = pre_video_guide.shape[1] if not extract_guide_from_window_start and pre_video_guide is not None else 0
for guide_no, (cur_video_guide, cur_video_mask) in enumerate(zip(video_guides, video_masks)):
src_video = src_mask = None
if cur_video_guide is not None:
src_video = cur_video_guide.permute(3, 0, 1, 2).float().div_(127.5).sub_(1.) # c, f, h, w
if cur_video_mask is not None and any_mask:
src_mask = cur_video_mask.permute(3, 0, 1, 2).float().div_(255)[0:1] # c, f, h, w
if pre_video_guide is not None and not extract_guide_from_window_start:
src_video = pre_video_guide if src_video is None else torch.cat( [pre_video_guide, src_video], dim=1)
if any_mask:
src_mask = torch.zeros_like(pre_video_guide[0:1]) if src_mask is None else torch.cat( [torch.zeros_like(pre_video_guide[0:1]), src_mask], dim=1)
if src_video is None:
if any_guide_padding:
src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color, dtype = torch.float, device= "cpu")
if any_mask:
src_mask = torch.zeros_like(src_video[0:1])
elif src_video.shape[1] < current_video_length:
if any_guide_padding:
src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color, dtype = src_video.dtype, device= src_video.device) ], dim=1)
if cur_video_mask is not None and any_mask:
src_mask = torch.cat([src_mask, torch.full( (1, current_video_length - src_mask.shape[1], *src_mask.shape[-2:] ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1)
else:
new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1
src_video = src_video[:, :new_num_frames]
if any_mask:
src_mask = src_mask[:, :new_num_frames]
for k, keep in enumerate(keep_video_guide_frames):
if not keep:
pos = prepend_count + k
src_video[:, pos:pos+1] = inpaint_color
src_mask[:, pos:pos+1] = 1
for k, frame in enumerate(inject_frames):
if frame != None:
pos = prepend_count + k
src_video[:, pos:pos+1], src_mask[:, pos:pos+1] = fit_image_into_canvas(frame, image_size, inpaint_color, device, True, outpainting_dims, return_mask= True)
src_videos.append(src_video)
src_masks.append(src_mask)
return src_videos, src_masks

228
wgp.py
View File

@ -394,7 +394,7 @@ def process_prompt_and_add_tasks(state, model_choice):
gr.Info("MMAudio can generate an Audio track only if the Video is at least 1s long")
if "F" in video_prompt_type:
if len(frames_positions.strip()) > 0:
positions = frames_positions.split(" ")
positions = frames_positions.replace(","," ").split(" ")
for pos_str in positions:
if not pos_str in ["L", "l"] and len(pos_str)>0:
if not is_integer(pos_str):
@ -2528,7 +2528,7 @@ def download_models(model_filename = None, model_type= None, module_type = False
from urllib.request import urlretrieve
from shared.utils.utils import create_progress_hook
from shared.utils.download import create_progress_hook
shared_def = {
"repoId" : "DeepBeepMeep/Wan2.1",
@ -3726,6 +3726,60 @@ def preprocess_image_with_mask(input_image, input_mask, height, width, fit_canva
input_mask = convert_tensor_to_image(full_frame)
return input_image, input_mask
def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_frames, start_frame, target_fps, size = 512):
if not input_video_path or max_frames <= 0:
return None, None
pad_frames = 0
if start_frame < 0:
pad_frames= -start_frame
max_frames += start_frame
start_frame = 0
any_mask = input_mask_path != None
video = get_resampled_video(input_video_path, start_frame, max_frames, target_fps)
if len(video) == 0: return None
if any_mask:
mask_video = get_resampled_video(input_mask_path, start_frame, max_frames, target_fps)
frame_height, frame_width, _ = video[0].shape
num_frames = min(len(video), len(mask_video))
if num_frames == 0: return None
video, mask_video = video[:num_frames], mask_video[:num_frames]
from preprocessing.face_preprocessor import FaceProcessor
face_processor = FaceProcessor()
face_list = []
for frame_idx in range(num_frames):
frame = video[frame_idx].cpu().numpy()
# video[frame_idx] = None
if any_mask:
mask = Image.fromarray(mask_video[frame_idx].cpu().numpy())
# mask_video[frame_idx] = None
if (frame_width, frame_height) != mask.size:
mask = mask.resize((frame_width, frame_height), resample=Image.Resampling.LANCZOS)
mask = np.array(mask)
alpha_mask = np.zeros((frame_height, frame_width, 3), dtype=np.uint8)
alpha_mask[mask > 127] = 1
frame = frame * alpha_mask
frame = Image.fromarray(frame)
face = face_processor.process(frame, resize_to=size, face_crop_scale = 1)
face_list.append(face)
face_processor = None
gc.collect()
torch.cuda.empty_cache()
face_tensor= torch.tensor(np.stack(face_list, dtype= np.float32) / 127.5 - 1).permute(-1, 0, 1, 2 ) # t h w c -> c t h w
if pad_frames > 0:
face_tensor= torch.cat([face_tensor[:, -1:].expand(-1, pad_frames, -1, -1), face_tensor ], dim=2)
if args.save_masks:
from preprocessing.dwpose.pose import save_one_video
saved_faces_frames = [np.array(face) for face in face_list ]
save_one_video(f"faces.mp4", saved_faces_frames, fps=target_fps, quality=8, macro_block_size=None)
return face_tensor
def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1):
@ -3742,7 +3796,13 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
box = [xmin, ymin, xmax, ymax]
box = [int(x) for x in box]
return box
inpaint_color = int(inpaint_color)
pad_frames = 0
if start_frame < 0:
pad_frames= -start_frame
max_frames += start_frame
start_frame = 0
if not input_video_path or max_frames <= 0:
return None, None
any_mask = input_mask_path != None
@ -3909,6 +3969,9 @@ def preprocess_video_with_mask(input_video_path, input_mask_path, height, width,
preproc_outside = None
gc.collect()
torch.cuda.empty_cache()
if pad_frames > 0:
masked_frames = masked_frames[0] * pad_frames + masked_frames
if any_mask: masked_frames = masks[0] * pad_frames + masks
return torch.stack(masked_frames), torch.stack(masks) if any_mask else None
@ -4646,7 +4709,8 @@ def generate_video(
current_video_length = video_length
# VAE Tiling
device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576
guide_inpaint_color = model_def.get("guide_inpaint_color", 127.5)
extract_guide_from_window_start = model_def.get("extract_guide_from_window_start", False)
i2v = test_class_i2v(model_type)
diffusion_forcing = "diffusion_forcing" in model_filename
t2v = base_model_type in ["t2v"]
@ -4662,6 +4726,7 @@ def generate_video(
multitalk = model_def.get("multitalk_class", False)
standin = model_def.get("standin_class", False)
infinitetalk = base_model_type in ["infinitetalk"]
animate = base_model_type in ["animate"]
if "B" in audio_prompt_type or "X" in audio_prompt_type:
from models.wan.multitalk.multitalk import parse_speakers_locations
@ -4822,7 +4887,6 @@ def generate_video(
repeat_no = 0
extra_generation = 0
initial_total_windows = 0
discard_last_frames = sliding_window_discard_last_frames
default_requested_frames_to_generate = current_video_length
if sliding_window:
@ -4843,7 +4907,7 @@ def generate_video(
if repeat_no >= total_generation: break
repeat_no +=1
gen["repeat_no"] = repeat_no
src_video = src_mask = src_ref_images = new_image_guide = new_image_mask = None
src_video = src_mask = src_ref_images = new_image_guide = new_image_mask = src_faces = None
prefix_video = pre_video_frame = None
source_video_overlap_frames_count = 0 # number of frames overalapped in source video for first window
source_video_frames_count = 0 # number of frames to use in source video (processing starts source_video_overlap_frames_count frames before )
@ -4899,7 +4963,7 @@ def generate_video(
return_latent_slice = slice(-(reuse_frames - 1 + discard_last_frames ) // latent_size - 1, None if discard_last_frames == 0 else -(discard_last_frames // latent_size) )
refresh_preview = {"image_guide" : image_guide, "image_mask" : image_mask} if image_mode >= 1 else {}
src_ref_images = image_refs
src_ref_images, src_ref_masks = image_refs, None
image_start_tensor = image_end_tensor = None
if window_no == 1 and (video_source is not None or image_start is not None):
if image_start is not None:
@ -4943,16 +5007,52 @@ def generate_video(
from models.wan.multitalk.multitalk import get_window_audio_embeddings
# special treatment for start frame pos when alignement to first frame requested as otherwise the start frame number will be negative due to overlapped frames (has been previously compensated later with padding)
audio_proj_split = get_window_audio_embeddings(audio_proj_full, audio_start_idx= aligned_window_start_frame + (source_video_overlap_frames_count if reset_control_aligment else 0 ), clip_length = current_video_length)
if vace:
video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None
if repeat_no == 1 and window_no == 1 and image_refs is not None and len(image_refs) > 0:
frames_positions_list = []
if frames_positions is not None and len(frames_positions)> 0:
positions = frames_positions.replace(","," ").split(" ")
cur_end_pos = -1 + (source_video_frames_count - source_video_overlap_frames_count)
last_frame_no = requested_frames_to_generate + source_video_frames_count - source_video_overlap_frames_count
joker_used = False
project_window_no = 1
for pos in positions :
if len(pos) > 0:
if pos in ["L", "l"]:
cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length
if cur_end_pos >= last_frame_no and not joker_used:
joker_used = True
cur_end_pos = last_frame_no -1
project_window_no += 1
frames_positions_list.append(cur_end_pos)
cur_end_pos -= sliding_window_discard_last_frames + reuse_frames
else:
frames_positions_list.append(int(pos)-1 + alignment_shift)
frames_positions_list = frames_positions_list[:len(image_refs)]
nb_frames_positions = len(frames_positions_list)
if nb_frames_positions > 0:
frames_to_inject = [None] * (max(frames_positions_list) + 1)
for i, pos in enumerate(frames_positions_list):
frames_to_inject[pos] = image_refs[i]
if video_guide is not None:
keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate)
keep_frames_parsed_full, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate)
if len(error) > 0:
raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
keep_frames_parsed = keep_frames_parsed[aligned_guide_start_frame: aligned_guide_end_frame ]
guide_frames_extract_start = aligned_window_start_frame if extract_guide_from_window_start else aligned_guide_start_frame
keep_frames_parsed = [True] * -guide_frames_extract_start if guide_frames_extract_start <0 else []
keep_frames_parsed += keep_frames_parsed_full[max(0, guide_frames_extract_start): aligned_guide_end_frame ]
guide_frames_extract_count = len(keep_frames_parsed)
if vace:
if "B" in video_prompt_type:
send_cmd("progress", [0, get_latest_status(state, "Extracting Face Movements")])
src_faces = extract_faces_from_video_with_mask(video_guide, video_mask, max_frames= guide_frames_extract_count, start_frame= guide_frames_extract_start, size= 512, target_fps = fps)
if src_faces is not None and src_faces.shape[1] < current_video_length:
src_faces = torch.cat([src_faces, torch.full( (3, current_video_length - src_faces.shape[1], 512, 512 ), -1, dtype = src_faces.dtype, device= src_faces.device) ], dim=1)
if vace or animate:
video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None
context_scale = [ control_net_weight]
if "V" in video_prompt_type:
process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None)
@ -4971,10 +5071,10 @@ def generate_video(
if preprocess_type2 is not None:
context_scale = [ control_net_weight /2, control_net_weight2 /2]
send_cmd("progress", [0, get_latest_status(state, status_info)])
inpaint_color = 0 if preprocess_type=="pose" and process_outside_mask=="inpaint" else 127
video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) , start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color )
inpaint_color = 0 if preprocess_type=="pose" else guide_inpaint_color
video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color )
if preprocess_type2 != None:
video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 )
video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 )
if video_guide_processed != None:
if sample_fit_canvas != None:
@ -4985,7 +5085,37 @@ def generate_video(
refresh_preview["video_guide"] = [refresh_preview["video_guide"], Image.fromarray(video_guide_processed2[0].cpu().numpy())]
if video_mask_processed != None:
refresh_preview["video_mask"] = Image.fromarray(video_mask_processed[0].cpu().numpy())
elif ltxv:
frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame]
if not vace and (any_letters(video_prompt_type ,"FV") or model_def.get("forced_guide_mask_inputs", False)):
any_mask = True
any_guide_padding = model_def.get("pad_guide_video", False)
from shared.utils.utils import prepare_video_guide_and_mask
src_videos, src_masks = prepare_video_guide_and_mask( [video_guide_processed, video_guide_processed2],
[video_mask_processed, video_mask_processed2],
pre_video_guide, image_size, current_video_length, latent_size,
any_mask, any_guide_padding, guide_inpaint_color, extract_guide_from_window_start,
keep_frames_parsed, frames_to_inject_parsed , outpainting_dims)
src_video, src_video2 = src_videos
src_mask, src_mask2 = src_masks
if src_video is None:
abort = True
break
if src_faces is not None:
if src_faces.shape[1] < src_video.shape[1]:
src_faces = torch.concat( [src_faces, src_faces[:, -1:].repeat(1, src_video.shape[1] - src_faces.shape[1], 1,1)], dim =1)
else:
src_faces = src_faces[:, :src_video.shape[1]]
if args.save_masks:
save_video( src_video, "masked_frames.mp4", fps)
if src_video2 is not None:
save_video( src_video2, "masked_frames2.mp4", fps)
if any_mask:
save_video( src_mask, "masks.mp4", fps, value_range=(0, 1))
elif ltxv:
preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw")
status_info = "Extracting " + processes_names[preprocess_type]
send_cmd("progress", [0, get_latest_status(state, status_info)])
@ -5023,7 +5153,7 @@ def generate_video(
sample_fit_canvas = None
else: # video to video
video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps)
video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, block_size= block_size)
if video_guide_processed is None:
src_video = pre_video_guide
else:
@ -5043,29 +5173,6 @@ def generate_video(
refresh_preview["image_mask"] = new_image_mask
if window_no == 1 and image_refs is not None and len(image_refs) > 0:
if repeat_no == 1:
frames_positions_list = []
if frames_positions is not None and len(frames_positions)> 0:
positions = frames_positions.split(" ")
cur_end_pos = -1 + (source_video_frames_count - source_video_overlap_frames_count) #if reset_control_aligment else 0
last_frame_no = requested_frames_to_generate + source_video_frames_count - source_video_overlap_frames_count
joker_used = False
project_window_no = 1
for pos in positions :
if len(pos) > 0:
if pos in ["L", "l"]:
cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length
if cur_end_pos >= last_frame_no and not joker_used:
joker_used = True
cur_end_pos = last_frame_no -1
project_window_no += 1
frames_positions_list.append(cur_end_pos)
cur_end_pos -= sliding_window_discard_last_frames + reuse_frames
else:
frames_positions_list.append(int(pos)-1 + alignment_shift)
frames_positions_list = frames_positions_list[:len(image_refs)]
nb_frames_positions = len(frames_positions_list)
if sample_fit_canvas is not None and (nb_frames_positions > 0 or "K" in video_prompt_type) :
from shared.utils.utils import get_outpainting_full_area_dimensions
w, h = image_refs[0].size
@ -5089,20 +5196,16 @@ def generate_video(
if remove_background_images_ref > 0:
send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")])
# keep image ratios if there is a background image ref (we will let the model preprocessor decide what to do) but remove bg if requested
image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , image_size[1], image_size[0],
image_refs[nb_frames_positions:], src_ref_masks = resize_and_remove_background(image_refs[nb_frames_positions:] , image_size[1], image_size[0],
remove_background_images_ref > 0, any_background_ref,
fit_into_canvas= 0 if (any_background_ref > 0 or model_def.get("lock_image_refs_ratios", False)) else 1,
block_size=block_size,
outpainting_dims =outpainting_dims )
outpainting_dims =outpainting_dims,
background_ref_outpainted = model_def.get("background_ref_outpainted", True) )
refresh_preview["image_refs"] = image_refs
if nb_frames_positions > 0:
frames_to_inject = [None] * (max(frames_positions_list) + 1)
for i, pos in enumerate(frames_positions_list):
frames_to_inject[pos] = image_refs[i]
if vace :
frames_to_inject_parsed = frames_to_inject[guide_start_frame: guide_end_frame]
image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications
src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_processed] if video_guide_processed2 == None else [video_guide_processed, video_guide_processed2],
@ -5116,7 +5219,7 @@ def generate_video(
any_background_ref = any_background_ref
)
if len(frames_to_inject_parsed) or any_background_ref:
new_image_refs = [convert_tensor_to_image(src_video[0], frame_no + aligned_guide_start_frame - aligned_window_start_frame) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject]
new_image_refs = [convert_tensor_to_image(src_video[0], frame_no + 0 if extract_guide_from_window_start else (aligned_guide_start_frame - aligned_window_start_frame) ) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject]
if any_background_ref:
new_image_refs += [convert_tensor_to_image(image_refs_copy[0], 0)] + image_refs[nb_frames_positions+1:]
else:
@ -5165,10 +5268,14 @@ def generate_video(
input_prompt = prompt,
image_start = image_start_tensor,
image_end = image_end_tensor,
input_frames = src_video,
input_frames = src_video,
input_frames2 = src_video2,
input_ref_images= src_ref_images,
input_ref_masks = src_ref_masks,
input_masks = src_mask,
input_masks2 = src_mask2,
input_video= pre_video_guide,
input_faces = src_faces,
denoising_strength=denoising_strength,
prefix_frames_count = source_video_overlap_frames_count if window_no <= 1 else reuse_frames,
frame_num= (current_video_length // latent_size)* latent_size + 1,
@ -5302,6 +5409,7 @@ def generate_video(
send_cmd("output")
else:
sample = samples.cpu()
abort = not is_image and sample.shape[1] < current_video_length
# if True: # for testing
# torch.save(sample, "output.pt")
# else:
@ -6980,7 +7088,7 @@ def refresh_video_prompt_type_alignment(state, video_prompt_type, video_prompt_t
def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ):
old_video_prompt_type = video_prompt_type
video_prompt_type = del_in_sequence(video_prompt_type, "PDESLCMUV")
video_prompt_type = del_in_sequence(video_prompt_type, "PDESLCMUVB")
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
visible = "V" in video_prompt_type
model_type = state["model_type"]
@ -7437,7 +7545,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
image_prompt_type = gr.Text(value= image_prompt_type_value, visible= False)
image_prompt_type_choices = []
if "T" in image_prompt_types_allowed:
image_prompt_type_choices += [("Text Prompt Only", "")]
image_prompt_type_choices += [("Text Prompt Only" if "S" in image_prompt_types_allowed else "New Video", "")]
if "S" in image_prompt_types_allowed:
image_prompt_type_choices += [("Start Video with Image", "S")]
any_start_image = True
@ -7516,7 +7624,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
if image_outputs: video_prompt_type_video_guide_label = video_prompt_type_video_guide_label.replace("Video", "Image")
video_prompt_type_video_guide = gr.Dropdown(
guide_preprocessing_choices,
value=filter_letters(video_prompt_type_value, "PDESLCMUV", guide_preprocessing.get("default", "") ),
value=filter_letters(video_prompt_type_value, "PDESLCMUVB", guide_preprocessing.get("default", "") ),
label= video_prompt_type_video_guide_label , scale = 2, visible= guide_preprocessing.get("visible", True) , show_label= True,
)
any_control_video = True
@ -7560,13 +7668,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
}
mask_preprocessing_choices = []
mask_preprocessing_labels = guide_preprocessing.get("labels", {})
mask_preprocessing_labels = mask_preprocessing.get("labels", {})
for process_type in mask_preprocessing["selection"]:
process_label = mask_preprocessing_labels.get(process_type, None)
process_label = mask_preprocessing_labels_all.get(process_type, process_type) if process_label is None else process_label
mask_preprocessing_choices.append( (process_label, process_type) )
video_prompt_type_video_mask_label = guide_preprocessing.get("label", "Area Processed")
video_prompt_type_video_mask_label = mask_preprocessing.get("label", "Area Processed")
video_prompt_type_video_mask = gr.Dropdown(
mask_preprocessing_choices,
value=filter_letters(video_prompt_type_value, "XYZWNA", mask_preprocessing.get("default", "")),
@ -7591,7 +7699,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
choices= image_ref_choices["choices"],
value=filter_letters(video_prompt_type_value, image_ref_choices["letters_filter"]),
visible = image_ref_choices.get("visible", True),
label=image_ref_choices.get("label", "Ref. Images Type"), show_label= True, scale = 2
label=image_ref_choices.get("label", "Inject Reference Images"), show_label= True, scale = 2
)
image_guide = gr.Image(label= "Control Image", height = 800, type ="pil", visible= image_mode_value==1 and "V" in video_prompt_type_value and ("U" in video_prompt_type_value or not "A" in video_prompt_type_value ) , value= ui_defaults.get("image_guide", None))
@ -7634,7 +7742,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
video_guide_outpainting_value = ui_defaults.get("video_guide_outpainting","#")
video_guide_outpainting = gr.Text(value=video_guide_outpainting_value , visible= False)
with gr.Group():
video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Landscape or Injected Reference Frames" if image_mode_value == 0 else "Enable Spatial Outpainting on Control Image", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") )
video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Landscape or Positioned Reference Frames" if image_mode_value == 0 else "Enable Spatial Outpainting on Control Image", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") )
with gr.Row(visible = not video_guide_outpainting_value.startswith("#")) as video_guide_outpainting_row:
video_guide_outpainting_value = video_guide_outpainting_value[1:] if video_guide_outpainting_value.startswith("#") else video_guide_outpainting_value
video_guide_outpainting_list = [0] * 4 if len(video_guide_outpainting_value) == 0 else [int(v) for v in video_guide_outpainting_value.split(" ")]
@ -7649,14 +7757,14 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value )
image_refs_single_image_mode = model_def.get("one_image_ref_needed", False)
image_refs_label = "Start Image" if hunyuan_video_avatar else ("Reference Image" if image_refs_single_image_mode else "Reference Images") + (" (each Image will start a new Clip)" if infinitetalk else "")
image_refs_label = "Start Image" if hunyuan_video_avatar else ("Reference Image" if image_refs_single_image_mode else "Reference Images") + (" (each Image will be associated to a Sliding Window)" if infinitetalk else "")
image_refs_row, image_refs, image_refs_extra = get_image_gallery(label= image_refs_label, value = ui_defaults.get("image_refs", None), visible= "I" in video_prompt_type_value, single_image_mode=image_refs_single_image_mode)
frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames (1=first, L=last of a window) no position for other Image Refs)" )
image_refs_relative_size = gr.Slider(20, 100, value=ui_defaults.get("image_refs_relative_size", 50), step=1, label="Rescale Internaly Image Ref (% in relation to Output Video) to change Output Composition", visible = model_def.get("any_image_refs_relative_size", False) and image_outputs)
no_background_removal = model_def.get("no_background_removal", False) or image_ref_choices is None
background_removal_label = model_def.get("background_removal_label", "Remove Backgrounds behind People / Objects")
background_removal_label = model_def.get("background_removal_label", "Remove Background behind People / Objects")
remove_background_images_ref = gr.Dropdown(
choices=[
@ -7664,7 +7772,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
(background_removal_label, 1),
],
value=0 if no_background_removal else ui_defaults.get("remove_background_images_ref",1),
label="Automatic Removal of Background of People or Objects (Only)", scale = 3, visible= "I" in video_prompt_type_value and not no_background_removal
label="Automatic Removal of Background behind People or Objects in Reference Images", scale = 3, visible= "I" in video_prompt_type_value and not no_background_removal
)
any_audio_voices_support = any_audio_track(base_model_type)
@ -8084,7 +8192,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
("Aligned to the beginning of the First Window of the new Video Sample", "T"),
],
value=filter_letters(video_prompt_type_value, "T"),
label="Control Video / Injected Frames / Control Audio temporal alignment when any Video to continue",
label="Control Video / Control Audio / Positioned Frames Temporal Alignment when any Video to continue",
visible = vace or ltxv or t2v or infinitetalk
)