Added Phantom model support

deepbeepmeep 2025-04-28 03:18:33 +02:00
parent 31abc06545
commit 480645c85f
8 changed files with 489 additions and 370 deletions

View File

@ -10,7 +10,9 @@
## 🔥 Latest News!!
* April 25 2025: 👋 Wan 2.1GP v4.3: Added preview mode and support for Sky Reels v2 Diffusion Forcing for high quality "infinite length videos" (see Window Sliding section below)
* April 27 2025: 👋 Wan 2.1GP v4.4: Phantom model support, a very good model for transferring people or objects into a video. It works quite well at 720p with more than 30 denoising steps.
* April 25 2025: 👋 Wan 2.1GP v4.3: Added preview mode and support for Sky Reels v2 Diffusion Forcing for high quality "infinite length videos" (see the Sliding Window section below). Note that Sky Reels v2 uses causal attention, which is only supported by Sdpa attention, so even if you choose another attention type, some of the processing will still use Sdpa attention.
* April 18 2025: 👋 Wan 2.1GP v4.2: FLF2V model support, official support from Wan for image2video start and end frames specialized for 720p.
* April 17 2025: 👋 Wan 2.1GP v4.1: Recam Master model support, view a video from a different angle. The video to process must be at least 81 frames long and you should set at least 15 steps denoising to get good results.
* April 13 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you !

View File

@ -31,6 +31,8 @@ class DTT2V:
text_encoder_filename = None,
quantizeTransformer = False,
dtype = torch.bfloat16,
VAE_dtype = torch.float32,
mixed_precision_transformer = False,
):
self.device = torch.device(f"cuda")
self.config = config
@ -50,24 +52,22 @@ class DTT2V:
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
self.vae = WanVAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype,
device=self.device)
logging.info(f"Creating WanModel from {model_filename}")
from mmgp import offload
# model_filename = "model.safetensors"
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) #, forcedConfigPath="config.json"
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) #, forcedConfigPath="config.json")
# offload.load_model_data(self.model, "recam.ckpt")
# self.model.cpu()
if self.dtype == torch.float16 and not "fp16" in model_filename:
self.model.to(self.dtype)
# offload.save_model(self.model, "rt1.3B.safetensors", config_file_path="config.json")
# offload.save_model(self.model, "rtint8.safetensors", do_quantize= "config.json")
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype, True)
offload.change_dtype(self.model, dtype, True)
# offload.save_model(self.model, "sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", config_file_path="config.json")
# offload.save_model(self.model, "sky_reels2_diffusion_forcing_720p_14B_quanto_xbf16_int8.safetensors", do_quantize= True, config_file_path="config.json")
# offload.save_model(self.model, "rtfp16_int8.safetensors", do_quantize= "config.json")
if self.dtype == torch.float16:
self.vae.model.to(self.dtype)
self.model.eval().requires_grad_(False)
self.scheduler = FlowUniPCMultistepScheduler()
@ -228,11 +228,16 @@ class DTT2V:
latent_height = height // 8
latent_width = width // 8
prompt_embeds = self.text_encoder([prompt], self.device)
prompt_embeds = [u.to(self.dtype).to(self.device) for u in prompt_embeds]
if self._interrupt:
return None
prompt_embeds = self.text_encoder([prompt], self.device)[0]
prompt_embeds = prompt_embeds.to(self.dtype).to(self.device)
if self.do_classifier_free_guidance:
negative_prompt_embeds = self.text_encoder([negative_prompt], self.device)
negative_prompt_embeds = [u.to(self.dtype).to(self.device) for u in negative_prompt_embeds]
negative_prompt_embeds = self.text_encoder([negative_prompt], self.device)[0]
negative_prompt_embeds = negative_prompt_embeds.to(self.dtype).to(self.device)
if self._interrupt:
return None
self.scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
init_timesteps = self.scheduler.timesteps
@ -305,6 +310,17 @@ class DTT2V:
del time_steps_comb
from mmgp import offload
freqs = get_rotary_pos_embed(latents[0].shape[1 :], enable_RIFLEx= False)
kwrags = {
"freqs" :freqs,
"fps" : fps_embeds,
"causal_block_size" : causal_block_size,
"causal_attention" : causal_attention,
"callback" : callback,
"pipeline" : self,
}
kwrags.update(i2v_extra_kwrags)
for i, timestep_i in enumerate(tqdm(step_matrix)):
offload.set_step_no_for_lora(self.model, i)
update_mask_i = step_update_mask[i]
@ -323,52 +339,45 @@ class DTT2V:
* noise_factor
)
timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
kwrags = {
kwrags.update({
"x" : torch.stack([latent_model_input[0]]),
"t" : timestep,
"freqs" :freqs,
"fps" : fps_embeds,
"causal_block_size" : causal_block_size,
"causal_attention" : causal_attention,
"callback" : callback,
"pipeline" : self,
"current_step" : i,
}
kwrags.update(i2v_extra_kwrags)
if not self.do_classifier_free_guidance:
noise_pred = self.model(
context=prompt_embeds,
**kwrags,
)[0]
if self._interrupt:
return None
noise_pred= noise_pred.to(torch.float32)
else:
if joint_pass:
noise_pred_cond, noise_pred_uncond = self.model(
context=prompt_embeds,
context2=negative_prompt_embeds,
})
# with torch.autocast(device_type="cuda"):
if True:
if not self.do_classifier_free_guidance:
noise_pred = self.model(
context=[prompt_embeds],
**kwrags,
)
if self._interrupt:
return None
else:
noise_pred_cond = self.model(
context=prompt_embeds,
**kwrags,
)[0]
if self._interrupt:
return None
noise_pred_uncond = self.model(
context=negative_prompt_embeds,
)[0]
if self._interrupt:
return None
noise_pred_cond= noise_pred_cond.to(torch.float32)
noise_pred_uncond= noise_pred_uncond.to(torch.float32)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
del noise_pred_cond, noise_pred_uncond
noise_pred= noise_pred.to(torch.float32)
else:
if joint_pass:
noise_pred_cond, noise_pred_uncond = self.model(
context= [prompt_embeds, negative_prompt_embeds],
**kwrags,
)
if self._interrupt:
return None
else:
noise_pred_cond = self.model(
context=[prompt_embeds],
**kwrags,
)[0]
if self._interrupt:
return None
noise_pred_uncond = self.model(
context=[negative_prompt_embeds],
**kwrags,
)[0]
if self._interrupt:
return None
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
del noise_pred_cond, noise_pred_uncond
for idx in range(valid_interval_start, valid_interval_end):
if update_mask_i[idx].item():
latents[0][:, idx] = sample_schedulers[idx].step(
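The denoising loop above builds its model keyword arguments once before the loop (rotary freqs, fps embedding, causal settings) and only updates the per-step fields on each iteration, with the positive and negative prompt embeddings passed together as a context list when joint_pass is enabled. A minimal sketch of that calling convention, using a stand-in model and made-up tensor shapes (not the actual WanModel signature):

```python
import torch

# Stand-in for the transformer call; the real WanModel signature differs,
# this only illustrates the shared-kwargs / context-list pattern used above.
def fake_model(context, x, t, freqs, fps, current_step):
    return [torch.zeros_like(x[0]) for _ in context]

latents = torch.randn(16, 21, 30, 52)           # [C, F, H, W], made-up latent shape
prompt_embeds = torch.randn(512, 4096)          # made-up text embedding shapes
negative_prompt_embeds = torch.randn(512, 4096)
guidance_scale = 6.0

# Base arguments built once before the loop (mirrors the kwrags dict above).
kwargs = {"freqs": None, "fps": None}
for i, t in enumerate([999, 800, 600]):
    # Only the per-step entries change; the dict is updated in place.
    kwargs.update({"x": [latents], "t": torch.tensor([t]), "current_step": i})
    # Joint pass: conditional and unconditional branches share one forward call.
    cond, uncond = fake_model(context=[prompt_embeds, negative_prompt_embeds], **kwargs)
    noise_pred = uncond + guidance_scale * (cond - uncond)
```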

View File

@ -48,47 +48,17 @@ class WanI2V:
self,
config,
checkpoint_dir,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
t5_cpu=False,
init_on_cpu=True,
i2v720p= True,
model_filename ="",
text_encoder_filename="",
quantizeTransformer = False,
dtype = torch.bfloat16
dtype = torch.bfloat16,
VAE_dtype = torch.float32,
mixed_precision_transformer = False
):
r"""
Initializes the image-to-video generation model components.
Args:
config (EasyDict):
Object containing model parameters initialized from config.py
checkpoint_dir (`str`):
Path to directory containing model checkpoints
device_id (`int`, *optional*, defaults to 0):
Id of target GPU device
rank (`int`, *optional*, defaults to 0):
Process rank for distributed training
t5_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for T5 model
dit_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for DiT model
use_usp (`bool`, *optional*, defaults to False):
Enable distribution strategy of USP.
t5_cpu (`bool`, *optional*, defaults to False):
Whether to place T5 model on CPU. Only works without t5_fsdp.
Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
init_on_cpu (`bool`, *optional*, defaults to True):
"""
self.device = torch.device(f"cuda")
self.config = config
self.rank = rank
self.use_usp = use_usp
self.t5_cpu = t5_cpu
self.dtype = dtype
self.VAE_dtype = VAE_dtype
self.num_train_timesteps = config.num_train_timesteps
self.param_dtype = config.param_dtype
# shard_fn = partial(shard_model, device_id=device_id)
@ -104,7 +74,7 @@ class WanI2V:
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
self.vae = WanVAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype = VAE_dtype,
device=self.device)
self.clip = CLIPModel(
@ -118,11 +88,9 @@ class WanI2V:
from mmgp import offload
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False)
if self.dtype == torch.float16 and not "fp16" in model_filename:
self.model.to(self.dtype)
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype, True)
offload.change_dtype(self.model, dtype, True)
# offload.save_model(self.model, "i2v_720p_fp16.safetensors",do_quantize=True)
if self.dtype == torch.float16:
self.vae.model.to(self.dtype)
# offload.save_model(self.model, "wan2.1_Fun_InP_1.3B_bf16_bis.safetensors")
self.model.eval().requires_grad_(False)
@ -142,7 +110,6 @@ class WanI2V:
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True,
callback = None,
enable_RIFLEx = False,
VAE_tile_size= 0,
@ -212,13 +179,13 @@ class WanI2V:
w = lat_w * self.vae_stride[2]
clip_image_size = self.clip.model.image_size
img_interpolated = resize_lanczos(img, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device, self.dtype)
img_interpolated = resize_lanczos(img, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype
img = resize_lanczos(img, clip_image_size, clip_image_size)
img = img.sub_(0.5).div_(0.5).to(self.device, self.dtype)
img = img.sub_(0.5).div_(0.5).to(self.device) #, self.dtype
if img2!= None:
img_interpolated2 = resize_lanczos(img2, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device, self.dtype)
img_interpolated2 = resize_lanczos(img2, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype
img2 = resize_lanczos(img2, clip_image_size, clip_image_size)
img2 = img2.sub_(0.5).div_(0.5).to(self.device, self.dtype)
img2 = img2.sub_(0.5).div_(0.5).to(self.device) #, self.dtype
max_seq_len = lat_frames * lat_h * lat_w // ( self.patch_size[1] * self.patch_size[2])
@ -244,25 +211,19 @@ class WanI2V:
if n_prompt == "":
n_prompt = self.sample_neg_prompt
# preprocess
if not self.t5_cpu:
# self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
if self._interrupt:
return None
context = [u.to(self.dtype) for u in context]
context_null = [u.to(self.dtype) for u in context_null]
# preprocess
context = self.text_encoder([input_prompt], self.device)[0]
context_null = self.text_encoder([n_prompt], self.device)[0]
context = context.to(self.dtype)
context_null = context_null.to(self.dtype)
if self._interrupt:
return None
clip_context = self.clip.visual([img[:, None, :, :]])
if offload_model:
self.clip.model.cpu()
from mmgp import offload
offload.last_offload_obj.unload_all()
@ -270,23 +231,20 @@ class WanI2V:
mean2 = 0
enc= torch.concat([
img_interpolated,
torch.full( (3, frame_num-2, h, w), mean2, device=self.device, dtype= self.dtype),
torch.full( (3, frame_num-2, h, w), mean2, device=self.device, dtype= self.VAE_dtype),
img_interpolated2,
], dim=1).to(self.device)
else:
enc= torch.concat([
img_interpolated,
torch.zeros(3, frame_num-1, h, w, device=self.device, dtype= self.dtype)
torch.zeros(3, frame_num-1, h, w, device=self.device, dtype= self.VAE_dtype)
], dim=1).to(self.device)
img, img2, img_interpolated, img_interpolated2 = None, None, None, None
lat_y = self.vae.encode([enc], VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
y = torch.concat([msk, lat_y])
lat_y = None
@contextmanager
def noop_no_sync():
yield
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
# evaluation mode
@ -317,7 +275,7 @@ class WanI2V:
freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx= enable_RIFLEx)
arg_c = {
'context': [context[0]],
'context': [context],
'clip_fea': clip_context,
'y': [y],
'freqs' : freqs,
@ -326,7 +284,7 @@ class WanI2V:
}
arg_null = {
'context': context_null,
'context': [context_null],
'clip_fea': clip_context,
'y': [y],
'freqs' : freqs,
@ -335,8 +293,7 @@ class WanI2V:
}
arg_both= {
'context': [context[0]],
'context2': context_null,
'context': [context, context_null],
'clip_fea': clip_context,
'y': [y],
'freqs' : freqs,
@ -344,9 +301,6 @@ class WanI2V:
'callback' : callback
}
if offload_model:
torch.cuda.empty_cache()
if self.model.enable_teacache:
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
@ -379,8 +333,6 @@ class WanI2V:
)[0]
if self._interrupt:
return None
if offload_model:
torch.cuda.empty_cache()
noise_pred_uncond = self.model(
latent_model_input,
t=timestep,
@ -392,8 +344,7 @@ class WanI2V:
if self._interrupt:
return None
del latent_model_input
if offload_model:
torch.cuda.empty_cache()
# CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
noise_pred_text = noise_pred_cond
if cfg_star_switch:
@ -412,9 +363,6 @@ class WanI2V:
del noise_pred_uncond
latent = latent.to(
torch.device('cpu') if offload_model else self.device)
temp_x0 = sample_scheduler.step(
noise_pred.unsqueeze(0),
t,
@ -429,29 +377,18 @@ class WanI2V:
callback(i, latent, False)
x0 = [latent.to(self.device, dtype=self.dtype)]
# x0 = [latent.to(self.device, dtype=self.dtype)]
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
x0 = [latent]
if self.rank == 0:
# x0 = [lat_y]
video = self.vae.decode(x0, VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
# x0 = [lat_y]
video = self.vae.decode(x0, VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
if any_end_frame and add_frames_for_end_image:
# video[:, -1:] = img_interpolated2
video = video[:, :-1]
else:
video = None
if any_end_frame and add_frames_for_end_image:
# video[:, -1:] = img_interpolated2
video = video[:, :-1]
del noise, latent
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
if dist.is_initialized():
dist.barrier()
return video
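Before the guidance blend, the loop above applies the CFG-Zero-star rescaling credited in the comment. A minimal sketch of what that step amounts to, assuming optimized_scale returns the per-sample least-squares projection coefficient of the conditional prediction onto the unconditional one (the helper's internals follow the referenced CFG-Zero-star repository, not code shown here):

```python
import torch

def optimized_scale(positive_flat: torch.Tensor, negative_flat: torch.Tensor) -> torch.Tensor:
    # alpha = <v_cond, v_uncond> / ||v_uncond||^2, one coefficient per sample.
    dot = (positive_flat * negative_flat).sum(dim=1, keepdim=True)
    norm_sq = (negative_flat * negative_flat).sum(dim=1, keepdim=True) + 1e-8
    return dot / norm_sq

def cfg_zero_star(noise_pred_cond, noise_pred_uncond, guide_scale, step, cfg_zero_step=-1):
    batch_size = noise_pred_cond.shape[0]
    alpha = optimized_scale(noise_pred_cond.view(batch_size, -1),
                            noise_pred_uncond.view(batch_size, -1)).view(batch_size, 1, 1, 1)
    if step <= cfg_zero_step:
        return noise_pred_cond * 0.0          # early steps are zeroed out, as in the loop above
    noise_pred_uncond = noise_pred_uncond * alpha
    return noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
```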

View File

@ -408,6 +408,9 @@ class WanAttentionBlock(nn.Module):
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
hint = None
attention_dtype = self.self_attn.q.weight.dtype
dtype = x.dtype
if self.block_id is not None and hints is not None:
kwargs = {
"grid_sizes" : grid_sizes,
@ -434,9 +437,11 @@ class WanAttentionBlock(nn.Module):
cam_emb = rearrange(cam_emb, 'b f h w d -> b (f h w) d')
x_mod += cam_emb
xlist = [x_mod]
xlist = [x_mod.to(attention_dtype)]
del x_mod
y = self.self_attn( xlist, grid_sizes, freqs, block_mask)
y = y.to(dtype)
if cam_emb != None:
y = self.projector(y)
@ -445,15 +450,18 @@ class WanAttentionBlock(nn.Module):
x, y = reshape_latent(x , 1), reshape_latent(y , 1)
del y
y = self.norm3(x)
y = y.to(attention_dtype)
ylist= [y]
del y
x += self.cross_attn(ylist, context)
x += self.cross_attn(ylist, context).to(dtype)
y = self.norm2(x)
y = reshape_latent(y , latent_frames)
y *= 1 + e[4]
y += e[3]
y = reshape_latent(y , 1)
y = y.to(attention_dtype)
ffn = self.ffn[0]
gelu = self.ffn[1]
@ -469,7 +477,7 @@ class WanAttentionBlock(nn.Module):
y_chunk[...] = ffn2(mlp_chunk)
del mlp_chunk
y = y.view(y_shape)
y = y.to(dtype)
x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames)
x.addcmul_(y, e[5])
x, y = reshape_latent(x , 1), reshape_latent(y , 1)
@ -532,7 +540,6 @@ class Head(nn.Module):
out_dim = math.prod(patch_size) * out_dim
self.norm = WanLayerNorm(dim, eps)
self.head = nn.Linear(dim, out_dim)
# modulation
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
@ -552,6 +559,7 @@ class Head(nn.Module):
x *= (1 + e[1])
x += e[0]
x = reshape_latent(x , 1)
x= x.to(self.head.weight.dtype)
x = self.head(x)
return x
@ -735,6 +743,44 @@ class WanModel(ModelMixin, ConfigMixin):
block.projector.bias = nn.Parameter(torch.zeros(dim))
def lock_layers_dtypes(self, dtype = torch.float32, force = False):
count = 0
layer_list = [self.head, self.head.head, self.patch_embedding, self.time_embedding, self.time_embedding[0], self.time_embedding[2],
self.time_projection, self.time_projection[1]] #, self.text_embedding, self.text_embedding[0], self.text_embedding[2] ]
if hasattr(self, "fps_embedding"):
layer_list += [self.fps_embedding, self.fps_projection, self.fps_projection[0], self.fps_projection[2]]
if hasattr(self, "vace_patch_embedding"):
layer_list += [self.vace_patch_embedding]
layer_list += [self.vace_blocks[0].before_proj]
for block in self.vace_blocks:
layer_list += [block.after_proj, block.norm3]
# cam master
if hasattr(self.blocks[0], "projector"):
for block in self.blocks:
layer_list += [block.projector]
for block in self.blocks:
layer_list += [block.norm3]
for layer in layer_list:
if hasattr(layer, "weight"):
if layer.weight.dtype == dtype :
count += 1
elif force:
if hasattr(layer, "weight"):
layer.weight.data = layer.weight.data.to(dtype)
if hasattr(layer, "bias"):
layer.bias.data = layer.bias.data.to(dtype)
count += 1
layer._lock_dtype = dtype
if count > 0:
self._lock_dtype = dtype
def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
rescale_func = np.poly1d(self.coefficients)
e_list = []
@ -788,7 +834,6 @@ class WanModel(ModelMixin, ConfigMixin):
freqs = None,
pipeline = None,
current_step = 0,
context2 = None,
is_uncond=False,
max_steps = 0,
slg_layers=None,
@ -797,7 +842,10 @@ class WanModel(ModelMixin, ConfigMixin):
fps = None,
causal_block_size = 1,
causal_attention = False,
x_neg = None
):
# dtype = self.blocks[0].self_attn.q.weight.dtype
dtype = self.patch_embedding.weight.dtype
if self.model_type == 'i2v':
assert clip_fea is not None and y is not None
@ -810,9 +858,9 @@ class WanModel(ModelMixin, ConfigMixin):
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
# grid_sizes = torch.stack(
# [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [self.patch_embedding(u.unsqueeze(0)).to(dtype) for u in x]
if x_neg !=None:
x_neg = [self.patch_embedding(u.unsqueeze(0)).to(dtype) for u in x_neg]
grid_sizes = [ list(u.shape[2:]) for u in x]
embed_sizes = grid_sizes[0]
@ -836,57 +884,46 @@ class WanModel(ModelMixin, ConfigMixin):
x = [u.flatten(2).transpose(1, 2) for u in x]
x = x[0]
if x_neg !=None:
x_neg = [u.flatten(2).transpose(1, 2) for u in x_neg]
x_neg = x_neg[0]
if t.dim() == 2:
b, f = t.shape
_flag_df = True
else:
_flag_df = False
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype)
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype) # self.patch_embedding.weight.dtype)
) # b, dim
e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
if self.inject_sample_info:
fps = torch.tensor(fps, dtype=torch.long, device=device)
fps_emb = self.fps_embedding(fps).float()
fps_emb = self.fps_embedding(fps).to(dtype) # float()
if _flag_df:
e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
else:
e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim))
# context
context = self.text_embedding(
torch.stack([
torch.cat(
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context
]))
if context2!=None:
context2 = self.text_embedding(
torch.stack([
torch.cat(
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context2
]))
context = [self.text_embedding( torch.cat( [u, u.new_zeros(self.text_len - u.size(0), u.size(1))] ).unsqueeze(0) ) for u in context ]
if clip_fea is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
if context2 != None:
context2 = torch.concat([context_clip, context2], dim=1)
context = [ torch.cat( [context_clip, u ], dim=1 ) for u in context ]
joint_pass = context2 != None
joint_pass = len(context) > 0
x_list = [x]
if joint_pass:
x_list = [x, x.clone()]
context_list = [context, context2]
if x_neg == None:
x_list += [x.clone() for i in range(len(context) - 1) ]
else:
x_list += [x.clone() for i in range(len(context) - 2) ] + [x_neg]
is_uncond = False
else:
x_list = [x]
context_list = [context]
del x
context_list = context
# arguments
@ -945,10 +982,7 @@ class WanModel(ModelMixin, ConfigMixin):
if callback != None:
callback(-1, None, False, True)
if pipeline._interrupt:
if joint_pass:
return None, None
else:
return [None]
return [None] * len(x_list)
if slg_layers is not None and block_idx in slg_layers:
if is_uncond and not joint_pass:
@ -983,10 +1017,7 @@ class WanModel(ModelMixin, ConfigMixin):
x_list[i] = self.unpatchify(x, grid_sizes)
del x
if joint_pass:
return x_list[0][0], x_list[1][0]
else:
return [u.float() for u in x_list[0]]
return [x[0].float() for x in x_list]
def unpatchify(self, x, grid_sizes):
r"""

View File

@ -752,7 +752,8 @@ def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
logging.info(f'loading {pretrained_path}')
# model.load_state_dict(
# torch.load(pretrained_path, map_location=device), assign=True)
offload.load_model_data(model, pretrained_path.replace(".pth", "_bf16.safetensors"), writable_tensors= False)
# offload.load_model_data(model, pretrained_path.replace(".pth", "_bf16.safetensors"), writable_tensors= False)
offload.load_model_data(model, pretrained_path.replace(".pth", ".safetensors"), writable_tensors= False)
return model
@ -782,20 +783,22 @@ class WanVAE:
self.model = _video_vae(
pretrained_path=vae_pth,
z_dim=z_dim,
).eval() #.requires_grad_(False).to(device)
).to(dtype).eval() #.requires_grad_(False).to(device)
def encode(self, videos, tile_size = 256, any_end_frame = False):
"""
videos: A list of videos each with shape [C, T, H, W].
"""
original_dtype = videos[0].dtype
if tile_size > 0:
return [ self.model.spatial_tiled_encode(u.unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
else:
return [ self.model.encode(u.unsqueeze(0), self.scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
return [ self.model.encode(u.to(self.dtype).unsqueeze(0), self.scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
def decode(self, zs, tile_size, any_end_frame = False):
if tile_size > 0:
return [ self.model.spatial_tiled_decode(u.unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
else:
return [ self.model.decode(u.unsqueeze(0), self.scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
return [ self.model.decode(u.to(self.dtype).unsqueeze(0), self.scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
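WanVAE above now takes a configurable dtype: the autoencoder is cast once at load time, encode/decode cast their inputs to that dtype on the way in, and latents/frames always come back as float32. A minimal wrapper illustrating the same convention (the inner encoder/decoder is a placeholder, not the real WanVAE):

```python
import torch
import torch.nn as nn

class PrecisionWrappedVAE(nn.Module):
    def __init__(self, vae: nn.Module, dtype: torch.dtype = torch.float32):
        super().__init__()
        self.dtype = dtype
        # Cast the VAE weights once, as WanVAE does with .to(dtype).eval() above.
        self.vae = vae.to(dtype).eval().requires_grad_(False)

    @torch.no_grad()
    def encode(self, videos):
        # Inputs are cast to the VAE dtype, latents come back as float32.
        return [self.vae.encode(v.to(self.dtype).unsqueeze(0)).float().squeeze(0) for v in videos]

    @torch.no_grad()
    def decode(self, latents):
        # Decoded frames are clamped to [-1, 1] and returned as float32.
        return [self.vae.decode(z.to(self.dtype).unsqueeze(0)).clamp_(-1, 1).float().squeeze(0)
                for z in latents]
```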

View File

@ -52,7 +52,9 @@ class WanT2V:
model_filename = None,
text_encoder_filename = None,
quantizeTransformer = False,
dtype = torch.bfloat16
dtype = torch.bfloat16,
VAE_dtype = torch.float32,
mixed_precision_transformer = False
):
self.device = torch.device(f"cuda")
self.config = config
@ -71,24 +73,23 @@ class WanT2V:
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
self.vae = WanVAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype,
device=self.device)
logging.info(f"Creating WanModel from {model_filename}")
from mmgp import offload
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False)
# offload.load_model_data(self.model, "recam.ckpt")
# model_filename
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False ) #, forcedConfigPath= "e:/vace_config.json")
# offload.load_model_data(self.model, "e:/vace.safetensors")
# offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth")
# self.model.to(torch.bfloat16)
# self.model.cpu()
# offload.save_model(self.model, "recam.safetensors")
if self.dtype == torch.float16 and not "fp16" in model_filename:
self.model.to(self.dtype)
# offload.save_model(self.model, "t2v_fp16.safetensors",do_quantize=True)
if self.dtype == torch.float16:
self.vae.model.to(self.dtype)
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype, True)
offload.change_dtype(self.model, dtype, True)
# offload.save_model(self.model, "mvace.safetensors", config_file_path="e:/vace_config.json")
# offload.save_model(self.model, "phantom_1.3B.safetensors")
self.model.eval().requires_grad_(False)
@ -252,6 +253,15 @@ class WanT2V:
return self.vae.decode(trimed_zs, tile_size= tile_size)
def get_vae_latents(self, ref_images, device, tile_size= 0):
ref_vae_latents = []
for ref_image in ref_images:
ref_image = TF.to_tensor(ref_image).sub_(0.5).div_(0.5).to(self.device)
img_vae_latent = self.vae.encode([ref_image.unsqueeze(1)], tile_size= tile_size)
ref_vae_latents.append(img_vae_latent[0])
return torch.cat(ref_vae_latents, dim=1)
def generate(self,
input_prompt,
input_frames= None,
@ -320,8 +330,15 @@ class WanT2V:
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if self._interrupt:
return None
context = self.text_encoder([input_prompt], self.device)[0]
context_null = self.text_encoder([n_prompt], self.device)[0]
context = context.to(self.dtype)
context_null = context_null.to(self.dtype)
input_ref_images_neg = None
phantom = False
if target_camera != None:
size = (source_video.shape[2], source_video.shape[1])
source_video = source_video.to(dtype=self.dtype , device=self.device)
@ -346,8 +363,12 @@ class WanT2V:
target_shape = list(z0[0].shape)
target_shape[0] = int(target_shape[0] / 2)
else:
if input_ref_images != None: # Phantom Ref images
phantom = True
input_ref_images = [self.get_vae_latents(input_ref_images, self.device)]
input_ref_images_neg = [torch.zeros_like(input_ref_images[0])]
F = frame_num
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + (input_ref_images[0].shape[1] if input_ref_images != None else 0),
size[1] // self.vae_stride[1],
size[0] // self.vae_stride[2])
@ -355,8 +376,8 @@ class WanT2V:
(self.patch_size[1] * self.patch_size[2]) *
target_shape[1])
context = [u.to(self.dtype) for u in context]
context_null = [u.to(self.dtype) for u in context_null]
if self._interrupt:
return None
noise = [ torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) ]
@ -393,21 +414,15 @@ class WanT2V:
freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
else:
freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
arg_c = {'context': context, 'freqs': freqs, 'pipeline': self, 'callback': callback}
arg_null = {'context': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback}
arg_both = {'context': context, 'context2': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback}
kwargs = {'freqs': freqs, 'pipeline': self, 'callback': callback}
if target_camera != None:
recam_dict = {'cam_emb': cam_emb}
arg_c.update(recam_dict)
arg_null.update(recam_dict)
arg_both.update(recam_dict)
kwargs.update({'cam_emb': cam_emb})
if input_frames != None:
vace_dict = {'vace_context' : z, 'vace_context_scale' : context_scale}
arg_c.update(vace_dict)
arg_null.update(vace_dict)
arg_both.update(vace_dict)
kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale})
if self.model.enable_teacache:
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
@ -424,39 +439,68 @@ class WanT2V:
timestep = [t]
offload.set_step_no_for_lora(self.model, i)
timestep = torch.stack(timestep)
kwargs["current_step"] = i
kwargs["t"] = timestep
if joint_pass:
noise_pred_cond, noise_pred_uncond = self.model(
latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both)
if phantom:
pos_it, pos_i, neg = self.model(
[torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latent_model_input, input_ref_images)],
x_neg = [torch.cat([latent[:,:-ref_latent_neg.shape[1]], ref_latent_neg], dim=1) for latent, ref_latent_neg in zip(latent_model_input, input_ref_images_neg)],
context = [context, context_null, context_null], **kwargs)
else:
noise_pred_cond, noise_pred_uncond = self.model(
latent_model_input, slg_layers=slg_layers_local, context = [context, context_null], **kwargs)
if self._interrupt:
return None
else:
noise_pred_cond = self.model(
latent_model_input, t=timestep,current_step=i, is_uncond = False, **arg_c)[0]
if self._interrupt:
return None
noise_pred_uncond = self.model(
latent_model_input, t=timestep,current_step=i, is_uncond = True, slg_layers=slg_layers_local, **arg_null)[0]
if self._interrupt:
return None
if phantom:
pos_it = self.model(
[torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latent_model_input, input_ref_images)], context = [context], **kwargs
)[0]
if self._interrupt:
return None
pos_i = self.model(
[torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latent_model_input, input_ref_images)], context = [context_null],**kwargs
)[0]
if self._interrupt:
return None
neg = self.model(
[torch.cat([latent[:,:-ref_latent_neg.shape[1]], ref_latent_neg], dim=1) for latent, ref_latent_neg in zip(latent_model_input, input_ref_images_neg)], context = [context_null], **kwargs
)[0]
if self._interrupt:
return None
else:
noise_pred_cond = self.model(
latent_model_input, is_uncond = False, context = [context], **kwargs)[0]
if self._interrupt:
return None
noise_pred_uncond = self.model(
latent_model_input, is_uncond = True, slg_layers=slg_layers_local,context = [context_null], **kwargs)[0]
if self._interrupt:
return None
# del latent_model_input
# CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
noise_pred_text = noise_pred_cond
if cfg_star_switch:
positive_flat = noise_pred_text.view(batch_size, -1)
negative_flat = noise_pred_uncond.view(batch_size, -1)
if phantom:
guide_scale_img= 5.0
guide_scale_text= guide_scale #7.5
noise_pred = neg + guide_scale_img * (pos_i - neg) + guide_scale_text * (pos_it - pos_i)
else:
noise_pred_text = noise_pred_cond
if cfg_star_switch:
positive_flat = noise_pred_text.view(batch_size, -1)
negative_flat = noise_pred_uncond.view(batch_size, -1)
alpha = optimized_scale(positive_flat,negative_flat)
alpha = alpha.view(batch_size, 1, 1, 1)
alpha = optimized_scale(positive_flat,negative_flat)
alpha = alpha.view(batch_size, 1, 1, 1)
if (i <= cfg_zero_step):
noise_pred = noise_pred_text*0. # it would be faster not to compute noise_pred...
else:
noise_pred_uncond *= alpha
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond)
del noise_pred_uncond
if (i <= cfg_zero_step):
noise_pred = noise_pred_text*0. # it would be faster not to compute noise_pred...
else:
noise_pred_uncond *= alpha
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond)
noise_pred_uncond, noise_pred_cond, noise_pred_text, pos_it, pos_i, neg = None, None, None, None, None, None
temp_x0 = sample_scheduler.step(
noise_pred[:, :target_shape[1]].unsqueeze(0),
@ -473,8 +517,12 @@ class WanT2V:
x0 = latents
if input_frames == None:
if phantom:
# phantom post processing
x0 = [x0_[:,:-input_ref_images[0].shape[1]] for x0_ in x0]
videos = self.vae.decode(x0, VAE_tile_size)
else:
# vace post processing
videos = self.decode_latent(x0, input_ref_images, VAE_tile_size)
del latents
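Phantom conditioning in generate() above works in three moves: the reference images are VAE-encoded and their latents appended after the video latents along the frame axis, each denoising step runs three branches (text plus references pos_it, references only pos_i, and a negative branch neg with zeroed reference latents), and the reference frames are trimmed off before decoding. A compact sketch of those helpers, assuming the [C, F, H, W] latent layout used here:

```python
import torch

def attach_ref_latents(latent: torch.Tensor, ref_latents: torch.Tensor) -> torch.Tensor:
    # Replace the trailing frames of the latent with the reference-image latents
    # (the model input used for the pos_it / pos_i branches above).
    return torch.cat([latent[:, :-ref_latents.shape[1]], ref_latents], dim=1)

def phantom_guidance(pos_it, pos_i, neg, guide_scale_text, guide_scale_img=5.0):
    # neg -> pos_i pulls the sample toward the reference images,
    # pos_i -> pos_it then pulls it toward the text prompt.
    return neg + guide_scale_img * (pos_i - neg) + guide_scale_text * (pos_it - pos_i)

def trim_ref_latents(x0: torch.Tensor, ref_latents: torch.Tensor) -> torch.Tensor:
    # Drop the reference frames again before VAE decoding.
    return x0[:, :-ref_latents.shape[1]]
```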

View File

@ -69,18 +69,29 @@ def remove_background(img, session=None):
def resize_and_remove_background(img_list, budget_width, budget_height, rm_background ):
def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, fit_into_canvas = False ):
if rm_background:
session = new_session()
output_list =[]
for img in img_list:
width, height = img.size
scale = (budget_height * budget_width / (height * width))**(1/2)
new_height = int( round(height * scale / 16) * 16)
new_width = int( round(width * scale / 16) * 16)
resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
if fit_into_canvas:
white_canvas = np.ones((budget_height, budget_width, 3), dtype=np.uint8) * 255
scale = min(budget_height / height, budget_width / width)
new_height = int(height * scale)
new_width = int(width * scale)
resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
top = (budget_height - new_height) // 2
left = (budget_width - new_width) // 2
white_canvas[top:top + new_height, left:left + new_width] = np.array(resized_image)
resized_image = Image.fromarray(white_canvas)
else:
scale = (budget_height * budget_width / (height * width))**(1/2)
new_height = int( round(height * scale / 16) * 16)
new_width = int( round(width * scale / 16) * 16)
resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
if rm_background:
resized_image = remove(resized_image, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
output_list.append(resized_image)
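The fit_into_canvas mode above letterboxes each reference image onto a white canvas of the requested size (centered, aspect ratio preserved) instead of the pixel-budget resize. The Phantom path in wgp.py calls it roughly like this (file names and the 1280x720 target are illustrative):

```python
from PIL import Image
from wan.utils.utils import resize_and_remove_background

# Pad two reference images onto a 1280x720 white canvas, optionally stripping
# their backgrounds first, before handing them to the Phantom pipeline.
refs = [Image.open("person.png"), Image.open("object.png")]  # illustrative paths
refs = resize_and_remove_background(refs, 1280, 720, rm_background=True, fit_into_canvas=True)
```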

wgp.py
View File

@ -40,7 +40,7 @@ global_queue_ref = []
AUTOSAVE_FILENAME = "queue.zip"
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.4.0"
target_mmgp_version = "3.4.1"
from importlib.metadata import version
mmgp_version = version("mmgp")
if mmgp_version != target_mmgp_version:
@ -133,10 +133,11 @@ def process_prompt_and_add_tasks(state, model_choice):
model_filename = state["model_filename"]
if model_choice != get_model_type(model_filename):
model_type = get_model_type(model_filename)
inputs = state.get(model_type, None)
if model_choice != model_type or inputs ==None:
raise gr.Error("Webform can not be used as the App has been restarted since the form was displayed. Please refresh the page")
inputs = state.get(get_model_type(model_filename), None)
inputs["state"] = state
inputs.pop("lset_name")
if inputs == None:
@ -176,7 +177,7 @@ def process_prompt_and_add_tasks(state, model_choice):
gr.Info(f"Resolution {resolution} not supported by image 2 video")
return
if "1.3B" in model_filename and width * height > 848*480:
if "1.3B" in model_filename and width * height > 848*480 and any( model in model_filename for model in ["image2video", "text2video"] ):
gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
return
@ -186,8 +187,28 @@ def process_prompt_and_add_tasks(state, model_choice):
if video_length > sliding_window_size:
gr.Info(f"The Number of Frames to generate ({video_length}) is greater than the Sliding Window Size ({sliding_window_size}) , multiple Windows will be generated")
if "phantom" in model_filename:
image_refs = inputs["image_refs"]
if "diffusion_forcing" in model_filename:
if isinstance(image_refs, list):
image_refs = [ convert_image(tup[0]) for tup in image_refs ]
os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg")
from wan.utils.utils import resize_and_remove_background
image_refs = resize_and_remove_background(image_refs, width, height, inputs["remove_background_image_ref"] ==1, fit_into_canvas= True)
if len(prompts) > 0:
prompts = ["\n".join(prompts)]
for single_prompt in prompts:
extra_inputs = {
"prompt" : single_prompt,
"image_refs": image_refs,
}
inputs.update(extra_inputs)
add_video_task(**inputs)
elif "diffusion_forcing" in model_filename:
image_start = inputs["image_start"]
video_source = inputs["video_source"]
keep_frames_video_source = inputs["keep_frames_video_source"]
@ -1362,10 +1383,6 @@ quantizeTransformer = args.quantize_transformer
check_loras = args.check_loras ==1
advanced = args.advanced
transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_bf16.safetensors",
"ckpts/wan2.1_recammaster_1.3B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_1.3B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors"]
transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_bf16.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors"]
transformer_choices = transformer_choices_t2v + transformer_choices_i2v
text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"]
server_config_filename = "wgp_config.json"
if not os.path.isdir("settings"):
@ -1401,11 +1418,32 @@ else:
text = reader.read()
server_config = json.loads(text)
# for src_path, tgt_path in zip( ["ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors"], ["ckpts/sky_reels2_diffusion_forcing_540p_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_540p_14B_bf16.safetensors"] ):
# if Path(src_path).is_file():
# shutil.move(src_path, tgt_path) )
# for path in ["ckpts/wan2.1_Vace_1.3B_preview_mbf16.safetensors", "sky_reels2_diffusion_forcing_1.3B_bf16.safetensors"]:
# if Path(path).is_file():
# os.remove(path)
model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B"]
path= "ckpts/sky_reels2_diffusion_forcing_1.3B_bf16.safetensors"
if os.path.isfile(path) and os.path.getsize(path) > 4000000000:
os.remove(path)
transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_mbf16.safetensors",
"ckpts/wan2.1_recammaster_1.3B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors",
"ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_720p_14B_bf16.safetensors","ckpts/sky_reels2_diffusion_forcing_720p_14B_quanto_int8.safetensors",
"ckpts/wan2_1_phantom_1.3B_mbf16.safetensors"]
transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors",
"ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors",
"ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_bf16.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors"]
transformer_choices = transformer_choices_t2v + transformer_choices_i2v
model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "phantom_1.3B"]
model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B",
"i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "recam_1.3B": "recammaster_1.3B",
"flf2v_720p" : "FLF2V_720p", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B" }
"flf2v_720p" : "FLF2V_720p", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B",
"sky_df_720p_14B" : "sky_reels2_diffusion_forcing_720p_14B",
"phantom_1.3B" : "phantom_1.3B", }
def get_model_type(model_filename):
@ -1417,29 +1455,47 @@ def get_model_type(model_filename):
def test_class_i2v(model_filename):
return "image2video" in model_filename or "Fun_InP" in model_filename or "FLF2V" in model_filename
def get_model_name(model_filename):
def get_model_name(model_filename, description_container = [""]):
if "Fun" in model_filename:
model_name = "Fun InP image2video"
model_name += " 14B" if "14B" in model_filename else " 1.3B"
description = "The Fun model is an alternative image 2 video that supports out the box End Image fixing (contrary to the original Wan image 2 video model). The 1.3B adds also image 2 to video capability to the 1.3B model."
elif "Vace" in model_filename:
model_name = "Vace ControlNet"
model_name += " 14B" if "14B" in model_filename else " 1.3B"
description = "The Vace ControlNet model is a powerful model that allows you to control the content of the generated video based of additional custom data : pose or depth video, images or objects you want to see in the video."
elif "image" in model_filename:
model_name = "Wan2.1 image2video"
model_name += " 720p" if "720p" in model_filename else " 480p"
if "720p" in model_filename:
description = "The standard Wan Image 2 Video specialized to generate 720p images. It also offers Start and End Image support (End Image is not supported in the original model but seems to work well)"
else:
description = "The standard Wan Image 2 Video specialized to generate 480p images. It also offers Start and End Image support (End Image is not supported in the original model but seems to work well)"
elif "recam" in model_filename:
model_name = "ReCamMaster"
model_name += " 14B" if "14B" in model_filename else " 1.3B"
description = "The Recam Master in theory should allow you to replay a video by applying a different camera movement. The model supports only video that are at least 81 frames long (any frame beyond will be ignored)"
elif "FLF2V" in model_filename:
model_name = "Wan2.1 FLF2V"
model_name += " 720p" if "720p" in model_filename else " 480p"
description = "The First Last Frame 2 Video model is the official model Image 2 Video model that support Start and End frames."
elif "sky_reels2_diffusion_forcing" in model_filename:
model_name = "SkyReels2 diffusion forcing"
model_name = "SkyReels2 Diffusion Forcing"
if "720p" in model_filename :
model_name += " 720p"
elif not "1.3B" in model_filename :
model_name += " 540p"
model_name += " 14B" if "14B" in model_filename else " 1.3B"
description = "The SkyReels 2 Diffusion Forcing model has been designed to generate very long videos that exceeds the usual 5s limit. You can also use this model to extend any existing video."
elif "phantom" in model_filename:
model_name = "Wan2.1 Phantom"
model_name += " 14B" if "14B" in model_filename else " 1.3B"
description = "The Phantom model is specialized to transfer people or objects of your choice into a generated Video. It produces very nices results when used at 720p."
else:
model_name = "Wan2.1 text2video"
model_name += " 14B" if "14B" in model_filename else " 1.3B"
description = "The original Wan Text 2 Video model. Most other models have been built on top of it"
description_container[0] = description
return model_name
@ -1493,13 +1549,28 @@ def get_default_settings(filename):
"slg_end_perc": 90
}
if get_model_type(filename) in ("sky_df_1.3B", "sky_df_14B"):
if get_model_type(filename) in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"):
ui_defaults.update({
"guidance_scale": 6.0,
"flow_shift": 8,
"sliding_window_discard_last_frames" : 0
"sliding_window_discard_last_frames" : 0,
"resolution": "1280x720" if "720p" in filename else "960x544",
"sliding_window_size" : 121 if "720p" in filename else 97,
"RIFLEx_setting": 2,
"guidance_scale": 6,
"flow_shift": 8,
})
if get_model_type(filename) in ("phantom_1.3B"):
ui_defaults.update({
"guidance_scale": 7.5,
"flow_shift": 5,
"resolution": "1280x720"
})
with open(defaults_filename, "w", encoding="utf-8") as f:
json.dump(ui_defaults, f, indent=4)
else:
@ -1649,7 +1720,7 @@ def download_models(transformer_filename, text_encoder_filename):
from huggingface_hub import hf_hub_download, snapshot_download
repoId = "DeepBeepMeep/Wan2.1"
sourceFolderList = ["xlm-roberta-large", "pose", "depth", "mask", "", ]
fileList = [ [], [],[], ["sam_vit_h_4b8939_fp16.safetensors"], ["Wan2.1_VAE_bf16.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
fileList = [ [], [],[], ["sam_vit_h_4b8939_fp16.safetensors"], ["Wan2.1_VAE.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
targetRoot = "ckpts/"
for sourceFolder, files in zip(sourceFolderList,fileList ):
if len(files)==0:
@ -1763,12 +1834,12 @@ def setup_loras(model_filename, transformer, lora_dir, lora_preselected_preset,
return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset
def load_t2v_model(model_filename, value, quantizeTransformer = False, dtype = torch.bfloat16):
def load_t2v_model(model_filename, value, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
cfg = WAN_CONFIGS['t2v-14B']
# cfg = WAN_CONFIGS['t2v-1.3B']
print(f"Loading '{model_filename}' model...")
if get_model_type(model_filename) in ("sky_df_1.3B", "sky_df_14B"):
if get_model_type(model_filename) in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"):
model_factory = wan.DTT2V
else:
model_factory = wan.WanT2V
@ -1779,52 +1850,32 @@ def load_t2v_model(model_filename, value, quantizeTransformer = False, dtype = t
model_filename=model_filename,
text_encoder_filename= text_encoder_filename,
quantizeTransformer = quantizeTransformer,
dtype = dtype
dtype = dtype,
VAE_dtype = VAE_dtype,
mixed_precision_transformer = mixed_precision_transformer
)
pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "vae": wan_model.vae.model }
return wan_model, pipe
def load_i2v_model(model_filename, value, quantizeTransformer = False, dtype = torch.bfloat16):
def load_i2v_model(model_filename, value, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
print(f"Loading '{model_filename}' model...")
if value == '720P':
cfg = WAN_CONFIGS['i2v-14B']
wan_model = wan.WanI2V(
config=cfg,
checkpoint_dir="ckpts",
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
i2v720p= True,
model_filename=model_filename,
text_encoder_filename=text_encoder_filename,
quantizeTransformer = quantizeTransformer,
dtype = dtype
)
pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "text_encoder_2": wan_model.clip.model, "vae": wan_model.vae.model } #
cfg = WAN_CONFIGS['i2v-14B']
wan_model = wan.WanI2V(
config=cfg,
checkpoint_dir="ckpts",
model_filename=model_filename,
text_encoder_filename=text_encoder_filename,
quantizeTransformer = quantizeTransformer,
dtype = dtype,
VAE_dtype = VAE_dtype,
mixed_precision_transformer = mixed_precision_transformer
)
pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "text_encoder_2": wan_model.clip.model, "vae": wan_model.vae.model } #
elif value == '480P':
cfg = WAN_CONFIGS['i2v-14B']
wan_model = wan.WanI2V(
config=cfg,
checkpoint_dir="ckpts",
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
i2v720p= False,
model_filename=model_filename,
text_encoder_filename=text_encoder_filename,
quantizeTransformer = quantizeTransformer,
dtype = dtype
)
pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "text_encoder_2": wan_model.clip.model, "vae": wan_model.vae.model } #
else:
raise Exception("Model i2v {value} not supported")
return wan_model, pipe
@ -1836,18 +1887,22 @@ def load_models(model_filename):
perc_reserved_mem_max = args.perc_reserved_mem_max
major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
default_dtype = torch.float16 if major < 8 else torch.bfloat16
# default_dtype = torch.bfloat16
if default_dtype == torch.float16 or args.fp16:
if major < 8:
print("Switching to f16 model as GPU architecture doesn't support bf16")
default_dtype = torch.float16
else:
default_dtype = torch.float16 if args.fp16 else torch.bfloat16
if default_dtype == torch.float16 :
if "quanto" in model_filename:
model_filename = model_filename.replace("quanto_int8", "quanto_fp16_int8")
download_models(model_filename, text_encoder_filename)
VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float
mixed_precision_transformer = server_config.get("mixed_precision","0") == "1"
if test_class_i2v(model_filename):
res720P = "720p" in model_filename
wan_model, pipe = load_i2v_model(model_filename, "720P" if res720P else "480P", quantizeTransformer = quantizeTransformer, dtype = default_dtype )
wan_model, pipe = load_i2v_model(model_filename, "720P" if res720P else "480P", quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
else:
wan_model, pipe = load_t2v_model(model_filename, "", quantizeTransformer = quantizeTransformer, dtype = default_dtype)
wan_model, pipe = load_t2v_model(model_filename, "", quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
wan_model._model_file_name = model_filename
kwargs = { "extraModelsToQuantize": None}
if profile == 2 or profile == 4:
@ -1888,8 +1943,13 @@ def get_default_flow(filename, i2v):
def generate_header(model_filename, compile, attention_mode):
header = "<DIV style='align:right;width:100%'><FONT SIZE=3>Attention mode <B>" + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() )
description_container = [""]
get_model_name(model_filename, description_container)
description = description_container[0]
header = "<DIV style='height:40px'>" + description + "</DIV>"
header += "<DIV style='align:right;width:100%'><FONT SIZE=3>Attention mode <B>" + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() )
if attention_mode not in attention_modes_installed:
header += " -NOT INSTALLED-"
elif attention_mode not in attention_modes_supported:
@ -1907,6 +1967,8 @@ def generate_header(model_filename, compile, attention_mode):
def apply_changes( state,
transformer_types_choices,
text_encoder_choice,
VAE_precision_choice,
mixed_precision_choice,
save_path_choice,
attention_choice,
compile_choice,
@ -1922,7 +1984,7 @@ def apply_changes( state,
if args.lock_config:
return
if gen_in_progress:
return "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>"
return "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>", gr.update(), gr.update()
global offloadobj, wan_model, server_config, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets
server_config = {"attention_mode" : attention_choice,
"transformer_types": transformer_types_choices,
@ -1931,6 +1993,8 @@ def apply_changes( state,
"compile" : compile_choice,
"profile" : profile_choice,
"vae_config" : vae_config_choice,
"vae_precision" : VAE_precision_choice,
"mixed_precision" : mixed_precision_choice,
"metadata_type": metadata_choice,
"transformer_quantization" : quantization_choice,
"boost" : boost_choice,
@ -2052,12 +2116,9 @@ def build_callback(state, pipe, send_cmd, status, num_inference_steps):
return callback
def abort_generation(state):
gen = get_gen_info(state)
if "in_progress" in gen:
if "in_progress" in gen and wan_model != None:
gen["abort"] = True
gen["extra_orders"] = 0
if wan_model != None:
wan_model._interrupt= True
wan_model._interrupt= True
msg = "Processing Request to abort Current Generation"
gen["status"] = msg
gr.Info(msg)
@ -2140,13 +2201,6 @@ def finalize_generation(state):
return gr.Gallery(selected_index=choice), gr.Button(interactive= True), gr.Button(visible= True), gr.Button(visible= False), gr.Column(visible= False), gr.HTML(visible= False, value="")
def refresh_gallery_on_trigger(state):
gen = get_gen_info(state)
if(gen.get("update_gallery", False)):
gen['update_gallery'] = False
return gr.update(value=gen.get("file_list", []))
def select_video(state , event_data: gr.EventData):
data= event_data._data
gen = get_gen_info(state)
@ -2385,6 +2439,8 @@ def generate_video(
# VAE Tiling
device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
if vae_config == 0:
if server_config.get("vae_precision", "16") == "32":
device_mem_capacity = device_mem_capacity / 2
if device_mem_capacity >= 24000:
use_vae_config = 1
elif device_mem_capacity >= 8000:
@ -2497,6 +2553,7 @@ def generate_video(
max_frames_to_generate = video_length
diffusion_forcing = "diffusion_forcing" in model_filename
vace = "Vace" in model_filename
phantom = "phantom" in model_filename
if diffusion_forcing or vace:
reuse_frames = min(sliding_window_size - 4, sliding_window_overlap)
if diffusion_forcing and source_video != None:
@ -2536,6 +2593,7 @@ def generate_video(
extra_windows = 0
guide_start_frame = 0
video_length = first_window_video_length
gen["extra_windows"] = 0
while not abort:
if sliding_window:
prompt = prompts[window_no] if window_no < len(prompts) else prompts[-1]
@ -2550,7 +2608,9 @@ def generate_video(
window_no += 1
gen["window_no"] = window_no
if diffusion_forcing:
if phantom:
src_ref_images = image_refs.copy() if image_refs != None else None
elif diffusion_forcing:
if video_source != None and len(video_source) > 0 and window_no == 1:
keep_frames_video_source= 1000 if len(keep_frames_video_source) ==0 else int(keep_frames_video_source)
prefix_video = preprocess_video(None, width=width, height=height,video_in=video_source, max_frames= keep_frames_video_source , start_frame = 0, fit_canvas= True, target_fps = fps)
@ -2559,7 +2619,7 @@ def generate_video(
prefix_video_frames_count = prefix_video.shape[1]
pre_video_guide = prefix_video[:, -reuse_frames:]
if vace:
elif vace:
# video_prompt_type = video_prompt_type +"G"
image_refs_copy = image_refs.copy() if image_refs != None else None # required since prepare_source do inplace modifications
video_guide_copy = video_guide
@ -2610,7 +2670,7 @@ def generate_video(
progress_args = [0, status + " - Encoding Prompt"]
send_cmd("progress", progress_args)
samples = torch.empty( (1,2)) #for testing
# samples = torch.empty( (1,2)) #for testing
# if False:
try:
@ -2633,7 +2693,6 @@ def generate_video(
guide_scale=guidance_scale,
n_prompt=negative_prompt,
seed=seed,
offload_model=False,
callback=callback,
enable_RIFLEx = enable_RIFLEx,
VAE_tile_size = VAE_tile_size,
@ -2738,6 +2797,7 @@ def generate_video(
if samples == None:
abort = True
state["prompt"] = ""
send_cmd("output")
else:
sample = samples.cpu()
if True: # for testing
@ -2839,7 +2899,6 @@ def generate_video(
print(f"New video saved to Path: "+video_path)
file_list.append(video_path)
state['update_gallery'] = True
send_cmd("output")
if sliding_window :
if max_frames_to_generate > 0 and extra_windows == 0:
@ -2847,8 +2906,6 @@ def generate_video(
if (current_length - prefix_video_frames_count)>= max_frames_to_generate:
break
video_length = min(sliding_window_size, ((max_frames_to_generate - (current_length - prefix_video_frames_count) + reuse_frames + discard_last_frames) // 4) * 4 + 1 )
else:
break
seed += 1
@ -3416,7 +3473,7 @@ def prepare_inputs_dict(target, inputs ):
if not "recam" in model_filename or not "diffusion_forcing" in model_filename:
inputs.pop("model_mode")
if not "Vace" in model_filename:
if not "Vace" in model_filename or not "phantom" in model_filename:
unsaved_params = ["keep_frames_video_guide", "video_prompt_type", "remove_background_image_ref"]
for k in unsaved_params:
inputs.pop(k)
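
A tiny check of the intended filter in `prepare_inputs_dict`: the video-prompt parameters are kept for Vace and Phantom checkpoints and dropped for everything else (the filenames below are made up for illustration):

```python
unsaved_params = ["keep_frames_video_guide", "video_prompt_type", "remove_background_image_ref"]

for model_filename in ["Wan_Vace_1.3B.safetensors",       # hypothetical names
                       "Wan_phantom_14B.safetensors",
                       "Wan_text2video_14B.safetensors"]:
    inputs = {**{k: None for k in unsaved_params}, "prompt": "a cat"}
    if not "Vace" in model_filename and not "phantom" in model_filename:
        for k in unsaved_params:
            inputs.pop(k)
    print(model_filename, sorted(inputs))
# Only the plain text2video entry ends up with just ["prompt"].
```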
@ -3776,6 +3833,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
diffusion_forcing = "diffusion_forcing" in model_filename
recammaster = "recam" in model_filename
vace = "Vace" in model_filename
phantom = "phantom" in model_filename
with gr.Column(visible= test_class_i2v(model_filename) or diffusion_forcing or recammaster) as image_prompt_column:
if diffusion_forcing:
image_prompt_type_value= ui_defaults.get("image_prompt_type","S")
@ -3835,23 +3893,27 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
model_mode = gr.Dropdown(visible=False)
keep_frames_video_source = gr.Text(visible=False)
with gr.Column(visible= vace ) as video_prompt_column:
with gr.Column(visible= vace or phantom) as video_prompt_column:
video_prompt_type_value= ui_defaults.get("video_prompt_type","")
video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False)
with gr.Row():
video_prompt_type_video_guide = gr.Dropdown(
choices=[
("None", ""),
("Transfer Human Motion from the Control Video", "PV"),
("Transfer Depth from the Control Video", "DV"),
("Recolorize the Control Video", "CV"),
# ("Alternate Video Ending", "OV"),
("Video contains Open Pose, Depth, Black & White, Inpainting ", "V"),
("Control Video and Mask video for stronger Inpainting ", "MV"),
],
value=filter_letters(video_prompt_type_value, "ODPCMV"),
label="Video to Video", scale = 3
)
if vace:
video_prompt_type_video_guide = gr.Dropdown(
choices=[
("None", ""),
("Transfer Human Motion from the Control Video", "PV"),
("Transfer Depth from the Control Video", "DV"),
("Recolorize the Control Video", "CV"),
# ("Alternate Video Ending", "OV"),
("Video contains Open Pose, Depth, Black & White, Inpainting ", "V"),
("Control Video and Mask video for stronger Inpainting ", "MV"),
],
value=filter_letters(video_prompt_type_value, "ODPCMV"),
label="Video to Video", scale = 3, visible= True
)
else:
video_prompt_type_video_guide = gr.Dropdown(visible= False)
video_prompt_video_guide_trigger = gr.Text(visible=False, value="")
video_prompt_type_image_refs = gr.Dropdown(
@ -3869,7 +3931,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
image_refs = gr.Gallery( label ="Reference Images",
type ="pil", show_label= True,
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value,
value= ui_defaults.get("image_refs", None) )
value= ui_defaults.get("image_refs", None),
)
# with gr.Row():
remove_background_image_ref = gr.Checkbox(value=ui_defaults.get("remove_background_image_ref",1), label= "Remove Background of Images References", visible= "I" in video_prompt_type_value, scale =1 )
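
`filter_letters` itself is not part of this diff; judging from its call site it presumably keeps only the characters of the stored `video_prompt_type` that belong to the allowed set, roughly as in this guess (not the repository's actual implementation):

```python
def filter_letters(value: str, allowed: str) -> str:
    # Keep only the letters of `value` that are in `allowed`, preserving order.
    return "".join(ch for ch in value if ch in allowed)

print(filter_letters("PVI", "ODPCMV"))  # -> "PV"; the "I" (image refs) flag is handled by another dropdown
```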
@ -3929,7 +3992,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
# ("832x1104 (3:4, 720p)", "832x1104"),
# ("960x960 (1:1, 720p)", "960x960"),
# 480p
# ("960x544 (16:9, 480p)", "960x544"),
("960x544 (16:9, 540p)", "960x544"),
("544x960 (16:9, 540p)", "544x960"),
("832x480 (16:9, 480p)", "832x480"),
("480x832 (9:16, 480p)", "480x832"),
# ("832x624 (4:3, 540p)", "832x624"),
@ -4082,13 +4146,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
gr.Markdown("<B>A Sliding Window allows you to generate video with a duration not limited by the Model</B>")
gr.Markdown("<B>It is automatically turned on if the number of frames to generate is higher than the Window Size</B>")
if diffusion_forcing:
sliding_window_size = gr.Slider(37, 137, value=ui_defaults.get("sliding_window_size", 97), step=20, label="Sliding Window Size")
sliding_window_size = gr.Slider(37, 137, value=ui_defaults.get("sliding_window_size", 97), step=20, label="Sliding Window Size (recommended to keep it at 97)")
                                sliding_window_overlap = gr.Slider(17, 97, value=ui_defaults.get("sliding_window_overlap",17), step=20, label="Window Frames Overlap (needed to maintain continuity between windows; a higher value will require more windows)")
sliding_window_discard_last_frames = gr.Slider(0, 10, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=1, visible = False)
sliding_window_discard_last_frames = gr.Slider(0, 12, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, visible = False)
else:
sliding_window_size = gr.Slider(5, 137, value=ui_defaults.get("sliding_window_size", 81), step=4, label="Sliding Window Size")
                                sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",17), step=4, label="Window Frames Overlap (needed to maintain continuity between windows; a higher value will require more windows)")
sliding_window_discard_last_frames = gr.Slider(0, 10, value=ui_defaults.get("sliding_window_discard_last_frames", 16), step=1, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
sliding_window_discard_last_frames = gr.Slider(0, 12, value=ui_defaults.get("sliding_window_discard_last_frames", 4), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
with gr.Tab("Miscellaneous", visible= not "recam" in model_filename):
@ -4429,13 +4493,22 @@ def generate_configuration_tab(state, blocks, header, model_choice):
quantization_choice = gr.Dropdown(
choices=[
("Int8 Quantization (recommended)", "int8"),
("Scaled Int8 Quantization (recommended)", "int8"),
("16 bits (no quantization)", "bf16"),
],
value= transformer_quantization,
label="Wan Transformer Model Quantization Type (if available)",
)
mixed_precision_choice = gr.Dropdown(
choices=[
("16 bits only, requires less VRAM", "0"),
("Mixed 16 / 32 bits, slightly more VRAM needed but better Quality", "1"),
],
value= server_config.get("mixed_precision", "0"),
label="Transformer Engine Calculation"
)
index = text_encoder_choices.index(text_encoder_filename)
index = 0 if index ==0 else index
text_encoder_choice = gr.Dropdown(
@ -4446,6 +4519,16 @@ def generate_configuration_tab(state, blocks, header, model_choice):
value= index,
label="Text Encoder model"
)
VAE_precision_choice = gr.Dropdown(
choices=[
("16 bits, requires less VRAM and faster", "16"),
("32 bits, requires twice more VRAM and slower but recommended with Window Sliding", "32"),
],
value= server_config.get("vae_precision", "16"),
label="VAE Encoding / Decoding precision"
)
save_path_choice = gr.Textbox(
label="Output Folder for Generated Videos",
value=server_config.get("save_path", save_path)
@ -4510,14 +4593,7 @@ def generate_configuration_tab(state, blocks, header, model_choice):
value= profile,
label="Profile (for power users only, not needed to change it)"
)
# default_ui_choice = gr.Dropdown(
# choices=[
# ("Text to Video", "t2v"),
# ("Image to Video", "i2v"),
# ],
# value= default_ui,
# label="Default mode when launching the App if not '--t2v' ot '--i2v' switch is specified when launching the server ",
# )
metadata_choice = gr.Dropdown(
choices=[
("Export JSON files", "json"),
@ -4563,6 +4639,8 @@ def generate_configuration_tab(state, blocks, header, model_choice):
state,
transformer_types_choices,
text_encoder_choice,
VAE_precision_choice,
mixed_precision_choice,
save_path_choice,
attention_choice,
compile_choice,
@ -4957,7 +5035,7 @@ def create_demo():
theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md")
with gr.Blocks(css=css, theme=theme, title= "Wan2GP") as main:
gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.3 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.4 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
global model_list
tab_state = gr.State({ "tab_no":0 })