diff --git a/README.md b/README.md
index 6207f6a..9523cd2 100644
--- a/README.md
+++ b/README.md
@@ -10,6 +10,7 @@
## 🔥 Latest News!!
+* April 25 2025: 👋 Wan 2.1GP v4.3: Added preview mode and support for Sky Reels v2 Diffusion Forcing for high quality "infinite length videos" (see the Sliding Window section below)
* April 18 2025: 👋 Wan 2.1GP v4.2: FLF2V model support, official support from Wan for image2video start and end frames specialized for 720p.
* April 17 2025: 👋 Wan 2.1GP v4.1: Recam Master model support, view a video from a different angle. The video to process must be at least 81 frames long and you should set at least 15 steps denoising to get good results.
* April 13 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you !
@@ -302,18 +303,22 @@ There is also a guide that describes the various combination of hints (https://g
It seems you will get better results if you turn on "Skip Layer Guidance" with its default configuration
-### VACE Slidig Window
-With this mode (that works for the moment only with Vace) you can merge mutiple Videos to form a very long video (up to 1 min). What is this very nice a about this feature is that the resulting video can be driven by the same control video. For instance the first 0-4s of the control video will be used to generate the first window then the next 4-8s of the control video will be used to generate the second window, and so on. So if your control video contains a person walking, your generate video could contain up to one minute of this person walking.
+### VACE and Sky Reels v2 Diffusion Forcing Sliding Window
+With this mode (which for the moment works only with Vace and Sky Reels v2) you can merge multiple Videos to form a very long video (up to 1 min).
-To turn on sliding window, you need to go in the Advanced Settings Tab *Sliding Window* and set the iteration number to a number greater than 1. This number corresponds to the default number of windows. You can still increase the number during the genreation by clicking the "One More Sample, Please !" button.
+When combined with Vace, this feature can use the same control video to generate the full Video that results from concatenating the different windows. For instance, the first 0-4s of the control video will be used to generate the first window, then the next 4-8s of the control video will be used to generate the second window, and so on. So if your control video contains a person walking, your generated video could contain up to one minute of this person walking.
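+As a rough illustration only (this helper is not part of the app code, and the 4s window is just an example value), the control video segment driving each window could be computed like this:
+
+```python
+def control_segment(window_index: int, window_seconds: float = 4.0):
+    """Return the (start, end) time in seconds of the control video used by one window."""
+    return window_index * window_seconds, (window_index + 1) * window_seconds
+
+# Window 0 is driven by 0-4s of the control video, window 1 by 4-8s, and so on.
+print(control_segment(0), control_segment(1))  # (0.0, 4.0) (4.0, 8.0)
+```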
-Each window duration will be set by the *Number of frames (16 = 1s)* form field. However the actual number of frames generated by each iteration will be less, because the *overlap frames* and *discard last frames*:
-- *overlap frames* : the first frames ofa new window are filled with last frames of the previous window in order to ensure continuity between the two windows
-- *discard last frames* : quite often the last frames of a window have a worse quality. You decide here how many ending frames of a new window should be dropped.
+When combined with Sky Reels v2, you can extend an existing video indefinitely.
-Number of Generated = [Number of iterations] * ([Number of frames] - [Overlap Frames] - [Discard Last Frames]) + [Overlap Frames]
+Sliding Windows are turned on by default and are triggered as soon as you try to generate a Video longer than the Window Size. You can go to the Advanced Settings Tab *Sliding Window* to set this Window Size. You can make the Video even longer during the generation process by adding one more Window to generate each time you click the "Extend the Video Sample, Please !" button.
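+For reference, the number of windows needed for a given total length can be estimated with this small sketch, which mirrors the window count computation added to `wgp.py` in this change (the frame numbers below are only example values):
+
+```python
+import math
+
+def total_windows(video_length: int, window_size: int, overlap: int, discard_last: int) -> int:
+    """Estimate how many sliding windows are needed to cover video_length frames."""
+    if video_length <= window_size:
+        return 1
+    left_after_first = video_length - window_size + discard_last
+    return 1 + math.ceil(left_after_first / (window_size - discard_last - overlap))
+
+print(total_windows(video_length=197, window_size=81, overlap=17, discard_last=4))  # 3
+```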
-Experimental: if your prompt is broken into multiple lines (each line separated by a carriage return), then each line of the prompt will be used for a new window. If there are more windows to generate than prompt lines, the last prompt line will be repeated.
+Although the window duration is set by the *Sliding Window Size* form field, the actual number of new frames generated by each iteration will be lower because of the *overlap frames* and *discard last frames* settings:
+- *overlap frames* : the first frames of a new window are filled with the last frames of the previous window in order to ensure continuity between the two windows
+- *discard last frames* : quite often (Vace model only) the last frames of a window have a worse quality. You can decide here how many ending frames of a new window should be dropped.
+
+Number of Generated Frames = [Number of Windows - 1] * ([Window Size] - [Overlap Frames] - [Discard Last Frames]) + [Window Size]
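+
+The formula above can be sanity checked with a minimal sketch (the numbers are only example values):
+
+```python
+def generated_frames(num_windows: int, window_size: int, overlap: int, discard_last: int) -> int:
+    """Total frames produced by a sliding window generation, per the formula above."""
+    return (num_windows - 1) * (window_size - overlap - discard_last) + window_size
+
+# 3 windows of 81 frames with 17 overlap frames and 4 discarded frames -> 201 frames
+print(generated_frames(3, 81, 17, 4))  # 201
+```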
+
+Experimental: if your prompt is broken into multiple lines (each line separated by a carriage return), then each line of the prompt will be used for a new window. If there are more windows to generate than prompt lines, the last prompt line will be repeated.
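+
+A small illustrative sketch (not the app's code) of how multi-line prompts map to windows, with the last line repeated once the prompt lines run out:
+
+```python
+def prompt_for_window(prompt: str, window_index: int) -> str:
+    """Pick the prompt line used for a given window, repeating the last line if needed."""
+    lines = [line for line in prompt.split("\n") if line.strip()]
+    return lines[min(window_index, len(lines) - 1)]
+
+walk_then_run = "a person walking\na person running"
+print(prompt_for_window(walk_then_run, 0))  # a person walking
+print(prompt_for_window(walk_then_run, 2))  # a person running (last line repeated)
+```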
### Command line parameters for Gradio Server
--i2v : launch the image to video generator\
diff --git a/requirements.txt b/requirements.txt
index c8eb6b2..8dc178c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,7 +12,7 @@ ftfy
dashscope
imageio-ffmpeg
# flash_attn
-gradio>=5.0.0
+gradio==5.23.0
numpy>=1.23.5,<2
einops
moviepy==1.0.3
diff --git a/wan/__init__.py b/wan/__init__.py
index df36ebe..54004dd 100644
--- a/wan/__init__.py
+++ b/wan/__init__.py
@@ -1,3 +1,4 @@
from . import configs, distributed, modules
from .image2video import WanI2V
from .text2video import WanT2V
+from .diffusion_forcing import DTT2V
\ No newline at end of file
diff --git a/wan/image2video.py b/wan/image2video.py
index 99e0959..da48e67 100644
--- a/wan/image2video.py
+++ b/wan/image2video.py
@@ -352,7 +352,7 @@ class WanI2V:
# self.model.to(self.device)
if callback != None:
- callback(-1, True)
+ callback(-1, None, True)
for i, t in enumerate(tqdm(timesteps)):
offload.set_step_no_for_lora(self.model, i)
@@ -426,7 +426,7 @@ class WanI2V:
del timestep
if callback is not None:
- callback(i, False)
+ callback(i, latent, False)
x0 = [latent.to(self.device, dtype=self.dtype)]
diff --git a/wan/modules/model.py b/wan/modules/model.py
index 37b8989..5346fac 100644
--- a/wan/modules/model.py
+++ b/wan/modules/model.py
@@ -10,6 +10,7 @@ import numpy as np
from typing import Union,Optional
from mmgp import offload
from .attention import pay_attention
+from torch.backends.cuda import sdp_kernel
__all__ = ['WanModel']
@@ -27,6 +28,10 @@ def sinusoidal_embedding_1d(dim, position):
return x
+def reshape_latent(latent, latent_frames):
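+    # Reshape the flattened (1, tokens, dim) sequence to (latent_frames, tokens_per_frame, dim)
+    # so per-frame modulation can broadcast; calling it again with latent_frames=1 flattens it back.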
+ if latent_frames == latent.shape[0]:
+ return latent
+ return latent.reshape(latent_frames, -1, latent.shape[-1] )
def identify_k( b: float, d: int, N: int):
@@ -167,7 +172,7 @@ class WanSelfAttention(nn.Module):
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
- def forward(self, xlist, grid_sizes, freqs):
+ def forward(self, xlist, grid_sizes, freqs, block_mask = None):
r"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
@@ -190,12 +195,44 @@ class WanSelfAttention(nn.Module):
del x
qklist = [q,k]
del q,k
+
q,k = apply_rotary_emb(qklist, freqs, head_first=False)
qkv_list = [q,k,v]
del q,k,v
- x = pay_attention(
- qkv_list,
- window_size=self.window_size)
+ if block_mask == None:
+ x = pay_attention(
+ qkv_list,
+ window_size=self.window_size)
+ else:
+ with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+ x = (
+ torch.nn.functional.scaled_dot_product_attention(
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
+ )
+ .transpose(1, 2)
+ .contiguous()
+ )
+
+ # if not self._flag_ar_attention:
+ # q = rope_apply(q, grid_sizes, freqs)
+ # k = rope_apply(k, grid_sizes, freqs)
+ # x = flash_attention(q=q, k=k, v=v, window_size=self.window_size)
+ # else:
+ # q = rope_apply(q, grid_sizes, freqs)
+ # k = rope_apply(k, grid_sizes, freqs)
+ # q = q.to(torch.bfloat16)
+ # k = k.to(torch.bfloat16)
+ # v = v.to(torch.bfloat16)
+
+ # with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+ # x = (
+ # torch.nn.functional.scaled_dot_product_attention(
+ # q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask
+ # )
+ # .transpose(1, 2)
+ # .contiguous()
+ # )
+
# output
x = x.flatten(2)
x = self.o(x)
@@ -360,7 +397,8 @@ class WanAttentionBlock(nn.Module):
context,
hints= None,
context_scale=1.0,
- cam_emb= None
+ cam_emb= None,
+ block_mask = None
):
r"""
Args:
@@ -381,13 +419,14 @@ class WanAttentionBlock(nn.Module):
hint = self.vace(hints, x, **kwargs)
else:
hint = self.vace(hints, None, **kwargs)
-
+ latent_frames = e.shape[0]
e = (self.modulation + e).chunk(6, dim=1)
-
# self-attention
x_mod = self.norm1(x)
+ x_mod = reshape_latent(x_mod , latent_frames)
x_mod *= 1 + e[1]
x_mod += e[0]
+ x_mod = reshape_latent(x_mod , 1)
if cam_emb != None:
cam_emb = self.cam_encoder(cam_emb)
cam_emb = cam_emb.repeat(1, 2, 1)
@@ -397,12 +436,13 @@ class WanAttentionBlock(nn.Module):
xlist = [x_mod]
del x_mod
- y = self.self_attn( xlist, grid_sizes, freqs)
+ y = self.self_attn( xlist, grid_sizes, freqs, block_mask)
if cam_emb != None:
y = self.projector(y)
- # x = x + gate_msa * self.projector(self.self_attn(input_x, freqs))
+ x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames)
x.addcmul_(y, e[2])
+ x, y = reshape_latent(x , 1), reshape_latent(y , 1)
del y
y = self.norm3(x)
ylist= [y]
@@ -410,8 +450,10 @@ class WanAttentionBlock(nn.Module):
x += self.cross_attn(ylist, context)
y = self.norm2(x)
+ y = reshape_latent(y , latent_frames)
y *= 1 + e[4]
y += e[3]
+ y = reshape_latent(y , 1)
ffn = self.ffn[0]
gelu = self.ffn[1]
@@ -428,7 +470,9 @@ class WanAttentionBlock(nn.Module):
del mlp_chunk
y = y.view(y_shape)
+ x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames)
x.addcmul_(y, e[5])
+ x, y = reshape_latent(x , 1), reshape_latent(y , 1)
if hint is not None:
if context_scale == 1:
@@ -500,10 +544,14 @@ class Head(nn.Module):
"""
# assert e.dtype == torch.float32
dtype = x.dtype
+
+ latent_frames = e.shape[0]
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
x = self.norm(x).to(dtype)
+ x = reshape_latent(x , latent_frames)
x *= (1 + e[1])
x += e[0]
+ x = reshape_latent(x , 1)
x = self.head(x)
return x
@@ -552,7 +600,8 @@ class WanModel(ModelMixin, ConfigMixin):
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
- recammaster = False
+ recammaster = False,
+ inject_sample_info = False,
):
r"""
Initialize the diffusion model backbone.
@@ -609,6 +658,10 @@ class WanModel(ModelMixin, ConfigMixin):
self.qk_norm = qk_norm
self.cross_attn_norm = cross_attn_norm
self.eps = eps
+ self.num_frame_per_block = 1
+ self.flag_causal_attention = False
+ self.block_mask = None
+ self.inject_sample_info = inject_sample_info
# embeddings
self.patch_embedding = nn.Conv3d(
@@ -617,6 +670,10 @@ class WanModel(ModelMixin, ConfigMixin):
nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
nn.Linear(dim, dim))
+ if inject_sample_info:
+ self.fps_embedding = nn.Embedding(2, dim)
+ self.fps_projection = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim * 6))
+
self.time_embedding = nn.Sequential(
nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
@@ -678,12 +735,13 @@ class WanModel(ModelMixin, ConfigMixin):
block.projector.bias = nn.Parameter(torch.zeros(dim))
- def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
+ def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
rescale_func = np.poly1d(self.coefficients)
e_list = []
for t in timesteps:
t = torch.stack([t])
- e_list.append(self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t)))
+ time_emb = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype) ) # b, dim
+ e_list.append(time_emb)
best_threshold = 0.01
best_diff = 1000
@@ -695,16 +753,13 @@ class WanModel(ModelMixin, ConfigMixin):
nb_steps = 0
diff = 1000
for i, t in enumerate(timesteps):
- skip = False
+ skip = False
if not (i<=start_step or i== len(timesteps)):
- accumulated_rel_l1_distance += rescale_func(((e_list[i]-previous_modulated_input).abs().mean() / previous_modulated_input.abs().mean()).cpu().item())
- # self.accumulated_rel_l1_distance_even += rescale_func(((e_list[i]-self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item())
-
+ accumulated_rel_l1_distance += abs(rescale_func(((e_list[i]-e_list[i-1]).abs().mean() / e_list[i-1].abs().mean()).cpu().item()))
if accumulated_rel_l1_distance < threshold:
skip = True
else:
accumulated_rel_l1_distance = 0
- previous_modulated_input = e_list[i]
if not skip:
nb_steps += 1
signed_diff = target_nb_steps - nb_steps
@@ -739,6 +794,9 @@ class WanModel(ModelMixin, ConfigMixin):
slg_layers=None,
callback = None,
cam_emb: torch.Tensor = None,
+ fps = None,
+ causal_block_size = 1,
+ causal_attention = False,
):
if self.model_type == 'i2v':
@@ -752,26 +810,53 @@ class WanModel(ModelMixin, ConfigMixin):
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
- x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
# grid_sizes = torch.stack(
# [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
grid_sizes = [ list(u.shape[2:]) for u in x]
embed_sizes = grid_sizes[0]
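+        # Optional block-causal attention (used for diffusion forcing): each group of
+        # causal_block_size frames attends only to itself and to earlier frame blocks,
+        # with the mask expanded over the spatial grid (height x width) of the latent.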
+ if causal_attention : #causal_block_size > 0:
+ frame_num = embed_sizes[0]
+ height = embed_sizes[1]
+ width = embed_sizes[2]
+ block_num = frame_num // causal_block_size
+ range_tensor = torch.arange(block_num).view(-1, 1)
+ range_tensor = range_tensor.repeat(1, causal_block_size).flatten()
+ causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f
+ causal_mask = causal_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(x[0].device)
+ causal_mask = causal_mask.repeat(1, height, width, 1, height, width)
+ causal_mask = causal_mask.reshape(frame_num * height * width, frame_num * height * width)
+ block_mask = causal_mask.unsqueeze(0).unsqueeze(0)
+ del causal_mask
offload.shared_state["embed_sizes"] = embed_sizes
offload.shared_state["step_no"] = current_step
offload.shared_state["max_steps"] = max_steps
-
x = [u.flatten(2).transpose(1, 2) for u in x]
x = x[0]
- # time embeddings
+ if t.dim() == 2:
+ b, f = t.shape
+ _flag_df = True
+ else:
+ _flag_df = False
+
e = self.time_embedding(
- sinusoidal_embedding_1d(self.freq_dim, t))
+ sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype)
+ ) # b, dim
e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
+ if self.inject_sample_info:
+ fps = torch.tensor(fps, dtype=torch.long, device=device)
+
+ fps_emb = self.fps_embedding(fps).float()
+ if _flag_df:
+ e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
+ else:
+ e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim))
+
# context
context = self.text_embedding(
torch.stack([
@@ -833,7 +918,7 @@ class WanModel(ModelMixin, ConfigMixin):
self.accumulated_rel_l1_distance = 0
else:
rescale_func = np.poly1d(self.coefficients)
- self.accumulated_rel_l1_distance += rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
+ self.accumulated_rel_l1_distance += abs(rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()))
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
self.teacache_skipped_steps += 1
@@ -858,7 +943,7 @@ class WanModel(ModelMixin, ConfigMixin):
for block_idx, block in enumerate(self.blocks):
offload.shared_state["layer"] = block_idx
if callback != None:
- callback(-1, False, True)
+ callback(-1, None, False, True)
if pipeline._interrupt:
if joint_pass:
return None, None
diff --git a/wan/modules/sage2_core.py b/wan/modules/sage2_core.py
index 2b7da50..e023a28 100644
--- a/wan/modules/sage2_core.py
+++ b/wan/modules/sage2_core.py
@@ -1075,13 +1075,13 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90(
q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128, WARPK=128)
q_size = q.size()
+ kv_len = k.size(seq_dim)
q_device = q.device
del q,k
# pad v to multiple of 128
# TODO: modify per_channel_fp8 kernel to handle this
- kv_len = k.size(seq_dim)
v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0
if v_pad_len > 0:
if tensor_layout == "HND":
diff --git a/wan/text2video.py b/wan/text2video.py
index 725d51f..45f4149 100644
--- a/wan/text2video.py
+++ b/wan/text2video.py
@@ -49,40 +49,14 @@ class WanT2V:
config,
checkpoint_dir,
rank=0,
- t5_fsdp=False,
- dit_fsdp=False,
- use_usp=False,
- t5_cpu=False,
model_filename = None,
text_encoder_filename = None,
quantizeTransformer = False,
dtype = torch.bfloat16
):
- r"""
- Initializes the Wan text-to-video generation model components.
-
- Args:
- config (EasyDict):
- Object containing model parameters initialized from config.py
- checkpoint_dir (`str`):
- Path to directory containing model checkpoints
- device_id (`int`, *optional*, defaults to 0):
- Id of target GPU device
- rank (`int`, *optional*, defaults to 0):
- Process rank for distributed training
- t5_fsdp (`bool`, *optional*, defaults to False):
- Enable FSDP sharding for T5 model
- dit_fsdp (`bool`, *optional*, defaults to False):
- Enable FSDP sharding for DiT model
- use_usp (`bool`, *optional*, defaults to False):
- Enable distribution strategy of USP.
- t5_cpu (`bool`, *optional*, defaults to False):
- Whether to place T5 model on CPU. Only works without t5_fsdp.
- """
self.device = torch.device(f"cuda")
self.config = config
self.rank = rank
- self.t5_cpu = t5_cpu
self.dtype = dtype
self.num_train_timesteps = config.num_train_timesteps
self.param_dtype = config.param_dtype
@@ -419,9 +393,9 @@ class WanT2V:
freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
else:
freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
- arg_c = {'context': context, 'freqs': freqs, 'pipeline': self}
- arg_null = {'context': context_null, 'freqs': freqs, 'pipeline': self}
- arg_both = {'context': context, 'context2': context_null, 'freqs': freqs, 'pipeline': self}
+ arg_c = {'context': context, 'freqs': freqs, 'pipeline': self, 'callback': callback}
+ arg_null = {'context': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback}
+ arg_both = {'context': context, 'context2': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback}
if target_camera != None:
recam_dict = {'cam_emb': cam_emb}
@@ -438,7 +412,7 @@ class WanT2V:
if self.model.enable_teacache:
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
if callback != None:
- callback(-1, True)
+ callback(-1, None, True)
for i, t in enumerate(tqdm(timesteps)):
if target_camera != None:
latent_model_input = [torch.cat([u,v], dim=1) for u,v in zip(latents,source_latents )]
@@ -494,7 +468,7 @@ class WanT2V:
del temp_x0
if callback is not None:
- callback(i, False)
+ callback(i, latents[0], False)
x0 = latents
diff --git a/wgp.py b/wgp.py
index 141cee0..3285d8e 100644
--- a/wgp.py
+++ b/wgp.py
@@ -115,6 +115,14 @@ def pil_to_base64_uri(pil_image, format="png", quality=75):
print(f"Error converting PIL to base64: {e}")
return None
+def is_integer(n):
+ try:
+ float(n)
+ except ValueError:
+ return False
+ else:
+ return float(n).is_integer()
+
def process_prompt_and_add_tasks(state, model_choice):
@@ -172,17 +180,67 @@ def process_prompt_and_add_tasks(state, model_choice):
gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
return
- sliding_window_repeat = inputs["sliding_window_repeat"]
- sliding_window = sliding_window_repeat > 0
+ if "diffusion_forcing" in model_filename or "Vace" in model_filename:
+ video_length = inputs["video_length"]
+ sliding_window_size = inputs["sliding_window_size"]
+ if video_length > sliding_window_size:
+            gr.Info(f"The Number of Frames to generate ({video_length}) is greater than the Sliding Window Size ({sliding_window_size}), so multiple Windows will be generated")
- if "recam" in model_filename:
+
+ if "diffusion_forcing" in model_filename:
+ image_start = inputs["image_start"]
+ video_source = inputs["video_source"]
+ keep_frames_video_source = inputs["keep_frames_video_source"]
+ image_prompt_type = inputs["image_prompt_type"]
+
+ if len(keep_frames_video_source) > 0:
+ if not is_integer(keep_frames_video_source):
+ gr.Info("The number of frames to keep must be an integer")
+ return
+
+ if "V" in image_prompt_type:
+ if video_source == None or len(video_source) == 0:
+ gr.Info("You must provide a Video to continue")
+ return
+ image_start = None
+
+ if "S" in image_prompt_type:
+ if image_start == None :
+ gr.Info("You must provide a Start Image")
+ return
+ if len(image_start) > 1:
+ gr.Info("Only one Start Image is supported for the moment")
+ return
+ if isinstance(image_start[0][0], str) :
+ gr.Info("Start Image should be an Image")
+ return
+
+ image_start = [ convert_image(tup[0]) for tup in image_start ]
+ video_source = None
+
+ if "T" in image_prompt_type:
+ image_start = None
+ video_source = None
+
+ if len(prompts) > 0:
+ prompts = ["\n".join(prompts)]
+
+ for single_prompt in prompts:
+ extra_inputs = {
+ "prompt" : single_prompt,
+ "image_start" : image_start,
+ "video_source" : video_source,
+ }
+ inputs.update(extra_inputs)
+ add_video_task(**inputs)
+ elif "recam" in model_filename:
video_source = inputs["video_source"]
if video_source == None:
gr.Info("You must provide a Source Video")
return
- frames = get_resampled_video(video_source, 0, 81)
+ frames = get_resampled_video(video_source, 0, 81, 16)
if len(frames)<81:
- gr.Info("Recammaster source video should be at least 81 frames one the resampling at 16 fps has been done")
+ gr.Info("Recammaster source video should be at least 81 frames once the resampling at 16 fps has been done")
return
for single_prompt in prompts:
extra_inputs = {
@@ -198,14 +256,6 @@ def process_prompt_and_add_tasks(state, model_choice):
video_guide = inputs["video_guide"]
video_mask = inputs["video_mask"]
- if sliding_window:
- if inputs["repeat_generation"]!=1:
- gr.Info("Only one Video generated per Prompt is supported when Sliding windows is used")
- return
- if inputs["sliding_window_overlap"]>=inputs["video_length"] :
- gr.Info("The number of frames of the Sliding Window Overlap must be less than the Number of Frames to Generate")
- return
-
if "1.3B" in model_filename :
resolution_reformated = str(height) + "*" + str(width)
if not resolution_reformated in VACE_SIZE_CONFIGS:
@@ -214,7 +264,7 @@ def process_prompt_and_add_tasks(state, model_choice):
return
if "I" in video_prompt_type:
if image_refs == None:
- gr.Info("You must provide at one Refererence Image")
+                gr.Info("You must provide at least one Reference Image")
return
else:
image_refs = None
@@ -231,17 +281,17 @@ def process_prompt_and_add_tasks(state, model_choice):
else:
video_mask = None
if "O" in video_prompt_type :
- keep_frames= inputs["keep_frames"]
+ keep_frames_video_guide= inputs["keep_frames_video_guide"]
video_length = inputs["video_length"]
- if len(keep_frames) ==0:
+ if len(keep_frames_video_guide) ==0:
gr.Info(f"Warning : you have asked to reuse all the frames of the control Video in the Alternate Video Ending it. Please make sure the number of frames of the control Video is lower than the total number of frames to generate otherwise it won't make a difference.")
# elif keep_frames >= video_length:
- # gr.Info(f"The number of frames in the control Video to reuse ({keep_frames}) in Alternate Video Ending can not be bigger than the total number of frames ({video_length}) to generate.")
+ # gr.Info(f"The number of frames in the control Video to reuse ({keep_frames_video_guide}) in Alternate Video Ending can not be bigger than the total number of frames ({video_length}) to generate.")
# return
elif "V" in video_prompt_type:
- keep_frames= inputs["keep_frames"]
+ keep_frames_video_guide= inputs["keep_frames_video_guide"]
video_length = inputs["video_length"]
- _, error = parse_keep_frames(keep_frames, video_length)
+ _, error = parse_keep_frames_video_guide(keep_frames_video_guide, video_length)
if len(error) > 0:
gr.Info(f"Invalid Keep Frames property: {error}")
return
@@ -254,7 +304,7 @@ def process_prompt_and_add_tasks(state, model_choice):
image_refs = resize_and_remove_background(image_refs, width, height, inputs["remove_background_image_ref"] ==1)
- if sliding_window and len(prompts) > 0:
+ if len(prompts) > 0:
prompts = ["\n".join(prompts)]
for single_prompt in prompts:
@@ -275,15 +325,20 @@ def process_prompt_and_add_tasks(state, model_choice):
return
if not "E" in image_prompt_type:
image_end = None
- if isinstance(image_start, list):
- image_start = [ convert_image(tup[0]) for tup in image_start ]
- else:
- image_start = [convert_image(image_start)]
+ if not isinstance(image_start, list):
+ image_start = [image_start]
+ if not all( not isinstance(img[0], str) for img in image_start) :
+ gr.Info("Start Image should be an Image")
+ return
+ image_start = [ convert_image(tup[0]) for tup in image_start ]
+
if image_end != None:
- if isinstance(image_end , list):
- image_end = [ convert_image(tup[0]) for tup in image_end ]
- else:
- image_end = [convert_image(image_end) ]
+ if not isinstance(image_end , list):
+ image_end = [image_end]
+ if not all( not isinstance(img[0], str) for img in image_end) :
+ gr.Info("End Image should be an Image")
+ return
+ image_end = [ convert_image(tup[0]) for tup in image_end ]
if len(image_start) != len(image_end):
gr.Info("The number of start and end images should be the same ")
return
@@ -573,16 +628,29 @@ def save_queue_action(state):
finally:
zip_buffer.close()
-def load_queue_action(filepath, state):
+def load_queue_action(filepath, state, evt:gr.EventData):
global task_id
+
gen = get_gen_info(state)
original_queue = gen.get("queue", [])
+ delete_autoqueue_file = False
+ if evt.target == None:
+
+ if original_queue or not Path(AUTOSAVE_FILENAME).is_file():
+ return
+ print(f"Autoloading queue from {AUTOSAVE_FILENAME}...")
+ filename = AUTOSAVE_FILENAME
+ delete_autoqueue_file = True
+ else:
+ if not filepath or not hasattr(filepath, 'name') or not Path(filepath.name).is_file():
+ print("[load_queue_action] Warning: No valid file selected or file not found.")
+ return update_queue_data(original_queue)
+ filename = filepath.name
+
+
save_path_base = server_config.get("save_path", "outputs")
loaded_cache_dir = os.path.join(save_path_base, "_loaded_queue_cache")
- if not filepath or not hasattr(filepath, 'name') or not Path(filepath.name).is_file():
- print("[load_queue_action] Warning: No valid file selected or file not found.")
- return update_queue_data(original_queue)
newly_loaded_queue = []
max_id_in_file = 0
@@ -590,14 +658,14 @@ def load_queue_action(filepath, state):
local_queue_copy_for_global_ref = None
try:
- print(f"[load_queue_action] Attempting to load queue from: {filepath.name}")
+ print(f"[load_queue_action] Attempting to load queue from: {filename}")
os.makedirs(loaded_cache_dir, exist_ok=True)
print(f"[load_queue_action] Using cache directory: {loaded_cache_dir}")
with tempfile.TemporaryDirectory() as tmpdir:
- with zipfile.ZipFile(filepath.name, 'r') as zf:
+ with zipfile.ZipFile(filename, 'r') as zf:
if "queue.json" not in zf.namelist(): raise ValueError("queue.json not found in zip file")
- print(f"[load_queue_action] Extracting {filepath.name} to {tmpdir}")
+ print(f"[load_queue_action] Extracting {filename} to {tmpdir}")
zf.extractall(tmpdir)
print(f"[load_queue_action] Extraction complete.")
@@ -677,8 +745,8 @@ def load_queue_action(filepath, state):
primary_preview_pil_list, secondary_preview_pil_list = get_preview_images(params)
- start_b64 = [pil_to_base64_uri(primary_preview_pil_list[0], format="jpeg", quality=70)] if primary_preview_pil_list[0] else None
- end_b64 = [pil_to_base64_uri(secondary_preview_pil_list[0], format="jpeg", quality=70)] if secondary_preview_pil_list[0] else None
+ start_b64 = [pil_to_base64_uri(primary_preview_pil_list[0], format="jpeg", quality=70)] if primary_preview_pil_list and primary_preview_pil_list[0] else None
+ end_b64 = [pil_to_base64_uri(secondary_preview_pil_list[0], format="jpeg", quality=70)] if secondary_preview_pil_list and secondary_preview_pil_list[0] else None
top_level_start_image = params.get("image_start") or params.get("image_refs")
top_level_end_image = params.get("image_end")
@@ -732,6 +800,11 @@ def load_queue_action(filepath, state):
print("[load_queue_action] Load failed. Returning DataFrame update for original queue.")
return update_queue_data(original_queue)
finally:
+ if delete_autoqueue_file:
+ if os.path.isfile(filename):
+ os.remove(filename)
+ print(f"Clear Queue: Deleted autosave file '{filename}'.")
+
if filepath and hasattr(filepath, 'name') and filepath.name and os.path.exists(filepath.name):
if tempfile.gettempdir() in os.path.abspath(filepath.name):
try:
@@ -919,78 +992,6 @@ def autosave_queue():
print(f"Error during autosave: {e}")
traceback.print_exc()
-
-def autoload_queue(state):
- global task_id
- try:
- gen = get_gen_info(state)
- original_queue = gen.get("queue", [])
- except AttributeError:
- print("[autoload_queue] Error: Initial state is not a dictionary. Cannot autoload.")
- return gr.update(visible=False), False, state
-
- loaded_flag = False
- dataframe_update = update_queue_data(original_queue)
-
- if not original_queue and Path(AUTOSAVE_FILENAME).is_file():
- print(f"Autoloading queue from {AUTOSAVE_FILENAME}...")
- class MockFile:
- def __init__(self, name):
- self.name = name
- mock_filepath = MockFile(AUTOSAVE_FILENAME)
- dataframe_update = load_queue_action(mock_filepath, state)
-
- gen = get_gen_info(state)
- loaded_queue_after_action = gen.get("queue", [])
-
- if loaded_queue_after_action:
- print(f"Autoload successful. Loaded {len(loaded_queue_after_action)} tasks into state.")
- loaded_flag = True
- else:
- print("Autoload attempted but queue in state remains empty (file might be empty or invalid).")
- with lock:
- gen["queue"] = []
- gen["prompts_max"] = 0
- update_global_queue_ref([])
- dataframe_update = update_queue_data([])
-
- # need to remove queue otherwise every new tab will be processed it again
- try:
- if os.path.isfile(AUTOSAVE_FILENAME):
- os.remove(AUTOSAVE_FILENAME)
- print(f"Clear Queue: Deleted autosave file '{AUTOSAVE_FILENAME}'.")
- except OSError as e:
- print(f"Clear Queue: Error deleting autosave file '{AUTOSAVE_FILENAME}': {e}")
- gr.Warning(f"Could not delete the autosave file '{AUTOSAVE_FILENAME}'. You may need to remove it manually.")
-
- else:
- if original_queue:
- print("Autoload skipped: Queue is not empty.")
- update_global_queue_ref(original_queue)
- dataframe_update = update_queue_data(original_queue)
- else:
- # print(f"Autoload skipped: {AUTOSAVE_FILENAME} not found.")
- update_global_queue_ref([])
- dataframe_update = update_queue_data([])
-
-
- return dataframe_update, loaded_flag, state
-
-def run_autoload_and_prepare_ui(current_state):
- df_update, loaded_flag, modified_state = autoload_queue(current_state)
- should_start_processing = loaded_flag
- accordion_update = gr.Accordion(open=True) if loaded_flag else gr.update()
- return df_update, gr.update(visible=loaded_flag), accordion_update, should_start_processing, modified_state
-
-def start_processing_if_needed(should_start, current_state):
- if not isinstance(current_state, dict) or 'gen' not in current_state:
- yield "Error: Invalid state received before processing."
- return
- if should_start:
- yield from process_tasks(current_state)
- else:
- yield None
-
def finalize_generation_with_state(current_state):
if not isinstance(current_state, dict) or 'gen' not in current_state:
return gr.update(), gr.update(interactive=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=""), gr.update(), current_state
@@ -1040,10 +1041,6 @@ def update_queue_data(queue):
update_global_queue_ref(queue)
data = get_queue_table(queue)
- # if len(data) == 0:
- # return gr.HTML(visible=False)
- # else:
- # return gr.HTML(value=data, visible= True)
if len(data) == 0:
return gr.DataFrame(visible=False)
else:
@@ -1365,7 +1362,8 @@ quantizeTransformer = args.quantize_transformer
check_loras = args.check_loras ==1
advanced = args.advanced
-transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_bf16.safetensors", "ckpts/wan2.1_recammaster_1.3B_bf16.safetensors"]
+transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_bf16.safetensors",
+ "ckpts/wan2.1_recammaster_1.3B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_1.3B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors"]
transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_bf16.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors"]
transformer_choices = transformer_choices_t2v + transformer_choices_i2v
text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"]
@@ -1404,10 +1402,10 @@ else:
server_config = json.loads(text)
-model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp", "recam_1.3B", "flf2v_720p"]
+model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B"]
model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B",
"i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "recam_1.3B": "recammaster_1.3B",
- "flf2v_720p" : "FLF2V_720p" }
+ "flf2v_720p" : "FLF2V_720p", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B" }
def get_model_type(model_filename):
@@ -1435,6 +1433,9 @@ def get_model_name(model_filename):
elif "FLF2V" in model_filename:
model_name = "Wan2.1 FLF2V"
model_name += " 720p" if "720p" in model_filename else " 480p"
+ elif "sky_reels2_diffusion_forcing" in model_filename:
+ model_name = "SkyReels2 diffusion forcing"
+ model_name += " 14B" if "14B" in model_filename else " 1.3B"
else:
model_name = "Wan2.1 text2video"
model_name += " 14B" if "14B" in model_filename else " 1.3B"
@@ -1491,6 +1492,14 @@ def get_default_settings(filename):
"slg_start_perc": 10,
"slg_end_perc": 90
}
+
+ if get_model_type(filename) in ("sky_df_1.3B", "sky_df_14B"):
+ ui_defaults.update({
+ "guidance_scale": 6.0,
+ "flow_shift": 8,
+ "sliding_window_discard_last_frames" : 0
+ })
+
with open(defaults_filename, "w", encoding="utf-8") as f:
json.dump(ui_defaults, f, indent=4)
else:
@@ -1669,7 +1678,7 @@ for file_name in to_remove:
download_models(transformer_filename, text_encoder_filename)
def sanitize_file_name(file_name, rep =""):
- return file_name.replace("/",rep).replace("\\",rep).replace(":",rep).replace("|",rep).replace("?",rep).replace("<",rep).replace(">",rep).replace("\"",rep)
+ return file_name.replace("/",rep).replace("\\",rep).replace(":",rep).replace("|",rep).replace("?",rep).replace("<",rep).replace(">",rep).replace("\"",rep).replace("\n",rep).replace("\r",rep)
def extract_preset(model_filename, lset_name, loras):
loras_choices = []
@@ -1759,14 +1768,14 @@ def load_t2v_model(model_filename, value, quantizeTransformer = False, dtype = t
cfg = WAN_CONFIGS['t2v-14B']
# cfg = WAN_CONFIGS['t2v-1.3B']
print(f"Loading '{model_filename}' model...")
+ if get_model_type(model_filename) in ("sky_df_1.3B", "sky_df_14B"):
+ model_factory = wan.DTT2V
+ else:
+ model_factory = wan.WanT2V
- wan_model = wan.WanT2V(
+ wan_model = model_factory(
config=cfg,
checkpoint_dir="ckpts",
- rank=0,
- t5_fsdp=False,
- dit_fsdp=False,
- use_usp=False,
model_filename=model_filename,
text_encoder_filename= text_encoder_filename,
quantizeTransformer = quantizeTransformer,
@@ -1922,7 +1931,7 @@ def apply_changes( state,
"compile" : compile_choice,
"profile" : profile_choice,
"vae_config" : vae_config_choice,
- "metadata_choice": metadata_choice,
+ "metadata_type": metadata_choice,
"transformer_quantization" : quantization_choice,
"boost" : boost_choice,
"clear_file_list" : clear_file_list,
@@ -1967,7 +1976,7 @@ def apply_changes( state,
model_transformer_type = transformer_types[0] if len(transformer_types) > 0 else model_types[0]
model_filename = get_model_filename(model_transformer_type, transformer_quantization)
state["model_filename"] = model_filename
- if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_choice", "clear_file_list"] for change in changes ):
+ if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list"] for change in changes ):
model_choice = gr.Dropdown()
else:
reload_needed = True
@@ -1995,9 +2004,10 @@ def get_gen_info(state):
state["gen"] = cache
return cache
-def build_callback(state, pipe, progress, status, num_inference_steps):
- def callback(step_idx, force_refresh, read_state = False):
- gen = get_gen_info(state)
+def build_callback(state, pipe, send_cmd, status, num_inference_steps):
+ gen = get_gen_info(state)
+ gen["num_inference_steps"] = num_inference_steps
+ def callback(step_idx, latent, force_refresh, read_state = False, override_num_inference_steps = -1):
refresh_id = gen.get("refresh", -1)
if force_refresh or step_idx >= 0:
pass
@@ -2008,7 +2018,10 @@ def build_callback(state, pipe, progress, status, num_inference_steps):
UI_refresh = state.get("refresh", 0)
if UI_refresh >= refresh_id:
return
-
+ if override_num_inference_steps > 0:
+ gen["num_inference_steps"] = override_num_inference_steps
+
+ num_inference_steps = gen.get("num_inference_steps", 0)
status = gen["progress_status"]
state["refresh"] = refresh_id
if read_state:
@@ -2029,8 +2042,12 @@ def build_callback(state, pipe, progress, status, num_inference_steps):
else:
progress_args = [0, status_msg]
- progress(*progress_args)
- gen["progress_args"] = progress_args
+ # progress(*progress_args)
+ send_cmd("progress", progress_args)
+ if latent != None:
+ send_cmd("preview", latent.to("cpu", non_blocking=True))
+
+ # gen["progress_args"] = progress_args
return callback
def abort_generation(state):
@@ -2042,17 +2059,18 @@ def abort_generation(state):
if wan_model != None:
wan_model._interrupt= True
msg = "Processing Request to abort Current Generation"
+ gen["status"] = msg
gr.Info(msg)
- return msg, gr.Button(interactive= False)
+ return gr.Button(interactive= False)
else:
- return "", gr.Button(interactive= True)
+ return gr.Button(interactive= True)
-def refresh_gallery(state, msg):
+def refresh_gallery(state): #, msg
gen = get_gen_info(state)
- gen["last_msg"] = msg
+ # gen["last_msg"] = msg
file_list = gen.get("file_list", None)
choice = gen.get("selected",0)
in_progress = "in_progress" in gen
@@ -2063,20 +2081,22 @@ def refresh_gallery(state, msg):
queue = gen.get("queue", [])
abort_interactive = not gen.get("abort", False)
if not in_progress or len(queue) == 0:
- return gr.Gallery(selected_index=choice, value = file_list), gr.HTML("", visible= False), gr.Button(visible=True), gr.Button(visible=False), gr.Row(visible=False), update_queue_data(queue), gr.Button(interactive= abort_interactive)
+ return gr.Gallery(selected_index=choice, value = file_list), gr.HTML("", visible= False), gr.Button(visible=True), gr.Button(visible=False), gr.Row(visible=False), update_queue_data(queue), gr.Button(interactive= abort_interactive), gr.Button(visible= False)
else:
task = queue[0]
start_img_md = ""
end_img_md = ""
prompt = task["prompt"]
params = task["params"]
- if "\n" in prompt and params.get("sliding_window_repeat", 0) > 0:
+ model_filename = params["model_filename"]
+ onemorewindow_visible = "Vace" in model_filename or "diffusion_forcing" in model_filename
+ if "\n" in prompt :
prompts = prompt.split("\n")
- repeat_no= gen.get("repeat_no",1)
- if repeat_no > len(prompts):
- repeat_no = len(prompts)
- repeat_no -= 1
-            prompts[repeat_no]="<B>" + prompts[repeat_no] + "</B>"
+ window_no= gen.get("window_no",1)
+ if window_no > len(prompts):
+ window_no = len(prompts)
+ window_no -= 1
+            prompts[window_no]="<B>" + prompts[window_no] + "</B>"
            prompt = "<BR>".join(prompts)
start_img_uri = task.get('start_image_data_base64')
@@ -2099,7 +2119,7 @@ def refresh_gallery(state, msg):
html += ""
html_output = gr.HTML(html, visible= True)
- return gr.Gallery(selected_index=choice, value = file_list), html_output, gr.Button(visible=False), gr.Button(visible=True), gr.Row(visible=True), update_queue_data(queue), gr.Button(interactive= abort_interactive)
+ return gr.Gallery(selected_index=choice, value = file_list), html_output, gr.Button(visible=False), gr.Button(visible=True), gr.Row(visible=True), update_queue_data(queue), gr.Button(interactive= abort_interactive), gr.Button(visible= onemorewindow_visible)
@@ -2153,7 +2173,7 @@ def convert_image(image):
image = image.convert('RGB')
return cast(Image, ImageOps.exif_transpose(image))
-def get_resampled_video(video_in, start_frame, max_frames):
+def get_resampled_video(video_in, start_frame, max_frames, target_fps):
from wan.utils.utils import resample
import decord
@@ -2162,13 +2182,13 @@ def get_resampled_video(video_in, start_frame, max_frames):
fps = reader.get_avg_fps()
- frame_nos = resample(fps, len(reader), max_target_frames_count= max_frames, target_fps=16, start_target_frame= start_frame)
+ frame_nos = resample(fps, len(reader), max_target_frames_count= max_frames, target_fps=target_fps, start_target_frame= start_frame)
frames_list = reader.get_batch(frame_nos)
return frames_list
-def preprocess_video(process_type, height, width, video_in, max_frames, start_frame=0, fit_canvas = False):
+def preprocess_video(process_type, height, width, video_in, max_frames, start_frame=0, fit_canvas = False, target_fps = 16):
- frames_list = get_resampled_video(video_in, start_frame, max_frames)
+ frames_list = get_resampled_video(video_in, start_frame, max_frames, target_fps)
if len(frames_list) == 0:
return None
@@ -2227,14 +2247,8 @@ def preprocess_video(process_type, height, width, video_in, max_frames, start_fr
return torch.stack(torch_frames)
-def parse_keep_frames(keep_frames, video_length):
- def is_integer(n):
- try:
- float(n)
- except ValueError:
- return False
- else:
- return float(n).is_integer()
+
+def parse_keep_frames_video_guide(keep_frames, video_length):
def absolute(n):
if n==0:
@@ -2280,7 +2294,7 @@ def parse_keep_frames(keep_frames, video_length):
def generate_video(
task_id,
- progress,
+ send_cmd,
prompt,
negative_prompt,
resolution,
@@ -2299,14 +2313,15 @@ def generate_video(
image_prompt_type,
image_start,
image_end,
+ model_mode,
+ video_source,
+ keep_frames_video_source,
video_prompt_type,
image_refs,
video_guide,
+ keep_frames_video_guide,
video_mask,
- camera_type,
- video_source,
- keep_frames,
- sliding_window_repeat,
+ sliding_window_size,
sliding_window_overlap,
sliding_window_discard_last_frames,
remove_background_image_ref,
@@ -2323,7 +2338,6 @@ def generate_video(
model_filename
):
-
global wan_model, offloadobj, reload_needed
gen = get_gen_info(state)
@@ -2345,9 +2359,9 @@ def generate_video(
offloadobj.release()
offloadobj = None
gc.collect()
- yield f"Loading model {get_model_name(model_filename)}..."
+ send_cmd("status", f"Loading model {get_model_name(model_filename)}...")
wan_model, offloadobj, trans = load_models(model_filename)
- yield f"Model loaded"
+ send_cmd("status", "Model loaded")
reload_needed= False
if attention_mode == "auto":
@@ -2355,7 +2369,8 @@ def generate_video(
elif attention_mode in attention_modes_supported:
attn = attention_mode
else:
- gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed or supported on your system. You should either install it or switch to the default 'sdpa' attention.")
+ send_cmd("info", f"You have selected attention mode '{attention_mode}'. However it is not installed or supported on your system. You should either install it or switch to the default 'sdpa' attention.")
+ send_cmd("exit")
return
width, height = resolution.split("x")
@@ -2447,19 +2462,15 @@ def generate_video(
if image2video:
if '480p' in model_filename:
- # teacache_thresholds = [0.13, .19, 0.26]
trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
elif '720p' in model_filename:
- teacache_thresholds = [0.18, 0.2 , 0.3]
trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
else:
raise gr.Error("Teacache not supported for this model")
else:
if '1.3B' in model_filename:
- # teacache_thresholds= [0.05, 0.07, 0.08]
trans.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
elif '14B' in model_filename:
- # teacache_thresholds = [0.14, 0.15, 0.2]
trans.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
else:
raise gr.Error("Teacache not supported for this model")
@@ -2467,14 +2478,13 @@ def generate_video(
target_camera = None
if "recam" in model_filename:
source_video = preprocess_video("", width=width, height=height,video_in=video_source, max_frames= video_length, start_frame = 0, fit_canvas= True)
- target_camera = camera_type
+ target_camera = model_mode
import random
if seed == None or seed <0:
seed = random.randint(0, 999999999)
global save_path
os.makedirs(save_path, exist_ok=True)
- video_no = 0
abort = False
gc.collect()
torch.cuda.empty_cache()
@@ -2483,307 +2493,364 @@ def generate_video(
gen["prompt"] = prompt
repeat_no = 0
extra_generation = 0
- start_frame = 0
- sliding_window = sliding_window_repeat > 0
+ initial_total_windows = 0
+ max_frames_to_generate = video_length
+ diffusion_forcing = "diffusion_forcing" in model_filename
+ vace = "Vace" in model_filename
+ if diffusion_forcing or vace:
+ reuse_frames = min(sliding_window_size - 4, sliding_window_overlap)
+ if diffusion_forcing and source_video != None:
+ video_length += sliding_window_overlap
+ sliding_window = ("Vace" in model_filename or diffusion_forcing) and video_length > sliding_window_size
+
if sliding_window:
- reuse_frames = sliding_window_overlap
- discard_last_frames = sliding_window_discard_last_frames #4
- repeat_generation = sliding_window_repeat
+ discard_last_frames = sliding_window_discard_last_frames
+ left_after_first_window = video_length - sliding_window_size + discard_last_frames
+ initial_total_windows= 1 + math.ceil(left_after_first_window / (sliding_window_size - discard_last_frames - reuse_frames))
+ video_length = sliding_window_size
prompts = prompt.split("\n")
prompts = [part for part in prompts if len(prompt)>0]
+ else:
+ initial_total_windows = 1
-
+ first_window_video_length = video_length
+ fps = 24 if diffusion_forcing else 16
+
gen["sliding_window"] = sliding_window
- frames_already_processed = None
- pre_video_guide = None
-
- while True:
+ while not abort:
extra_generation += gen.get("extra_orders",0)
gen["extra_orders"] = 0
total_generation = repeat_generation + extra_generation
gen["total_generation"] = total_generation
- if abort or repeat_no >= total_generation:
+ if repeat_no >= total_generation:
break
-
- if "Vace" in model_filename and (repeat_no == 0 or sliding_window):
- if sliding_window:
- prompt = prompts[repeat_no] if repeat_no < len(prompts) else prompts[-1]
-
- # video_prompt_type = video_prompt_type +"G"
- image_refs_copy = image_refs.copy() if image_refs != None else None # required since prepare_source do inplace modifications
- video_guide_copy = video_guide
- video_mask_copy = video_mask
- if any(process in video_prompt_type for process in ("P", "D", "G")) :
- prompts_max = gen["prompts_max"]
-
- status = get_generation_status(prompt_no, prompts_max, 1, 1, sliding_window)
- preprocess_type = None
- if "P" in video_prompt_type :
- progress_args = [0, status + " - Extracting Open Pose Information"]
- preprocess_type = "pose"
- elif "D" in video_prompt_type :
- progress_args = [0, status + " - Extracting Depth Information"]
- preprocess_type = "depth"
- elif "G" in video_prompt_type :
- progress_args = [0, status + " - Extracting Gray Level Information"]
- preprocess_type = "gray"
-
- if preprocess_type != None :
- progress(*progress_args )
- gen["progress_args"] = progress_args
- video_guide_copy = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length if repeat_no ==0 else video_length - reuse_frames, start_frame = start_frame)
- keep_frames_parsed, error = parse_keep_frames(keep_frames, video_length)
- if len(error) > 0:
- raise gr.Error(f"invalid keep frames {keep_frames}")
- if repeat_no == 0:
- image_size = VACE_SIZE_CONFIGS[resolution_reformated] # default frame dimensions until it is set by video_src (if there is any)
- src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_copy],
- [video_mask_copy ],
- [image_refs_copy],
- video_length, image_size = image_size, device ="cpu",
- original_video= "O" in video_prompt_type,
- keep_frames=keep_frames_parsed,
- start_frame = start_frame,
- pre_src_video = [pre_video_guide]
- )
- if repeat_no == 0 and src_video != None and len(src_video) > 0:
- image_size = src_video[0].shape[-2:]
-
- else:
- src_video, src_mask, src_ref_images = None, None, None
-
-
repeat_no +=1
gen["repeat_no"] = repeat_no
- prompts_max = gen["prompts_max"]
- status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation, sliding_window)
+ src_video, src_mask, src_ref_images = None, None, None
+ prefix_video = None
+ prefix_video_frames_count = 0
+ frames_already_processed = None
+ pre_video_guide = None
+ window_no = 0
+ extra_windows = 0
+ guide_start_frame = 0
+ video_length = first_window_video_length
+ while not abort:
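+            # Inner loop: one sliding window is generated per iteration; the prompt line and the
+            # control video segment to use advance with window_no.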
+ if sliding_window:
+ prompt = prompts[window_no] if window_no < len(prompts) else prompts[-1]
+ extra_windows += gen.get("extra_windows",0)
+ if extra_windows > 0:
+ video_length = sliding_window_size
+ gen["extra_windows"] = 0
+ total_windows = initial_total_windows + extra_windows
+ gen["total_windows"] = total_windows
+ if window_no >= total_windows:
+ break
+ window_no += 1
+ gen["window_no"] = window_no
+
+ if diffusion_forcing:
+ if video_source != None and len(video_source) > 0 and window_no == 1:
+ keep_frames_video_source= 1000 if len(keep_frames_video_source) ==0 else int(keep_frames_video_source)
+ prefix_video = preprocess_video(None, width=width, height=height,video_in=video_source, max_frames= keep_frames_video_source , start_frame = 0, fit_canvas= True, target_fps = fps)
+ prefix_video = prefix_video .permute(3, 0, 1, 2)
+ prefix_video = prefix_video .float().div_(127.5).sub_(1.) # c, f, h, w
+ prefix_video_frames_count = prefix_video.shape[1]
+ pre_video_guide = prefix_video[:, -reuse_frames:]
+
+ if vace:
+ # video_prompt_type = video_prompt_type +"G"
+ image_refs_copy = image_refs.copy() if image_refs != None else None # required since prepare_source do inplace modifications
+ video_guide_copy = video_guide
+ video_mask_copy = video_mask
+ if any(process in video_prompt_type for process in ("P", "D", "G")) :
+ prompts_max = gen["prompts_max"]
+
+ status = get_latest_status(state)
+
+ preprocess_type = None
+ if "P" in video_prompt_type :
+ progress_args = [0, status + " - Extracting Open Pose Information"]
+ preprocess_type = "pose"
+ elif "D" in video_prompt_type :
+ progress_args = [0, status + " - Extracting Depth Information"]
+ preprocess_type = "depth"
+ elif "G" in video_prompt_type :
+ progress_args = [0, status + " - Extracting Gray Level Information"]
+ preprocess_type = "gray"
+
+ if preprocess_type != None :
+ send_cmd("progress", progress_args)
+ video_guide_copy = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length if window_no == 1 else video_length - reuse_frames, start_frame = guide_start_frame, target_fps = fps)
+ keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, max_frames_to_generate)
+ if len(error) > 0:
+ raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
+ keep_frames_parsed = keep_frames_parsed[guide_start_frame: guide_start_frame + video_length]
+ if window_no == 1:
+ image_size = VACE_SIZE_CONFIGS[resolution_reformated] # default frame dimensions until it is set by video_src (if there is any)
+ src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_copy],
+ [video_mask_copy ],
+ [image_refs_copy],
+ video_length, image_size = image_size, device ="cpu",
+ original_video= "O" in video_prompt_type,
+ keep_frames=keep_frames_parsed,
+ start_frame = guide_start_frame,
+ pre_src_video = [pre_video_guide]
+ )
+ if window_no == 1 and src_video != None and len(src_video) > 0:
+ image_size = src_video[0].shape[-2:]
+ prompts_max = gen["prompts_max"]
+ status = get_latest_status(state)
- gen["progress_status"] = status
- gen["progress_phase"] = (" - Encoding Prompt", -1 )
- callback = build_callback(state, trans, progress, status, num_inference_steps)
- progress_args = [0, status + " - Encoding Prompt"]
- progress(*progress_args )
- gen["progress_args"] = progress_args
+ gen["progress_status"] = status
+ gen["progress_phase"] = (" - Encoding Prompt", -1 )
+ callback = build_callback(state, trans, send_cmd, status, num_inference_steps)
+ progress_args = [0, status + " - Encoding Prompt"]
+ send_cmd("progress", progress_args)
- try:
- start_time = time.time()
+ samples = torch.empty( (1,2)) #for testing
+ # if False:
+
+ try:
+ if trans.enable_teacache:
+ trans.teacache_counter = 0
+ trans.num_steps = num_inference_steps
+ trans.teacache_skipped_steps = 0
+ trans.previous_residual_uncond = None
+ trans.previous_residual_cond = None
+
+ if image2video:
+ samples = wan_model.generate(
+ prompt,
+ image_start,
+ image_end if image_end != None else None,
+ frame_num=(video_length // 4)* 4 + 1,
+ max_area=MAX_AREA_CONFIGS[resolution_reformated],
+ shift=flow_shift,
+ sampling_steps=num_inference_steps,
+ guide_scale=guidance_scale,
+ n_prompt=negative_prompt,
+ seed=seed,
+ offload_model=False,
+ callback=callback,
+ enable_RIFLEx = enable_RIFLEx,
+ VAE_tile_size = VAE_tile_size,
+ joint_pass = joint_pass,
+ slg_layers = slg_layers,
+ slg_start = slg_start_perc/100,
+ slg_end = slg_end_perc/100,
+ cfg_star_switch = cfg_star_switch,
+ cfg_zero_step = cfg_zero_step,
+ add_frames_for_end_image = "image2video" in model_filename
+ )
+ elif diffusion_forcing:
+ samples = wan_model.generate(
+ prompt = prompt,
+ negative_prompt = negative_prompt,
+ image = image_start,
+ input_video= pre_video_guide,
+ height = height,
+ width = width,
+ seed = seed,
+ num_frames = (video_length // 4)* 4 + 1, #377
+ num_inference_steps = num_inference_steps,
+ shift = flow_shift,
+ guidance_scale= guidance_scale,
+ callback= callback,
+ VAE_tile_size = VAE_tile_size,
+ joint_pass = joint_pass,
+ addnoise_condition = 20,
+ ar_step = model_mode, #5
+ causal_block_size = 5,
+ causal_attention = True,
+ fps = fps,
+ )
+ else:
+ samples = wan_model.generate(
+ prompt,
+ input_frames = src_video,
+ input_ref_images= src_ref_images,
+ input_masks = src_mask,
+ source_video= source_video,
+ target_camera= target_camera,
+ frame_num=(video_length // 4)* 4 + 1,
+ size=(width, height),
+ shift=flow_shift,
+ sampling_steps=num_inference_steps,
+ guide_scale=guidance_scale,
+ n_prompt=negative_prompt,
+ seed=seed,
+ offload_model=False,
+ callback=callback,
+ enable_RIFLEx = enable_RIFLEx,
+ VAE_tile_size = VAE_tile_size,
+ joint_pass = joint_pass,
+ slg_layers = slg_layers,
+ slg_start = slg_start_perc/100,
+ slg_end = slg_end_perc/100,
+ cfg_star_switch = cfg_star_switch,
+ cfg_zero_step = cfg_zero_step,
+ )
+ except Exception as e:
+ if temp_filename!= None and os.path.isfile(temp_filename):
+ os.remove(temp_filename)
+ offload.last_offload_obj.unload_all()
+ offload.unload_loras_from_model(trans)
+ # if compile:
+ # cache_size = torch._dynamo.config.cache_size_limit
+ # torch.compiler.reset()
+ # torch._dynamo.config.cache_size_limit = cache_size
+
+ gc.collect()
+ torch.cuda.empty_cache()
+ s = str(e)
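+                    # map well known CUDA error messages to a VRAM or RAM diagnosis so the user gets an actionable hint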
+ keyword_list = {"CUDA out of memory" : "VRAM", "Tried to allocate":"VRAM", "CUDA error: out of memory": "RAM", "CUDA error: too many resources requested": "RAM"}
+ crash_type = ""
+ for keyword, tp in keyword_list.items():
+ if keyword in s:
+ crash_type = tp
+ break
+ state["prompt"] = ""
+ if crash_type == "VRAM":
+                        new_error =  "The generation of the video has encountered an error: it is likely that you have insufficient VRAM and you should therefore reduce the video resolution or its number of frames."
+ elif crash_type == "RAM":
+                        new_error = "The generation of the video has encountered an error: it is likely that you have insufficient RAM and / or the Reserved RAM allocation should be reduced using 'perc_reserved_mem_max' or by using a different Profile."
+ else:
+ new_error = gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
+ tb = traceback.format_exc().split('\n')[:-1]
+ print('\n'.join(tb))
+ send_cmd("error", new_error)
+ return
if trans.enable_teacache:
- trans.teacache_counter = 0
- trans.num_steps = num_inference_steps
- trans.teacache_skipped_steps = 0
+ print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" )
trans.previous_residual_uncond = None
trans.previous_residual_cond = None
- video_no += 1
- if image2video:
- samples = wan_model.generate(
- prompt,
- image_start,
- image_end if image_end != None else None,
- frame_num=(video_length // 4)* 4 + 1,
- max_area=MAX_AREA_CONFIGS[resolution_reformated],
- shift=flow_shift,
- sampling_steps=num_inference_steps,
- guide_scale=guidance_scale,
- n_prompt=negative_prompt,
- seed=seed,
- offload_model=False,
- callback=callback,
- enable_RIFLEx = enable_RIFLEx,
- VAE_tile_size = VAE_tile_size,
- joint_pass = joint_pass,
- slg_layers = slg_layers,
- slg_start = slg_start_perc/100,
- slg_end = slg_end_perc/100,
- cfg_star_switch = cfg_star_switch,
- cfg_zero_step = cfg_zero_step,
- add_frames_for_end_image = "image2video" in model_filename
- )
- else:
- samples = wan_model.generate(
- prompt,
- input_frames = src_video,
- input_ref_images= src_ref_images,
- input_masks = src_mask,
- source_video= source_video,
- target_camera= target_camera,
- frame_num=(video_length // 4)* 4 + 1,
- size=(width, height),
- shift=flow_shift,
- sampling_steps=num_inference_steps,
- guide_scale=guidance_scale,
- n_prompt=negative_prompt,
- seed=seed,
- offload_model=False,
- callback=callback,
- enable_RIFLEx = enable_RIFLEx,
- VAE_tile_size = VAE_tile_size,
- joint_pass = joint_pass,
- slg_layers = slg_layers,
- slg_start = slg_start_perc/100,
- slg_end = slg_end_perc/100,
- cfg_star_switch = cfg_star_switch,
- cfg_zero_step = cfg_zero_step,
- )
- # samples = torch.empty( (1,2)) #for testing
- except Exception as e:
- if temp_filename!= None and os.path.isfile(temp_filename):
- os.remove(temp_filename)
+ if samples != None:
+ samples = samples.to("cpu")
offload.last_offload_obj.unload_all()
- offload.unload_loras_from_model(trans)
- # if compile:
- # cache_size = torch._dynamo.config.cache_size_limit
- # torch.compiler.reset()
- # torch._dynamo.config.cache_size_limit = cache_size
-
gc.collect()
torch.cuda.empty_cache()
- s = str(e)
- keyword_list = {"CUDA out of memory" : "VRAM", "Tried to allocate":"VRAM", "CUDA error: out of memory": "RAM", "CUDA error: too many resources requested": "RAM"}
- crash_type = ""
- for keyword, tp in keyword_list.items():
- if keyword in s:
- crash_type = tp
- break
- state["prompt"] = ""
- if crash_type == "VRAM":
- new_error = "The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames."
- elif crash_type == "RAM":
- new_error = "The generation of the video has encountered an error: it is likely that you have unsufficient RAM and / or Reserved RAM allocation should be reduced using 'perc_reserved_mem_max' or using a different Profile."
+
+ if samples == None:
+ abort = True
+ state["prompt"] = ""
else:
- new_error = gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
- tb = traceback.format_exc().split('\n')[:-1]
- print('\n'.join(tb))
- raise gr.Error(new_error, print_exception= False)
+ sample = samples.cpu()
+                        # if False: # for testing
+                        #     torch.save(sample, "output.pt")
+                        # else:
+                        #     sample = torch.load("output.pt")
- finally:
- pass
- # with tracker_lock:
- # if task_id in progress_tracker:
- # del progress_tracker[task_id]
-
- if trans.enable_teacache:
- print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" )
- trans.previous_residual_uncond = None
- trans.previous_residual_cond = None
-
- if samples != None:
- samples = samples.to("cpu")
- offload.last_offload_obj.unload_all()
- gc.collect()
- torch.cuda.empty_cache()
-
- if samples == None:
- end_time = time.time()
- abort = True
- state["prompt"] = ""
- # yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
- else:
- sample = samples.cpu()
- if sliding_window :
- start_frame += video_length
- if discard_last_frames > 0:
- sample = sample[: , :-discard_last_frames]
- start_frame -= discard_last_frames
- pre_video_guide = sample[:, -reuse_frames:]
- if repeat_no > 1:
+ if sliding_window :
+ guide_start_frame += video_length
+ if discard_last_frames > 0:
+ sample = sample[: , :-discard_last_frames]
+ guide_start_frame -= discard_last_frames
+ pre_video_guide = sample[:, -reuse_frames:]
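+                        # when continuing a source video, prepend it minus its last reuse_frames (they overlap with
+                        # this window); for windows after the first, the duplicated head frames are dropped below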
+ if prefix_video != None:
+ sample = torch.cat([ prefix_video[:, :-reuse_frames], sample], dim = 1)
+ prefix_video = None
+ if sliding_window and window_no > 1:
sample = sample[: , reuse_frames:]
- start_frame -= reuse_frames
+ guide_start_frame -= reuse_frames
- time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
- if os.name == 'nt':
- file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4"
- else:
- file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100]).strip()}.mp4"
- video_path = os.path.join(save_path, file_name)
- # if False: # for testing
- # torch.save(sample, "output.pt")
- # else:
- # sample =torch.load("output.pt")
- exp = 0
- fps = 16
-
- if len(temporal_upsampling) > 0 or len(spatial_upsampling) > 0:
- progress_args = [(num_inference_steps , num_inference_steps) , status + " - Upsampling" , num_inference_steps]
- progress(*progress_args )
- gen["progress_args"] = progress_args
-
- if temporal_upsampling == "rife2":
- exp = 1
- elif temporal_upsampling == "rife4":
- exp = 2
-
- if exp > 0:
- from rife.inference import temporal_interpolation
- if sliding_window and repeat_no > 1:
- sample = torch.cat([frames_already_processed[:, -2:-1], sample], dim=1)
- sample = temporal_interpolation( os.path.join("ckpts", "flownet.pkl"), sample, exp, device=processing_device)
- sample = sample[:, 1:]
+ time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
+ if os.name == 'nt':
+ file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4"
else:
- sample = temporal_interpolation( os.path.join("ckpts", "flownet.pkl"), sample, exp, device=processing_device)
+ file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100]).strip()}.mp4"
+ video_path = os.path.join(save_path, file_name)
+ exp = 0
- fps = fps * 2**exp
+ if len(temporal_upsampling) > 0 or len(spatial_upsampling) > 0:
+ progress_args = [(num_inference_steps , num_inference_steps) , status + " - Upsampling" , num_inference_steps]
+ send_cmd("progress", progress_args)
- if len(spatial_upsampling) > 0:
- from wan.utils.utils import resize_lanczos # need multithreading or to do lanczos with cuda
- if spatial_upsampling == "lanczos1.5":
- scale = 1.5
- else:
- scale = 2
- sample = (sample + 1) / 2
- h, w = sample.shape[-2:]
- h *= scale
- w *= scale
- h = int(h)
- w = int(w)
- new_frames =[]
- for i in range( sample.shape[1] ):
- frame = sample[:, i]
- frame = resize_lanczos(frame, h, w)
- frame = frame.unsqueeze(1)
- new_frames.append(frame)
- sample = torch.cat(new_frames, dim=1)
- new_frames = None
- sample = sample * 2 - 1
+ if temporal_upsampling == "rife2":
+ exp = 1
+ elif temporal_upsampling == "rife4":
+ exp = 2
+
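+                            # RIFE temporal interpolation: exp=1 doubles the frame rate, exp=2 quadruples it;
+                            # for later windows the last processed frame is prepended so interpolation stays
+                            # continuous across the window boundary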
+ if exp > 0:
+ from rife.inference import temporal_interpolation
+ if sliding_window and window_no > 1:
+ sample = torch.cat([frames_already_processed[:, -2:-1], sample], dim=1)
+ sample = temporal_interpolation( os.path.join("ckpts", "flownet.pkl"), sample, exp, device=processing_device)
+ sample = sample[:, 1:]
+ else:
+ sample = temporal_interpolation( os.path.join("ckpts", "flownet.pkl"), sample, exp, device=processing_device)
- if sliding_window :
- if repeat_no == 1:
+ fps = fps * 2**exp
+
+ if len(spatial_upsampling) > 0:
+ from wan.utils.utils import resize_lanczos # need multithreading or to do lanczos with cuda
+ if spatial_upsampling == "lanczos1.5":
+ scale = 1.5
+ else:
+ scale = 2
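+                            # Lanczos upscaling: map from [-1, 1] to [0, 1], resize every frame, then map back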
+ sample = (sample + 1) / 2
+ h, w = sample.shape[-2:]
+ h *= scale
+ w *= scale
+ h = int(h)
+ w = int(w)
+ new_frames =[]
+ for i in range( sample.shape[1] ):
+ frame = sample[:, i]
+ frame = resize_lanczos(frame, h, w)
+ frame = frame.unsqueeze(1)
+ new_frames.append(frame)
+ sample = torch.cat(new_frames, dim=1)
+ new_frames = None
+ sample = sample * 2 - 1
+
+ if sliding_window :
+ if frames_already_processed == None:
+ frames_already_processed = sample
+ else:
+ sample = torch.cat([frames_already_processed, sample], dim=1)
frames_already_processed = sample
+
+ cache_video(
+ tensor=sample[None],
+ save_file=video_path,
+ fps=fps,
+ nrow=1,
+ normalize=True,
+ value_range=(-1, 1))
+
+ inputs = get_function_arguments(generate_video, locals())
+ inputs.pop("send_cmd")
+ configs = prepare_inputs_dict("metadata", inputs)
+
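+                        # persist the generation settings either as a sidecar .json file or embedded
+                        # in the MP4 comment tag, depending on the server configuration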
+ metadata_choice = server_config.get("metadata_type","metadata")
+ if metadata_choice == "json":
+ with open(video_path.replace('.mp4', '.json'), 'w') as f:
+ json.dump(configs, f, indent=4)
+ elif metadata_choice == "metadata":
+ from mutagen.mp4 import MP4
+ file = MP4(video_path)
+ file.tags['©cmt'] = [json.dumps(configs)]
+ file.save()
+
+                        print(f"New video saved to Path: {video_path}")
+ file_list.append(video_path)
+ state['update_gallery'] = True
+ send_cmd("output")
+ if sliding_window :
+ if max_frames_to_generate > 0 and extra_windows == 0:
+ current_length = sample.shape[1]
+ if (current_length - prefix_video_frames_count)>= max_frames_to_generate:
+ break
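+                                # next window length: the frames still missing plus the overlap and discarded
+                                # frames that will be trimmed, rounded to the 4n+1 count expected by the model
+                                # and capped by the sliding window size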
+ video_length = min(sliding_window_size, ((max_frames_to_generate - (current_length - prefix_video_frames_count) + reuse_frames + discard_last_frames) // 4) * 4 + 1 )
else:
- sample = torch.cat([frames_already_processed, sample], dim=1)
- frames_already_processed = sample
+ break
- cache_video(
- tensor=sample[None],
- save_file=video_path,
- fps=fps,
- nrow=1,
- normalize=True,
- value_range=(-1, 1))
-
- inputs = get_function_arguments(generate_video, locals())
- inputs.pop("progress")
- configs = prepare_inputs_dict("metadata", inputs)
-
- metadata_choice = server_config.get("metadata_choice","metadata")
- if metadata_choice == "json":
- with open(video_path.replace('.mp4', '.json'), 'w') as f:
- json.dump(configs, f, indent=4)
- elif metadata_choice == "metadata":
- from mutagen.mp4 import MP4
- file = MP4(video_path)
- file.tags['©cmt'] = [json.dumps(configs)]
- file.save()
-
- print(f"New video saved to Path: "+video_path)
- file_list.append(video_path)
- state['update_gallery'] = True
- if not sliding_window:
- seed += 1
- yield status
+ seed += 1
if temp_filename!= None and os.path.isfile(temp_filename):
os.remove(temp_filename)
@@ -2795,12 +2862,73 @@ def prepare_generate_video(state):
else:
return gr.Button(visible= False), gr.Button(visible= True), gr.Column(visible= True)
+def generate_preview(latents):
+ import einops
+
+ latent_channels = 16
+ latent_dimensions = 3
+ latents = latents.unsqueeze(0)
+ latent_rgb_factors = [
+ [-0.1299, -0.1692, 0.2932],
+ [ 0.0671, 0.0406, 0.0442],
+ [ 0.3568, 0.2548, 0.1747],
+ [ 0.0372, 0.2344, 0.1420],
+ [ 0.0313, 0.0189, -0.0328],
+ [ 0.0296, -0.0956, -0.0665],
+ [-0.3477, -0.4059, -0.2925],
+ [ 0.0166, 0.1902, 0.1975],
+ [-0.0412, 0.0267, -0.1364],
+ [-0.1293, 0.0740, 0.1636],
+ [ 0.0680, 0.3019, 0.1128],
+ [ 0.0032, 0.0581, 0.0639],
+ [-0.1251, 0.0927, 0.1699],
+ [ 0.0060, -0.0633, 0.0005],
+ [ 0.3477, 0.2275, 0.2950],
+ [ 0.1984, 0.0913, 0.1861]
+ ]
+
+    # credit for the RGB factors likely goes to ComfyUI
+
+    # latent_rgb_factors_bias = [-0.1835, -0.0868, -0.3360] # unused alternative bias, kept for reference
+    latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761]
+ nb_latents = latents.shape[2]
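+    # preview at most 4 latent frames, evenly spaced across the sequence, to keep the preview cheap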
+ latents_to_preview = 4
+ latents_to_preview = min(nb_latents, latents_to_preview)
+ skip_latent = nb_latents / latents_to_preview
+ latent_no = 0
+ selected_latents = []
+ while latent_no < nb_latents:
+ selected_latents.append( latents[:, : , int(latent_no): int(latent_no)+1])
+ latent_no += skip_latent
+
+ latents = torch.cat(selected_latents, dim = 2)
+ weight = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)[:, :, None, None, None]
+ bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)
+
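+    # project the 16 latent channels to RGB with a fixed per-pixel linear map (1x1x1 conv3d), then clamp to [0, 1]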
+ images = torch.nn.functional.conv3d(latents, weight, bias=bias, stride=1, padding=0, dilation=1, groups=1)
+ images = images.clamp(0.0, 1.0)
+
+
+ images = (images * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)
+ images = einops.rearrange(images, 'b c t h w -> (b h) (t w) c')
+ h, w, _ = images.shape
+ scale = 200 / h
+ images= Image.fromarray(images)
+ images = images.resize(( int(w*scale),int(h*scale)), resample=Image.Resampling.BILINEAR)
+
+ return images
+
+
+def process_tasks(state):
+ from wan.utils.thread_utils import AsyncStream, async_run
-def process_tasks(state, progress=gr.Progress()):
gen = get_gen_info(state)
queue = gen.get("queue", [])
+ progress = None
if len(queue) == 0:
+ gen["status_display"] = False
return
gen = get_gen_info(state)
clear_file_list = server_config.get("clear_file_list", 0)
@@ -2824,8 +2952,9 @@ def process_tasks(state, progress=gr.Progress()):
gen_in_progress = True
gen["in_progress"] = True
- yield "Generating Video"
-
+ gen["preview"] = None
+ gen["status"] = "Generating Video"
+ yield time.time(), time.time()
prompt_no = 0
while len(queue) > 0:
prompt_no += 1
@@ -2833,27 +2962,57 @@ def process_tasks(state, progress=gr.Progress()):
task = queue[0]
task_id = task["id"]
params = task['params']
- iterator = iter(generate_video(task_id, progress, **params))
- while True:
+
+ com_stream = AsyncStream()
+ send_cmd = com_stream.output_queue.push
+ def generate_video_error_handler():
try:
- ok = False
- status = next(iterator, "#")
- ok = True
- if status == "#":
- break
+ generate_video(task_id, send_cmd, **params)
except Exception as e:
- _ , exc_value, exc_traceback = sys.exc_info()
- raise exc_value.with_traceback(exc_traceback)
+ tb = traceback.format_exc().split('\n')[:-1]
+ print('\n'.join(tb))
+ send_cmd("error",str(e))
finally:
- if not ok:
- queue.clear()
- gen["prompts_max"] = 0
- gen["prompt"] = ""
- yield status
+ send_cmd("exit", None)
+
+
+ async_run(generate_video_error_handler)
+
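+            # consume the commands pushed by the worker thread: progress / preview / status updates,
+            # newly saved outputs, info and error messages, and finally "exit" when the task ends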
+ while True:
+ cmd, data = com_stream.output_queue.next()
+ if cmd == "exit":
+ break
+ elif cmd == "info":
+ gr.Info(data)
+ elif cmd == "error":
+ queue.clear()
+ gen["prompts_max"] = 0
+ gen["prompt"] = ""
+ gen["status_display"] = False
+
+ raise gr.Error(data, print_exception= False)
+ elif cmd == "status":
+ gen["status"] = data
+ elif cmd == "output":
+ gen["preview"] = None
+ yield time.time() , time.time()
+ elif cmd == "progress":
+ gen["progress_args"] = data
+ # progress(*data)
+ elif cmd == "preview":
+ preview= None if data== None else generate_preview(data)
+ gen["preview"] = preview
+ yield time.time() , gr.Text()
+ else:
+ raise Exception(f"unknown command {cmd}")
+
abort = gen.get("abort", False)
if abort:
gen["abort"] = False
- yield "Video Generation Aborted"
+            status = "Video Generation Aborted"
+ yield gr.Text(), gr.Text()
+ gen["status"] = status
+
queue[:] = [item for item in queue if item['id'] != task['id']]
update_global_queue_ref(queue)
@@ -2861,25 +3020,29 @@ def process_tasks(state, progress=gr.Progress()):
gen["prompt"] = ""
end_time = time.time()
if abort:
- yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
+ status = f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
else:
- yield f"Total Generation Time: {end_time-start_time:.1f}s"
+ status = f"Total Generation Time: {end_time-start_time:.1f}s"
+ gen["status"] = status
+ gen["status_display"] = False
-def get_generation_status(prompt_no, prompts_max, repeat_no, repeat_max, sliding_window):
- item = "Sliding Window" if sliding_window else "Sample"
+def get_generation_status(prompt_no, prompts_max, repeat_no, repeat_max, window_no, total_windows):
if prompts_max == 1:
if repeat_max == 1:
- return "Video"
+ status = "Video"
else:
- return f"{item} {repeat_no}/{repeat_max}"
+ status = f"Sample {repeat_no}/{repeat_max}"
else:
if repeat_max == 1:
- return f"Prompt {prompt_no}/{prompts_max}"
+ status = f"Prompt {prompt_no}/{prompts_max}"
else:
- return f"Prompt {prompt_no}/{prompts_max}, {item} {repeat_no}/{repeat_max}"
+ status = f"Prompt {prompt_no}/{prompts_max}, Sample {repeat_no}/{repeat_max}"
+ if total_windows > 1:
+ status += f", Sliding Window {window_no}/{total_windows}"
+ return status
refresh_id = 0
@@ -2888,15 +3051,22 @@ def get_new_refresh_id():
refresh_id += 1
return refresh_id
-def update_status(state):
+def get_latest_status(state):
gen = get_gen_info(state)
prompt_no = gen["prompt_no"]
prompts_max = gen.get("prompts_max",0)
total_generation = gen.get("total_generation", 1)
repeat_no = gen["repeat_no"]
- sliding_window = gen["sliding_window"]
- status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation, sliding_window)
- gen["progress_status"] = status
+ total_generation += gen.get("extra_orders", 0)
+ total_windows = gen.get("total_windows", 0)
+ total_windows += gen.get("extra_windows", 0)
+ window_no = gen.get("window_no", 0)
+ status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation, window_no, total_windows)
+ return status
+
+def update_status(state):
+ gen = get_gen_info(state)
+ gen["progress_status"] = get_latest_status(state)
gen["refresh"] = get_new_refresh_id()
@@ -2908,19 +3078,28 @@ def one_more_sample(state):
in_progress = gen.get("in_progress", False)
if not in_progress :
return state
- prompt_no = gen["prompt_no"]
- prompts_max = gen.get("prompts_max",0)
- total_generation = gen["total_generation"] + extra_orders
- repeat_no = gen["repeat_no"]
- status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation, gen.get("sliding_window",False))
-
-
- gen["progress_status"] = status
+ total_generation = gen.get("total_generation", 0) + extra_orders
+ gen["progress_status"] = get_latest_status(state)
gen["refresh"] = get_new_refresh_id()
gr.Info(f"An extra sample generation is planned for a total of {total_generation} videos for this prompt")
return state
+def one_more_window(state):
+ gen = get_gen_info(state)
+ extra_windows = gen.get("extra_windows", 0)
+ extra_windows += 1
+ gen["extra_windows"]= extra_windows
+ in_progress = gen.get("in_progress", False)
+ if not in_progress :
+ return state
+ total_windows = gen.get("total_windows", 0) + extra_windows
+ gen["progress_status"] = get_latest_status(state)
+ gen["refresh"] = get_new_refresh_id()
+    gr.Info(f"An extra window generation is planned for a total of {total_windows} windows for this sample")
+
+ return state
+
def get_new_preset_msg(advanced = True):
if advanced:
return "Enter here a Name for a Lora Preset or Choose one in the List"
@@ -3234,11 +3413,21 @@ def prepare_inputs_dict(target, inputs ):
inputs.pop("image_prompt_type")
- if not "recam" in model_filename:
- inputs.pop("camera_type")
+    if not "recam" in model_filename and not "diffusion_forcing" in model_filename:
+ inputs.pop("model_mode")
if not "Vace" in model_filename:
- unsaved_params = ["video_prompt_type", "keep_frames", "remove_background_image_ref", "sliding_window_repeat", "sliding_window_overlap", "sliding_window_discard_last_frames"]
+ unsaved_params = ["keep_frames_video_guide", "video_prompt_type", "remove_background_image_ref"]
+ for k in unsaved_params:
+ inputs.pop(k)
+
+ if not "diffusion_forcing" in model_filename:
+ unsaved_params = ["keep_frames_video_source"]
+ for k in unsaved_params:
+ inputs.pop(k)
+
+    if not "Vace" in model_filename and not "diffusion_forcing" in model_filename:
+ unsaved_params = [ "sliding_window_size", "sliding_window_overlap", "sliding_window_discard_last_frames"]
for k in unsaved_params:
inputs.pop(k)
@@ -3276,14 +3465,15 @@ def save_inputs(
image_prompt_type,
image_start,
image_end,
- video_prompt_type,
+ model_mode,
+ video_source,
+ keep_frames_video_source,
+ video_prompt_type,
image_refs,
video_guide,
+ keep_frames_video_guide,
video_mask,
- camera_type,
- video_source,
- keep_frames,
- sliding_window_repeat,
+ sliding_window_size,
sliding_window_overlap,
sliding_window_discard_last_frames,
remove_background_image_ref,
@@ -3344,10 +3534,7 @@ def download_loras():
return
def refresh_image_prompt_type(state, image_prompt_type):
- if args.multiple_images:
- return gr.Gallery(visible = "S" in image_prompt_type ), gr.Gallery(visible = "E" in image_prompt_type )
- else:
- return gr.Image(visible = "S" in image_prompt_type ), gr.Image(visible = "E" in image_prompt_type )
+ return gr.update(visible = "S" in image_prompt_type ), gr.update(visible = "E" in image_prompt_type ), gr.update(visible = "V" in image_prompt_type) , gr.update(visible = "V" in image_prompt_type )
def refresh_video_prompt_type(state, video_prompt_type):
return gr.Gallery(visible = "I" in video_prompt_type), gr.Video(visible= "V" in video_prompt_type),gr.Video(visible= "M" in video_prompt_type ), gr.Text(visible= "V" in video_prompt_type) , gr.Checkbox(visible= "I" in video_prompt_type)
@@ -3482,8 +3669,22 @@ def refresh_video_prompt_video_guide_trigger(video_prompt_type, video_prompt_typ
return video_prompt_type, video_prompt_type_video_guide, gr.update(visible= "V" in video_prompt_type ), gr.update(visible= "M" in video_prompt_type) , gr.update(visible= "V" in video_prompt_type )
-
-def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_choice = None, header = None):
+def refresh_preview(state):
+ gen = get_gen_info(state)
+ preview = gen.get("preview", None)
+ return preview
+
+def init_process_queue_if_any(state):
+ gen = get_gen_info(state)
+ if bool(gen.get("queue",[])):
+ state["validate_success"] = 1
+ return gr.Button(visible=False), gr.Button(visible=True), gr.Column(visible=True)
+ else:
+ return gr.Button(visible=True), gr.Button(visible=False), gr.Column(visible=False)
+
+
+
+def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_choice = None, header = None, main = None):
global inputs_names #, advanced
if update_form:
@@ -3572,45 +3773,69 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
if not update_form:
state = gr.State(state_dict)
trigger_refresh_input_type = gr.Text(interactive= False, visible= False)
- with gr.Column(visible= test_class_i2v(model_filename) ) as image_prompt_column:
- image_prompt_type_value= ui_defaults.get("image_prompt_type","S")
- image_prompt_type = gr.Radio( [("Use only a Start Image", "S"),("Use both a Start and an End Image", "SE")], value =image_prompt_type_value, label="Location", show_label= False, scale= 3)
-
- if args.multiple_images:
+ diffusion_forcing = "diffusion_forcing" in model_filename
+ recammaster = "recam" in model_filename
+ vace = "Vace" in model_filename
+ with gr.Column(visible= test_class_i2v(model_filename) or diffusion_forcing or recammaster) as image_prompt_column:
+ if diffusion_forcing:
+ image_prompt_type_value= ui_defaults.get("image_prompt_type","S")
+ image_prompt_type = gr.Radio( [("Start Video with Image", "S"),("Continue Video", "V"),("Text Prompt Only", "T")], value =image_prompt_type_value, label="Location", show_label= False, visible= True, scale= 3)
+ # image_start = gr.Image(label= "Image as a starting point for a new video", type ="pil",value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value )
image_start = gr.Gallery(
label="Images as starting points for new videos", type ="pil", #file_types= "image",
- columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value)
+ columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value)
+ image_end = gr.Gallery(visible=False)
+ video_source = gr.Video(label= "Video to Continue", visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),)
+ model_mode = gr.Dropdown(
+ choices=[
+ ("Synchronous", 0),
+ ("Asynchronous (better quality but around 50% extra steps added)", 5),
+ ],
+ value=ui_defaults.get("model_mode", 0),
+ label="Generation Type", scale = 3,
+ visible= True
+ )
+                        keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= "V" in image_prompt_type_value, scale = 2, label= "Truncate Video beyond this number of Frames (empty = Keep All)" )
+ elif recammaster:
+ image_prompt_type = gr.Radio(visible= False)
+ image_start = gr.Gallery(visible = False)
+ image_end = gr.Gallery(visible=False)
+ video_source = gr.Video(label= "Video Source", visible = True, value= ui_defaults.get("video_source", None),)
+ model_mode = gr.Dropdown(
+ choices=[
+ ("Pan Right", 1),
+ ("Pan Left", 2),
+ ("Tilt Up", 3),
+ ("Tilt Down", 4),
+ ("Zoom In", 5),
+ ("Zoom Out", 6),
+ ("Translate Up (with rotation)", 7),
+ ("Translate Down (with rotation)", 8),
+ ("Arc Left (with rotation)", 9),
+ ("Arc Right (with rotation)", 10),
+ ],
+ value=ui_defaults.get("model_mode", 1),
+ label="Camera Movement Type", scale = 3,
+ visible= True
+ )
+ keep_frames_video_source = gr.Text(visible=False)
else:
- image_start = gr.Image(label= "Image as a starting point for a new video", type ="pil",value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value )
+ image_prompt_type_value= ui_defaults.get("image_prompt_type","S")
+ image_prompt_type = gr.Radio( [("Use only a Start Image", "S"),("Use both a Start and an End Image", "SE")], value =image_prompt_type_value, label="Location", show_label= False, visible= True, scale= 3)
+
+ image_start = gr.Gallery(
+ label="Images as starting points for new videos", type ="pil", #file_types= "image",
+ columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value)
- if args.multiple_images:
image_end = gr.Gallery(
label="Images as ending points for new videos", type ="pil", #file_types= "image",
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None))
- else:
- image_end = gr.Image(label= "Last Image for a new video", type ="pil", visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None))
- with gr.Column(visible= "recam" in model_filename ) as recam_column:
- camera_type = gr.Dropdown(
- choices=[
- ("Pan Right", 1),
- ("Pan Left", 2),
- ("Tilt Up", 3),
- ("Tilt Down", 4),
- ("Zoom In", 5),
- ("Zoom Out", 6),
- ("Translate Up (with rotation)", 7),
- ("Translate Down (with rotation)", 8),
- ("Arc Left (with rotation)", 9),
- ("Arc Right (with rotation)", 10),
- ],
- value=ui_defaults.get("camera_type", 1),
- label="Camera Movement Type", scale = 3
- )
- video_source = gr.Video(label= "Video Source", value= ui_defaults.get("video_source", None),)
+ video_source = gr.Video(visible=False)
+ model_mode = gr.Dropdown(visible=False)
+ keep_frames_video_source = gr.Text(visible=False)
-
- with gr.Column(visible= "Vace" in model_filename ) as video_prompt_column:
+ with gr.Column(visible= vace ) as video_prompt_column:
video_prompt_type_value= ui_defaults.get("video_prompt_type","")
video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False)
with gr.Row():
@@ -3639,9 +3864,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
)
# video_prompt_type_image_refs = gr.Checkbox(value="I" in video_prompt_type_value , label= "Use References Images (Faces, Objects) to customize New Video", scale =1 )
-
video_guide = gr.Video(label= "Control Video", visible= "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None),)
- keep_frames = gr.Text(value=ui_defaults.get("keep_frames","") , visible= "V" in video_prompt_type_value, scale = 2, label= "Frames to keep in Control Video (empty=All, 1=first, a:b for a range, space to separate values)" ) #, -1=last
+ keep_frames_video_guide = gr.Text(value=ui_defaults.get("keep_frames_video_guide","") , visible= "V" in video_prompt_type_value, scale = 2, label= "Frames to keep in Control Video (empty=All, 1=first, a:b for a range, space to separate values)" ) #, -1=last
image_refs = gr.Gallery( label ="Reference Images",
type ="pil", show_label= True,
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value,
@@ -3716,11 +3940,15 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
label="Resolution"
)
with gr.Row():
- if "recam" in model_filename:
+ if recammaster:
video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s), locked", interactive= False)
+ elif diffusion_forcing:
+ video_length = gr.Slider(17, 737, value=ui_defaults.get("video_length", 97), step=20, label="Number of frames (24 = 1s)", interactive= True)
+ elif vace:
+ video_length = gr.Slider(17, 737, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True)
else:
video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", 81), step=4, label="Number of frames (16 = 1s)", interactive= True)
-
+ with gr.Row():
num_inference_steps = gr.Slider(1, 100, value=ui_defaults.get("num_inference_steps",30), step=1, label="Number of Inference Steps")
@@ -3737,7 +3965,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
choices=[
("Generate every combination of images and texts", 0),
("Match images and text prompts", 1),
- ], visible= args.multiple_images, label= "Multiple Images as Texts Prompts"
+ ], visible= True, label= "Multiple Images as Texts Prompts"
)
with gr.Row():
guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance_scale",5), step=0.5, label="Guidance Scale", visible=True)
@@ -3848,14 +4076,20 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
with gr.Row():
cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)")
- with gr.Tab("Sliding Window", visible= "Vace" in model_filename ) as sliding_window_tab:
+ with gr.Tab("Sliding Window", visible= "Vace" in model_filename or "diffusion_forcing" in model_filename) as sliding_window_tab:
with gr.Column():
gr.Markdown("A Sliding Window allows you to generate video with a duration not limited by the Model")
+ gr.Markdown("It is automatically turned on if the number of frames to generate is higher than the Window Size")
+ if diffusion_forcing:
+ sliding_window_size = gr.Slider(37, 137, value=ui_defaults.get("sliding_window_size", 97), step=20, label="Sliding Window Size")
+ sliding_window_overlap = gr.Slider(17, 97, value=ui_defaults.get("sliding_window_overlap",17), step=20, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
+ sliding_window_discard_last_frames = gr.Slider(0, 10, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=1, visible = False)
+ else:
+ sliding_window_size = gr.Slider(5, 137, value=ui_defaults.get("sliding_window_size", 81), step=4, label="Sliding Window Size")
+ sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",17), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
+ sliding_window_discard_last_frames = gr.Slider(0, 10, value=ui_defaults.get("sliding_window_discard_last_frames", 16), step=1, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
- sliding_window_repeat = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_repeat", 0), step=1, label="Sliding Window Iterations (O=Disabled)")
- sliding_window_overlap = gr.Slider(1, 32, value=ui_defaults.get("sliding_window_overlap",16), step=1, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
- sliding_window_discard_last_frames = gr.Slider(1, 10, value=ui_defaults.get("sliding_window_discard_last_frames", 4), step=1, label="Discard Last Frames of a Window (that may have bad quality)")
with gr.Tab("Miscellaneous", visible= not "recam" in model_filename):
gr.Markdown("With Riflex you can generate videos longer than 5s which is the default duration of videos used to train the model")
@@ -3875,110 +4109,206 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
if not update_form:
with gr.Column():
gen_status = gr.Text(interactive= False, label = "Status")
+ status_trigger = gr.Text(interactive= False, visible=False)
output = gr.Gallery( label="Generated videos", show_label=False, elem_id="gallery" , columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False)
+ output_trigger = gr.Text(interactive= False, visible=False)
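+                    # hidden trigger fields: activate_status and process_tasks write timestamps into them, and
+                    # their .change events drive the status text and output gallery refreshes while a generation runs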
+
generate_btn = gr.Button("Generate")
add_to_queue_btn = gr.Button("Add New Prompt To Queue", visible = False)
with gr.Column(visible= False) as current_gen_column:
+ with gr.Accordion("Preview", open=False) as queue_accordion:
+ preview = gr.Image(label="Preview", height=200, show_label= False)
+ preview_trigger = gr.Text(visible= False)
+ gen_info = gr.HTML(visible=False, min_height=1)
with gr.Row():
- gen_info = gr.HTML(visible=False, min_height=1)
- with gr.Row():
- onemore_btn = gr.Button("One More Sample Please !")
+ onemoresample_btn = gr.Button("One More Sample Please !")
+ onemorewindow_btn = gr.Button("Extend this Sample Please !", visible = False)
abort_btn = gr.Button("Abort")
with gr.Accordion("Queue Management", open=False) as queue_accordion:
- queue_df = gr.DataFrame(
- headers=["Qty","Prompt", "Length","Steps","", "", "", "", ""],
- datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"],
- column_widths= ["5%", None, "7%", "7%", "10%", "10%", "3%", "3%", "34"],
- interactive=False,
- col_count=(9, "fixed"),
- wrap=True,
- value=[],
- line_breaks= True,
- visible= True,
- elem_id="queue_df"
- )
- with gr.Row():
- queue_zip_base64_output = gr.Text(visible=False)
- save_queue_btn = gr.DownloadButton("Save Queue", size="sm")
- load_queue_btn = gr.UploadButton("Load Queue", file_types=[".zip"], size="sm")
- clear_queue_btn = gr.Button("Clear Queue", size="sm", variant="stop")
- quit_button = gr.Button("Save and Quit", size="sm", variant="secondary")
- with gr.Row(visible=False) as quit_confirmation_row:
- confirm_quit_button = gr.Button("Confirm", elem_id="comfirm_quit_btn_hidden", size="sm", variant="stop")
- cancel_quit_button = gr.Button("Cancel", size="sm", variant="secondary")
- hidden_force_quit_trigger = gr.Button("force_quit", visible=False, elem_id="force_quit_btn_hidden")
- hidden_countdown_state = gr.Number(value=-1, visible=False, elem_id="hidden_countdown_state_num")
- single_hidden_trigger_btn = gr.Button("trigger_countdown", visible=False, elem_id="trigger_info_single_btn")
+ with gr.Row( ):
+ queue_df = gr.DataFrame(
+ headers=["Qty","Prompt", "Length","Steps","", "", "", "", ""],
+ datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"],
+ column_widths= ["5%", None, "7%", "7%", "10%", "10%", "3%", "3%", "34"],
+ interactive=False,
+ col_count=(9, "fixed"),
+ wrap=True,
+ value=[],
+ line_breaks= True,
+ visible= True,
+ elem_id="queue_df",
+ max_height= 1000
- start_quit_timer_js = """
- () => {
- function findAndClickGradioButton(elemId) {
- const gradioApp = document.querySelector('gradio-app') || document;
- const button = gradioApp.querySelector(`#${elemId}`);
- if (button) { button.click(); }
- }
+ )
+ with gr.Row(visible= True):
+ queue_zip_base64_output = gr.Text(visible=False)
+ save_queue_btn = gr.DownloadButton("Save Queue", size="sm")
+ load_queue_btn = gr.UploadButton("Load Queue", file_types=[".zip"], size="sm")
+ clear_queue_btn = gr.Button("Clear Queue", size="sm", variant="stop")
+ quit_button = gr.Button("Save and Quit", size="sm", variant="secondary")
+ with gr.Row(visible=False) as quit_confirmation_row:
+ confirm_quit_button = gr.Button("Confirm", elem_id="comfirm_quit_btn_hidden", size="sm", variant="stop")
+ cancel_quit_button = gr.Button("Cancel", size="sm", variant="secondary")
+ hidden_force_quit_trigger = gr.Button("force_quit", visible=False, elem_id="force_quit_btn_hidden")
+ hidden_countdown_state = gr.Number(value=-1, visible=False, elem_id="hidden_countdown_state_num")
+ single_hidden_trigger_btn = gr.Button("trigger_countdown", visible=False, elem_id="trigger_info_single_btn")
- if (window.quitCountdownTimeoutId) clearTimeout(window.quitCountdownTimeoutId);
+ extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column,
+ prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row, sliding_window_tab,
+ video_prompt_type_video_guide, video_prompt_type_image_refs] # show_advanced presets_column,
+ if update_form:
+ locals_dict = locals()
+ gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs
+ return gen_inputs
+ else:
+ target_state = gr.Text(value = "state", interactive= False, visible= False)
+ target_settings = gr.Text(value = "settings", interactive= False, visible= False)
- let js_click_count = 0;
- const max_clicks = 5;
+ image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end, video_source, keep_frames_video_source] )
+ video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, video_mask, keep_frames_video_guide])
+ video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_image_ref ])
+ video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [video_prompt_type, video_prompt_type_video_guide], outputs = [video_prompt_type, video_guide, keep_frames_video_guide, video_mask])
- function countdownStep() {
- if (js_click_count < max_clicks) {
- findAndClickGradioButton('trigger_info_single_btn');
- js_click_count++;
- window.quitCountdownTimeoutId = setTimeout(countdownStep, 1000);
- } else {
- findAndClickGradioButton('force_quit_btn_hidden');
- }
- }
+ show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
+ fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
+ queue_df.select( fn=handle_celll_selection, inputs=state, outputs=[queue_df, modal_image_display, modal_container])
+ save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
+ confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
+ save_lset, inputs=[state, lset_name, loras_choices, loras_multipliers, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn,refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
+ delete_lset_btn.click(validate_delete_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
+ confirm_delete_lset_btn.click(delete_lset, inputs=[state, lset_name], outputs=[lset_name, apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
+ cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn,confirm_save_lset_btn, cancel_lset_btn,save_lset_prompt_drop ])
+ apply_lset_btn.click(apply_lset, inputs=[state, wizard_prompt_activated_var, lset_name,loras_choices, loras_multipliers, prompt], outputs=[wizard_prompt_activated_var, loras_choices, loras_multipliers, prompt]).then(
+ fn = fill_wizard_prompt, inputs = [state, wizard_prompt_activated_var, prompt, wizard_prompt], outputs = [ wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars]
+ )
+ refresh_lora_btn.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
+ refresh_lora_btn2.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
+ output.select(select_video, state, None )
+ preview_trigger.change(refresh_preview, inputs= [state], outputs= [preview])
- countdownStep();
- }
- """
+ def refresh_status_async(state, progress=gr.Progress()):
+ gen = get_gen_info(state)
+ gen["progress"] = progress
- cancel_quit_timer_js = """
- () => {
- if (window.quitCountdownTimeoutId) {
- clearTimeout(window.quitCountdownTimeoutId);
- window.quitCountdownTimeoutId = null;
- console.log("Quit countdown cancelled (single trigger).");
- }
- }
- """
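+                    # poll the shared gen dict twice a second, forwarding pending progress updates and
+                    # status text to the UI until the generation clears its status_display flag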
+ while True:
+ progress_args= gen.get("progress_args", None)
+ if progress_args != None:
+ progress(*progress_args)
+ gen["progress_args"] = None
+ status= gen.get("status","")
+ if status == None or len(status) > 0:
+ yield status
+ gen["status"]= ""
+ if not gen.get("status_display", False):
+ return
+ time.sleep(0.5)
- trigger_zip_download_js = """
- (base64String) => {
- if (!base64String) {
- console.log("No base64 zip data received, skipping download.");
- return;
- }
- try {
- const byteCharacters = atob(base64String);
- const byteNumbers = new Array(byteCharacters.length);
- for (let i = 0; i < byteCharacters.length; i++) {
- byteNumbers[i] = byteCharacters.charCodeAt(i);
- }
- const byteArray = new Uint8Array(byteNumbers);
- const blob = new Blob([byteArray], { type: 'application/zip' });
+ def activate_status(state):
+ if state.get("validate_success",0) != 1:
+ return
+ gen = get_gen_info(state)
+ gen["status_display"] = True
+ return time.time()
- const url = URL.createObjectURL(blob);
- const a = document.createElement('a');
- a.style.display = 'none';
- a.href = url;
- a.download = 'queue.zip';
- document.body.appendChild(a);
- a.click();
+ status_trigger.change(refresh_status_async, inputs= [state] , outputs= [gen_status], show_progress_on= [gen_status])
- window.URL.revokeObjectURL(url);
- document.body.removeChild(a);
- console.log("Zip download triggered.");
- } catch (e) {
- console.error("Error processing base64 data or triggering download:", e);
- }
- }
- """
+ output_trigger.change(refresh_gallery,
+ inputs = [state],
+ outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn, onemorewindow_btn])
+
+
+
+ abort_btn.click(abort_generation, [state], [ abort_btn] ) #.then(refresh_gallery, inputs = [state, gen_info], outputs = [output, gen_info, queue_df] )
+ onemoresample_btn.click(fn=one_more_sample,inputs=[state], outputs= [state])
+ onemorewindow_btn.click(fn=one_more_window,inputs=[state], outputs= [state])
+
+ inputs_names= list(inspect.signature(save_inputs).parameters)[1:-1]
+ locals_dict = locals()
+ gen_inputs = [locals_dict[k] for k in inputs_names] + [state]
+ save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
+ save_inputs, inputs =[target_settings] + gen_inputs, outputs = [])
+
+
+ model_choice.change(fn=validate_wizard_prompt,
+ inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
+ outputs= [prompt]
+ ).then(fn=save_inputs,
+ inputs =[target_state] + gen_inputs,
+ outputs= None
+ ).then(fn= change_model,
+ inputs=[state, model_choice],
+ outputs= [header]
+ ).then(fn= fill_inputs,
+ inputs=[state],
+ outputs=gen_inputs + extra_inputs
+ ).then(fn= preload_model_when_switching,
+ inputs=[state],
+ outputs=[gen_status])
+
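+            # Generate button chain: validate the wizard prompt, save the current settings, queue the task,
+            # switch the UI to generation mode, start the status poller, then process the queue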
+ generate_btn.click(fn=validate_wizard_prompt,
+ inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
+ outputs= [prompt]
+ ).then(fn=save_inputs,
+ inputs =[target_state] + gen_inputs,
+ outputs= None
+ ).then(fn=process_prompt_and_add_tasks,
+ inputs = [state, model_choice],
+ outputs= queue_df
+ ).then(fn=prepare_generate_video,
+ inputs= [state],
+ outputs= [generate_btn, add_to_queue_btn, current_gen_column]
+ ).then(fn=activate_status,
+ inputs= [state],
+ outputs= [status_trigger],
+ ).then(fn=process_tasks,
+ inputs= [state],
+ outputs= [preview_trigger, output_trigger],
+ ).then(finalize_generation,
+ inputs= [state],
+ outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info]
+ ).then(
+ fn=lambda s: gr.Accordion(open=False) if len(get_gen_info(s).get("queue", [])) <= 1 else gr.update(),
+ inputs=[state],
+ outputs=[queue_accordion]
+ ).then(unload_model_if_needed,
+ inputs= [state],
+ outputs= []
+ )
+
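+            # restoring a saved queue (file upload or app load): reload the queue, reopen the accordion
+            # and resume processing if tasks are pending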
+ gr.on(triggers=[load_queue_btn.upload, main.load],
+ fn=load_queue_action,
+ inputs=[load_queue_btn, state],
+ outputs=[queue_df]
+ ).then(
+ fn=lambda s: (gr.update(visible=bool(get_gen_info(s).get("queue",[]))), gr.Accordion(open=True)) if bool(get_gen_info(s).get("queue",[])) else (gr.update(visible=False), gr.update()),
+ inputs=[state],
+ outputs=[current_gen_column, queue_accordion]
+ ).then(
+ fn=init_process_queue_if_any,
+ inputs=[state],
+ outputs=[generate_btn, add_to_queue_btn, current_gen_column, ]
+ ).then(fn=activate_status,
+ inputs= [state],
+ outputs= [status_trigger],
+ ).then(
+ fn=process_tasks,
+ inputs=[state],
+ outputs=[preview_trigger, output_trigger],
+ trigger_mode="once"
+ ).then(
+ fn=finalize_generation_with_state,
+ inputs=[state],
+ outputs=[output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info, queue_accordion, state],
+ trigger_mode="always_last"
+ ).then(
+ unload_model_if_needed,
+ inputs= [state],
+ outputs= []
+ )
+
+
+ start_quit_timer_js, cancel_quit_timer_js, trigger_zip_download_js = get_timer_js()
single_hidden_trigger_btn.click(
fn=show_countdown_info_from_state,
@@ -4026,39 +4356,6 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
js=trigger_zip_download_js
)
- should_start_flag = gr.State(False)
- load_queue_btn.upload(
- fn=load_queue_action,
- inputs=[load_queue_btn, state],
- outputs=[queue_df]
- ).then(
- fn=lambda s: (gr.update(visible=bool(get_gen_info(s).get("queue",[]))), gr.Accordion(open=True)) if bool(get_gen_info(s).get("queue",[])) else (gr.update(visible=False), gr.update()),
- inputs=[state],
- outputs=[current_gen_column, queue_accordion]
- ).then(
- fn=lambda s: (
- (gr.Button(visible=False), gr.Button(visible=True), gr.Column(visible=True), True)
- if bool(get_gen_info(s).get("queue",[]))
- else (gr.Button(visible=True), gr.Button(visible=False), gr.Column(visible=False), False)
- ),
- inputs=[state],
- outputs=[generate_btn, add_to_queue_btn, current_gen_column, should_start_flag]
- ).then(
- fn=start_processing_if_needed,
- inputs=[should_start_flag, state],
- outputs=[gen_status],
- trigger_mode="once"
- ).then(
- fn=finalize_generation_with_state,
- inputs=[state],
- outputs=[output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info, queue_accordion, state],
- trigger_mode="always_last"
- ).then(
- unload_model_if_needed,
- inputs= [state],
- outputs= []
- )
-
clear_queue_btn.click(
fn=clear_queue_action,
inputs=[state],
@@ -4069,95 +4366,6 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
outputs=[current_gen_column, queue_accordion]
)
- extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column,
- prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row, sliding_window_tab,
- video_prompt_type_video_guide, video_prompt_type_image_refs, recam_column] # show_advanced presets_column,
- if update_form:
- locals_dict = locals()
- gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs
- return gen_inputs
- else:
- target_state = gr.Text(value = "state", interactive= False, visible= False)
- target_settings = gr.Text(value = "settings", interactive= False, visible= False)
-
- image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end])
- video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, video_mask, keep_frames])
- video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_image_ref ])
- video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [video_prompt_type, video_prompt_type_video_guide], outputs = [video_prompt_type, video_guide, keep_frames, video_mask])
-
- show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
- fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
- queue_df.select( fn=handle_celll_selection, inputs=state, outputs=[queue_df, modal_image_display, modal_container])
- save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
- confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
- save_lset, inputs=[state, lset_name, loras_choices, loras_multipliers, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn,refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
- delete_lset_btn.click(validate_delete_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
- confirm_delete_lset_btn.click(delete_lset, inputs=[state, lset_name], outputs=[lset_name, apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
- cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn,confirm_save_lset_btn, cancel_lset_btn,save_lset_prompt_drop ])
- apply_lset_btn.click(apply_lset, inputs=[state, wizard_prompt_activated_var, lset_name,loras_choices, loras_multipliers, prompt], outputs=[wizard_prompt_activated_var, loras_choices, loras_multipliers, prompt]).then(
- fn = fill_wizard_prompt, inputs = [state, wizard_prompt_activated_var, prompt, wizard_prompt], outputs = [ wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars]
- )
- refresh_lora_btn.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
- refresh_lora_btn2.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
- output.select(select_video, state, None )
-
- gen_status.change(refresh_gallery,
- inputs = [state, gen_status],
- outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn])
-
-
- abort_btn.click(abort_generation, [state], [gen_status, abort_btn] ) #.then(refresh_gallery, inputs = [state, gen_info], outputs = [output, gen_info, queue_df] )
- onemore_btn.click(fn=one_more_sample,inputs=[state], outputs= [state])
-
- inputs_names= list(inspect.signature(save_inputs).parameters)[1:-1]
- locals_dict = locals()
- gen_inputs = [locals_dict[k] for k in inputs_names] + [state]
- save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
- save_inputs, inputs =[target_settings] + gen_inputs, outputs = [])
-
-
- model_choice.change(fn=validate_wizard_prompt,
- inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
- outputs= [prompt]
- ).then(fn=save_inputs,
- inputs =[target_state] + gen_inputs,
- outputs= None
- ).then(fn= change_model,
- inputs=[state, model_choice],
- outputs= [header]
- ).then(fn= fill_inputs,
- inputs=[state],
- outputs=gen_inputs + extra_inputs
- ).then(fn= preload_model_when_switching,
- inputs=[state],
- outputs=[gen_status])
-
- generate_btn.click(fn=validate_wizard_prompt,
- inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
- outputs= [prompt]
- ).then(fn=save_inputs,
- inputs =[target_state] + gen_inputs,
- outputs= None
- ).then(fn=process_prompt_and_add_tasks,
- inputs = [state, model_choice],
- outputs= queue_df
- ).then(fn=prepare_generate_video,
- inputs= [state],
- outputs= [generate_btn, add_to_queue_btn, current_gen_column]
- ).then(fn=process_tasks,
- inputs= [state],
- outputs= [gen_status],
- ).then(finalize_generation,
- inputs= [state],
- outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info]
- ).then(
- fn=lambda s: gr.Accordion(open=False) if len(get_gen_info(s).get("queue", [])) <= 1 else gr.update(),
- inputs=[state],
- outputs=[queue_accordion]
- ).then(unload_model_if_needed,
- inputs= [state],
- outputs= []
- )
add_to_queue_btn.click(fn=validate_wizard_prompt,
inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
@@ -4183,10 +4391,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
outputs=[modal_container]
)
- return ( state,
- loras_choices, lset_name, state, queue_df, current_gen_column,
- gen_status, output, abort_btn, generate_btn, add_to_queue_btn,
- gen_info, queue_accordion, video_guide, video_mask, video_prompt_video_guide_trigger
+ return ( state, loras_choices, lset_name, state,
+ video_guide, video_mask, video_prompt_video_guide_trigger
)
@@ -4437,6 +4643,77 @@ def select_tab(tab_state, evt:gr.SelectData):
tab_state["tab_no"] = new_tab_no
return gr.Tabs()
+def get_timer_js():
+ start_quit_timer_js = """
+ () => {
+ function findAndClickGradioButton(elemId) {
+ const gradioApp = document.querySelector('gradio-app') || document;
+ const button = gradioApp.querySelector(`#${elemId}`);
+ if (button) { button.click(); }
+ }
+
+ if (window.quitCountdownTimeoutId) clearTimeout(window.quitCountdownTimeoutId);
+
+ let js_click_count = 0;
+ const max_clicks = 5;
+
+ function countdownStep() {
+ if (js_click_count < max_clicks) {
+ findAndClickGradioButton('trigger_info_single_btn');
+ js_click_count++;
+ window.quitCountdownTimeoutId = setTimeout(countdownStep, 1000);
+ } else {
+ findAndClickGradioButton('force_quit_btn_hidden');
+ }
+ }
+
+ countdownStep();
+ }
+ """
+
+ cancel_quit_timer_js = """
+ () => {
+ if (window.quitCountdownTimeoutId) {
+ clearTimeout(window.quitCountdownTimeoutId);
+ window.quitCountdownTimeoutId = null;
+ console.log("Quit countdown cancelled (single trigger).");
+ }
+ }
+ """
+
+ trigger_zip_download_js = """
+ (base64String) => {
+ if (!base64String) {
+ console.log("No base64 zip data received, skipping download.");
+ return;
+ }
+ try {
+ const byteCharacters = atob(base64String);
+ const byteNumbers = new Array(byteCharacters.length);
+ for (let i = 0; i < byteCharacters.length; i++) {
+ byteNumbers[i] = byteCharacters.charCodeAt(i);
+ }
+ const byteArray = new Uint8Array(byteNumbers);
+ const blob = new Blob([byteArray], { type: 'application/zip' });
+
+ const url = URL.createObjectURL(blob);
+ const a = document.createElement('a');
+ a.style.display = 'none';
+ a.href = url;
+ a.download = 'queue.zip';
+ document.body.appendChild(a);
+ a.click();
+
+ window.URL.revokeObjectURL(url);
+ document.body.removeChild(a);
+ console.log("Zip download triggered.");
+ } catch (e) {
+ console.error("Error processing base64 data or triggering download:", e);
+ }
+ }
+ """
+ return start_quit_timer_js, cancel_quit_timer_js, trigger_zip_download_js
+
def create_demo():
global vmc_event_handler
css = """
@@ -4670,6 +4947,7 @@ def create_demo():
z-index: 2;
pointer-events: none;
}
+
"""
UI_theme = server_config.get("UI_theme", "default")
UI_theme = args.theme if len(args.theme) > 0 else UI_theme
@@ -4678,8 +4956,8 @@ def create_demo():
else:
theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md")
- with gr.Blocks(css=css, theme=theme, title= "Wan2GP") as demo:
- gr.Markdown("