Mirror of https://github.com/Wan-Video/Wan2.1.git, synced 2025-11-04 22:26:36 +00:00

Commit 71697bd7c5 (parent 16b6fbacec): Added Fun InP models support

README.md (11 lines changed)
@@ -19,8 +19,9 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
 
 ## 🔥 Latest News!!
-* Mar 20 2025: 👋 Good news! Official support for RTX 50xx, please check the installation instructions below.
-* Mar 19 2025: 👋 Wan2.1GP v3.2:
+* Mar 27 2025: 👋 Added support for the new Wan Fun InP models (image2video). The 14B Fun InP probably has better end-image support, but unfortunately existing loras do not work as well with it. The great novelty is the Fun InP image2video 1.3B model: image-to-video is now accessible to even lower hardware configurations. It is not as good as the 14B models, but very impressive for its size. You can choose any of these models in the Configuration tab. Many thanks to the VideoX-Fun team (https://github.com/aigc-apps/VideoX-Fun)
+* Mar 26 2025: 👋 Good news! Official support for RTX 50xx, please check the installation instructions below.
+* Mar 24 2025: 👋 Wan2.1GP v3.2:
     - Added Classifier-Free Guidance Zero Star. The video should match the text prompt better (especially with text2video) at no performance cost: many thanks to the **CFG Zero * Team.**\
       Don't hesitate to give them a star if you appreciate the results: https://github.com/WeichenFan/CFG-Zero-star
     - Added back support for Pytorch compilation with Loras. It seems it had been broken for some time.
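For context on the CFG Zero Star entry above: the technique replaces plain classifier-free guidance with a per-sample rescaling of the unconditional prediction and skips guidance entirely for the first few denoising steps. The sketch below is a minimal illustration based on the linked CFG-Zero-star repository; the function name, the default of 5 zero-init steps, and the exact formula are assumptions, not code from this commit.

```python
import torch

def cfg_zero_star(cond: torch.Tensor, uncond: torch.Tensor, guide_scale: float,
                  step: int, zero_init_steps: int = 5) -> torch.Tensor:
    """Classifier-free guidance with the CFG Zero* optimized scale (illustrative sketch)."""
    if step < zero_init_steps:
        # "zero init": skip the prediction entirely for the earliest denoising steps
        return torch.zeros_like(cond)
    b = cond.shape[0]
    pos = cond.reshape(b, -1).float()
    neg = uncond.reshape(b, -1).float()
    # optimized scale: least-squares projection of the conditional onto the unconditional prediction
    alpha = (pos * neg).sum(dim=1, keepdim=True) / (neg * neg).sum(dim=1, keepdim=True).clamp_min(1e-8)
    alpha = alpha.reshape(b, *([1] * (cond.dim() - 1))).to(cond.dtype)
    # standard CFG, but with the unconditional branch rescaled by alpha
    return uncond * alpha + guide_scale * (cond - uncond * alpha)
```

In this commit the feature surfaces as the `cfg_star_switch` and `cfg_zero_step` values that gradio_server.py passes into the generator (see the hunks around generate_video() and WanI2V.generate() below).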
@@ -188,6 +189,10 @@ To run the image to video generator (in Low VRAM mode):
 ```bash
 python gradio_server.py --i2v
 ```
+To run the 1.3B Fun InP image to video generator (in Low VRAM mode):
+```bash
+python gradio_server.py --i2v-1-3B
+```
 
 To be able to input multiple images with the image to video generator:
 ```bash
@@ -271,6 +276,8 @@ You can define multiple lines of macros. If there is only one macro line, the ap
 --t2v : launch the text to video generator (default defined in the configuration)\
 --t2v-14B : launch the 14B model text to video generator\
 --t2v-1-3B : launch the 1.3B model text to video generator\
+--i2v-14B : launch the 14B model image to video generator\
+--i2v-1-3B : launch the Fun InP 1.3B model image to video generator\
 --quantize-transformer bool: (default True) : enable / disable on the fly transformer quantization\
 --lora-dir path : Path of directory that contains Loras in diffusers / safetensor format\
 --lora-preset preset : name of preset file (without the extension) to preload
gradio_server.py (152 lines changed)
@@ -205,6 +205,18 @@ def _parse_args():
         action="store_true",
         help="text to video mode 1.3B model"
     )
+    parser.add_argument(
+        "--i2v-1-3B",
+        action="store_true",
+        help="Fun InP image to video mode 1.3B model"
+    )
+
+    parser.add_argument(
+        "--i2v-14B",
+        action="store_true",
+        help="image to video mode 14B model"
+    )
+
 
     parser.add_argument(
         "--compile",
@@ -257,22 +269,21 @@ def get_lora_dir(i2v):
     root_lora_dir = "loras_i2v" if i2v else "loras"
 
-    if not i2v:
-        if "1.3B" in transformer_filename_t2v:
-            lora_dir_1_3B = os.path.join(root_lora_dir, "1.3B")
-            if os.path.isdir(lora_dir_1_3B ):
-                return lora_dir_1_3B
-        else:
-            lora_dir_14B = os.path.join(root_lora_dir, "14B")
-            if os.path.isdir(lora_dir_14B ):
-                return lora_dir_14B
+    if "1.3B" in (transformer_filename_i2v if i2v else transformer_filename_t2v) :
+        lora_dir_1_3B = os.path.join(root_lora_dir, "1.3B")
+        if os.path.isdir(lora_dir_1_3B ):
+            return lora_dir_1_3B
+    else:
+        lora_dir_14B = os.path.join(root_lora_dir, "14B")
+        if os.path.isdir(lora_dir_14B ):
+            return lora_dir_14B
     return root_lora_dir
 
 attention_modes_installed = get_attention_modes()
 attention_modes_supported = get_supported_attention_modes()
 args = _parse_args()
 args.flow_reverse = True
 
 # torch.backends.cuda.matmul.allow_fp16_accumulation = True
 lock_ui_attention = False
 lock_ui_transformer = False
 lock_ui_compile = False
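The refactor above makes get_lora_dir() pick a per-size subfolder for image2video loras as well as text2video ones. As a reading aid, here is a condensed standalone version of the same selection logic; the folder names come from the hunk, the example checkpoint name is one of the choices added by this commit, and the helper name pick_lora_dir is hypothetical.

```python
import os

# Folder layout implied by get_lora_dir():
#   loras/1.3B,     loras/14B      - text2video loras, per model size
#   loras_i2v/1.3B, loras_i2v/14B  - image2video loras, per model size
def pick_lora_dir(i2v: bool, model_filename: str) -> str:
    root = "loras_i2v" if i2v else "loras"
    sub = "1.3B" if "1.3B" in model_filename else "14B"
    candidate = os.path.join(root, sub)
    # fall back to the root lora folder when the size-specific subfolder does not exist
    return candidate if os.path.isdir(candidate) else root

print(pick_lora_dir(True, "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors"))  # loras_i2v/1.3B if present, else loras_i2v
```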
@@ -285,7 +296,7 @@ check_loras = args.check_loras ==1
 advanced = args.advanced
 
 transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors"]
-transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors"]
+transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", ]
 text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"]
 
 server_config_filename = "gradio_config.json"
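Several later hunks refer to this list purely by index: the default configuration uses entry 1, the new --i2v-1-3B flag selects entry 4, and the Configuration tab dropdown values 0 through 6 map onto the same positions. For readability, the same list with its indices spelled out:

```python
transformer_choices_i2v = [
    "ckpts/wan2.1_image2video_480p_14B_bf16.safetensors",         # 0
    "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors",  # 1  default transformer_filename_i2v
    "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors",         # 2
    "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors",  # 3
    "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors",                 # 4  selected by --i2v-1-3B
    "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors",                  # 5
    "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors",           # 6
]
```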
@@ -293,7 +304,7 @@ server_config_filename = "gradio_config.json"
 if not Path(server_config_filename).is_file():
     server_config = {"attention_mode" : "auto",
                      "transformer_filename": transformer_choices_t2v[0],
-                     "transformer_filename_i2v": transformer_choices_i2v[1], ########
+                     "transformer_filename_i2v": transformer_choices_i2v[1],
                      "text_encoder_filename" : text_encoder_choices[1],
                      "save_path": os.path.join(os.getcwd(), "gradio_outputs"),
                      "compile" : "",
@@ -363,7 +374,7 @@ def get_default_settings(filename, i2v):
     return ui_defaults
 
 transformer_filename_t2v = server_config["transformer_filename"]
-transformer_filename_i2v = server_config.get("transformer_filename_i2v", transformer_choices_i2v[1]) ########
+transformer_filename_i2v = server_config.get("transformer_filename_i2v", transformer_choices_i2v[1])
 
 text_encoder_filename = server_config["text_encoder_filename"]
 attention_mode = server_config["attention_mode"]
@@ -395,11 +406,22 @@ if args.t2v_14B:
     transformer_filename_t2v = transformer_choices_t2v[2]
     lock_ui_transformer = False
 
+if args.i2v_14B:
+    use_image2video = True
+    if not "14B" in transformer_filename_i2v:
+        transformer_filename_i2v = transformer_choices_t2v[3]
+    lock_ui_transformer = False
+
 if args.t2v_1_3B:
     transformer_filename_t2v = transformer_choices_t2v[0]
     use_image2video = False
     lock_ui_transformer = False
 
+if args.i2v_1_3B:
+    transformer_filename_i2v = transformer_choices_i2v[4]
+    use_image2video = True
+    lock_ui_transformer = False
+
 only_allow_edit_in_advanced = False
 lora_preselected_preset = args.lora_preset
 lora_preselected_preset_for_i2v = use_image2video
@@ -701,7 +723,10 @@ def get_default_flow(filename, i2v):
 
 
 def get_model_name(model_filename):
-    if "image" in model_filename:
+    if "Fun" in model_filename:
+        model_name = "Fun InP image2video"
+        model_name += " 14B" if "14B" in model_filename else " 1.3B"
+    elif "image" in model_filename:
         model_name = "Wan2.1 image2video"
         model_name += " 720p" if "720p" in model_filename else " 480p"
     else:
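A condensed sketch of the new naming logic, applied to checkpoint names from this commit. The final else branch (text2video naming) is outside the visible hunk, so its return value below is only a placeholder assumption.

```python
def short_model_name(model_filename: str) -> str:
    # mirrors the branching added in the hunk above
    if "Fun" in model_filename:
        return "Fun InP image2video " + ("14B" if "14B" in model_filename else "1.3B")
    if "image" in model_filename:
        return "Wan2.1 image2video " + ("720p" if "720p" in model_filename else "480p")
    return "Wan2.1 text2video"  # placeholder for the unchanged branch not shown here

print(short_model_name("ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors"))          # Fun InP image2video 1.3B
print(short_model_name("ckpts/wan2.1_image2video_720p_14B_bf16.safetensors"))  # Wan2.1 image2video 720p
```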
@@ -752,7 +777,7 @@ def apply_changes( state,
     global offloadobj, wan_model, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets
     server_config = {"attention_mode" : attention_choice,
                      "transformer_filename": transformer_choices_t2v[transformer_t2v_choice],
-                     "transformer_filename_i2v": transformer_choices_i2v[transformer_i2v_choice], ##########
+                     "transformer_filename_i2v": transformer_choices_i2v[transformer_i2v_choice],
                      "text_encoder_filename" : text_encoder_choices[text_encoder_choice],
                      "save_path" : save_path_choice,
                      "compile" : compile_choice,
@@ -859,9 +884,9 @@ def refresh_gallery(state, txt):
     prompts_max = state.get("prompts_max",0)
     prompt_no = state.get("prompt_no",0)
     if prompts_max >1 :
-        label = f"Current Prompt ({prompt_no+1}/{prompts_max})"
+        label = f"Prompt ({prompt_no+1}/{prompts_max}) of Video being Generated"
     else:
-        label = f"Current Prompt"
+        label = f"Prompt of Video being Generated"
     return gr.Gallery(selected_index=choice, value = file_list), gr.Text(visible= True, value=prompt, label=label)
 
 
@@ -998,7 +1023,7 @@ def generate_video(
     if slg_switch == 0:
         slg_layers = None
     if image2video:
-        if "480p" in transformer_filename_i2v and width * height > 848*480:
+        if "480p" in model_filename and not "Fun" in model_filename and width * height > 848*480:
             gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
             return
 
@@ -1007,10 +1032,9 @@ def generate_video(
             gr.Info(f"Resolution {resolution} not supported by image 2 video")
             return
 
-    else:
-        if "1.3B" in transformer_filename_t2v and width * height > 848*480:
-            gr.Info("You must use the 14B text to video model to generate videos with a resolution equivalent to 720P")
-            return
+    if "1.3B" in model_filename and width * height > 848*480:
+        gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
+        return
 
     offload.shared_state["_attention"] = attn
 
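The two checks above gate resolution on a raw pixel count rather than on a named resolution. A small illustrative snippet (not part of the app) showing how the 848x480 budget behaves for common sizes:

```python
MAX_480P_PIXELS = 848 * 480  # 407,040 pixels: the ceiling enforced for the 480p and 1.3B models above

for w, h in [(832, 480), (848, 480), (1280, 720)]:
    needs_big_model = w * h > MAX_480P_PIXELS
    print(f"{w}x{h}: {'requires the 720p / 14B model' if needs_big_model else 'fits the 480p / 1.3B budget'}")
```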
@@ -1142,7 +1166,7 @@ def generate_video(
     if len(list_mult_choices_nums ) < len(loras_choices):
         list_mult_choices_nums += [1.0] * ( len(loras_choices) - len(list_mult_choices_nums ) )
     loras_selected = [ lora for i, lora in enumerate(loras) if str(i) in loras_choices]
-    pinnedLora = False #profile !=5 #False # # #
+    pinnedLora = profile !=5 #False # # #
     offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, preprocess_sd=preprocess_loras, pinnedLora=pinnedLora, split_linear_modules_map = None)
     errors = trans._loras_errors
     if len(errors) > 0:
@@ -1273,7 +1297,8 @@ def generate_video(
             slg_start = slg_start/100,
             slg_end = slg_end/100,
             cfg_star_switch = cfg_star_switch,
-            cfg_zero_step = cfg_zero_step,
+            cfg_zero_step = cfg_zero_step,
+            add_frames_for_end_image = not "Fun" in transformer_filename_i2v
         )
 
     else:
@@ -2111,7 +2136,7 @@ def generate_video_tab(image2video=False):
         [state],
         [output , abort_btn, generate_btn, onemore_btn, gen_info]
     )
-    return loras_choices, lset_name, header, state
+    return loras_column, loras_choices, presets_column, lset_name, header, state
 
 def generate_configuration_tab():
     state_dict = {}
@@ -2138,7 +2163,10 @@ def generate_configuration_tab():
                 ("WAN 2.1 - 480p 14B Image to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 0),
                 ("WAN 2.1 - 480p 14B Image to Video quantized to 8 bits (recommended) - the default engine but quantized", 1),
                 ("WAN 2.1 - 720p 14B Image to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 2),
-                ("WAN 2.1 - 720p 14B Image to Video quantized to 8 bits (recommended) - the default engine but quantized", 3),
+                ("WAN 2.1 - 720p 14B Image to Video quantized to 8 bits - the default engine but quantized", 3),
+                ("WAN 2.1 - Fun InP 1.3B 16 bits - the small model for fast generations with low VRAM requirements", 4),
+                ("WAN 2.1 - Fun InP 14B 16 bits - Fun InP version in its original glory, offers a slightly better image quality but slower and requires more RAM", 5),
+                ("WAN 2.1 - Fun InP 14B quantized to 8 bits - quantized Fun InP version", 6),
             ],
             value= index,
             label="Transformer model for Image to Video",
@@ -2274,7 +2302,7 @@ def generate_configuration_tab():
     )
 
 def generate_about_tab():
-    gr.Markdown("<H2>Waw2.1GP - Wan 2.1 model for the GPU Poor by <B>DeepBeepMeep</B> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>GitHub</A>)</H2>")
+    gr.Markdown("<H2>Wan2.1GP - Wan 2.1 model for the GPU Poor by <B>DeepBeepMeep</B> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>GitHub</A>)</H2>")
     gr.Markdown("Original Wan 2.1 Model by <B>Alibaba</B> (<A HREF='https://github.com/Wan-Video/Wan2.1'>GitHub</A>)")
     gr.Markdown("Many thanks to:")
    gr.Markdown("- <B>Cocktail Peanuts</B> : QA and simple installation via Pinokio.computer")
@@ -2293,29 +2321,51 @@ def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
     new_i2v = evt.index == 1
     use_image2video = new_i2v
 
-    if new_t2v:
-        lora_model_filename = t2v_state["loras_model"]
-        if ("1.3B" in transformer_filename_t2v and not "1.3B" in lora_model_filename or "14B" in transformer_filename_t2v and not "14B" in lora_model_filename):
+    if new_t2v or new_i2v:
+        state = i2v_state if new_i2v else t2v_state
+        lora_model_filename = state["loras_model"]
+        model_filename = model_needed(new_i2v)
+        if ("1.3B" in model_filename and not "1.3B" in lora_model_filename or "14B" in model_filename and not "14B" in lora_model_filename):
             lora_dir = get_lora_dir(new_i2v)
             loras, loras_names, loras_presets, _, _, _, _ = setup_loras(new_i2v, None, lora_dir, lora_preselected_preset, None)
-            t2v_state["loras"] = loras
-            t2v_state["loras_names"] = loras_names
-            t2v_state["loras_presets"] = loras_presets
-            t2v_state["loras_model"] = transformer_filename_t2v
+            state["loras"] = loras
+            state["loras_names"] = loras_names
+            state["loras_presets"] = loras_presets
+            state["loras_model"] = model_filename
 
-        t2v_advanced = t2v_state["advanced"]
+        advanced = state["advanced"]
         new_loras_choices = [(name, str(i)) for i, name in enumerate(loras_names)]
-        lset_choices = [(preset, preset) for preset in loras_presets] + [(get_new_preset_msg(t2v_advanced), "")]
-        return [
-            gr.Dropdown(choices=new_loras_choices, visible=len(loras_names)>0, value=[]),
-            gr.Dropdown(choices=lset_choices, value=get_new_preset_msg(t2v_advanced), visible=len(loras_names)>0),
-            t2v_header,
-            gr.Dropdown(),
-            gr.Dropdown(),
-            i2v_header,
-        ]
-    return [gr.Dropdown(), gr.Dropdown(), t2v_header,
-            gr.Dropdown(), gr.Dropdown(), i2v_header]
+        lset_choices = [(preset, preset) for preset in loras_presets] + [(get_new_preset_msg(advanced), "")]
+        visible = len(loras_names)>0
+        if new_t2v:
+            return [
+                gr.Column(visible= visible),
+                gr.Dropdown(choices=new_loras_choices, visible=visible, value=[]),
+                gr.Column(visible= visible),
+                gr.Dropdown(choices=lset_choices, value=get_new_preset_msg(advanced), visible=visible),
+                t2v_header,
+                gr.Column(),
+                gr.Dropdown(),
+                gr.Column(),
+                gr.Dropdown(),
+                i2v_header,
+            ]
+        else:
+            return [
+                gr.Column(),
+                gr.Dropdown(),
+                gr.Column(),
+                gr.Dropdown(),
+                t2v_header,
+                gr.Column(visible= visible),
+                gr.Dropdown(choices=new_loras_choices, visible=visible, value=[]),
+                gr.Column(visible= visible),
+                gr.Dropdown(choices=lset_choices, value=get_new_preset_msg(advanced), visible=visible),
+                i2v_header,
+            ]
+
+    return [gr.Column(), gr.Dropdown(), gr.Column(), gr.Dropdown(), t2v_header,
+            gr.Column(), gr.Dropdown(), gr.Column(), gr.Dropdown(), i2v_header]
 
 
 def create_demo():
@@ -2336,7 +2386,7 @@ def create_demo():
     }
     """
     with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo:
-        gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.2 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
+        gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.3 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
         gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>")
 
         with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False):
@@ -2350,9 +2400,9 @@ def create_demo():
 
         with gr.Tabs(selected="i2v" if use_image2video else "t2v") as main_tabs:
             with gr.Tab("Text To Video", id="t2v") as t2v_tab:
-                t2v_loras_choices, t2v_lset_name, t2v_header, t2v_state = generate_video_tab()
+                t2v_loras_column, t2v_loras_choices, t2v_presets_column, t2v_lset_name, t2v_header, t2v_state = generate_video_tab()
             with gr.Tab("Image To Video", id="i2v") as i2v_tab:
-                i2v_loras_choices, i2v_lset_name, i2v_header, i2v_state = generate_video_tab(True)
+                i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_state = generate_video_tab(True)
             if not args.lock_config:
                 with gr.Tab("Configuration"):
                     generate_configuration_tab()
@@ -2362,8 +2412,8 @@ def create_demo():
             fn=on_tab_select,
             inputs=[t2v_state, i2v_state],
             outputs=[
-                t2v_loras_choices, t2v_lset_name, t2v_header,
-                i2v_loras_choices, i2v_lset_name, i2v_header
+                t2v_loras_column, t2v_loras_choices, t2v_presets_column, t2v_lset_name, t2v_header,
+                i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header
             ]
         )
     return demo
wan/image2video.py

@@ -117,7 +117,8 @@ class WanI2V:
         logging.info(f"Creating WanModel from {model_filename}")
         from mmgp import offload
 
-        self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, writable_tensors= False)
+        self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, writable_tensors= False) #forcedConfigPath= "ckpts/config2.json",
+        # offload.save_model(self.model, "wan2.1_Fun_InP_1.3B_bf16_bis.safetensors")
         self.model.eval().requires_grad_(False)
 
         if t5_fsdp or dit_fsdp or use_usp:
@@ -169,6 +170,7 @@ class WanI2V:
         slg_end = 1.0,
         cfg_star_switch = True,
         cfg_zero_step = 5,
+        add_frames_for_end_image = True
     ):
         r"""
         Generates video frames from input image and text prompt using diffusion process.
@@ -212,8 +214,9 @@ class WanI2V:
         if any_end_frame:
             any_end_frame = True
             img2 = TF.to_tensor(img2).sub_(0.5).div_(0.5).to(self.device)
-            frame_num +=1
-            lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
+            if add_frames_for_end_image:
+                frame_num +=1
+                lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
 
         h, w = img.shape[1:]
         aspect_ratio = h / w
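A quick worked example of the latent frame arithmetic above. The temporal VAE stride of 4 is the usual Wan 2.1 value and is assumed here; it is not shown in this hunk.

```python
VAE_TEMPORAL_STRIDE = 4  # assumed value of self.vae_stride[0]

def lat_frames(frame_num: int) -> int:
    # same formula as in the hunk above
    return int((frame_num - 2) // VAE_TEMPORAL_STRIDE + 2)

print(lat_frames(81))  # 21 latent frames for a standard 81-frame generation
print(lat_frames(82))  # 22 when one frame is appended to hold the end image (add_frames_for_end_image)
```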
@@ -237,7 +240,11 @@ class WanI2V:
         msk = torch.ones(1, frame_num, lat_h, lat_w, device=self.device)
         if any_end_frame:
             msk[:, 1: -1] = 0
-            msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:-1], torch.repeat_interleave(msk[:, -1:], repeats=4, dim=1) ], dim=1)
+            if add_frames_for_end_image:
+                msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:-1], torch.repeat_interleave(msk[:, -1:], repeats=4, dim=1) ], dim=1)
+            else:
+                msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] ], dim=1)
+
         else:
             msk[:, 1:] = 0
             msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] ], dim=1)
@@ -283,7 +290,7 @@ class WanI2V:
             torch.zeros(3, frame_num-1, h, w, device="cpu", dtype= torch.bfloat16)
         ], dim=1).to(self.device)
 
-        lat_y = self.vae.encode([enc], VAE_tile_size, any_end_frame= any_end_frame)[0]
+        lat_y = self.vae.encode([enc], VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
         y = torch.concat([msk, lat_y])
 
         @contextmanager
@@ -441,9 +448,9 @@ class WanI2V:
 
         if self.rank == 0:
             # x0 = [lat_y]
-            video = self.vae.decode(x0, VAE_tile_size, any_end_frame= any_end_frame)[0]
+            video = self.vae.decode(x0, VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
 
-            if any_end_frame:
+            if any_end_frame and add_frames_for_end_image:
                 # video[:, -1:] = img2_interpolated
                 video = video[:, :-1]