Merge branch 'main' into queue_editor_html_queue

This commit is contained in:
Chris Malone 2025-09-25 22:53:40 +10:00
commit a10d8f9004
17 changed files with 233 additions and 1324 deletions

View File

@ -20,15 +20,21 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates:
### September 25 2025: WanGP v8.73 - Here Are ~~Two~~Three New Contenders in the Vace Arena!
So in today's release you will find two Vace wannabes, each covering only a subset of Vace features but offering some interesting advantages:
- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion transfers*. It does that very well. You can use this model to either *Replace* a person in a Video or *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone*?). By default it will keep the original soundtrack. *Wan 2.2 Animate* seems to be, under the hood, a derived i2v model and should support the corresponding Lora Accelerators (for instance *FusioniX t2v*). Also, as a WanGP exclusive, you will find support for *Outpainting*.
In order to use Wan 2.2 Animate you will first need to stop by the embedded *Mat Anyone* tool to extract the *Video Mask* of the person whose motion you want to extract.
- **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video and ask it to change it (it is specialized in changing clothes) and voila! The nice thing about it is that it is based on the *Wan 2.2 5B* model and is therefore very fast, especially if you use the *FastWan* finetune that is also part of the package.
Also, because I wanted to spoil you:
- **Qwen Edit Plus**: also known as the *Qwen Edit 25th September Update*, which is specialized in combining multiple Objects / People. There is also new support for *Pose Transfer* & *Recolorisation*. All of this is made easy to use in WanGP. For now you will find only the quantized version, since HF crashes when uploading the unquantized version.
*Update 8.71*: fixed Fast Lucy Edit that didn't contain the Lora
*Update 8.72*: shadow drop of Qwen Edit Plus
*Update 8.73*: Qwen Preview & InfiniteTalk Start image
### September 15 2025: WanGP v8.6 - Attack of the Clones

View File

@ -0,0 +1,17 @@
{
"model": {
"name": "Qwen Image Edit Plus 20B",
"architecture": "qwen_image_edit_plus_20B",
"description": "Qwen Image Edit Plus is a generative model that can generate very high quality images with long texts in it. Best results will be at 720p. This model is optimized to combine multiple Subjects & Objects.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_plus_20B_quanto_bf16_int8.safetensors"
],
"preload_URLs": "qwen_image_edit_20B",
"attention": {
"<89": "sdpa"
}
},
"prompt": "add a hat",
"resolution": "1024x1024",
"batch_size": 1
}
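For orientation, a minimal sketch of reading such a finetune definition file (the keys match the JSON above; the helper name and file path are illustrative, not WanGP's actual loader):

```python
import json

def load_model_def(path: str) -> dict:
    # Read a finetune definition like the one above and pull out the main fields.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    model = data["model"]
    return {
        "name": model["name"],
        "architecture": model["architecture"],
        "urls": model.get("URLs", []),
        # May be a plain string referring to another model's preloads, as above.
        "preload_urls": model.get("preload_URLs", []),
    }

# Example (hypothetical path): load_model_def("defaults/qwen_image_edit_plus_20B.json")
```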

View File

@ -27,7 +27,7 @@ conda activate wan2gp
### Step 2: Install PyTorch
```shell
# Install PyTorch 2.7.0 with CUDA 12.8
pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128
```
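A quick sanity check that the CUDA 12.8 build of PyTorch 2.7.0 was actually installed (plain PyTorch, nothing WanGP-specific):

```python
import torch

# Expect a 2.7.0 build, CUDA version 12.8, and True if a GPU is visible.
print(torch.__version__, torch.version.cuda, torch.cuda.is_available())
```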

View File

@ -28,7 +28,7 @@ class family_handler():
extra_model_def["any_image_refs_relative_size"] = True extra_model_def["any_image_refs_relative_size"] = True
extra_model_def["no_background_removal"] = True extra_model_def["no_background_removal"] = True
extra_model_def["image_ref_choices"] = { extra_model_def["image_ref_choices"] = {
"choices":[("No Reference Image", ""),("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "KI"), "choices":[("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "KI"),
("Up to two Images are Style Images", "KIJ")], ("Up to two Images are Style Images", "KIJ")],
"default": "KI", "default": "KI",
"letters_filter": "KIJ", "letters_filter": "KIJ",

View File

@ -173,8 +173,14 @@ class family_handler():
video_prompt_type = video_prompt_type.replace("M","")
ui_defaults["video_prompt_type"] = video_prompt_type

if settings_version < 2.36:
    if base_model_type in ["hunyuan_avatar", "hunyuan_custom_audio"]:
        audio_prompt_type = ui_defaults["audio_prompt_type"]
        if "A" not in audio_prompt_type:
            audio_prompt_type += "A"
        ui_defaults["audio_prompt_type"] = audio_prompt_type
    pass

@staticmethod
def update_default_settings(base_model_type, model_def, ui_defaults):
@ -197,6 +203,7 @@ class family_handler():
"guidance_scale": 7.5, "guidance_scale": 7.5,
"flow_shift": 13, "flow_shift": 13,
"video_prompt_type": "I", "video_prompt_type": "I",
"audio_prompt_type": "A",
}) })
elif base_model_type in ["hunyuan_custom_edit"]: elif base_model_type in ["hunyuan_custom_edit"]:
ui_defaults.update({ ui_defaults.update({
@ -213,4 +220,5 @@ class family_handler():
"skip_steps_start_step_perc": 25, "skip_steps_start_step_perc": 25,
"video_length": 129, "video_length": 129,
"video_prompt_type": "KI", "video_prompt_type": "KI",
"audio_prompt_type": "A",
}) })

View File

@ -200,7 +200,8 @@ class QwenImagePipeline(): #DiffusionPipeline
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.tokenizer_max_length = 1024
if processor is not None:
# self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
self.prompt_template_encode_start_idx = 64
else:
self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
@ -232,6 +233,21 @@ class QwenImagePipeline(): #DiffusionPipeline
txt = [template.format(e) for e in prompt]
if self.processor is not None and image is not None:
img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
if isinstance(image, list):
base_img_prompt = ""
for i, img in enumerate(image):
base_img_prompt += img_prompt_template.format(i + 1)
elif image is not None:
base_img_prompt = img_prompt_template.format(1)
else:
base_img_prompt = ""
template = self.prompt_template_encode
drop_idx = self.prompt_template_encode_start_idx
txt = [template.format(base_img_prompt + e) for e in prompt]
model_inputs = self.processor(
    text=txt,
    images=image,
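To make the added placeholder logic concrete, here is a small standalone sketch of how the per-image prompt expands for two reference images (plain Python; the shortened template is illustrative, the placeholder string is the one used above):

```python
# Standalone illustration of the "Picture N" placeholder expansion added above.
img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
images = ["ref_a.png", "ref_b.png"]  # stand-ins for two PIL reference images

base_img_prompt = "".join(img_prompt_template.format(i + 1) for i, _ in enumerate(images))
template = "<|im_start|>user\n{}<|im_end|>\n"  # shortened template for readability
print(template.format(base_img_prompt + "add a hat"))
# -> <|im_start|>user
#    Picture 1: <|vision_start|><|image_pad|><|vision_end|>Picture 2: <|vision_start|><|image_pad|><|vision_end|>add a hat<|im_end|>
```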
@ -464,7 +480,7 @@ class QwenImagePipeline(): #DiffusionPipeline
def prepare_latents(
self,
images,
batch_size,
num_channels_latents,
height,
@ -482,24 +498,30 @@ class QwenImagePipeline(): #DiffusionPipeline
shape = (batch_size, num_channels_latents, 1, height, width)
image_latents = None
if images is not None and len(images) > 0:
    if not isinstance(images, list):
        images = [images]
    all_image_latents = []
    for image in images:
        image = image.to(device=device, dtype=dtype)
        if image.shape[1] != self.latent_channels:
            image_latents = self._encode_vae_image(image=image, generator=generator)
        else:
            image_latents = image
        if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
            # expand init_latents for batch_size
            additional_image_per_prompt = batch_size // image_latents.shape[0]
            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
        elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            image_latents = torch.cat([image_latents], dim=0)
        image_latents = self._pack_latents(image_latents)
        all_image_latents.append(image_latents)
    image_latents = torch.cat(all_image_latents, dim=1)

if isinstance(generator, list) and len(generator) != batch_size:
    raise ValueError(
@ -568,6 +590,7 @@ class QwenImagePipeline(): #DiffusionPipeline
joint_pass= True,
lora_inpaint = False,
outpainting_dims = None,
qwen_edit_plus = False,
):
r"""
Function invoked when calling the pipeline for generation.
@ -683,61 +706,54 @@ class QwenImagePipeline(): #DiffusionPipeline
batch_size = prompt_embeds.shape[0]
device = "cuda"

condition_images = []
vae_image_sizes = []
vae_images = []
image_mask_latents = None
ref_size = 1024
ref_text_encoder_size = 384 if qwen_edit_plus else 1024
if image is not None:
    if not isinstance(image, list): image = [image]
    if height * width < ref_size * ref_size: ref_size = round(math.sqrt(height * width))
    for ref_no, img in enumerate(image):
        image_width, image_height = img.size
        any_mask = ref_no == 0 and image_mask is not None
        if (image_height * image_width > ref_size * ref_size) and not any_mask:
            vae_height, vae_width = calculate_new_dimensions(ref_size, ref_size, image_height, image_width, False, block_size=multiple_of)
        else:
            vae_height, vae_width = image_height, image_width
            vae_width = vae_width // multiple_of * multiple_of
            vae_height = vae_height // multiple_of * multiple_of
        vae_image_sizes.append((vae_width, vae_height))
        condition_height, condition_width = calculate_new_dimensions(ref_text_encoder_size, ref_text_encoder_size, image_height, image_width, False, block_size=multiple_of)
        condition_images.append(img.resize((condition_width, condition_height), resample=Image.Resampling.LANCZOS))
        if img.size != (vae_width, vae_height):
            img = img.resize((vae_width, vae_height), resample=Image.Resampling.LANCZOS)
        if any_mask:
            if lora_inpaint:
                image_mask_rebuilt = torch.where(convert_image_to_tensor(image_mask)>-0.5, 1., 0. )[0:1]
                img = convert_image_to_tensor(img)
                green = torch.tensor([-1.0, 1.0, -1.0]).to(img)
                green_image = green[:, None, None].expand_as(img)
                img = torch.where(image_mask_rebuilt > 0, green_image, img)
                img = convert_tensor_to_image(img)
            else:
                image_mask_latents = convert_image_to_tensor(image_mask.resize((vae_width // 8, vae_height // 8), resample=Image.Resampling.LANCZOS))
                image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1]
                image_mask_rebuilt = image_mask_latents.repeat_interleave(8, dim=-1).repeat_interleave(8, dim=-2).unsqueeze(0)
                # convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png")
                image_mask_latents = image_mask_latents.to(device).unsqueeze(0).unsqueeze(0).repeat(1,16,1,1,1)
                image_mask_latents = self._pack_latents(image_mask_latents)
        # img.save("nnn.png")
        vae_images.append(convert_image_to_tensor(img).unsqueeze(0).unsqueeze(2))
has_neg_prompt = negative_prompt is not None or (
    negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
    image=condition_images,
    prompt=prompt,
    prompt_embeds=prompt_embeds,
    prompt_embeds_mask=prompt_embeds_mask,
@ -747,7 +763,7 @@ class QwenImagePipeline(): #DiffusionPipeline
)
if do_true_cfg:
    negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
        image=condition_images,
        prompt=negative_prompt,
        prompt_embeds=negative_prompt_embeds,
        prompt_embeds_mask=negative_prompt_embeds_mask,
@ -763,7 +779,7 @@ class QwenImagePipeline(): #DiffusionPipeline
# 4. Prepare latent variables
num_channels_latents = self.transformer.in_channels // 4
latents, image_latents = self.prepare_latents(
    vae_images,
    batch_size * num_images_per_prompt,
    num_channels_latents,
    height,
@ -779,7 +795,12 @@ class QwenImagePipeline(): #DiffusionPipeline
img_shapes = [
    [
        (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
        # (1, image_height // self.vae_scale_factor // 2, image_width // self.vae_scale_factor // 2),
        *[
            (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2)
            for vae_width, vae_height in vae_image_sizes
        ],
    ]
] * batch_size
else:
@ -950,8 +971,9 @@ class QwenImagePipeline(): #DiffusionPipeline
latents = latents.to(latents_dtype)
if callback is not None:
    preview = self._unpack_latents(latents, height, width, self.vae_scale_factor)
    preview = preview.squeeze(0)
    callback(i, preview, False)

self._current_timestep = None
@ -971,7 +993,7 @@ class QwenImagePipeline(): #DiffusionPipeline
latents = latents / latents_std + latents_mean
output_image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
if image_mask is not None and not lora_inpaint: # not (lora_inpaint and outpainting_dims is not None):
    output_image = vae_images[0].squeeze(2) * (1 - image_mask_rebuilt) + output_image.to(vae_images[0]) * image_mask_rebuilt
return output_image
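Taken together, the changes above encode each reference image separately, pack its latents, and concatenate everything along the token dimension. A simplified, self-contained sketch of that packing idea, assuming the 2x2 patch packing used by Qwen/Flux-style pipelines and random tensors in place of VAE outputs:

```python
import torch

def pack_latents(latents: torch.Tensor) -> torch.Tensor:
    # (B, C, 1, H, W) -> (B, (H/2)*(W/2), C*4): fold each 2x2 spatial patch into the channel dim.
    b, c, _, h, w = latents.shape
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)

# Two "encoded reference images" of different sizes (stand-ins for VAE outputs).
img_latents = [torch.randn(1, 16, 1, 64, 64), torch.randn(1, 16, 1, 48, 80)]
packed = torch.cat([pack_latents(x) for x in img_latents], dim=1)
print(packed.shape)  # torch.Size([1, 1984, 64]) -> 1024 + 960 tokens, all refs in one sequence
```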

View File

@ -20,7 +20,7 @@ class family_handler():
"fit_into_canvas_image_refs": 0, "fit_into_canvas_image_refs": 0,
} }
if base_model_type in ["qwen_image_edit_20B"]: if base_model_type in ["qwen_image_edit_20B", "qwen_image_edit_plus_20B"]:
extra_model_def["inpaint_support"] = True extra_model_def["inpaint_support"] = True
extra_model_def["image_ref_choices"] = { extra_model_def["image_ref_choices"] = {
"choices": [ "choices": [
@ -42,11 +42,20 @@ class family_handler():
"image_modes" : [2], "image_modes" : [2],
} }
if base_model_type in ["qwen_image_edit_plus_20B"]:
extra_model_def["guide_preprocessing"] = {
"selection": ["", "PV", "SV", "CV"],
}
extra_model_def["mask_preprocessing"] = {
"selection": ["", "A"],
"visible": False,
}
return extra_model_def

@staticmethod
def query_supported_types():
    return ["qwen_image_20B", "qwen_image_edit_20B", "qwen_image_edit_plus_20B"]

@staticmethod
def query_family_maps():
@ -113,9 +122,16 @@ class family_handler():
"denoising_strength" : 1., "denoising_strength" : 1.,
"model_mode" : 0, "model_mode" : 0,
}) })
elif base_model_type in ["qwen_image_edit_plus_20B"]:
ui_defaults.update({
"video_prompt_type": "I",
"denoising_strength" : 1.,
"model_mode" : 0,
})
@staticmethod
def validate_generative_settings(base_model_type, model_def, inputs):
    if base_model_type in ["qwen_image_edit_20B", "qwen_image_edit_plus_20B"]:
        model_mode = inputs["model_mode"]
        denoising_strength = inputs["denoising_strength"]
        video_guide_outpainting = inputs["video_guide_outpainting"]
@ -126,3 +142,9 @@ class family_handler():
gr.Info("Denoising Strength will be ignored while using Lora Inpainting") gr.Info("Denoising Strength will be ignored while using Lora Inpainting")
if outpainting_dims is not None and model_mode == 0 : if outpainting_dims is not None and model_mode == 0 :
return "Outpainting is not supported with Masked Denoising " return "Outpainting is not supported with Masked Denoising "
@staticmethod
def get_rgb_factors(base_model_type ):
from shared.RGB_factors import get_rgb_factors
latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("qwen")
return latent_rgb_factors, latent_rgb_factors_bias
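The new `get_rgb_factors` hook appears to feed the low-cost latent previews added to the pipeline callback above: each latent channel is mapped to RGB by a fixed linear projection. A toy sketch of that idea with stand-in values (not WanGP's actual coefficients):

```python
import torch

# Toy illustration: project a 16-channel latent to an RGB preview with a fixed
# per-channel linear map plus bias (random stand-ins, not the real factors).
latents = torch.randn(16, 60, 104)   # (C, H, W) latent
rgb_factors = torch.randn(16, 3)     # one RGB triplet per latent channel
rgb_bias = torch.zeros(3)

preview = torch.einsum("chw,cr->rhw", latents, rgb_factors) + rgb_bias[:, None, None]
preview = ((preview.clamp(-1, 1) + 1) / 2 * 255).byte()  # rough 8-bit preview image
print(preview.shape)  # torch.Size([3, 60, 104])
```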

View File

@ -51,10 +51,10 @@ class model_factory():
transformer_filename = model_filename[0]
processor = None
tokenizer = None
if base_model_type in ["qwen_image_edit_20B", "qwen_image_edit_plus_20B"]:
    processor = Qwen2VLProcessor.from_pretrained(os.path.join(checkpoint_dir, "Qwen2.5-VL-7B-Instruct"))
    tokenizer = AutoTokenizer.from_pretrained(os.path.join(checkpoint_dir, "Qwen2.5-VL-7B-Instruct"))
self.base_model_type = base_model_type

base_config_file = "configs/qwen_image_20B.json"
with open(base_config_file, 'r', encoding='utf-8') as f:
@ -173,7 +173,7 @@ class model_factory():
self.vae.tile_latent_min_height = VAE_tile_size[1]
self.vae.tile_latent_min_width = VAE_tile_size[1]
qwen_edit_plus = self.base_model_type in ["qwen_image_edit_plus_20B"]
self.vae.enable_slicing()
# width, height = aspect_ratios["16:9"]
@ -182,17 +182,19 @@ class model_factory():
image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True)
if input_frames is not None:
    input_ref_images = [convert_tensor_to_image(input_frames)] + ([] if input_ref_images is None else input_ref_images)

if input_ref_images is not None:
    # image stiching method
    stiched = input_ref_images[0]
    if "K" in video_prompt_type:
        w, h = input_ref_images[0].size
        height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
    if not qwen_edit_plus:
        for new_img in input_ref_images[1:]:
            stiched = stitch_images(stiched, new_img)
        input_ref_images = [stiched]

image = self.pipeline(
    prompt=input_prompt,
@ -212,7 +214,8 @@ class model_factory():
generator=torch.Generator(device="cuda").manual_seed(seed),
lora_inpaint = image_mask is not None and model_mode == 1,
outpainting_dims = outpainting_dims,
qwen_edit_plus = qwen_edit_plus,
)
if image is None: return None
return image.transpose(0, 1)
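For context, the older edit path stitches all reference images into one canvas before encoding, while Qwen Edit Plus now keeps them separate. A minimal sketch of horizontal stitching with PIL, as an assumption about what `stitch_images` roughly does rather than its actual implementation:

```python
from PIL import Image

def stitch_images_side_by_side(left: Image.Image, right: Image.Image) -> Image.Image:
    # Paste the two images next to each other on a shared white canvas.
    height = max(left.height, right.height)
    canvas = Image.new("RGB", (left.width + right.width, height), "white")
    canvas.paste(left, (0, 0))
    canvas.paste(right, (left.width, 0))
    return canvas

# stiched = stitch_images_side_by_side(ref_images[0], ref_images[1])
```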

View File

@ -443,38 +443,32 @@ class WanAny2V:
# image2video
if model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "flf2v_720p"]:
    any_end_frame = False
    if infinitetalk:
        new_shot = "Q" in video_prompt_type
        if input_frames is not None:
            image_ref = input_frames[:, 0]
        else:
            if input_ref_images is None:
                if pre_video_frame is None: raise Exception("Missing Reference Image")
                input_ref_images, new_shot = [pre_video_frame], False
            new_shot = new_shot and window_no <= len(input_ref_images)
            image_ref = convert_image_to_tensor(input_ref_images[ min(window_no, len(input_ref_images))-1 ])
        if new_shot or input_video is None:
            input_video = image_ref.unsqueeze(1)
        else:
            color_correction_strength = 0 #disable color correction as transition frames between shots may have a complete different color level than the colors of the new shot
    _ , preframes_count, height, width = input_video.shape
    input_video = input_video.to(device=self.device).to(dtype= self.VAE_dtype)
    if infinitetalk:
        image_start = image_ref.to(input_video)
        control_pre_frames_count = 1
        control_video = image_start.unsqueeze(1)
    else:
        image_start = input_video[:, -1]
        control_pre_frames_count = preframes_count
        control_video = input_video

    color_reference_frame = image_start.unsqueeze(1).clone()

    any_end_frame = image_end is not None
    add_frames_for_end_image = any_end_frame and model_type == "i2v"

View File

@ -1,479 +0,0 @@
import math
import os
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import logging
import numpy as np
import torch
from diffusers.image_processor import PipelineImageInput
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from tqdm import tqdm
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from wan.modules.posemb_layers import get_rotary_pos_embed
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
class DTT2V:
def __init__(
self,
config,
checkpoint_dir,
rank=0,
model_filename = None,
text_encoder_filename = None,
quantizeTransformer = False,
dtype = torch.bfloat16,
):
self.device = torch.device(f"cuda")
self.config = config
self.rank = rank
self.dtype = dtype
self.num_train_timesteps = config.num_train_timesteps
self.param_dtype = config.param_dtype
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
device=torch.device('cpu'),
checkpoint_path=text_encoder_filename,
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
shard_fn= None)
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
self.vae = WanVAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
device=self.device)
logging.info(f"Creating WanModel from {model_filename}")
from mmgp import offload
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False, forcedConfigPath="config.json")
# offload.load_model_data(self.model, "recam.ckpt")
# self.model.cpu()
# offload.save_model(self.model, "recam.safetensors")
if self.dtype == torch.float16 and not "fp16" in model_filename:
self.model.to(self.dtype)
# offload.save_model(self.model, "t2v_fp16.safetensors",do_quantize=True)
if self.dtype == torch.float16:
self.vae.model.to(self.dtype)
self.model.eval().requires_grad_(False)
self.scheduler = FlowUniPCMultistepScheduler()
@property
def do_classifier_free_guidance(self) -> bool:
return self._guidance_scale > 1
def encode_image(
self, image: PipelineImageInput, height: int, width: int, num_frames: int, tile_size = 0, causal_block_size = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# prefix_video
prefix_video = np.array(image.resize((width, height))).transpose(2, 0, 1)
prefix_video = torch.tensor(prefix_video).unsqueeze(1) # .to(image_embeds.dtype).unsqueeze(1)
if prefix_video.dtype == torch.uint8:
prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0
prefix_video = prefix_video.to(self.device)
prefix_video = [self.vae.encode(prefix_video.unsqueeze(0), tile_size = tile_size)[0]] # [(c, f, h, w)]
if prefix_video[0].shape[1] % causal_block_size != 0:
truncate_len = prefix_video[0].shape[1] % causal_block_size
print("the length of prefix video is truncated for the casual block size alignment.")
prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
predix_video_latent_length = prefix_video[0].shape[1]
return prefix_video, predix_video_latent_length
def prepare_latents(
self,
shape: Tuple[int],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
) -> torch.Tensor:
return randn_tensor(shape, generator, device=device, dtype=dtype)
def generate_timestep_matrix(
self,
num_frames,
step_template,
base_num_frames,
ar_step=5,
num_pre_ready=0,
casual_block_size=1,
shrink_interval_with_mask=False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
step_matrix, step_index = [], []
update_mask, valid_interval = [], []
num_iterations = len(step_template) + 1
num_frames_block = num_frames // casual_block_size
base_num_frames_block = base_num_frames // casual_block_size
if base_num_frames_block < num_frames_block:
infer_step_num = len(step_template)
gen_block = base_num_frames_block
min_ar_step = infer_step_num / gen_block
assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting"
# print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block)
step_template = torch.cat(
[
torch.tensor([999], dtype=torch.int64, device=step_template.device),
step_template.long(),
torch.tensor([0], dtype=torch.int64, device=step_template.device),
]
) # to handle the counter in row works starting from 1
pre_row = torch.zeros(num_frames_block, dtype=torch.long)
if num_pre_ready > 0:
pre_row[: num_pre_ready // casual_block_size] = num_iterations
while torch.all(pre_row >= (num_iterations - 1)) == False:
new_row = torch.zeros(num_frames_block, dtype=torch.long)
for i in range(num_frames_block):
if i == 0 or pre_row[i - 1] >= (
num_iterations - 1
): # the first frame or the last frame is completely denoised
new_row[i] = pre_row[i] + 1
else:
new_row[i] = new_row[i - 1] - ar_step
new_row = new_row.clamp(0, num_iterations)
update_mask.append(
(new_row != pre_row) & (new_row != num_iterations)
) # False: no need to update True: need to update
step_index.append(new_row)
step_matrix.append(step_template[new_row])
pre_row = new_row
# for long video we split into several sequences, base_num_frames is set to the model max length (for training)
terminal_flag = base_num_frames_block
if shrink_interval_with_mask:
idx_sequence = torch.arange(num_frames_block, dtype=torch.int64)
update_mask = update_mask[0]
update_mask_idx = idx_sequence[update_mask]
last_update_idx = update_mask_idx[-1].item()
terminal_flag = last_update_idx + 1
# for i in range(0, len(update_mask)):
for curr_mask in update_mask:
if terminal_flag < num_frames_block and curr_mask[terminal_flag]:
terminal_flag += 1
valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag))
step_update_mask = torch.stack(update_mask, dim=0)
step_index = torch.stack(step_index, dim=0)
step_matrix = torch.stack(step_matrix, dim=0)
if casual_block_size > 1:
step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval]
return step_matrix, step_index, step_update_mask, valid_interval
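To visualize the staircase schedule this method builds, here is a small standalone rerun of the same row-update rule with toy sizes (independent of the class; `ar_step` controls how far each frame block lags behind the previous one):

```python
import torch

# Toy sizes: 4 frame blocks, 6 denoising iterations, blocks lag by ar_step = 2.
num_blocks, num_iterations, ar_step = 4, 6, 2
pre_row = torch.zeros(num_blocks, dtype=torch.long)
rows = []
while not torch.all(pre_row >= num_iterations - 1):
    new_row = torch.zeros(num_blocks, dtype=torch.long)
    for i in range(num_blocks):
        if i == 0 or pre_row[i - 1] >= num_iterations - 1:
            new_row[i] = pre_row[i] + 1            # block advances one denoising step
        else:
            new_row[i] = new_row[i - 1] - ar_step  # later blocks trail the previous one
    new_row = new_row.clamp(0, num_iterations)
    rows.append(new_row)
    pre_row = new_row
print(torch.stack(rows))  # each column starts denoising ar_step iterations later
```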
@torch.no_grad()
def generate(
self,
prompt: Union[str, List[str]],
negative_prompt: Union[str, List[str]] = "",
image: PipelineImageInput = None,
height: int = 480,
width: int = 832,
num_frames: int = 97,
num_inference_steps: int = 50,
shift: float = 1.0,
guidance_scale: float = 5.0,
seed: float = 0.0,
overlap_history: int = 17,
addnoise_condition: int = 0,
base_num_frames: int = 97,
ar_step: int = 5,
causal_block_size: int = 1,
causal_attention: bool = False,
fps: int = 24,
VAE_tile_size = 0,
joint_pass = False,
callback = None,
):
generator = torch.Generator(device=self.device)
generator.manual_seed(seed)
# if base_num_frames > base_num_frames:
# causal_block_size = 0
self._guidance_scale = guidance_scale
i2v_extra_kwrags = {}
prefix_video = None
predix_video_latent_length = 0
if image:
frame_width, frame_height = image.size
scale = min(height / frame_height, width / frame_width)
height = (int(frame_height * scale) // 16) * 16
width = (int(frame_width * scale) // 16) * 16
prefix_video, predix_video_latent_length = self.encode_image(image, height, width, num_frames, tile_size=VAE_tile_size, causal_block_size=causal_block_size)
latent_length = (num_frames - 1) // 4 + 1
latent_height = height // 8
latent_width = width // 8
prompt_embeds = self.text_encoder([prompt], self.device)
prompt_embeds = [u.to(self.dtype).to(self.device) for u in prompt_embeds]
if self.do_classifier_free_guidance:
negative_prompt_embeds = self.text_encoder([negative_prompt], self.device)
negative_prompt_embeds = [u.to(self.dtype).to(self.device) for u in negative_prompt_embeds]
self.scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
init_timesteps = self.scheduler.timesteps
fps_embeds = [fps] * prompt_embeds[0].shape[0]
fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]
transformer_dtype = self.dtype
# with torch.cuda.amp.autocast(dtype=self.dtype), torch.no_grad():
if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames:
# short video generation
latent_shape = [16, latent_length, latent_height, latent_width]
latents = self.prepare_latents(
latent_shape, dtype=torch.float32, device=self.device, generator=generator
)
latents = [latents]
if prefix_video is not None:
latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32)
base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length
step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
latent_length, init_timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size
)
sample_schedulers = []
for _ in range(latent_length):
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=1000, shift=1, use_dynamic_shifting=False
)
sample_scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
sample_schedulers.append(sample_scheduler)
sample_schedulers_counter = [0] * latent_length
if callback != None:
callback(-1, None, True)
freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= False)
for i, timestep_i in enumerate(tqdm(step_matrix)):
update_mask_i = step_update_mask[i]
valid_interval_i = valid_interval[i]
valid_interval_start, valid_interval_end = valid_interval_i
timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
noise_factor = 0.001 * addnoise_condition
timestep_for_noised_condition = addnoise_condition
latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
latent_model_input[0][:, valid_interval_start:predix_video_latent_length] * (1.0 - noise_factor)
+ torch.randn_like(latent_model_input[0][:, valid_interval_start:predix_video_latent_length])
* noise_factor
)
timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
kwrags = {
"x" : torch.stack([latent_model_input[0]]),
"t" : timestep,
"freqs" :freqs,
"fps" : fps_embeds,
# "causal_block_size" : causal_block_size,
"callback" : callback,
"pipeline" : self
}
kwrags.update(i2v_extra_kwrags)
if not self.do_classifier_free_guidance:
noise_pred = self.model(
context=prompt_embeds,
**kwrags,
)[0]
if self._interrupt:
return None
noise_pred= noise_pred.to(torch.float32)
else:
if joint_pass:
noise_pred_cond, noise_pred_uncond = self.model(
context=prompt_embeds,
context2=negative_prompt_embeds,
**kwrags,
)
if self._interrupt:
return None
else:
noise_pred_cond = self.model(
context=prompt_embeds,
**kwrags,
)[0]
if self._interrupt:
return None
noise_pred_uncond = self.model(
context=negative_prompt_embeds,
**kwrags,
)[0]
if self._interrupt:
return None
noise_pred_cond= noise_pred_cond.to(torch.float32)
noise_pred_uncond= noise_pred_uncond.to(torch.float32)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
del noise_pred_cond, noise_pred_uncond
for idx in range(valid_interval_start, valid_interval_end):
if update_mask_i[idx].item():
latents[0][:, idx] = sample_schedulers[idx].step(
noise_pred[:, idx - valid_interval_start],
timestep_i[idx],
latents[0][:, idx],
return_dict=False,
generator=generator,
)[0]
sample_schedulers_counter[idx] += 1
if callback is not None:
callback(i, latents[0], False)
x0 = latents[0].unsqueeze(0)
videos = self.vae.decode(x0, tile_size= VAE_tile_size)
videos = (videos / 2 + 0.5).clamp(0, 1)
videos = [video for video in videos]
videos = [video.permute(1, 2, 3, 0) * 255 for video in videos]
videos = [video.cpu().numpy().astype(np.uint8) for video in videos]
return videos
else:
# long video generation
base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length
overlap_history_frames = (overlap_history - 1) // 4 + 1
n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1
print(f"n_iter:{n_iter}")
output_video = None
for i in range(n_iter):
if output_video is not None: # i !=0
prefix_video = output_video[:, -overlap_history:].to(self.device)
prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)]
if prefix_video[0].shape[1] % causal_block_size != 0:
truncate_len = prefix_video[0].shape[1] % causal_block_size
print("the length of prefix video is truncated for the casual block size alignment.")
prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
predix_video_latent_length = prefix_video[0].shape[1]
finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames
left_frame_num = latent_length - finished_frame_num
base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames)
else: # i == 0
base_num_frames_iter = base_num_frames
latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
latents = self.prepare_latents(
latent_shape, dtype=torch.float32, device=self.device, generator=generator
)
latents = [latents]
if prefix_video is not None:
latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32)
step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
base_num_frames_iter,
init_timesteps,
base_num_frames_iter,
ar_step,
predix_video_latent_length,
causal_block_size,
)
sample_schedulers = []
for _ in range(base_num_frames_iter):
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=1000, shift=1, use_dynamic_shifting=False
)
sample_scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
sample_schedulers.append(sample_scheduler)
sample_schedulers_counter = [0] * base_num_frames_iter
if callback != None:
callback(-1, None, True)
freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= False)
for i, timestep_i in enumerate(tqdm(step_matrix)):
update_mask_i = step_update_mask[i]
valid_interval_i = valid_interval[i]
valid_interval_start, valid_interval_end = valid_interval_i
timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
noise_factor = 0.001 * addnoise_condition
timestep_for_noised_condition = addnoise_condition
latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
* (1.0 - noise_factor)
+ torch.randn_like(
latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
)
* noise_factor
)
timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
kwrags = {
"x" : torch.stack([latent_model_input[0]]),
"t" : timestep,
"freqs" :freqs,
"fps" : fps_embeds,
"causal_block_size" : causal_block_size,
"causal_attention" : causal_attention,
"callback" : callback,
"pipeline" : self
}
kwrags.update(i2v_extra_kwrags)
if not self.do_classifier_free_guidance:
noise_pred = self.model(
context=prompt_embeds,
**kwrags,
)[0]
if self._interrupt:
return None
noise_pred= noise_pred.to(torch.float32)
else:
if joint_pass:
noise_pred_cond, noise_pred_uncond = self.model(
context=prompt_embeds,
context2=negative_prompt_embeds,
**kwrags,
)
if self._interrupt:
return None
else:
noise_pred_cond = self.model(
context=prompt_embeds,
**kwrags,
)[0]
if self._interrupt:
return None
noise_pred_uncond = self.model(
context=negative_prompt_embeds,
)[0]
if self._interrupt:
return None
noise_pred_cond= noise_pred_cond.to(torch.float32)
noise_pred_uncond= noise_pred_uncond.to(torch.float32)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
del noise_pred_cond, noise_pred_uncond
for idx in range(valid_interval_start, valid_interval_end):
if update_mask_i[idx].item():
latents[0][:, idx] = sample_schedulers[idx].step(
noise_pred[:, idx - valid_interval_start],
timestep_i[idx],
latents[0][:, idx],
return_dict=False,
generator=generator,
)[0]
sample_schedulers_counter[idx] += 1
if callback is not None:
callback(i, latents[0].squeeze(0), False)
x0 = latents[0].unsqueeze(0)
videos = [self.vae.decode(x0, tile_size= VAE_tile_size)[0]]
if output_video is None:
output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w
else:
output_video = torch.cat(
[output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1
) # c, f, h, w
return output_video

View File

@ -1,698 +0,0 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial
from mmgp import offload
import torch
import torch.nn as nn
import torch.cuda.amp as amp
import torch.distributed as dist
from tqdm import tqdm
from PIL import Image
import torchvision.transforms.functional as TF
import torch.nn.functional as F
from .distributed.fsdp import shard_model
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wan.modules.posemb_layers import get_rotary_pos_embed
from .utils.vace_preprocessor import VaceVideoProcessor
def optimized_scale(positive_flat, negative_flat):
# Calculate dot production
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
# Squared norm of uncondition
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
st_star = dot_product / squared_norm
return st_star
class WanT2V:
def __init__(
self,
config,
checkpoint_dir,
rank=0,
model_filename = None,
text_encoder_filename = None,
quantizeTransformer = False,
dtype = torch.bfloat16
):
self.device = torch.device(f"cuda")
self.config = config
self.rank = rank
self.dtype = dtype
self.num_train_timesteps = config.num_train_timesteps
self.param_dtype = config.param_dtype
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
device=torch.device('cpu'),
checkpoint_path=text_encoder_filename,
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
shard_fn= None)
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
self.vae = WanVAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
device=self.device)
logging.info(f"Creating WanModel from {model_filename}")
from mmgp import offload
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False)
# offload.load_model_data(self.model, "recam.ckpt")
# self.model.cpu()
# offload.save_model(self.model, "recam.safetensors")
if self.dtype == torch.float16 and not "fp16" in model_filename:
self.model.to(self.dtype)
# offload.save_model(self.model, "t2v_fp16.safetensors",do_quantize=True)
if self.dtype == torch.float16:
self.vae.model.to(self.dtype)
self.model.eval().requires_grad_(False)
self.sample_neg_prompt = config.sample_neg_prompt
if "Vace" in model_filename:
self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
min_area=480*832,
max_area=480*832,
min_fps=config.sample_fps,
max_fps=config.sample_fps,
zero_start=True,
seq_len=32760,
keep_last=True)
self.adapt_vace_model()
self.scheduler = FlowUniPCMultistepScheduler()
def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0):
if ref_images is None:
ref_images = [None] * len(frames)
else:
assert len(frames) == len(ref_images)
if masks is None:
latents = self.vae.encode(frames, tile_size = tile_size)
else:
inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
inactive = self.vae.encode(inactive, tile_size = tile_size)
reactive = self.vae.encode(reactive, tile_size = tile_size)
latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
cat_latents = []
for latent, refs in zip(latents, ref_images):
if refs is not None:
if masks is None:
ref_latent = self.vae.encode(refs, tile_size = tile_size)
else:
ref_latent = self.vae.encode(refs, tile_size = tile_size)
ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
assert all([x.shape[1] == 1 for x in ref_latent])
latent = torch.cat([*ref_latent, latent], dim=1)
cat_latents.append(latent)
return cat_latents
def vace_encode_masks(self, masks, ref_images=None):
if ref_images is None:
ref_images = [None] * len(masks)
else:
assert len(masks) == len(ref_images)
result_masks = []
for mask, refs in zip(masks, ref_images):
c, depth, height, width = mask.shape
new_depth = int((depth + 3) // self.vae_stride[0])
height = 2 * (int(height) // (self.vae_stride[1] * 2))
width = 2 * (int(width) // (self.vae_stride[2] * 2))
# reshape
mask = mask[0, :, :, :]
mask = mask.view(
depth, height, self.vae_stride[1], width, self.vae_stride[1]
) # depth, height, 8, width, 8
mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width
mask = mask.reshape(
self.vae_stride[1] * self.vae_stride[2], depth, height, width
) # 8*8, depth, height, width
# interpolation
mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0)
if refs is not None:
length = len(refs)
mask_pad = torch.zeros_like(mask[:, :length, :, :])
mask = torch.cat((mask_pad, mask), dim=1)
result_masks.append(mask)
return result_masks
def vace_latent(self, z, m):
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, pre_src_video = None):
image_sizes = []
trim_video = len(keep_frames)
for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)):
prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1]
num_frames = total_frames - prepend_count
if sub_src_mask is not None and sub_src_video is not None:
src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame)
# src_video is [-1, 1], 0 = inpainting area (in fact 127 in [0, 255])
# src_mask is [-1, 1], 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255])
src_video[i] = src_video[i].to(device)
src_mask[i] = src_mask[i].to(device)
if prepend_count > 0:
src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1)
src_video_shape = src_video[i].shape
if src_video_shape[1] != total_frames:
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
image_sizes.append(src_video[i].shape[2:])
elif sub_src_video is None:
if prepend_count > 0:
src_video[i] = torch.cat( [sub_pre_src_video, torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)], dim=1)
src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), torch.ones((3, num_frames, image_size[0], image_size[1]), device=device)] ,1)
else:
src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
src_mask[i] = torch.ones_like(src_video[i], device=device)
image_sizes.append(image_size)
else:
src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame)
src_video[i] = src_video[i].to(device)
src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device)
if prepend_count > 0:
src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1)
src_video_shape = src_video[i].shape
if src_video_shape[1] != total_frames:
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
image_sizes.append(src_video[i].shape[2:])
for k, keep in enumerate(keep_frames):
if not keep:
src_video[i][:, k:k+1] = 0
src_mask[i][:, k:k+1] = 1
for i, ref_images in enumerate(src_ref_images):
if ref_images is not None:
image_size = image_sizes[i]
for j, ref_img in enumerate(ref_images):
if ref_img is not None:
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
if ref_img.shape[-2:] != image_size:
canvas_height, canvas_width = image_size
ref_height, ref_width = ref_img.shape[-2:]
white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
scale = min(canvas_height / ref_height, canvas_width / ref_width)
new_height = int(ref_height * scale)
new_width = int(ref_width * scale)
resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
ref_img = white_canvas
src_ref_images[i][j] = ref_img.to(device)
return src_video, src_mask, src_ref_images
def decode_latent(self, zs, ref_images=None, tile_size= 0 ):
if ref_images is None:
ref_images = [None] * len(zs)
else:
assert len(zs) == len(ref_images)
trimed_zs = []
for z, refs in zip(zs, ref_images):
if refs is not None:
z = z[:, len(refs):, :, :]
trimed_zs.append(z)
return self.vae.decode(trimed_zs, tile_size= tile_size)
def generate_timestep_matrix(
self,
num_frames,
step_template,
base_num_frames,
ar_step=5,
num_pre_ready=0,
casual_block_size=1,
shrink_interval_with_mask=False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
step_matrix, step_index = [], []
update_mask, valid_interval = [], []
num_iterations = len(step_template) + 1
num_frames_block = num_frames // casual_block_size
base_num_frames_block = base_num_frames // casual_block_size
if base_num_frames_block < num_frames_block:
infer_step_num = len(step_template)
gen_block = base_num_frames_block
min_ar_step = infer_step_num / gen_block
assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting"
# print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block)
step_template = torch.cat(
[
torch.tensor([999], dtype=torch.int64, device=step_template.device),
step_template.long(),
torch.tensor([0], dtype=torch.int64, device=step_template.device),
]
) # pad the template: row counters start at 1, index 0 maps to 999 (pure noise) and the final index to 0 (fully denoised)
pre_row = torch.zeros(num_frames_block, dtype=torch.long)
if num_pre_ready > 0:
pre_row[: num_pre_ready // casual_block_size] = num_iterations
while not torch.all(pre_row >= (num_iterations - 1)):
new_row = torch.zeros(num_frames_block, dtype=torch.long)
for i in range(num_frames_block):
if i == 0 or pre_row[i - 1] >= (
num_iterations - 1
): # the first frame, or the previous frame is already fully denoised
new_row[i] = pre_row[i] + 1
else:
new_row[i] = new_row[i - 1] - ar_step
new_row = new_row.clamp(0, num_iterations)
update_mask.append(
(new_row != pre_row) & (new_row != num_iterations)
) # False: no need to update True: need to update
step_index.append(new_row)
step_matrix.append(step_template[new_row])
pre_row = new_row
# for long video we split into several sequences, base_num_frames is set to the model max length (for training)
terminal_flag = base_num_frames_block
if shrink_interval_with_mask:
idx_sequence = torch.arange(num_frames_block, dtype=torch.int64)
update_mask = update_mask[0]
update_mask_idx = idx_sequence[update_mask]
last_update_idx = update_mask_idx[-1].item()
terminal_flag = last_update_idx + 1
# for i in range(0, len(update_mask)):
for curr_mask in update_mask:
if terminal_flag < num_frames_block and curr_mask[terminal_flag]:
terminal_flag += 1
valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag))
step_update_mask = torch.stack(update_mask, dim=0)
step_index = torch.stack(step_index, dim=0)
step_matrix = torch.stack(step_matrix, dim=0)
if casual_block_size > 1:
step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval]
return step_matrix, step_index, step_update_mask, valid_interval
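# Illustrative usage of generate() below (not part of the original source); it assumes
# `pipe` is an already-initialized instance of this class with its text encoder, VAE and
# transformer loaded:
#
#   video = pipe.generate(
#       "a red panda surfing a wave at sunset",
#       size=(832, 480),      # (width, height)
#       frame_num=81,
#       sampling_steps=30,
#       guide_scale=5.0,
#       seed=42,
#   )
#   # returns a (C, N, H, W) tensor on rank 0, None on other ranks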
def generate(self,
input_prompt,
input_frames= None,
input_masks = None,
input_ref_images = None,
source_video=None,
target_camera=None,
context_scale=1.0,
size=(1280, 720),
frame_num=81,
shift=5.0,
sample_solver='unipc',
sampling_steps=50,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True,
callback = None,
enable_RIFLEx = None,
VAE_tile_size = 0,
joint_pass = False,
slg_layers = None,
slg_start = 0.0,
slg_end = 1.0,
cfg_star_switch = True,
cfg_zero_step = 5,
):
r"""
Generates video frames from text prompt using diffusion process.
Args:
input_prompt (`str`):
Text prompt for content generation
size (tuple[`int`], *optional*, defaults to (1280,720)):
Controls video resolution, (width,height).
frame_num (`int`, *optional*, defaults to 81):
How many frames to sample from a video. The number should be 4n+1
shift (`float`, *optional*, defaults to 5.0):
Noise schedule shift parameter. Affects temporal dynamics
sample_solver (`str`, *optional*, defaults to 'unipc'):
Solver used to sample the video.
sampling_steps (`int`, *optional*, defaults to 50):
Number of diffusion sampling steps. Higher values improve quality but slow generation
guide_scale (`float`, *optional*, defaults to 5.0):
Classifier-free guidance scale. Controls prompt adherence vs. creativity
n_prompt (`str`, *optional*, defaults to ""):
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
seed (`int`, *optional*, defaults to -1):
Random seed for noise generation. If -1, use random seed.
offload_model (`bool`, *optional*, defaults to True):
If True, offloads models to CPU during generation to save VRAM
Returns:
torch.Tensor:
Generated video frames tensor. Dimensions: (C, N, H, W) where:
- C: Color channels (3 for RGB)
- N: Number of frames (81)
- H: Frame height (from size)
- W: Frame width (from size)
"""
# preprocess
if n_prompt == "":
n_prompt = self.sample_neg_prompt
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
frame_num = max(17, frame_num) # must match causal_block_size for value of 5
frame_num = int( round( (frame_num - 17) / 20)* 20 + 17 )
num_frames = frame_num
addnoise_condition = 20
causal_attention = True
fps = 16
ar_step = 5
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if target_camera != None:
size = (source_video.shape[2], source_video.shape[1])
source_video = source_video.to(dtype=self.dtype , device=self.device)
source_video = source_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.)
source_latents = self.vae.encode([source_video]) #.to(dtype=self.dtype, device=self.device)
del source_video
# Process target camera (recammaster)
from wan.utils.cammmaster_tools import get_camera_embedding
cam_emb = get_camera_embedding(target_camera)
cam_emb = cam_emb.to(dtype=self.dtype, device=self.device)
if input_frames != None:
# vace context encode
input_frames = [u.to(self.device) for u in input_frames]
input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
input_masks = [u.to(self.device) for u in input_masks]
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size)
m0 = self.vace_encode_masks(input_masks, input_ref_images)
z = self.vace_latent(z0, m0)
target_shape = list(z0[0].shape)
target_shape[0] = int(target_shape[0] / 2)
else:
F = frame_num
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
size[1] // self.vae_stride[1],
size[0] // self.vae_stride[2])
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
(self.patch_size[1] * self.patch_size[2]) *
target_shape[1])
context = [u.to(self.dtype) for u in context]
context_null = [u.to(self.dtype) for u in context_null]
noise = [ torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) ]
# evaluation mode
# if sample_solver == 'unipc':
# sample_scheduler = FlowUniPCMultistepScheduler(
# num_train_timesteps=self.num_train_timesteps,
# shift=1,
# use_dynamic_shifting=False)
# sample_scheduler.set_timesteps(
# sampling_steps, device=self.device, shift=shift)
# timesteps = sample_scheduler.timesteps
# elif sample_solver == 'dpm++':
# sample_scheduler = FlowDPMSolverMultistepScheduler(
# num_train_timesteps=self.num_train_timesteps,
# shift=1,
# use_dynamic_shifting=False)
# sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
# timesteps, _ = retrieve_timesteps(
# sample_scheduler,
# device=self.device,
# sigmas=sampling_sigmas)
# else:
# raise NotImplementedError("Unsupported solver.")
# sample videos
latents = noise
del noise
batch_size =len(latents)
if target_camera != None:
shape = list(latents[0].shape[1:])
shape[0] *= 2
freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
else:
freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
# arg_c = {'context': context, 'freqs': freqs, 'pipeline': self, 'callback': callback}
# arg_null = {'context': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback}
# arg_both = {'context': context, 'context2': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback}
i2v_extra_kwrags = {}
if target_camera != None:
recam_dict = {'cam_emb': cam_emb}
i2v_extra_kwrags.update(recam_dict)
if input_frames != None:
vace_dict = {'vace_context' : z, 'vace_context_scale' : context_scale}
i2v_extra_kwrags.update(vace_dict)
latent_length = (num_frames - 1) // 4 + 1
# `size` is (width, height); derive the latent grid from it (`height`/`width` are not defined in this scope)
latent_height = size[1] // 8
latent_width = size[0] // 8
if ar_step == 0:
causal_block_size = 1
fps_embeds = [fps] #* prompt_embeds[0].shape[0]
fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]
self.scheduler.set_timesteps(sampling_steps, device=self.device, shift=shift)
init_timesteps = self.scheduler.timesteps
base_num_frames_iter = latent_length
latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
prefix_video = None
predix_video_latent_length = 0
if prefix_video is not None:
latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32)
step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
base_num_frames_iter,
init_timesteps,
base_num_frames_iter,
ar_step,
predix_video_latent_length,
causal_block_size,
)
sample_schedulers = []
for _ in range(base_num_frames_iter):
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=1000, shift=1, use_dynamic_shifting=False
)
sample_scheduler.set_timesteps(sampling_steps, device=self.device, shift=shift)
sample_schedulers.append(sample_scheduler)
sample_schedulers_counter = [0] * base_num_frames_iter
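# One scheduler per latent frame: frames are denoised asynchronously, so each frame
# tracks its own position in the timestep sequence and is stepped independently below.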
updated_num_steps= len(step_matrix)
if callback != None:
callback(-1, None, True, override_num_inference_steps = updated_num_steps)
if self.model.enable_teacache:
self.model.compute_teacache_threshold(self.model.teacache_start_step, init_timesteps, self.model.teacache_multiplier) # use init_timesteps: `timesteps` only exists in the commented-out scheduler block above
# if callback != None:
# callback(-1, None, True)
for i, timestep_i in enumerate(tqdm(step_matrix)):
update_mask_i = step_update_mask[i]
valid_interval_i = valid_interval[i]
valid_interval_start, valid_interval_end = valid_interval_i
timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
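# Prefix conditioning: already-known prefix latents get a small amount of fresh noise mixed
# back in and a fixed low timestep, so the model treats them as nearly-clean context rather
# than frames to be denoised from scratch.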
if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
noise_factor = 0.001 * addnoise_condition
timestep_for_noised_condition = addnoise_condition
latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
* (1.0 - noise_factor)
+ torch.randn_like(
latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
)
* noise_factor
)
timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
kwrags = {
"x" : torch.stack([latent_model_input[0]]),
"t" : timestep,
"freqs" :freqs,
"fps" : fps_embeds,
"causal_block_size" : causal_block_size,
"causal_attention" : causal_attention,
"callback" : callback,
"pipeline" : self,
"current_step" : i,
}
kwrags.update(i2v_extra_kwrags)
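# Classifier-free guidance: either a single joint pass that returns the conditional and
# unconditional predictions together, or two separate model calls, followed by the usual
# uncond + scale * (cond - uncond) combination.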
if not self.do_classifier_free_guidance:
noise_pred = self.model(
context=context,
**kwrags,
)[0]
if self._interrupt:
return None
noise_pred= noise_pred.to(torch.float32)
else:
if joint_pass:
noise_pred_cond, noise_pred_uncond = self.model(
context=context,
context2=context_null,
**kwrags,
)
if self._interrupt:
return None
else:
noise_pred_cond = self.model(
context=context,
**kwrags,
)[0]
if self._interrupt:
return None
noise_pred_uncond = self.model(
context=context_null,
**kwrags,  # same inputs as the conditional pass; only the text context differs
)[0]
if self._interrupt:
return None
noise_pred_cond= noise_pred_cond.to(torch.float32)
noise_pred_uncond= noise_pred_uncond.to(torch.float32)
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
del noise_pred_cond, noise_pred_uncond
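# Only frames flagged by this iteration's update mask are advanced, each by its own scheduler.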
for idx in range(valid_interval_start, valid_interval_end):
if update_mask_i[idx].item():
latents[0][:, idx] = sample_schedulers[idx].step(
noise_pred[:, idx - valid_interval_start],
timestep_i[idx],
latents[0][:, idx],
return_dict=False,
generator=seed_g,
)[0]
sample_schedulers_counter[idx] += 1
if callback is not None:
callback(i, latents[0].squeeze(0), False)
# for i, t in enumerate(tqdm(timesteps)):
# if target_camera != None:
# latent_model_input = [torch.cat([u,v], dim=1) for u,v in zip(latents,source_latents )]
# else:
# latent_model_input = latents
# slg_layers_local = None
# if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps):
# slg_layers_local = slg_layers
# timestep = [t]
# offload.set_step_no_for_lora(self.model, i)
# timestep = torch.stack(timestep)
# if joint_pass:
# noise_pred_cond, noise_pred_uncond = self.model(
# latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both)
# if self._interrupt:
# return None
# else:
# noise_pred_cond = self.model(
# latent_model_input, t=timestep,current_step=i, is_uncond = False, **arg_c)[0]
# if self._interrupt:
# return None
# noise_pred_uncond = self.model(
# latent_model_input, t=timestep,current_step=i, is_uncond = True, slg_layers=slg_layers_local, **arg_null)[0]
# if self._interrupt:
# return None
# # del latent_model_input
# # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
# noise_pred_text = noise_pred_cond
# if cfg_star_switch:
# positive_flat = noise_pred_text.view(batch_size, -1)
# negative_flat = noise_pred_uncond.view(batch_size, -1)
# alpha = optimized_scale(positive_flat,negative_flat)
# alpha = alpha.view(batch_size, 1, 1, 1)
# if (i <= cfg_zero_step):
# noise_pred = noise_pred_text*0. # it would be faster not to compute noise_pred...
# else:
# noise_pred_uncond *= alpha
# noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond)
# del noise_pred_uncond
# temp_x0 = sample_scheduler.step(
# noise_pred[:, :target_shape[1]].unsqueeze(0),
# t,
# latents[0].unsqueeze(0),
# return_dict=False,
# generator=seed_g)[0]
# latents = [temp_x0.squeeze(0)]
# del temp_x0
# if callback is not None:
# callback(i, latents[0], False)
x0 = latents
if input_frames == None:
videos = self.vae.decode(x0, VAE_tile_size)
else:
videos = self.decode_latent(x0, input_ref_images, VAE_tile_size)
del latents
del sample_scheduler
return videos[0] if self.rank == 0 else None
def adapt_vace_model(self):
model = self.model
modules_dict= { k: m for k, m in model.named_modules()}
for model_layer, vace_layer in model.vace_layers_mapping.items():
module = modules_dict[f"vace_blocks.{vace_layer}"]
target = modules_dict[f"blocks.{model_layer}"]
setattr(target, "vace", module )
delattr(model, "vace_blocks")

View File

@@ -245,8 +245,10 @@ class family_handler():
                 "visible" : False,
             }
-        if vace_class or base_model_type in ["infinitetalk", "animate"]:
+        if vace_class or base_model_type in ["animate"]:
             image_prompt_types_allowed = "TVL"
+        elif base_model_type in ["infinitetalk"]:
+            image_prompt_types_allowed = "TSVL"
         elif base_model_type in ["ti2v_2_2"]:
             image_prompt_types_allowed = "TSVL"
         elif base_model_type in ["lucy_edit"]:

View File

@@ -21,6 +21,7 @@ from .utils.get_default_model import get_matanyone_model
 from .matanyone.inference.inference_core import InferenceCore
 from .matanyone_wrapper import matanyone
 from shared.utils.audio_video import save_video, save_image
+from mmgp import offload
 arg_device = "cuda"
 arg_sam_model_type="vit_h"
@@ -539,7 +540,7 @@ def video_matting(video_state,video_input, end_slider, matting_type, interactive
     file_name = ".".join(file_name.split(".")[:-1])
     from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files
-    source_audio_tracks, audio_metadata = extract_audio_tracks(video_input)
+    source_audio_tracks, audio_metadata = extract_audio_tracks(video_input, verbose= offload.default_verboseLevel )
     output_fg_path = f"./mask_outputs/{file_name}_fg.mp4"
     output_fg_temp_path = f"./mask_outputs/{file_name}_fg_tmp.mp4"
     if len(source_audio_tracks) == 0:
@@ -679,7 +680,6 @@ def load_unload_models(selected):
     }
     # os.path.join('.')
-    from mmgp import offload
     # sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[arg_sam_model_type], ".")
     sam_checkpoint = None

View File

@@ -52,7 +52,7 @@ matplotlib
 # Utilities
 ftfy
 piexif
-pynvml
+nvidia-ml-py
 misaki
 # Optional / commented out

View File

@@ -1,6 +1,6 @@
 # thanks Comfyui for the rgb factors (https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py)
 def get_rgb_factors(model_family, model_type = None):
-    if model_family == "wan":
+    if model_family in ["wan", "qwen"]:
         if model_type =="ti2v_2_2":
             latent_channels = 48
             latent_dimensions = 3
@@ -261,7 +261,7 @@ def get_rgb_factors(model_family, model_type = None):
             [ 0.0249, -0.0469, -0.1703]
         ]
         latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
     else:
         latent_rgb_factors_bias = latent_rgb_factors = None
     return latent_rgb_factors, latent_rgb_factors_bias

View File

@@ -321,7 +321,7 @@ def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu
     ref_width, ref_height = ref_img.size
     if (ref_height, ref_width) == image_size and outpainting_dims == None:
         ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
-        canvas = torch.zeros_like(ref_img) if return_mask else None
+        canvas = torch.zeros_like(ref_img[:1]) if return_mask else None
     else:
         if outpainting_dims != None:
             final_height, final_width = image_size
@@ -374,7 +374,7 @@ def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, im
         if pre_video_guide is not None:
             src_video = pre_video_guide if src_video is None else torch.cat( [pre_video_guide, src_video], dim=1)
             if any_mask:
-                src_mask = torch.zeros_like(pre_video_guide[0:1]) if src_mask is None else torch.cat( [torch.zeros_like(pre_video_guide[0:1]), src_mask], dim=1)
+                src_mask = torch.zeros_like(pre_video_guide[:1]) if src_mask is None else torch.cat( [torch.zeros_like(pre_video_guide[:1]), src_mask], dim=1)
         if any_guide_padding:
             if src_video is None:

48
wgp.py
View File

@@ -63,8 +63,8 @@ AUTOSAVE_FILENAME = "queue.zip"
 PROMPT_VARS_MAX = 10
 target_mmgp_version = "3.6.0"
-WanGP_version = "8.71"
-settings_version = 2.35
+WanGP_version = "8.73"
+settings_version = 2.36
 max_source_video_frames = 3000
 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -5041,7 +5041,7 @@ def generate_video(
         if repeat_no >= total_generation: break
         repeat_no +=1
         gen["repeat_no"] = repeat_no
-        src_video = src_video2 = src_mask = src_mask2 = src_faces = src_ref_images = src_ref_masks = None
+        src_video = src_video2 = src_mask = src_mask2 = src_faces = src_ref_images = src_ref_masks = sparse_video_image = None
         prefix_video = pre_video_frame = None
         source_video_overlap_frames_count = 0 # number of frames overalapped in source video for first window
         source_video_frames_count = 0 # number of frames to use in source video (processing starts source_video_overlap_frames_count frames before )
@@ -5169,7 +5169,7 @@ def generate_video(
                 frames_to_inject[pos] = image_refs[i]
-        video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None
+        video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = sparse_video_image = None
         if video_guide is not None:
             keep_frames_parsed_full, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate)
             if len(error) > 0:
@@ -5259,7 +5259,7 @@ def generate_video(
             any_guide_padding = model_def.get("pad_guide_video", False)
             from shared.utils.utils import prepare_video_guide_and_mask
             src_videos, src_masks = prepare_video_guide_and_mask( [video_guide_processed] + ([] if video_guide_processed2 is None else [video_guide_processed2]),
-                                                    [video_mask_processed] + ([] if video_mask_processed2 is None else [video_mask_processed2]),
+                                                    [video_mask_processed] + ([] if video_guide_processed2 is None else [video_mask_processed2]),
                                                     None if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else pre_video_guide,
                                                     image_size, current_video_length, latent_size,
                                                     any_mask, any_guide_padding, guide_inpaint_color,
@@ -5281,9 +5281,12 @@ def generate_video(
                 src_faces = src_faces[:, :src_video.shape[1]]
             if video_guide is not None or len(frames_to_inject_parsed) > 0:
                 if args.save_masks:
-                    if src_video is not None: save_video( src_video, "masked_frames.mp4", fps)
-                    if src_video2 is not None: save_video( src_video2, "masked_frames2.mp4", fps)
-                    if any_mask: save_video( src_mask, "masks.mp4", fps, value_range=(0, 1))
+                    if src_video is not None:
+                        save_video( src_video, "masked_frames.mp4", fps)
+                        if any_mask: save_video( src_mask, "masks.mp4", fps, value_range=(0, 1))
+                    if src_video2 is not None:
+                        save_video( src_video2, "masked_frames2.mp4", fps)
+                        if any_mask: save_video( src_mask2, "masks2.mp4", fps, value_range=(0, 1))
                 if video_guide is not None:
                     preview_frame_no = 0 if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else (guide_start_frame - window_start_frame)
                     refresh_preview["video_guide"] = convert_tensor_to_image(src_video, preview_frame_no)
@@ -6766,11 +6769,11 @@ def switch_image_mode(state):
     inpaint_support = model_def.get("inpaint_support", False)
     if inpaint_support:
         if image_mode == 1:
-            video_prompt_type = del_in_sequence(video_prompt_type, "VAG")
+            video_prompt_type = del_in_sequence(video_prompt_type, "VAG" + all_guide_processes)
             video_prompt_type = add_to_sequence(video_prompt_type, "KI")
         elif image_mode == 2:
+            video_prompt_type = del_in_sequence(video_prompt_type, "KI" + all_guide_processes)
             video_prompt_type = add_to_sequence(video_prompt_type, "VAG")
-            video_prompt_type = del_in_sequence(video_prompt_type, "KI")
     ui_defaults["video_prompt_type"] = video_prompt_type
     return str(time.time())
@@ -7156,10 +7159,11 @@ def refresh_video_prompt_type_alignment(state, video_prompt_type, video_prompt_t
     video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
     return video_prompt_type
+all_guide_processes ="PDESLCMUVB"
 def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ):
     old_video_prompt_type = video_prompt_type
-    video_prompt_type = del_in_sequence(video_prompt_type, "PDESLCMUVB")
+    video_prompt_type = del_in_sequence(video_prompt_type, all_guide_processes)
     video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
     visible = "V" in video_prompt_type
     model_type = state["model_type"]
@@ -7169,8 +7173,12 @@ def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt
     image_outputs = image_mode > 0
     keep_frames_video_guide_visible = not image_outputs and visible and not model_def.get("keep_frames_video_guide_not_supported", False)
     image_mask_guide, image_guide, image_mask = switch_image_guide_editor(image_mode, old_video_prompt_type , video_prompt_type, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value )
-    return video_prompt_type, gr.update(visible = visible and not image_outputs), image_guide, gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and any_outpainting), gr.update(visible= visible and not "U" in video_prompt_type ), gr.update(visible= mask_visible and not image_outputs), image_mask, image_mask_guide, gr.update(visible= mask_visible)
+    mask_preprocessing = model_def.get("mask_preprocessing", None)
+    if mask_preprocessing is not None:
+        mask_selector_visible = mask_preprocessing.get("visible", True)
+    else:
+        mask_selector_visible = True
+    return video_prompt_type, gr.update(visible = visible and not image_outputs), image_guide, gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and any_outpainting), gr.update(visible= visible and mask_selector_visible and not "U" in video_prompt_type ) , gr.update(visible= mask_visible and not image_outputs), image_mask, image_mask_guide, gr.update(visible= mask_visible)
 def refresh_video_prompt_type_video_guide_alt(state, video_prompt_type, video_prompt_type_video_guide_alt, image_mode):
@@ -7662,8 +7670,12 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                     keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= len(filter_letters(image_prompt_type_value, "VL"))>0 , scale = 2, label= "Truncate Video beyond this number of resampled Frames (empty=Keep All, negative truncates from End)" )
                 any_control_video = any_control_image = False
-                guide_preprocessing = model_def.get("guide_preprocessing", None)
-                mask_preprocessing = model_def.get("mask_preprocessing", None)
+                if image_mode_value ==2:
+                    guide_preprocessing = { "selection": ["V", "VG"]}
+                    mask_preprocessing = { "selection": ["A"]}
+                else:
+                    guide_preprocessing = model_def.get("guide_preprocessing", None)
+                    mask_preprocessing = model_def.get("mask_preprocessing", None)
                 guide_custom_choices = model_def.get("guide_custom_choices", None)
                 image_ref_choices = model_def.get("image_ref_choices", None)
@@ -7707,7 +7719,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                     if image_outputs: video_prompt_type_video_guide_label = video_prompt_type_video_guide_label.replace("Video", "Image")
                     video_prompt_type_video_guide = gr.Dropdown(
                         guide_preprocessing_choices,
-                        value=filter_letters(video_prompt_type_value, "PDESLCMUVB", guide_preprocessing.get("default", "") ),
+                        value=filter_letters(video_prompt_type_value, all_guide_processes, guide_preprocessing.get("default", "") ),
                         label= video_prompt_type_video_guide_label , scale = 2, visible= guide_preprocessing.get("visible", True) , show_label= True,
                     )
                     any_control_video = True
@@ -7793,8 +7805,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                     if image_guide_value is None:
                         image_mask_guide_value = None
                     else:
-                        image_mask_value = rgb_bw_to_rgba_mask(image_mask_value)
-                        image_mask_guide_value = { "background" : image_guide_value, "composite" : None, "layers": [image_mask_value] }
+                        image_mask_guide_value = { "background" : image_guide_value, "composite" : None}
+                        image_mask_guide_value["layers"] = [] if image_mask_value is None else [rgb_bw_to_rgba_mask(image_mask_value)]
                     image_mask_guide = gr.ImageEditor(
                         label="Control Image to be Inpainted" if image_mode_value == 2 else "Control Image and Mask",