Added RIFLEx support

This commit is contained in:
DeepBeepMeep 2025-03-02 16:46:52 +01:00
parent 4203fd3732
commit 3731ab70e1
6 changed files with 317 additions and 194 deletions

View File

@ -19,7 +19,8 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
## 🔥 Latest News!!
* Mar 03, 2025: 👋 Wan2.1GP DeepBeepMeep out-of-this-world version! Memory consumption reduced by a factor of 2, with the possibility to generate more than 10s of video at 720p
* Mar 03, 2025: 👋 Wan2.1GP by DeepBeepMeep brings: memory consumption reduced by a factor of 2, with the possibility to generate more than 10s of video at 720p with an RTX 4090 and 10s of video at 480p with less than 12GB of VRAM. Many thanks to RIFLEx (https://github.com/thu-ml/RIFLEx) for their algorithm that allows generating nice-looking videos longer than 5s.
* Feb 25, 2025: 👋 We've released the inference code and weights of Wan2.1.
* Feb 27, 2025: 👋 Wan2.1 has been integrated into [ComfyUI](https://comfyanonymous.github.io/ComfyUI_examples/wan/). Enjoy!
@ -35,7 +36,6 @@ This version has the following improvements over the original Alibaba model:
- Improved gradio interface with progression bar and more options
- Multiple prompts / multiple generations per prompt
- Support for multiple pretrained Loras with 32 GB of RAM or less
- Switch easily between Hunyuan and Fast Hunyuan models and quantized / non quantized models
- Much simpler installation
@ -105,10 +105,28 @@ pip install https://github.com/deepbeepmeep/SageAttention/raw/refs/heads/main/re
## Run the application
### Run a Gradio Server on port 7860 (recommended)
To run the text to video generator (in Low VRAM mode):
```bash
python gradio_server.py
#or
python gradio_server.py --t2v
```
To run the image to video generator (in Low VRAM mode):
```bash
python gradio_server.py --i2v
```
Within the application you can also choose which video generator will be launched, without specifying a command-line switch.
To run the application while loading the diffusion model entirely in VRAM (slightly faster, but requires 24 GB of VRAM for an 8-bit quantized 14B model):
```bash
python gradio_server.py --profile 3
```
Please note that the diffusion model of Wan2.1GP is extremely VRAM-optimized, which greatly benefits low-VRAM systems since the diffusion / denoising step is the longest part of the generation process. However, the VAE encoder (at the beginning of an image-to-video generation) and the VAE decoder (at the end of any video generation) are only about 20% lighter and will temporarily require 22 GB of VRAM for a 720p generation and 12 GB of VRAM for a 480p generation. Therefore, if you have less VRAM than that, you may experience slowdowns at the beginning and at the end of the generation process due to PyTorch VRAM offloading.
### Loras support
@ -131,7 +149,8 @@ You will find prebuilt Loras on https://civitai.com/ or you will be able to buil
### Command line parameters for Gradio Server
--profile no : default (4) : number of the profile to use, between 1 and 5\
--i2v : launch the image to video generator\
--t2v : launch the text to video generator\
--quantize-transformer bool : (default True) : enable / disable on-the-fly transformer quantization\
--lora-dir path : path of the directory that contains Loras in diffusers / safetensor format\
--lora-preset preset : name of the preset file (without the extension) to preload
@ -141,6 +160,7 @@ You will find prebuilt Loras on https://civitai.com/ or you will be able to buil
--open-browser : automatically open the browser when launching the Gradio Server\
--compile : turn on PyTorch compilation\
--attention mode : force the attention mode among sdpa, flash, sage, sage2\
--profile no : default (4) : number of the profile to use, between 1 and 5\
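
For example, several of these switches can be combined in a single invocation (the Lora directory and preset name below are just placeholders):
```bash
python gradio_server.py --t2v --profile 2 --attention sage --lora-dir loras --lora-preset mypreset --open-browser
```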
### Profiles (for power users only)
You can choose between 5 profiles; these will try to make the most of your hardware, but they have little impact for HunyuanVideo GP:

View File

@ -288,7 +288,7 @@ if __name__ == '__main__':
# resolution = '720P'
resolution = '480P'
load_model(resolution)
load_i2v_model(resolution)
print("Step1: Init prompt_expander...", end='', flush=True)
if args.prompt_extend_method == "dashscope":

View File

@ -18,6 +18,8 @@ from wan.utils.utils import cache_video
from wan.modules.attention import get_attention_modes
import torch
import gc
import traceback
def _parse_args():
parser = argparse.ArgumentParser(
@ -752,7 +754,7 @@ def generate_video(
state["in_progress"] = True
state["selected"] = 0
enable_riflex = RIFLEx_setting == 0 and video_length > (5* 24) or RIFLEx_setting == 1
enable_RIFLEx = RIFLEx_setting == 0 and video_length > (5* 16) or RIFLEx_setting == 1
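# RIFLEx is enabled automatically (RIFLEx_setting == 0) when the clip exceeds 5 s at 16 fps (80 frames), or unconditionally (RIFLEx_setting == 1)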
# VAE Tiling
device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
@ -810,7 +812,8 @@ def generate_video(
n_prompt=negative_prompt,
seed=seed,
offload_model=False,
callback=callback
callback=callback,
enable_RIFLEx = enable_RIFLEx
)
else:
@ -824,9 +827,10 @@ def generate_video(
n_prompt=negative_prompt,
seed=seed,
offload_model=False,
callback=callback
callback=callback,
enable_RIFLEx = enable_RIFLEx
)
except:
except Exception as e:
gen_in_progress = False
if temp_filename!= None and os.path.isfile(temp_filename):
os.remove(temp_filename)
@ -838,7 +842,21 @@ def generate_video(
gc.collect()
torch.cuda.empty_cache()
s = str(e)
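# Heuristic: treat the failure as a VRAM / out-of-memory crash if the exception message
# (or, failing that, the names of the most recent stack frames) contains one of these keywords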
keyword_list = ["vram", "VRAM", "memory", "triton", "cuda", "allocat"]
VRAM_crash= False
if any( keyword in s for keyword in keyword_list):
VRAM_crash = True
else:
stack = traceback.extract_stack(f=None, limit=5)
for frame in stack:
if any( keyword in frame.name for keyword in keyword_list):
VRAM_crash = True
break
if VRAM_crash:
raise gr.Error("The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames.")
else:
raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
if samples != None:
@ -949,7 +967,7 @@ def create_demo():
else:
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v1 - AI Text To Video Generator (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</H1></div>")
gr.Markdown("<FONT SIZE=3>With this first release of Wan 2.1GP by <B>DeepBeepMeep</B> the VRAM requirements have been divided by more than 2 with no quality loss</FONT>")
gr.Markdown("<FONT SIZE=3>With this first release of Wan 2.1GP by <B>DeepBeepMeep</B>, the VRAM requirements have been divided by more than 2 with no quality loss</FONT>")
if use_image2video and False:
pass
@ -959,7 +977,7 @@ def create_demo():
gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s) : 5 GB of VRAM")
gr.Markdown("- 1280 x 720 with a 14B model: 192 frames (8s): 11 GB of VRAM")
gr.Markdown("Note that the VAE stages (encoding / decoding at image2video ) or just the decoding at text2video will create a temporary VRAM peak (up to 12GB for 420P and 22 GB for 720P)")
gr.Markdown("It is not recommmended to generate a video longer than 8s even if there is still some VRAM left as some artifact may appear")
gr.Markdown("Please note that if your turn on compilation, the first generation step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
@ -1092,12 +1110,24 @@ def create_demo():
with gr.Row():
if use_image2video:
resolution = gr.Dropdown(
choices=[
# 720p
("720p", "1280x720"),
("480p", "832x480"),
],
value="832x480",
label="Resolution (video will have the same height / width ratio than the original image)"
)
else:
resolution = gr.Dropdown(
choices=[
# 720p
("1280x720 (16:9, 720p)", "1280x720"),
("720x1280 (9:16, 720p)", "720x1280"),
("1024x1024 (4:3, 720p, T2V only)", "1024x024"),
("1024x1024 (4:3, 720p)", "1024x024"),
# ("832x1104 (3:4, 720p)", "832x1104"),
# ("960x960 (1:1, 720p)", "960x960"),
# 480p
@ -1125,7 +1155,7 @@ def create_demo():
with gr.Row(visible= len(loras)>0):
lset_choices = [ (preset, preset) for preset in loras_presets ] + [(new_preset_msg, "")]
with gr.Column(scale=5):
lset_name = gr.Dropdown(show_label=False, allow_custom_value= True, scale=5, filterable=False, choices= lset_choices, value=default_lora_preset)
lset_name = gr.Dropdown(show_label=False, allow_custom_value= True, scale=5, filterable=True, choices= lset_choices, value=default_lora_preset)
with gr.Column(scale=1):
# with gr.Column():
with gr.Row(height=17):

View File

@ -143,7 +143,9 @@ class WanI2V:
n_prompt="",
seed=-1,
offload_model=True,
callback = None
callback = None,
enable_RIFLEx = False
):
r"""
Generates video frames from input image and text prompt using diffusion process.
@ -262,7 +264,6 @@ class WanI2V:
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
# evaluation mode
with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
@ -288,11 +289,14 @@ class WanI2V:
# sample videos
latent = noise
freqs = self.model.get_rope_freqs(nb_latent_frames = int((frame_num - 1)/4 + 1), RIFLEx_k = 4 if enable_RIFLEx else None )
arg_c = {
'context': [context[0]],
'clip_fea': clip_context,
'seq_len': max_seq_len,
'y': [y],
'freqs' : freqs,
'pipeline' : self
}
@ -301,6 +305,7 @@ class WanI2V:
'clip_fea': clip_context,
'seq_len': max_seq_len,
'y': [y],
'freqs' : freqs,
'pipeline' : self
}

View File

@ -6,6 +6,8 @@ import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
import numpy as np
from typing import Union,Optional
from .attention import pay_attention
@ -25,7 +27,49 @@ def sinusoidal_embedding_1d(dim, position):
return x
# @amp.autocast(enabled=False)
def identify_k( b: float, d: int, N: int):
"""
This function identifies the index of the intrinsic frequency component in a RoPE-based pre-trained diffusion transformer.
Args:
b (`float`): The base frequency for RoPE.
d (`int`): Dimension of the frequency tensor
N (`int`): the first observed repetition frame in latent space
Returns:
k (`int`): the index of intrinsic frequency component
N_k (`int`): the period of intrinsic frequency component in latent space
Example:
In HunyuanVideo, where b=256 and d=16, the repetition occurs after approximately 8s (N=48 in latent space).
k, N_k = identify_k(b=256, d=16, N=48)
In this case, the intrinsic frequency index k is 4, and the period N_k is 50.
"""
# Compute the period of each frequency in RoPE according to Eq.(4)
periods = []
for j in range(1, d // 2 + 1):
theta_j = 1.0 / (b ** (2 * (j - 1) / d))
N_j = round(2 * torch.pi / theta_j)
periods.append(N_j)
# Identify the intrinsic frequency whose period is closest to N, see Eq.(7)
diffs = [abs(N_j - N) for N_j in periods]
k = diffs.index(min(diffs)) + 1
N_k = periods[k-1]
return k, N_k
def rope_params_riflex(max_seq_len, dim, theta=10000, L_test=30, k=6):
assert dim % 2 == 0
exponents = torch.arange(0, dim, 2, dtype=torch.float64).div(dim)
inv_theta_pow = 1.0 / torch.pow(theta, exponents)
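# RIFLEx: lower the k-th (intrinsic) frequency so that less than one full period (a 0.9 fraction of a cycle)
# fits within the L_test latent frames, which prevents the temporal repetition seen in extended videos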
if k is not None:  # k=None keeps the standard RoPE frequencies (RIFLEx adjustment disabled)
    inv_theta_pow[k-1] = 0.9 * 2 * torch.pi / L_test
freqs = torch.outer(torch.arange(max_seq_len), inv_theta_pow)
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
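# Usage sketch (illustration only, not part of this commit): for Wan's temporal RoPE slice
# (theta=10000, d - 4 * (d // 6) == 44 dims) and roughly 26 latent frames, identify_k returns
# the intrinsic index k == 4, matching the RIFLEx_k=4 that the pipelines pass to get_rope_freqs:
#     k, N_k = identify_k(b=10000, d=44, N=26)
#     temporal_freqs = rope_params_riflex(1024, dim=44, L_test=26, k=k)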
def rope_params(max_seq_len, dim, theta=10000):
assert dim % 2 == 0
freqs = torch.outer(
@ -588,14 +632,6 @@ class WanModel(ModelMixin, ConfigMixin):
self.head = Head(dim, out_dim, patch_size, eps)
# buffers (don't use register_buffer otherwise dtype will be changed in to())
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
d = dim // num_heads
self.freqs = torch.cat([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6))
],
dim=1)
if model_type == 'i2v':
self.img_emb = MLPProj(1280, dim)
@ -603,6 +639,29 @@ class WanModel(ModelMixin, ConfigMixin):
# initialize weights
self.init_weights()
# self.freqs = torch.cat([
# rope_params(1024, d - 4 * (d // 6)), #44
# rope_params(1024, 2 * (d // 6)), #42
# rope_params(1024, 2 * (d // 6)) #42
# ],dim=1)
def get_rope_freqs(self, nb_latent_frames, RIFLEx_k = None):
dim = self.dim
num_heads = self.num_heads
d = dim // num_heads
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
freqs = torch.cat([
rope_params_riflex(1024, dim= d - 4 * (d // 6), L_test=nb_latent_frames, k = RIFLEx_k ), #44
rope_params(1024, 2 * (d // 6)), #42
rope_params(1024, 2 * (d // 6)) #42
],dim=1)
return freqs
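# Example call, as used by the t2v / i2v pipelines in this commit (the latent frame count
# follows from the VAE's 4x temporal compression):
#     freqs = self.model.get_rope_freqs(nb_latent_frames=int((frame_num - 1) / 4 + 1),
#                                       RIFLEx_k=4 if enable_RIFLEx else None)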
def forward(
self,
x,
@ -611,6 +670,7 @@ class WanModel(ModelMixin, ConfigMixin):
seq_len,
clip_fea=None,
y=None,
freqs = None,
pipeline = None,
):
r"""
@ -638,8 +698,8 @@ class WanModel(ModelMixin, ConfigMixin):
assert clip_fea is not None and y is not None
# params
device = self.patch_embedding.weight.device
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
if freqs.device != device:
freqs = freqs.to(device)
if y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
@ -683,7 +743,7 @@ class WanModel(ModelMixin, ConfigMixin):
e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=self.freqs,
freqs=freqs,
context=context,
context_lens=context_lens)

View File

@ -128,7 +128,8 @@ class WanT2V:
n_prompt="",
seed=-1,
offload_model=True,
callback = None
callback = None,
enable_RIFLEx = None
):
r"""
Generates video frames from text prompt using diffusion process.
@ -209,7 +210,6 @@ class WanT2V:
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
# evaluation mode
with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
@ -235,8 +235,16 @@ class WanT2V:
# sample videos
latents = noise
arg_c = {'context': context, 'seq_len': seq_len, 'pipeline': self}
arg_null = {'context': context_null, 'seq_len': seq_len, 'pipeline': self}
# from .modules.model import identify_k
# for nf in range(20, 50):
# k, N_k = identify_k(10000, 44, 26)
# print(f"value nb latent frames={nf}, k={k}, n_k={N_k}")
freqs = self.model.get_rope_freqs(nb_latent_frames = int((frame_num - 1)/4 + 1), RIFLEx_k = 4 if enable_RIFLEx else None )
arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
if callback != None:
callback(-1, None)