Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-04 14:16:57 +00:00)

Commit 27d4c8eb4d ("righted a wrong"), parent 960f1f87c1
@@ -20,6 +20,14 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTX Video models

**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep

## 🔥 Latest Updates :

### August 8 2025: WanGP v7.72 - Qwen Rebirth

Ever wondered what impact skipping guidance has on a model that expects it? Just look at Qwen Image in WanGP 7.71, whose outputs were erratic. I had somehow convinced myself that Qwen was a distilled model; in fact Qwen was dying for a negative prompt, and in WanGP 7.72 it finally gets one.
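
What the negative prompt buys you is true classifier-free guidance: the conditional prediction is pushed away from the negative-prompt prediction and then rescaled. A minimal sketch of that combination, mirroring the pipeline change included further down in this commit (tensor names are illustrative, not the actual pipeline variables):

```python
import torch

def apply_true_cfg(noise_pred: torch.Tensor,
                   neg_noise_pred: torch.Tensor,
                   true_cfg_scale: float) -> torch.Tensor:
    # Push the conditional prediction away from the negative-prompt prediction.
    comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
    # Rescale so the result keeps the conditional prediction's norm.
    cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
    noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
    return comb_pred * (cond_norm / noise_norm)
```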

As Qwen is not so picky after all, I have also added a quantized text encoder, which cuts Qwen's RAM requirements by 10 GB (the quantized text encoder produced garbage before).
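
The switch that picks the quantized weights is small; here is a sketch of the selection logic (the actual change to get_qwen_text_encoder_filename appears further down in this commit):

```python
def get_qwen_text_encoder_filename(text_encoder_quantization: str) -> str:
    # Default to the bf16 text encoder weights.
    filename = "ckpts/Qwen2.5-VL-7B-Instruct/Qwen2.5-VL-7B-Instruct_bf16.safetensors"
    if text_encoder_quantization == "int8":
        # Previously gated off with an "and False"; v7.72 re-enables the int8 weights.
        filename = filename.replace("bf16", "quanto_bf16_int8")
    return filename
```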

Hopefully this release also fixes the Sage/Sage2 black screen on some GPUs.

### August 6 2025: WanGP v7.71 - Picky, picky

This release comes with two new models :

i2v_inference.py (deleted, 682 lines)
@@ -1,682 +0,0 @@
import os
import time
import argparse
import json
import torch
import traceback
import gc
import random

# These imports rely on your existing code structure
# They must match the location of your WAN code, etc.
import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
from wan.modules.attention import get_attention_modes
from wan.utils.utils import cache_video
from mmgp import offload, safetensors2, profile_type

try:
    import triton
except ImportError:
    pass

DATA_DIR = "ckpts"

# --------------------------------------------------
# HELPER FUNCTIONS
# --------------------------------------------------

def sanitize_file_name(file_name):
    """Clean up file name from special chars."""
    return (
        file_name.replace("/", "")
        .replace("\\", "")
        .replace(":", "")
        .replace("|", "")
        .replace("?", "")
        .replace("<", "")
        .replace(">", "")
        .replace('"', "")
    )

def extract_preset(lset_name, lora_dir, loras):
    """
    Load a .lset JSON that lists the LoRA files to apply, plus multipliers
    and possibly a suggested prompt prefix.
    """
    lset_name = sanitize_file_name(lset_name)
    if not lset_name.endswith(".lset"):
        lset_name_filename = os.path.join(lora_dir, lset_name + ".lset")
    else:
        lset_name_filename = os.path.join(lora_dir, lset_name)

    if not os.path.isfile(lset_name_filename):
        raise ValueError(f"Preset '{lset_name}' not found in {lora_dir}")

    with open(lset_name_filename, "r", encoding="utf-8") as reader:
        text = reader.read()
    lset = json.loads(text)

    loras_choices_files = lset["loras"]
    loras_choices = []
    missing_loras = []
    for lora_file in loras_choices_files:
        # Build absolute path and see if it is in loras
        full_lora_path = os.path.join(lora_dir, lora_file)
        if full_lora_path in loras:
            idx = loras.index(full_lora_path)
            loras_choices.append(str(idx))
        else:
            missing_loras.append(lora_file)

    if len(missing_loras) > 0:
        missing_list = ", ".join(missing_loras)
        raise ValueError(f"Missing LoRA files for preset: {missing_list}")

    loras_mult_choices = lset["loras_mult"]
    prompt_prefix = lset.get("prompt", "")
    full_prompt = lset.get("full_prompt", False)
    return loras_choices, loras_mult_choices, prompt_prefix, full_prompt

def get_attention_mode(args_attention, installed_modes):
    """
    Decide which attention mode to use: either the user choice or auto fallback.
    """
    if args_attention == "auto":
        for candidate in ["sage2", "sage", "sdpa"]:
            if candidate in installed_modes:
                return candidate
        return "sdpa"  # last fallback
    elif args_attention in installed_modes:
        return args_attention
    else:
        raise ValueError(
            f"Requested attention mode '{args_attention}' not installed. "
            f"Installed modes: {installed_modes}"
        )

def load_i2v_model(model_filename, text_encoder_filename, is_720p):
    """
    Load the i2v model with a specific size config and text encoder.
    """
    if is_720p:
        print("Loading 14B-720p i2v model ...")
        cfg = WAN_CONFIGS['i2v-14B']
        wan_model = wan.WanI2V(
            config=cfg,
            checkpoint_dir=DATA_DIR,
            model_filename=model_filename,
            text_encoder_filename=text_encoder_filename
        )
    else:
        print("Loading 14B-480p i2v model ...")
        cfg = WAN_CONFIGS['i2v-14B']
        wan_model = wan.WanI2V(
            config=cfg,
            checkpoint_dir=DATA_DIR,
            model_filename=model_filename,
            text_encoder_filename=text_encoder_filename
        )
    # Pipe structure
    pipe = {
        "transformer": wan_model.model,
        "text_encoder": wan_model.text_encoder.model,
        "text_encoder_2": wan_model.clip.model,
        "vae": wan_model.vae.model
    }
    return wan_model, pipe

def setup_loras(pipe, lora_dir, lora_preset, num_inference_steps):
    """
    Load loras from a directory, optionally apply a preset.
    """
    from pathlib import Path
    import glob

    if not lora_dir or not Path(lora_dir).is_dir():
        print("No valid --lora-dir provided or directory doesn't exist, skipping LoRA setup.")
        return [], [], [], "", "", False

    # Gather LoRA files
    loras = sorted(
        glob.glob(os.path.join(lora_dir, "*.sft"))
        + glob.glob(os.path.join(lora_dir, "*.safetensors"))
    )
    loras_names = [Path(x).stem for x in loras]

    # Offload them with no activation
    offload.load_loras_into_model(pipe["transformer"], loras, activate_all_loras=False)

    # If user gave a preset, apply it
    default_loras_choices = []
    default_loras_multis_str = ""
    default_prompt_prefix = ""
    preset_applied_full_prompt = False
    if lora_preset:
        loras_choices, loras_mult, prefix, full_prompt = extract_preset(lora_preset, lora_dir, loras)
        default_loras_choices = loras_choices
        # If user stored loras_mult as a list or string in JSON, unify that to str
        if isinstance(loras_mult, list):
            # Just store them in a single line
            default_loras_multis_str = " ".join([str(x) for x in loras_mult])
        else:
            default_loras_multis_str = str(loras_mult)
        default_prompt_prefix = prefix
        preset_applied_full_prompt = full_prompt

    return (
        loras,
        loras_names,
        default_loras_choices,
        default_loras_multis_str,
        default_prompt_prefix,
        preset_applied_full_prompt
    )

def parse_loras_and_activate(
    transformer,
    loras,
    loras_choices,
    loras_mult_str,
    num_inference_steps
):
    """
    Activate the chosen LoRAs with multipliers over the pipeline's transformer.
    Supports stepwise expansions (like "0.5,0.8" for partial steps).
    """
    if not loras or not loras_choices:
        # no LoRAs selected
        return

    # Handle multipliers
    def is_float_or_comma_list(x):
        """
        Example: "0.5", or "0.8,1.0", etc. is valid.
        """
        if not x:
            return False
        for chunk in x.split(","):
            try:
                float(chunk.strip())
            except ValueError:
                return False
        return True

    # Convert multiline or spaced lines to a single list
    lines = [
        line.strip()
        for line in loras_mult_str.replace("\r", "\n").split("\n")
        if line.strip() and not line.strip().startswith("#")
    ]
    # Now combine them by space
    joined_line = " ".join(lines)  # "1.0 2.0,3.0"
    if not joined_line.strip():
        multipliers = []
    else:
        multipliers = joined_line.split(" ")

    # Expand each item
    final_multipliers = []
    for mult in multipliers:
        mult = mult.strip()
        if not mult:
            continue
        if is_float_or_comma_list(mult):
            # Could be "0.7" or "0.5,0.6"
            if "," in mult:
                # expand over steps
                chunk_vals = [float(x.strip()) for x in mult.split(",")]
                expanded = expand_list_over_steps(chunk_vals, num_inference_steps)
                final_multipliers.append(expanded)
            else:
                final_multipliers.append(float(mult))
        else:
            raise ValueError(f"Invalid LoRA multiplier: '{mult}'")

    # If fewer multipliers than chosen LoRAs => pad with 1.0
    needed = len(loras_choices) - len(final_multipliers)
    if needed > 0:
        final_multipliers += [1.0]*needed

    # Actually activate them
    offload.activate_loras(transformer, loras_choices, final_multipliers)

def expand_list_over_steps(short_list, num_steps):
    """
    If user gave (0.5, 0.8) for example, expand them over `num_steps`.
    The expansion is simply linear slice across steps.
    """
    result = []
    inc = len(short_list) / float(num_steps)
    idxf = 0.0
    for _ in range(num_steps):
        value = short_list[int(idxf)]
        result.append(value)
        idxf += inc
    return result

def download_models_if_needed(transformer_filename_i2v, text_encoder_filename, local_folder=DATA_DIR):
    """
    Checks if all required WAN 2.1 i2v files exist locally under 'ckpts/'.
    If not, downloads them from a Hugging Face Hub repo.
    Adjust the 'repo_id' and needed files as appropriate.
    """
    import os
    from pathlib import Path

    try:
        from huggingface_hub import hf_hub_download, snapshot_download
    except ImportError as e:
        raise ImportError(
            "huggingface_hub is required for automatic model download. "
            "Please install it via `pip install huggingface_hub`."
        ) from e

    # Identify just the filename portion for each path
    def basename(path_str):
        return os.path.basename(path_str)

    repo_id = "DeepBeepMeep/Wan2.1"
    target_root = local_folder

    # You can customize this list as needed for i2v usage.
    # At minimum you need:
    # 1) The requested i2v transformer file
    # 2) The requested text encoder file
    # 3) VAE file
    # 4) The open-clip xlm-roberta-large weights
    #
    # If your i2v config references additional files, add them here.
    needed_files = [
        "Wan2.1_VAE.pth",
        "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
        basename(text_encoder_filename),
        basename(transformer_filename_i2v),
    ]

    # The original script also downloads an entire "xlm-roberta-large" folder
    # via snapshot_download. If you require that for your pipeline,
    # you can add it here, for example:
    subfolder_name = "xlm-roberta-large"
    if not Path(os.path.join(target_root, subfolder_name)).exists():
        snapshot_download(repo_id=repo_id, allow_patterns=subfolder_name + "/*", local_dir=target_root)

    for filename in needed_files:
        local_path = os.path.join(target_root, filename)
        if not os.path.isfile(local_path):
            print(f"File '{filename}' not found locally. Downloading from {repo_id} ...")
            hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=target_root
            )
        else:
            # Already present
            pass

    print("All required i2v files are present.")


# --------------------------------------------------
# ARGUMENT PARSER
# --------------------------------------------------

def parse_args():
    parser = argparse.ArgumentParser(
        description="Image-to-Video inference using WAN 2.1 i2v"
    )
    # Model + Tools
    parser.add_argument(
        "--quantize-transformer",
        action="store_true",
        help="Use on-the-fly transformer quantization"
    )
    parser.add_argument(
        "--compile",
        action="store_true",
        help="Enable PyTorch 2.0 compile for the transformer"
    )
    parser.add_argument(
        "--attention",
        type=str,
        default="auto",
        help="Which attention to use: auto, sdpa, sage, sage2, flash"
    )
    parser.add_argument(
        "--profile",
        type=int,
        default=4,
        help="Memory usage profile number [1..5]; see original script or use 2 if you have low VRAM"
    )
    parser.add_argument(
        "--preload",
        type=int,
        default=0,
        help="Megabytes of the diffusion model to preload in VRAM (only used in some profiles)"
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="Verbosity level [0..5]"
    )

    # i2v Model
    parser.add_argument(
        "--transformer-file",
        type=str,
        default=f"{DATA_DIR}/wan2.1_image2video_480p_14B_quanto_int8.safetensors",
        help="Which i2v model to load"
    )
    parser.add_argument(
        "--text-encoder-file",
        type=str,
        default=f"{DATA_DIR}/models_t5_umt5-xxl-enc-quanto_int8.safetensors",
        help="Which text encoder to use"
    )

    # LoRA
    parser.add_argument(
        "--lora-dir",
        type=str,
        default="",
        help="Path to a directory containing i2v LoRAs"
    )
    parser.add_argument(
        "--lora-preset",
        type=str,
        default="",
        help="A .lset preset name in the lora_dir to auto-apply"
    )

    # Generation Options
    parser.add_argument("--prompt", type=str, default=None, required=True, help="Prompt for generation")
    parser.add_argument("--negative-prompt", type=str, default="", help="Negative prompt")
    parser.add_argument("--resolution", type=str, default="832x480", help="WxH")
    parser.add_argument("--frames", type=int, default=64, help="Number of frames (16=1s if fps=16). Must be multiple of 4 +/- 1 in WAN.")
    parser.add_argument("--steps", type=int, default=30, help="Number of denoising steps.")
    parser.add_argument("--guidance-scale", type=float, default=5.0, help="Classifier-free guidance scale")
    parser.add_argument("--flow-shift", type=float, default=3.0, help="Flow shift parameter. Generally 3.0 for 480p, 5.0 for 720p.")
    parser.add_argument("--riflex", action="store_true", help="Enable RIFLEx for longer videos")
    parser.add_argument("--teacache", type=float, default=0.25, help="TeaCache multiplier, e.g. 0.5, 2.0, etc.")
    parser.add_argument("--teacache-start", type=float, default=0.1, help="TeaCache start step percentage [0..100]")
    parser.add_argument("--seed", type=int, default=-1, help="Random seed. -1 means random each time.")
    parser.add_argument("--slg-layers", type=str, default=None, help="Which layers to use for skip layer guidance")
    parser.add_argument("--slg-start", type=float, default=0.0, help="Percentage in to start SLG")
    parser.add_argument("--slg-end", type=float, default=1.0, help="Percentage in to end SLG")

    # LoRA usage
    parser.add_argument("--loras-choices", type=str, default="", help="Comma-separated list of chosen LoRA indices or preset names to load. Usually you only use the preset.")
    parser.add_argument("--loras-mult", type=str, default="", help="Multipliers for each chosen LoRA. Example: '1.0 1.2,1.3' etc.")

    # Input
    parser.add_argument(
        "--input-image",
        type=str,
        default=None,
        required=True,
        help="Path to an input image (or multiple)."
    )
    parser.add_argument(
        "--output-file",
        type=str,
        default="output.mp4",
        help="Where to save the resulting video."
    )

    return parser.parse_args()

# --------------------------------------------------
# MAIN
# --------------------------------------------------

def main():
    args = parse_args()

    # Setup environment
    offload.default_verboseLevel = args.verbose
    installed_attn_modes = get_attention_modes()

    # Decide attention
    chosen_attention = get_attention_mode(args.attention, installed_attn_modes)
    offload.shared_state["_attention"] = chosen_attention

    # Determine i2v resolution format
    if "720" in args.transformer_file:
        is_720p = True
    else:
        is_720p = False

    # Make sure we have the needed models locally
    download_models_if_needed(args.transformer_file, args.text_encoder_file)

    # Load i2v
    wan_model, pipe = load_i2v_model(
        model_filename=args.transformer_file,
        text_encoder_filename=args.text_encoder_file,
        is_720p=is_720p
    )
    wan_model._interrupt = False

    # Offload / profile
    # e.g. for your script: offload.profile(pipe, profile_no=args.profile, compile=..., quantizeTransformer=...)
    # pass the budgets if you want, etc.
    kwargs = {}
    if args.profile == 2 or args.profile == 4:
        # preload is in MB
        if args.preload == 0:
            budgets = {"transformer": 100, "text_encoder": 100, "*": 1000}
        else:
            budgets = {"transformer": args.preload, "text_encoder": 100, "*": 1000}
        kwargs["budgets"] = budgets
    elif args.profile == 3:
        kwargs["budgets"] = {"*": "70%"}

    compile_choice = "transformer" if args.compile else ""
    # Create the offload object
    offloadobj = offload.profile(
        pipe,
        profile_no=args.profile,
        compile=compile_choice,
        quantizeTransformer=args.quantize_transformer,
        **kwargs
    )

    # If user wants to use LoRAs
    (
        loras,
        loras_names,
        default_loras_choices,
        default_loras_multis_str,
        preset_prompt_prefix,
        preset_full_prompt
    ) = setup_loras(pipe, args.lora_dir, args.lora_preset, args.steps)

    # Combine user prompt with preset prompt if the preset indicates so
    if preset_prompt_prefix:
        if preset_full_prompt:
            # Full override
            user_prompt = preset_prompt_prefix
        else:
            # Just prefix
            user_prompt = preset_prompt_prefix + "\n" + args.prompt
    else:
        user_prompt = args.prompt

    # Actually parse user LoRA choices if they did not rely purely on the preset
    if args.loras_choices:
        # If user gave e.g. "0,1", we treat that as new additions
        lora_choice_list = [x.strip() for x in args.loras_choices.split(",")]
    else:
        # Use the defaults from the preset
        lora_choice_list = default_loras_choices

    # Activate them
    parse_loras_and_activate(
        pipe["transformer"], loras, lora_choice_list, args.loras_mult or default_loras_multis_str, args.steps
    )

    # Negative prompt
    negative_prompt = args.negative_prompt or ""

    # Sanity check resolution
    if "*" in args.resolution.lower():
        print("ERROR: resolution must be e.g. 832x480 not '832*480'. Fixing it.")
        resolution_str = args.resolution.lower().replace("*", "x")
    else:
        resolution_str = args.resolution

    try:
        width, height = [int(x) for x in resolution_str.split("x")]
    except:
        raise ValueError(f"Invalid resolution: '{resolution_str}'")

    # Parse slg_layers from comma-separated string to a Python list of ints (or None if not provided)
    if args.slg_layers:
        slg_list = [int(x) for x in args.slg_layers.split(",")]
    else:
        slg_list = None

    # Additional checks (from your original code).
    if "480p" in args.transformer_file:
        # Then we cannot exceed certain area for 480p model
        if width * height > 832*480:
            raise ValueError("You must use the 720p i2v model to generate bigger than 832x480.")
        # etc.

    # Handle random seed
    if args.seed < 0:
        args.seed = random.randint(0, 999999999)
    print(f"Using seed={args.seed}")

    # Setup tea cache if needed
    trans = wan_model.model
    trans.enable_cache = (args.teacache > 0)
    if trans.enable_cache:
        if "480p" in args.transformer_file:
            # example from your code
            trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
        elif "720p" in args.transformer_file:
            trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
        else:
            raise ValueError("Teacache not supported for this model variant")

    # Attempt generation
    print("Starting generation ...")
    start_time = time.time()

    # Read the input image
    if not os.path.isfile(args.input_image):
        raise ValueError(f"Input image does not exist: {args.input_image}")

    from PIL import Image
    input_img = Image.open(args.input_image).convert("RGB")

    # Possibly load more than one image if you want "multiple images" – but here we'll just do single for demonstration

    # Define the generation call
    # - frames => must be multiple of 4 plus 1 as per original script's note, e.g. 81, 65, ...
    # You can correct to that if needed:
    frame_count = (args.frames // 4)*4 + 1  # ensures it's 4*N+1
    # RIFLEx
    enable_riflex = args.riflex

    # If teacache => reset counters
    if trans.enable_cache:
        trans.teacache_counter = 0
        trans.cache_multiplier = args.teacache
        trans.cache_start_step = int(args.teacache_start * args.steps / 100.0)
        trans.num_steps = args.steps
        trans.cache_skipped_steps = 0
        trans.previous_residual_uncond = None
        trans.previous_residual_cond = None

    # VAE Tiling
    device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
    if device_mem_capacity >= 28000:  # 81 frames 720p requires about 28 GB VRAM
        use_vae_config = 1
    elif device_mem_capacity >= 8000:
        use_vae_config = 2
    else:
        use_vae_config = 3

    if use_vae_config == 1:
        VAE_tile_size = 0
    elif use_vae_config == 2:
        VAE_tile_size = 256
    else:
        VAE_tile_size = 128

    print('Using VAE tile size of', VAE_tile_size)

    # Actually run the i2v generation
    try:
        sample_frames = wan_model.generate(
            input_prompt = user_prompt,
            image_start = input_img,
            frame_num=frame_count,
            width=width,
            height=height,
            # max_area=MAX_AREA_CONFIGS[f"{width}*{height}"],  # or you can pass your custom
            shift=args.flow_shift,
            sampling_steps=args.steps,
            guide_scale=args.guidance_scale,
            n_prompt=negative_prompt,
            seed=args.seed,
            offload_model=False,
            callback=None,  # or define your own callback if you want
            enable_RIFLEx=enable_riflex,
            VAE_tile_size=VAE_tile_size,
            joint_pass=slg_list is None,  # set if you want a small speed improvement without SLG
            slg_layers=slg_list,
            slg_start=args.slg_start,
            slg_end=args.slg_end,
        )
    except Exception as e:
        offloadobj.unload_all()
        gc.collect()
        torch.cuda.empty_cache()

        err_str = f"Generation failed with error: {e}"
        # Attempt to detect OOM errors
        s = str(e).lower()
        if any(keyword in s for keyword in ["memory", "cuda", "alloc"]):
            raise RuntimeError("Likely out-of-VRAM or out-of-RAM error. " + err_str)
        else:
            traceback.print_exc()
            raise RuntimeError(err_str)

    # After generation
    offloadobj.unload_all()
    gc.collect()
    torch.cuda.empty_cache()

    if sample_frames is None:
        raise RuntimeError("No frames were returned (maybe generation was aborted or failed).")

    # If teacache was used, we can see how many steps were skipped
    if trans.enable_cache:
        print(f"TeaCache skipped steps: {trans.teacache_skipped_steps} / {args.steps}")

    # Save result
    sample_frames = sample_frames.cpu()  # shape = c, t, h, w => [3, T, H, W]
    os.makedirs(os.path.dirname(args.output_file) or ".", exist_ok=True)

    # Use the provided helper from your code to store the MP4
    # By default, you used cache_video(tensor=..., save_file=..., fps=16, ...)
    # or you can do your own. We'll do the same for consistency:
    cache_video(
        tensor=sample_frames[None],  # shape => [1, c, T, H, W]
        save_file=args.output_file,
        fps=16,
        nrow=1,
        normalize=True,
        value_range=(-1, 1)
    )

    end_time = time.time()
    elapsed_s = end_time - start_time
    print(f"Done! Output written to {args.output_file}. Generation time: {elapsed_s:.1f} seconds.")

if __name__ == "__main__":
    main()

@@ -458,6 +458,7 @@ class QwenImagePipeline(): #DiffusionPipeline
        callback=None,
        pipeline=None,
        loras_slists=None,
+        joint_pass= True,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

@@ -656,43 +657,55 @@ class QwenImagePipeline(): #DiffusionPipeline
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)

-                noise_pred = self.transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    encoder_hidden_states_mask=prompt_embeds_mask,
-                    encoder_hidden_states=prompt_embeds,
-                    img_shapes=img_shapes,
-                    txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
-                    attention_kwargs=self.attention_kwargs,
-                    return_dict=False,
-                    **kwargs
-                )[0]
-                if noise_pred == None: return None
-
-                if do_true_cfg:
-                    # with self.transformer.cache_context("uncond"):
-                    neg_noise_pred = self.transformer(
-                        hidden_states=latents,
-                        timestep=timestep / 1000,
-                        guidance=guidance,
-                        encoder_hidden_states_mask=negative_prompt_embeds_mask,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        img_shapes=img_shapes,
-                        txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
-                        attention_kwargs=self.attention_kwargs,
-                        return_dict=False,
-                        **kwargs
-                    )[0]
-                    if neg_noise_pred == None: return None
+                if do_true_cfg and joint_pass:
+                    noise_pred, neg_noise_pred = self.transformer(
+                        hidden_states=latents,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        encoder_hidden_states_mask_list=[prompt_embeds_mask, negative_prompt_embeds_mask],
+                        encoder_hidden_states_list=[prompt_embeds, negative_prompt_embeds],
+                        img_shapes=img_shapes,
+                        txt_seq_lens_list=[prompt_embeds_mask.sum(dim=1).tolist(), negative_prompt_embeds_mask.sum(dim=1).tolist()],
+                        attention_kwargs=self.attention_kwargs,
+                        **kwargs
+                    )
+                    if noise_pred == None: return None
+                else:
+                    noise_pred = self.transformer(
+                        hidden_states=latents,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        encoder_hidden_states_mask_list=[prompt_embeds_mask],
+                        encoder_hidden_states_list=[prompt_embeds],
+                        img_shapes=img_shapes,
+                        txt_seq_lens_list=[prompt_embeds_mask.sum(dim=1).tolist()],
+                        attention_kwargs=self.attention_kwargs,
+                        **kwargs
+                    )[0]
+                    if noise_pred == None: return None
+
+                    if do_true_cfg:
+                        neg_noise_pred = self.transformer(
+                            hidden_states=latents,
+                            timestep=timestep / 1000,
+                            guidance=guidance,
+                            encoder_hidden_states_mask_list=[negative_prompt_embeds_mask],
+                            encoder_hidden_states_list=[negative_prompt_embeds],
+                            img_shapes=img_shapes,
+                            txt_seq_lens_list=[negative_prompt_embeds_mask.sum(dim=1).tolist()],
+                            attention_kwargs=self.attention_kwargs,
+                            **kwargs
+                        )[0]
+                        if neg_noise_pred == None: return None
+
+                if do_true_cfg:
                    comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
                    if comb_pred == None: return None

                    cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
                    noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
                    noise_pred = comb_pred * (cond_norm / noise_norm)
+                    neg_noise_pred = None
                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

@@ -2,7 +2,7 @@ import torch

def get_qwen_text_encoder_filename(text_encoder_quantization):
    text_encoder_filename = "ckpts/Qwen2.5-VL-7B-Instruct/Qwen2.5-VL-7B-Instruct_bf16.safetensors"
-    if text_encoder_quantization =="int8" and False:
+    if text_encoder_quantization =="int8":
        text_encoder_filename = text_encoder_filename.replace("bf16", "quanto_bf16_int8")
    return text_encoder_filename

@@ -11,10 +11,8 @@ class family_handler():
    def query_model_def(base_model_type, model_def):
        model_def_output = {
            "image_outputs" : True,
-            "no_negative_prompt" : True,
        }

-        model_def_output["embedded_guidance"] = True

        return model_def_output

@@ -69,7 +67,7 @@ class family_handler():
    @staticmethod
    def update_default_settings(base_model_type, model_def, ui_defaults):
        ui_defaults.update({
-            "embedded_guidance": 4,
+            "guidance_scale": 4,
        })
        if model_def.get("reference_image", False):
            ui_defaults.update({

@@ -73,6 +73,7 @@ class model_factory():
        self,
        seed: int | None = None,
        input_prompt: str = "replace the logo with the text 'Black Forest Labs'",
+        n_prompt = None,
        sampling_steps: int = 20,
        input_ref_images = None,
        width= 832,

@@ -84,6 +85,7 @@ class model_factory():
        batch_size = 1,
        video_prompt_type = "",
        VAE_tile_size = None,
+        joint_pass = True,
        **bbargs
    ):
        # Generate with different aspect ratios

@@ -102,8 +104,12 @@ class model_factory():

        # width, height = aspect_ratios["16:9"]

+        if n_prompt is None or len(n_prompt) == 0:
+            n_prompt= "text, watermark, copyright, blurry, low resolution"
+
        image = self.pipeline(
            prompt=input_prompt,
+            negative_prompt=n_prompt,
            width=width,
            height=height,
            num_inference_steps=sampling_steps,

@@ -112,6 +118,7 @@ class model_factory():
            callback = callback,
            pipeline=self,
            loras_slists=loras_slists,
+            joint_pass = joint_pass,
            generator=torch.Generator(device="cuda").manual_seed(seed)
        )
        if image is None: return None

@@ -512,52 +512,29 @@ class QwenImageTransformer2DModel(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor = None,
-        encoder_hidden_states_mask: torch.Tensor = None,
+        encoder_hidden_states_list = None,
+        encoder_hidden_states_mask_list = None,
        timestep: torch.LongTensor = None,
        img_shapes: Optional[List[Tuple[int, int, int]]] = None,
-        txt_seq_lens: Optional[List[int]] = None,
+        txt_seq_lens_list = None,
        guidance: torch.Tensor = None,  # TODO: this should probably be removed
        attention_kwargs: Optional[Dict[str, Any]] = None,
-        return_dict: bool = True,
        callback= None,
        pipeline =None,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
-        """
-        The [`QwenTransformer2DModel`] forward method.
-
-        Args:
-            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
-                Input `hidden_states`.
-            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
-                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
-                Mask of the input conditions.
-            timestep ( `torch.LongTensor`):
-                Used to indicate denoising step.
-            attention_kwargs (`dict`, *optional*):
-                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
-                `self.processor` in
-                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
-                tuple.
-
-        Returns:
-            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
-            `tuple` where the first element is the sample tensor.
-        """
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-
        hidden_states = self.img_in(hidden_states)

        timestep = timestep.to(hidden_states.dtype)
-        encoder_hidden_states = self.txt_norm(encoder_hidden_states)
-        encoder_hidden_states = self.txt_in(encoder_hidden_states)
+        hidden_states_list = [hidden_states if i == 0 else hidden_states.clone() for i, _ in enumerate(encoder_hidden_states_list)]
+
+        new_encoder_hidden_states_list = []
+        for encoder_hidden_states in encoder_hidden_states_list:
+            encoder_hidden_states = self.txt_norm(encoder_hidden_states)
+            encoder_hidden_states = self.txt_in(encoder_hidden_states)
+            new_encoder_hidden_states_list.append(encoder_hidden_states)
+        encoder_hidden_states_list = new_encoder_hidden_states_list
+        new_encoder_hidden_states_list = encoder_hidden_states = None

        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000

@@ -568,27 +545,30 @@ class QwenImageTransformer2DModel(nn.Module):
            else self.time_text_embed(timestep, guidance, hidden_states)
        )

-        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
+        image_rotary_emb_list = [ self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) for txt_seq_lens in txt_seq_lens_list]

+        hidden_states = None

        for index_block, block in enumerate(self.transformer_blocks):
            if callback != None:
                callback(-1, None, False, True)
            if pipeline._interrupt:
-                return [None]
-            encoder_hidden_states, hidden_states = block(
-                hidden_states=hidden_states,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_hidden_states_mask=encoder_hidden_states_mask,
-                temb=temb,
-                image_rotary_emb=image_rotary_emb,
-                joint_attention_kwargs=attention_kwargs,
-            )
+                return [None] * len(hidden_states_list)
+            for hidden_states, encoder_hidden_states, encoder_hidden_states_mask, image_rotary_emb in zip(hidden_states_list, encoder_hidden_states_list, encoder_hidden_states_mask_list, image_rotary_emb_list):
+                encoder_hidden_states[...], hidden_states[...] = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_hidden_states_mask=encoder_hidden_states_mask,
+                    temb=temb,
+                    image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=attention_kwargs,
+                )

        # Use only the image part (hidden_states) from the dual-stream blocks
-        hidden_states = self.norm_out(hidden_states, temb)
-        output = self.proj_out(hidden_states)
-
-        if not return_dict:
-            return (output,)
-
-        return Transformer2DModelOutput(sample=output)
+        output_list = []
+        for i in range(len(hidden_states_list)):
+            hidden_states = self.norm_out(hidden_states_list[i], temb)
+            hidden_states_list[i] = None
+            output_list.append(self.proj_out(hidden_states))
+
+        return output_list
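
To make the control flow of the new list-based forward easier to follow, here is a small self-contained toy of the same looping pattern: each text stream (conditional and negative prompt) keeps its own copy of the image tokens, and every block processes all streams before moving on, so conditional and negative predictions come out of a single pass. The DummyBlock, dimensions, and function names are illustrative stand-ins, not the actual WanGP modules.

```python
import torch
import torch.nn as nn

class DummyBlock(nn.Module):
    """Stand-in for a dual-stream block: returns updated (text, image) states."""
    def __init__(self, dim: int):
        super().__init__()
        self.img_proj = nn.Linear(dim, dim)
        self.txt_proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, encoder_hidden_states):
        return self.txt_proj(encoder_hidden_states), self.img_proj(hidden_states)

def joint_forward(blocks, img_tokens, text_streams):
    # One copy of the image tokens per text stream, like hidden_states_list above.
    hidden_states_list = [img_tokens if i == 0 else img_tokens.clone()
                          for i in range(len(text_streams))]
    for block in blocks:
        for i in range(len(text_streams)):
            text_streams[i], hidden_states_list[i] = block(hidden_states_list[i], text_streams[i])
    # One prediction per stream, like output_list in the diff.
    return hidden_states_list

blocks = nn.ModuleList([DummyBlock(32) for _ in range(2)])
img = torch.randn(1, 16, 32)
cond_txt, neg_txt = torch.randn(1, 8, 32), torch.randn(1, 8, 32)
outputs = joint_forward(blocks, img, [cond_txt, neg_txt])
print([o.shape for o in outputs])  # two predictions from a single pass over the blocks
```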

@@ -131,7 +131,7 @@ from pathlib import Path
import torch

def remux_with_audio(video_path: Path, output_path: Path, audio: torch.Tensor, sampling_rate: int):
-    from wan.utils.utils import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files
+    from shared.utils.utils import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files

    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
        temp_path = Path(f.name)

@@ -184,7 +184,7 @@ class VaceVideoProcessor(object):

    def _get_frameid_bbox_adjust_last(self, fps, video_frames_count, canvas_height, canvas_width, h, w, fit_into_canvas, crop_box, rng, max_frames= 0, start_frame =0):
-        from wan.utils.utils import resample
+        from shared.utils.utils import resample

        target_fps = self.max_fps

wgp.py (2 changed lines)
@@ -50,7 +50,7 @@ AUTOSAVE_FILENAME = "queue.zip"
PROMPT_VARS_MAX = 10

target_mmgp_version = "3.5.7"
-WanGP_version = "7.71"
+WanGP_version = "7.72"
settings_version = 2.23
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None