Refactored Loras

DeepBeepMeep 2025-03-14 23:43:04 +01:00
parent 48b777ae6c
commit d233dd7ed9
7 changed files with 344 additions and 115 deletions


@@ -19,6 +19,7 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
 ## 🔥 Latest News!!
+* Mar 14, 2025: 👋 Wan2.1GP v1.7: Lora Fest special edition: very fast loading / unloading of Loras, for all the Lora collectors out there. You will need to refresh the requirements with *pip install -r requirements.txt*
 * Mar 13, 2025: 👋 Wan2.1GP v1.6: Better Lora support, accelerated Lora loading. You will need to refresh the requirements with *pip install -r requirements.txt*
 * Mar 10, 2025: 👋 Wan2.1GP v1.5: Official TeaCache support + Smart TeaCache (automatically finds the best parameters for a requested speed multiplier), 10% speed boost with no quality loss, improved Lora presets (they can now include prompts and comments to guide the user)
 * Mar 07, 2025: 👋 Wan2.1GP v1.4: Fixed PyTorch compilation, now it is really 20% faster when activated
@@ -157,6 +158,11 @@ python gradio_server.py --attention sdpa
 Every Lora stored in the subfolder 'loras' (for t2v) or 'loras_i2v' (for i2v) will be loaded automatically. You will then be able to activate / deactivate any of them when running the application by selecting them in the area below "Activated Loras".
+If you want to manage the Loras for the 1.3B model and the 14B model in different areas (they are not compatible), just create the following subfolders:
+- loras/1.3B
+- loras/14B
+
 For each activated Lora, you may specify a *multiplier*, a single float number that corresponds to its weight (default is 1.0). The multipliers for each Lora should be separated by a space character or a carriage return. For instance:\
 *1.2 0.8* means that the first Lora will have a 1.2 multiplier and the second one 0.8.
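
For illustration, here is a minimal standalone sketch of the two conventions described above, the model-size subfolder lookup and the multiplier parsing. It is not the application's actual code, and the function names below are only examples:

```python
# Minimal sketch of the Lora conventions described above (illustrative, not the app's code).
import os

def pick_lora_dir(root="loras", model_size="1.3B"):
    # Use loras/1.3B or loras/14B when that subfolder exists, otherwise fall back to the root folder.
    candidate = os.path.join(root, model_size)
    return candidate if os.path.isdir(candidate) else root

def parse_multipliers(text, num_loras):
    # "1.2 0.8" -> [1.2, 0.8]; entries may be separated by spaces or carriage returns,
    # and any Lora without an explicit value keeps the 1.0 default.
    mults = [float(m) for m in text.split()]
    return mults + [1.0] * (num_loras - len(mults))

print(pick_lora_dir())                  # 'loras/1.3B' if that folder exists, otherwise 'loras'
print(parse_multipliers("1.2 0.8", 3))  # [1.2, 0.8, 1.0]
```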


@@ -170,6 +170,18 @@ def _parse_args():
     return args

+def get_lora_dir(root_lora_dir):
+    if not use_image2video:
+        if "1.3B" in transformer_filename_t2v:
+            lora_dir_1_3B = os.path.join(root_lora_dir, "1.3B")
+            if os.path.isdir(lora_dir_1_3B ):
+                return lora_dir_1_3B
+        else:
+            lora_dir_14B = os.path.join(root_lora_dir, "14B")
+            if os.path.isdir(lora_dir_14B ):
+                return lora_dir_14B
+    return root_lora_dir
+
 attention_modes_supported = get_attention_modes()
 args = _parse_args()
@@ -238,9 +250,10 @@ if args.i2v:
 lora_dir =args.lora_dir
 if use_image2video and len(lora_dir)==0:
-    lora_dir =args.lora_dir_i2v
+    root_lora_dir =args.lora_dir_i2v
 if len(lora_dir) ==0:
-    lora_dir = "loras_i2v" if use_image2video else "loras"
+    root_lora_dir = "loras_i2v" if use_image2video else "loras"
+lora_dir = get_lora_dir(root_lora_dir)
 lora_preselected_preset = args.lora_preset
 default_tea_cache = 0
 # if args.fast : #or args.fastest
@@ -295,35 +308,42 @@ def sanitize_file_name(file_name, rep =""):
     return file_name.replace("/",rep).replace("\\",rep).replace(":",rep).replace("|",rep).replace("?",rep).replace("<",rep).replace(">",rep).replace("\"",rep)

 def extract_preset(lset_name, loras):
+    loras_choices = []
+    loras_choices_files = []
+    loras_mult_choices = ""
+    prompt =""
+    full_prompt =""
     lset_name = sanitize_file_name(lset_name)
     if not lset_name.endswith(".lset"):
         lset_name_filename = os.path.join(lora_dir, lset_name + ".lset" )
     else:
         lset_name_filename = os.path.join(lora_dir, lset_name )
+    error = ""
     if not os.path.isfile(lset_name_filename):
-        raise gr.Error(f"Preset '{lset_name}' not found ")
-
-    with open(lset_name_filename, "r", encoding="utf-8") as reader:
-        text = reader.read()
-    lset = json.loads(text)
-
-    loras_choices_files = lset["loras"]
-    loras_choices = []
-    missing_loras = []
-    for lora_file in loras_choices_files:
-        loras_choice_no = loras.index(os.path.join(lora_dir, lora_file))
-        if loras_choice_no < 0:
-            missing_loras.append(lora_file)
-        else:
-            loras_choices.append(str(loras_choice_no))
-
-    if len(missing_loras) > 0:
-        raise gr.Error(f"Unable to apply Lora preset '{lset_name} because the following Loras files are missing: {missing_loras}")
-
-    loras_mult_choices = lset["loras_mult"]
-    prompt = lset.get("prompt", "")
-    return loras_choices, loras_mult_choices, prompt, lset.get("full_prompt", False)
+        error = f"Preset '{lset_name}' not found "
+    else:
+        missing_loras = []
+        with open(lset_name_filename, "r", encoding="utf-8") as reader:
+            text = reader.read()
+        lset = json.loads(text)
+
+        loras_choices_files = lset["loras"]
+        for lora_file in loras_choices_files:
+            choice = os.path.join(lora_dir, lora_file)
+            if choice not in loras:
+                missing_loras.append(lora_file)
+            else:
+                loras_choice_no = loras.index(choice)
+                loras_choices.append(str(loras_choice_no))
+
+        if len(missing_loras) > 0:
+            error = f"Unable to apply Lora preset '{lset_name} because the following Loras files are missing or invalid: {missing_loras}"
+
+        loras_mult_choices = lset["loras_mult"]
+        prompt = lset.get("prompt", "")
+        full_prompt = lset.get("full_prompt", False)
+    return loras_choices, loras_mult_choices, prompt, full_prompt, error

 def get_default_prompt(i2v):
     if i2v:
@@ -332,7 +352,7 @@ def get_default_prompt(i2v):
         return "A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect."

-def setup_loras(pipe, lora_dir, lora_preselected_preset, split_linear_modules_map = None):
+def setup_loras(transformer, lora_dir, lora_preselected_preset, split_linear_modules_map = None):
     loras =[]
     loras_names = []
     default_loras_choices = []
@@ -341,32 +361,35 @@ def setup_loras(transformer, lora_dir, lora_preselected_preset, split_linear_modules_m
     default_lora_preset = ""
     default_prompt = ""

-    if use_image2video or not "1.3B" in transformer_filename_t2v:
-        from pathlib import Path
-        if lora_dir != None :
-            if not os.path.isdir(lora_dir):
-                raise Exception("--lora-dir should be a path to a directory that contains Loras")
-        if lora_dir != None:
-            import glob
-            dir_loras = glob.glob( os.path.join(lora_dir , "*.sft") ) + glob.glob( os.path.join(lora_dir , "*.safetensors") )
-            dir_loras.sort()
-            loras += [element for element in dir_loras if element not in loras ]
-
-            dir_presets = glob.glob( os.path.join(lora_dir , "*.lset") )
-            dir_presets.sort()
-            loras_presets = [ Path(Path(file_path).parts[-1]).stem for file_path in dir_presets]
-
-        if len(loras) > 0:
-            loras_names = [ Path(lora).stem for lora in loras ]
-            offload.load_loras_into_model(pipe["transformer"], loras, activate_all_loras=False, split_linear_modules_map = split_linear_modules_map) #lora_multiplier,
-
-        if len(lora_preselected_preset) > 0:
-            if not os.path.isfile(os.path.join(lora_dir, lora_preselected_preset + ".lset")):
-                raise Exception(f"Unknown preset '{lora_preselected_preset}'")
-            default_lora_preset = lora_preselected_preset
-            default_loras_choices, default_loras_multis_str, default_prompt, _ = extract_preset(default_lora_preset, loras)
+    from pathlib import Path
+    if lora_dir != None :
+        if not os.path.isdir(lora_dir):
+            raise Exception("--lora-dir should be a path to a directory that contains Loras")
+
+    if lora_dir != None:
+        import glob
+        dir_loras = glob.glob( os.path.join(lora_dir , "*.sft") ) + glob.glob( os.path.join(lora_dir , "*.safetensors") )
+        dir_loras.sort()
+        loras += [element for element in dir_loras if element not in loras ]
+
+        dir_presets = glob.glob( os.path.join(lora_dir , "*.lset") )
+        dir_presets.sort()
+        loras_presets = [ Path(Path(file_path).parts[-1]).stem for file_path in dir_presets]
+
+    loras = offload.load_loras_into_model(transformer, loras, activate_all_loras=False, check_only= True, split_linear_modules_map = split_linear_modules_map) #lora_multiplier,
+
+    if len(loras) > 0:
+        loras_names = [ Path(lora).stem for lora in loras ]
+
+    if len(lora_preselected_preset) > 0:
+        if not os.path.isfile(os.path.join(lora_dir, lora_preselected_preset + ".lset")):
+            raise Exception(f"Unknown preset '{lora_preselected_preset}'")
+        default_lora_preset = lora_preselected_preset
+        default_loras_choices, default_loras_multis_str, default_prompt, _ , error = extract_preset(default_lora_preset, loras)
+        if len(error) > 0:
+            print(error[:200])
     if len(default_prompt) == 0:
         default_prompt = get_default_prompt(use_image2video)
     return loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets
@@ -450,8 +473,8 @@ def load_models(i2v, lora_dir, lora_preselected_preset ):
         kwargs["budgets"] = { "*" : "70%" }

-    loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = setup_loras(pipe, lora_dir, lora_preselected_preset, None)
-    offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, **kwargs)
+    offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", **kwargs)
+    loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = setup_loras(pipe["transformer"], lora_dir, lora_preselected_preset, None)

     return wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets
@@ -542,7 +565,7 @@ def apply_changes( state,
     state["config_new"] = server_config
     state["config_old"] = old_server_config

-    global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config, boost
+    global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config, boost, lora_dir
     attention_mode = server_config["attention_mode"]
     profile = server_config["profile"]
     compile = server_config["compile"]
@@ -560,7 +583,7 @@ def apply_changes( state,
         offloadobj.release()
         offloadobj = None
     yield "<DIV ALIGN=CENTER>Please wait while the new configuration is being applied</DIV>"
+    lora_dir = get_lora_dir(root_lora_dir)
     wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset )
@@ -590,7 +613,11 @@ def update_defaults(state, num_inference_steps,flow_shift):
     num_inference_steps, flow_shift = get_default_flow(trans_file)
     header = generate_header(trans_file, server_config["compile"], server_config["attention_mode"] )
-    return num_inference_steps, flow_shift, header
+    new_loras_choices = [ (loras_name, str(i)) for i,loras_name in enumerate(loras_names)]
+    lset_choices = [ (preset, preset) for preset in loras_presets]
+    lset_choices.append( (new_preset_msg, ""))
+    return num_inference_steps, flow_shift, header, gr.Dropdown(choices=lset_choices, value= ""), gr.Dropdown(choices=new_loras_choices, value= [])

 from moviepy.editor import ImageSequenceClip
@@ -603,23 +630,32 @@ def save_video(final_frames, output_path, fps=24):
     ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path, verbose= False, logger = None)

 def build_callback(state, pipe, progress, status, num_inference_steps):
-    def callback(step_idx, latents):
-        step_idx += 1
-        if state.get("abort", False):
-            # pipe._interrupt = True
-            status_msg = status + " - Aborting"
-        elif step_idx == num_inference_steps:
-            status_msg = status + " - VAE Decoding"
-        else:
-            status_msg = status + " - Denoising"
-
-        progress( (step_idx , num_inference_steps) , status_msg , num_inference_steps)
+    def callback(step_idx, latents, read_state = False):
+        status = state["progress_status"]
+        if read_state:
+            phase, step_idx = state["progress_phase"]
+        else:
+            step_idx += 1
+            if state.get("abort", False):
+                # pipe._interrupt = True
+                phase = " - Aborting"
+            elif step_idx == num_inference_steps:
+                phase = " - VAE Decoding"
+            else:
+                phase = " - Denoising"
+            state["progress_phase"] = (phase, step_idx)
+        status_msg = status + phase
+        if step_idx >= 0:
+            progress( (step_idx , num_inference_steps) , status_msg , num_inference_steps)
+        else:
+            progress(0, status_msg)
     return callback

 def abort_generation(state):
     if "in_progress" in state:
         state["abort"] = True
+        state["extra_orders"] = 0
         wan_model._interrupt= True
         return gr.Button(interactive= False)
     else:
@@ -635,10 +671,11 @@ def finalize_gallery(state):
         del state["in_progress"]
         choice = state.get("selected",0)

+    state["extra_orders"] = 0
     time.sleep(0.2)
     global gen_in_progress
     gen_in_progress = False
-    return gr.Gallery(selected_index=choice), gr.Button(interactive= True)
+    return gr.Gallery(selected_index=choice), gr.Button(interactive= True), gr.Button(visible= True), gr.Checkbox(visible= False)

 def select_video(state , event_data: gr.EventData):
     data= event_data._data
@@ -656,6 +693,32 @@ def expand_slist(slist, num_inference_steps ):
     return new_slist

+def one_more_video(state):
+    extra_orders = state.get("extra_orders", 0)
+    extra_orders += 1
+    state["extra_orders"] = extra_orders
+    prompts_max = state["prompts_max"]
+    prompt_no = state["prompt_no"]
+    video_no = state["video_no"]
+    total_video = state["total_video"]
+    # total_video += (prompts_max- prompt_no)
+    total_video += 1
+    total_generation = state["total_generation"] + extra_orders
+    state["total_video"] = total_video
+    state["progress_status"] = f"Video {video_no}/{total_video}"
+    offload.shared_state["refresh"] = 1
+    # if (prompts_max - prompt_no) > 1:
+    #     gr.Info(f"An extra video generation is planned for a total of {total_generation} videos for the next {prompts_max - prompt_no} prompts")
+    # else:
+    gr.Info(f"An extra video generation is planned for a total of {total_generation} videos for this prompt")
+
+    return state
+
+def prepare_generate_video():
+    return gr.Button(visible= False), gr.Checkbox(visible= True)
+
 def generate_video(
     prompt,
     negative_prompt,
@@ -733,6 +796,7 @@ def generate_video(
     else:
         VAE_tile_size = 128

+    trans = wan_model.model
     global gen_in_progress
     gen_in_progress = True
@@ -740,7 +804,7 @@ def generate_video(
     if len(prompt) ==0:
         return
     prompts = prompt.replace("\r", "").split("\n")
-    prompts = [prompt for prompt in prompts if len(prompt)>0 and not prompt.startswith("#")]
+    prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
     if len(prompts) ==0:
         return
     if use_image2video:
@@ -808,9 +872,13 @@ def generate_video(
         list_mult_choices_nums.append(float(mult))
     if len(list_mult_choices_nums ) < len(loras_choices):
         list_mult_choices_nums += [1.0] * ( len(loras_choices) - len(list_mult_choices_nums ) )
-    offload.activate_loras(wan_model.model, loras_choices, list_mult_choices_nums)
+    loras_selected = [ lora for i, lora in enumerate(loras) if str(i) in loras_choices]
+    pinnedLora = False # profile !=5
+    offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, pinnedLora=pinnedLora, split_linear_modules_map = None)
+    errors = trans._loras_errors
+    if len(errors) > 0:
+        error_files = [msg for _ , msg in errors]
+        raise gr.Error("Error while loading Loras: " + ", ".join(error_files))

     seed = None if seed == -1 else seed
     # negative_prompt = "" # not applicable in the inference
@@ -825,7 +893,6 @@ def generate_video(
     joint_pass = boost ==1
     # TeaCache
-    trans = wan_model.model
     trans.enable_teacache = tea_cache > 0
     if trans.enable_teacache:
         if use_image2video:
@@ -857,11 +924,23 @@ def generate_video(
     os.makedirs(save_path, exist_ok=True)
     video_no = 0
     total_video = repeat_generation * len(prompts)
+    state["total_video"] = total_video
+    extra_generation = 0
     abort = False
     start_time = time.time()
-    for prompt in prompts:
-        for _ in range(repeat_generation):
-            if abort:
+    state["prompts_max"] = len(prompts)
+    for no, prompt in enumerate(prompts):
+        repeat_no = 0
+        state["prompt_no"] = no
+        extra_generation = 0
+        while True:
+            extra_orders = state.get("extra_orders",0)
+            state["extra_orders"] = 0
+            extra_generation += extra_orders
+            state["total_generation"] = repeat_generation + extra_generation
+            # total_video += (len(prompts)- no) * extra_orders
+            total_video += extra_orders
+            if abort or repeat_no >= (repeat_generation + extra_generation):
                 break

             if trans.enable_teacache:
@@ -875,9 +954,12 @@ def generate_video(
             video_no += 1
             status = f"Video {video_no}/{total_video}"
+            state["video_no"] = video_no
+            state["progress_status"] = status
+            state["progress_phase"] = (" - Encoding Prompt", -1 )
             progress(0, desc=status + " - Encoding Prompt" )

             callback = build_callback(state, trans, progress, status, num_inference_steps)
+            offload.shared_state["callback"] = callback

             gc.collect()
@@ -887,7 +969,7 @@ def generate_video(
             if use_image2video:
                 samples = wan_model.generate(
                     prompt,
-                    image_to_continue[ (video_no-1) % len(image_to_continue)].convert('RGB'),
+                    image_to_continue[no].convert('RGB'),
                     frame_num=(video_length // 4)* 4 + 1,
                     max_area=MAX_AREA_CONFIGS[resolution],
                     shift=flow_shift,
@@ -923,6 +1005,7 @@ def generate_video(
                 if temp_filename!= None and os.path.isfile(temp_filename):
                     os.remove(temp_filename)
                 offload.last_offload_obj.unload_all()
+                offload.unload_loras_from_model(trans)
                 # if compile:
                 #     cache_size = torch._dynamo.config.cache_size_limit
                 #     torch.compiler.reset()
@@ -967,9 +1050,9 @@ def generate_video(
                 time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
                 if os.name == 'nt':
-                    file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50])}.mp4"
+                    file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4"
                 else:
-                    file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100])}.mp4"
+                    file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100]).strip()}.mp4"
                 video_path = os.path.join(os.getcwd(), "gradio_outputs", file_name)
                 cache_video(
                     tensor=sample[None],
@@ -987,10 +1070,13 @@ def generate_video(
                 end_time = time.time()
                 yield f"Total Generation Time: {end_time-start_time:.1f}s"
             seed += 1
+            repeat_no += 1

     if temp_filename!= None and os.path.isfile(temp_filename):
         os.remove(temp_filename)
     gen_in_progress = False
+    offload.unload_loras_from_model(trans)

 new_preset_msg = "Enter a Name for a Lora Preset or Choose One Above"
@@ -998,19 +1084,19 @@ new_preset_msg = "Enter a Name for a Lora Preset or Choose One Above"

 def validate_delete_lset(lset_name):
     if len(lset_name) == 0 or lset_name == new_preset_msg:
         gr.Info(f"Choose a Preset to delete")
-        return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False)
+        return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False)
     else:
-        return gr.Button(visible= False), gr.Checkbox(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True)
+        return gr.Button(visible= False), gr.Checkbox(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True)

 def validate_save_lset(lset_name):
     if len(lset_name) == 0 or lset_name == new_preset_msg:
         gr.Info("Please enter a name for the preset")
-        return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False),gr.Checkbox(visible= False)
+        return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False),gr.Checkbox(visible= False)
     else:
-        return gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True),gr.Checkbox(visible= True)
+        return gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True),gr.Checkbox(visible= True)

 def cancel_lset():
-    return gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
+    return gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)

 def save_lset(lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_cbox):
     global loras_presets
@@ -1047,7 +1133,7 @@ def save_lset(lset_name, loras_choices, loras_mult_choices, prompt, save_lset_pr
     lset_choices = [ ( preset, preset) for preset in loras_presets ]
     lset_choices.append( (new_preset_msg, ""))
-    return gr.Dropdown(choices=lset_choices, value= lset_name), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
+    return gr.Dropdown(choices=lset_choices, value= lset_name), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)

 def delete_lset(lset_name):
     global loras_presets
@@ -1065,23 +1151,57 @@ def delete_lset(lset_name):
     lset_choices = [ (preset, preset) for preset in loras_presets]
     lset_choices.append((new_preset_msg, ""))
-    return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Checkbox(visible= False)
+    return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Checkbox(visible= False)
+
+def refresh_lora_list(lset_name, loras_choices):
+    global loras,loras_names, loras_presets
+    prev_lora_names_selected = [ loras_names[int(i)] for i in loras_choices]
+
+    loras, loras_names, _, _, _, _, loras_presets = setup_loras(wan_model.model, lora_dir, lora_preselected_preset, None)
+    gc.collect()
+    new_loras_choices = [ (loras_name, str(i)) for i,loras_name in enumerate(loras_names)]
+    new_loras_dict = { loras_name: str(i) for i,loras_name in enumerate(loras_names) }
+    lora_names_selected = []
+    for lora in prev_lora_names_selected:
+        lora_id = new_loras_dict.get(lora, None)
+        if lora_id!= None:
+            lora_names_selected.append(lora_id)
+
+    lset_choices = [ (preset, preset) for preset in loras_presets]
+    lset_choices.append((new_preset_msg, ""))
+    if lset_name in loras_presets:
+        pos = loras_presets.index(lset_name)
+    else:
+        pos = len(loras_presets)
+        lset_name =""
+
+    errors = wan_model.model._loras_errors
+    if len(errors) > 0:
+        error_files = [path for path, _ in errors]
+        gr.Info("Error while refreshing Lora List, invalid Lora files: " + ", ".join(error_files))
+    else:
+        gr.Info("Lora List has been refreshed")
+
+    return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Dropdown(choices=new_loras_choices, value= lora_names_selected)

 def apply_lset(lset_name, loras_choices, loras_mult_choices, prompt):
     if len(lset_name) == 0 or lset_name== new_preset_msg:
         gr.Info("Please choose a preset in the list or create one")
     else:
-        loras_choices, loras_mult_choices, preset_prompt, full_prompt = extract_preset(lset_name, loras)
-        if full_prompt:
-            prompt = preset_prompt
-        elif len(preset_prompt) > 0:
-            prompts = prompt.replace("\r", "").split("\n")
-            prompts = [prompt for prompt in prompts if len(prompt)>0 and not prompt.startswith("#")]
-            prompt = "\n".join(prompts)
-            prompt = preset_prompt + '\n' + prompt
-        gr.Info(f"Lora Preset '{lset_name}' has been applied")
+        loras_choices, loras_mult_choices, preset_prompt, full_prompt, error = extract_preset(lset_name, loras)
+        if len(error) > 0:
+            gr.Info(error)
+        else:
+            if full_prompt:
+                prompt = preset_prompt
+            elif len(preset_prompt) > 0:
+                prompts = prompt.replace("\r", "").split("\n")
+                prompts = [prompt for prompt in prompts if len(prompt)>0 and not prompt.startswith("#")]
+                prompt = "\n".join(prompts)
+                prompt = preset_prompt + '\n' + prompt
+            gr.Info(f"Lora Preset '{lset_name}' has been applied")

     return loras_choices, loras_mult_choices, prompt
@@ -1094,21 +1214,21 @@ def create_demo():
     state = gr.State({})

     if use_image2video:
-        gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v1 - AI Image To Video Generator (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</H1></div>")
+        gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v1.7 - AI Image To Video Generator (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</H1></div>")
     else:
-        gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v1 - AI Text To Video Generator (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</H1></div>")
+        gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v1.7 - AI Text To Video Generator (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</H1></div>")

-    gr.Markdown("<FONT SIZE=3>With this first release of Wan 2.1GP by <B>DeepBeepMeep</B>, the VRAM requirements have been divided by more than 2 with no quality loss</FONT>")
+    gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP by <B>DeepBeepMeep</B>, a super fast and low VRAM Video Generator !</FONT>")

     if use_image2video and False:
         pass
     else:
-        gr.Markdown("The VRAM requirements will depend greatly of the resolution and the duration of the video, for instance : 24 GB of VRAM (RTX 3090 / RTX 4090), the limits are as follows:")
+        gr.Markdown("The VRAM requirements will depend greatly of the resolution and the duration of the video, for instance :")
         gr.Markdown("- 848 x 480 with a 14B model: 80 frames (5s) : 8 GB of VRAM")
         gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s) : 5 GB of VRAM")
         gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM")
         gr.Markdown("It is not recommmended to generate a video longer than 8s (128 frames) even if there is still some VRAM left as some artifacts may appear")
-        gr.Markdown("Please note that if your turn on compilation, the first generation step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
+        gr.Markdown("Please note that if your turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")

     # css = """<STYLE>
@@ -1302,6 +1422,7 @@ def create_demo():
                     # with gr.Column():
                     with gr.Row(height=17):
                         apply_lset_btn = gr.Button("Apply Lora Preset", size="sm", min_width= 1)
+                        refresh_lora_btn = gr.Button("Refresh Lora List", size="sm", min_width= 1)
                         # save_lset_prompt_cbox = gr.Checkbox(label="Save Prompt Comments in Preset", value=False, visible= False)
                         save_lset_prompt_drop= gr.Dropdown(
                             choices=[
@@ -1334,7 +1455,7 @@ def create_demo():
                 with gr.Row(visible=False) as advanced_row:
                     with gr.Column():
                         seed = gr.Slider(-1, 999999999, value=-1, step=1, label="Seed (-1 for random)")
-                        repeat_generation = gr.Slider(1, 25.0, value=1.0, step=1, label="Number of Generated Video per prompt")
+                        repeat_generation = gr.Slider(1, 25.0, value=1.0, step=1, label="Default Number of Generated Videos per Prompt")
                         with gr.Row():
                             negative_prompt = gr.Textbox(label="Negative Prompt", value="")
                         with gr.Row():
@@ -1377,22 +1498,25 @@ def create_demo():
                     label="Generated videos", show_label=False, elem_id="gallery"
                     , columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= False)
                 generate_btn = gr.Button("Generate")
+                onemore_btn = gr.Button("One More Please !", visible= False)
                 abort_btn = gr.Button("Abort")

-        save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
-        confirm_save_lset_btn.click(save_lset, inputs=[lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
-        delete_lset_btn.click(validate_delete_lset, inputs=[lset_name], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
-        confirm_delete_lset_btn.click(delete_lset, inputs=[lset_name], outputs=[lset_name, apply_lset_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
-        cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn,confirm_save_lset_btn, cancel_lset_btn,save_lset_prompt_drop ])
+        save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
+        confirm_save_lset_btn.click(save_lset, inputs=[lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn,refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
+        delete_lset_btn.click(validate_delete_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
+        confirm_delete_lset_btn.click(delete_lset, inputs=[lset_name], outputs=[lset_name, apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
+        cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn,confirm_save_lset_btn, cancel_lset_btn,save_lset_prompt_drop ])
         apply_lset_btn.click(apply_lset, inputs=[lset_name,loras_choices, loras_mult_choices, prompt], outputs=[loras_choices, loras_mult_choices, prompt])
+        refresh_lora_btn.click(refresh_lora_list, inputs=[lset_name,loras_choices], outputs=[lset_name, loras_choices])

         gen_status.change(refresh_gallery, inputs = [state], outputs = output )

         abort_btn.click(abort_generation,state,abort_btn )
         output.select(select_video, state, None )
+        onemore_btn.click(fn=one_more_video,inputs=[state], outputs= [state])

-        generate_btn.click(
+        generate_btn.click(fn=prepare_generate_video,inputs=[], outputs= [generate_btn, onemore_btn]).then(
             fn=generate_video,
             inputs=[
                 prompt,
@@ -1420,7 +1544,7 @@ def create_demo():
         ).then(
             finalize_gallery,
             [state],
-            [output , abort_btn]
+            [output , abort_btn, generate_btn, onemore_btn]
         )

         apply_btn.click(
@@ -1441,7 +1565,7 @@ def create_demo():
         ).then(
             update_defaults,
             [state, num_inference_steps, flow_shift],
-            [num_inference_steps, flow_shift, header]
+            [num_inference_steps, flow_shift, header, lset_name , loras_choices ]
         )

     return demo


@@ -16,5 +16,5 @@ gradio>=5.0.0
 numpy>=1.23.5,<2
 einops
 moviepy==1.0.3
-mmgp==3.2.6
+mmgp==3.2.7
 peft==0.14.0


@@ -331,7 +331,7 @@ class WanI2V:
             callback(-1, None)

         for i, t in enumerate(tqdm(timesteps)):
-            offload.set_step_no_for_lora(i)
+            offload.set_step_no_for_lora(self.model, i)
             latent_model_input = [latent.to(self.device)]
             timestep = [t]


@@ -60,6 +60,30 @@ try:
 except ImportError:
     sageattn = None

+# # try:
+# if True:
+#     from sageattention import sageattn_qk_int8_pv_fp8_window_cuda
+#     @torch.compiler.disable()
+#     def sageattn_window_wrapper(
+#             qkv_list,
+#             attention_length,
+#             window
+#         ):
+#         q,k, v = qkv_list
+#         padding_length = q.shape[0] -attention_length
+#         q = q[:attention_length, :, : ].unsqueeze(0)
+#         k = k[:attention_length, :, : ].unsqueeze(0)
+#         v = v[:attention_length, :, : ].unsqueeze(0)
+#         o = sageattn_qk_int8_pv_fp8_window_cuda(q, k, v, tensor_layout="NHD", window = window).squeeze(0)
+#         del q, k ,v
+#         qkv_list.clear()
+#         if padding_length > 0:
+#             o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0)
+#         return o
+# # except ImportError:
+# #     sageattn = sageattn_qk_int8_pv_fp8_window_cuda
+
 @torch.compiler.disable()
 def sdpa_wrapper(
@@ -119,7 +143,8 @@ def pay_attention(
     deterministic=False,
     dtype=torch.bfloat16,
     version=None,
-    force_attention= None
+    force_attention= None,
+    cross_attn= False
 ):
     """
     q: [B, Lq, Nq, C1].
@@ -194,9 +219,67 @@ def pay_attention(
             max_seqlen_kv=lk,
         ).unflatten(0, (b, lq))
     elif attn=="sage2":
-        qkv_list = [q,k,v]
-        del q,k,v
-        x = sageattn_wrapper(qkv_list, lq).unsqueeze(0)
+        import math
+        if cross_attn or True:
+            qkv_list = [q,k,v]
+            del q,k,v
+            x = sageattn_wrapper(qkv_list, lq).unsqueeze(0)
+        # else:
+        #     layer = offload.shared_state["layer"]
+        #     embed_sizes = offload.shared_state["embed_sizes"]
+        #     current_step = offload.shared_state["step_no"]
+        #     max_steps = offload.shared_state["max_steps"]
+
+        #     nb_latents = embed_sizes[0] * embed_sizes[1]* embed_sizes[2]
+        #     window = 0
+        #     start_window_step = int(max_steps * 0.4)
+        #     start_layer = 10
+        #     if (layer < start_layer ) or current_step <start_window_step:
+        #         window = 0
+        #     else:
+        #         coef = min((max_steps - current_step)/(max_steps-start_window_step),1)*max(min((25 - layer)/(25-start_layer),1),0) * 0.7 + 0.3
+        #         print(f"step: {current_step}, layer: {layer}, coef:{coef:0.1f}]")
+        #         window = math.ceil(coef* nb_latents)
+
+        #     invert_spaces = (layer + current_step) % 2 == 0 and window > 0
+
+        #     def flip(q):
+        #         q = q.reshape(*embed_sizes, *q.shape[-2:])
+        #         q = q.transpose(0,2)
+        #         q = q.contiguous()
+        #         q = q.transpose(0,2)
+        #         q = q.reshape( -1, *q.shape[-2:])
+        #         return q
+
+        #     def flop(q):
+        #         q = q.reshape(embed_sizes[2], embed_sizes[1], embed_sizes[0] , *q.shape[-2:])
+        #         q = q.transpose(0,2)
+        #         q = q.contiguous()
+        #         q = q.transpose(0,2)
+        #         q = q.reshape( -1, *q.shape[-2:])
+        #         return q
+
+        #     if invert_spaces:
+        #         q = flip(q)
+        #         k = flip(k)
+        #         v = flip(v)
+        #     qkv_list = [q,k,v]
+        #     del q,k,v
+
+        #     x = sageattn_window_wrapper(qkv_list, lq, window= window) #.unsqueeze(0)
+
+        #     if invert_spaces:
+        #         x = flop(x)
+        #     x = x.unsqueeze(0)
+
     elif attn=="sdpa":
         qkv_list = [q, k, v]
         del q, k , v


@@ -8,7 +8,7 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.models.modeling_utils import ModelMixin
 import numpy as np
 from typing import Union,Optional
+from mmgp import offload
 from .attention import pay_attention

 __all__ = ['WanModel']
@@ -302,7 +302,7 @@ class WanT2VCrossAttention(WanSelfAttention):
         # compute attention
         qvl_list=[q, k, v]
         del q, k, v
-        x = pay_attention(qvl_list, k_lens=context_lens)
+        x = pay_attention(qvl_list, k_lens=context_lens, cross_attn= True)

         # output
         x = x.flatten(2)
@@ -716,7 +716,9 @@ class WanModel(ModelMixin, ConfigMixin):
         pipeline = None,
         current_step = 0,
         context2 = None,
-        is_uncond=False
+        is_uncond=False,
+        max_steps = 0
     ):
         r"""
         Forward pass through the diffusion model
@@ -755,6 +757,12 @@ class WanModel(ModelMixin, ConfigMixin):
         #     [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
         grid_sizes = [ list(u.shape[2:]) for u in x]
+        embed_sizes = grid_sizes[0]
+        offload.shared_state["embed_sizes"] = embed_sizes
+        offload.shared_state["step_no"] = current_step
+        offload.shared_state["max_steps"] = max_steps
+
         x = [u.flatten(2).transpose(1, 2) for u in x]
         seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
@@ -843,7 +851,11 @@ class WanModel(ModelMixin, ConfigMixin):
             # context=context,
             context_lens=context_lens)

-        for block in self.blocks:
+        for l, block in enumerate(self.blocks):
+            offload.shared_state["layer"] = l
+            if "refresh" in offload.shared_state:
+                del offload.shared_state["refresh"]
+                offload.shared_state["callback"](-1, -1, True)
+
             if pipeline._interrupt:
                 if joint_pass:
                     return None, None


@@ -243,6 +243,10 @@ class WanT2V:
         arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
         arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}

+        # arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
+        # arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
+        # arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
+
         if self.model.enable_teacache:
             self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
         if callback != None:
@@ -250,7 +254,7 @@ class WanT2V:
         for i, t in enumerate(tqdm(timesteps)):
             latent_model_input = latents
             timestep = [t]
-            offload.set_step_no_for_lora(i)
+            offload.set_step_no_for_lora(self.model, i)

             timestep = torch.stack(timestep)
             # self.model.to(self.device)