New Multitabs, Save Settings, End Frame

DeepBeepMeep 2025-03-24 01:00:52 +01:00
parent 6c8cd8b163
commit f2c5a06626
7 changed files with 599 additions and 306 deletions

View File

@ -19,7 +19,14 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
## 🔥 Latest News!!
* Mar 18 2025: 👋 Wan2.1GP v3.0:
- New tab-based interface: you can switch from i2v to t2v and back without restarting the app
- Experimental Dual Frames mode for i2v: you can also specify an End frame. It doesn't always work, so you may need a few attempts.
- You can save default settings in the files *i2v_settings.json* and *t2v_settings.json* that will be used when launching the app (you can also specify the path to different settings files); a minimal example of such a file appears after this news list
- Slight acceleration with Loras
You will need one more *pip install -r requirements.txt*
Many thanks to *Tophness*, who created the framework (and did a big part of the work) for the multitab and saved-settings features
* Mar 18 2025: 👋 Wan2.1GP v2.11: Added more command line parameters to prefill the generation settings, plus a customizable output directory and a choice of the type of metadata attached to generated videos. Many thanks to *Tophness* for his contributions. You will need one more *pip install -r requirements.txt* to reflect new dependencies\
* Mar 18 2025: 👋 Wan2.1GP v2.1: More Loras! Added support for 'Safetensors' and 'Replicate' Lora formats.\
You will need to refresh the requirements with a *pip install -r requirements.txt*
* Mar 17 2025: 👋 Wan2.1GP v2.0: The Lora festival continues:
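The settings files are plain JSON. A minimal sketch of creating one is shown below; the key names used here (`prompt`, `resolution`, `num_inference_steps`, `seed`) are assumptions for illustration only, not the app's documented schema.

```python
# Hypothetical example: write a defaults file for the i2v tab.
# The key names are illustrative assumptions, not the app's actual schema.
import json

i2v_defaults = {
    "prompt": "a cat walking on the beach",
    "resolution": "480p",
    "num_inference_steps": 30,
    "seed": -1,
}

with open("i2v_settings.json", "w", encoding="utf-8") as f:
    json.dump(i2v_defaults, f, indent=4)
```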
@ -243,11 +250,12 @@ You can define multiple lines of macros. If there is only one macro line, the ap
--seed no : set default seed value\
--frames no : set the default number of frames to generate\
--steps no : set the default number of denoising steps\
--res resolution : default resolution, choices=["480p", "720p", "823p", "1024p", "1280p"]\
--teacache speed multiplier: Tea cache speed multiplier, choices=["0", "1.5", "1.75", "2.0", "2.25", "2.5"]\
--slg : turn on skip layer guidance for improved quality\
--check-loras : filter loras that are incompatible (will take a few seconds while refreshing the lora list or while starting the app)\
--advanced : turn on the advanced mode while launching the app\
--i2v-settings : path to launch settings for i2v\
--t2v-settings : path to launch settings for t2v
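As a rough illustration of how such launch defaults might be prefilled, the snippet below builds an argparse parser around the flag names listed above; the defaults and types are assumptions and this is not the app's actual parser.

```python
# Minimal sketch of a parser for the launch-defaults flags documented above.
# Only the flag names come from this README; defaults and types are assumptions.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=-1, help="default seed value")
parser.add_argument("--frames", type=int, default=81, help="default number of frames to generate")
parser.add_argument("--steps", type=int, default=30, help="default number of denoising steps")
parser.add_argument("--res", default="480p", choices=["480p", "720p", "823p", "1024p", "1280p"])
parser.add_argument("--advanced", action="store_true", help="turn on the advanced mode while launching the app")
parser.add_argument("--i2v-settings", default="i2v_settings.json", help="path to launch settings for i2v")
parser.add_argument("--t2v-settings", default="t2v_settings.json", help="path to launch settings for t2v")

args = parser.parse_args()
print(args)
```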
### Profiles (for power users only)
You can choose between 5 profiles, but two are really relevant here:

File diff suppressed because it is too large

View File

@ -16,6 +16,6 @@ gradio>=5.0.0
numpy>=1.23.5,<2
einops
moviepy==1.0.3
mmgp==3.3.0
peft==0.14.0
mutagen

View File

@ -26,6 +26,82 @@ from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from wan.modules.posemb_layers import get_rotary_pos_embed
from PIL import Image
def lanczos(samples, width, height):
images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
result = torch.stack(images)
return result.to(samples.device, samples.dtype)
def bislerp(samples, width, height):
def slerp(b1, b2, r):
'''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
c = b1.shape[-1]
#norms
b1_norms = torch.norm(b1, dim=-1, keepdim=True)
b2_norms = torch.norm(b2, dim=-1, keepdim=True)
#normalize
b1_normalized = b1 / b1_norms
b2_normalized = b2 / b2_norms
#zero when norms are zero
b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0
b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0
#slerp
dot = (b1_normalized*b2_normalized).sum(1)
omega = torch.acos(dot)
so = torch.sin(omega)
#technically not mathematically correct, but more pleasing?
res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized
res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)
#edge cases for same or polar opposites
res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
return res
def common_upscale(samples, width, height, upscale_method, crop):
orig_shape = tuple(samples.shape)
if len(orig_shape) > 4:
samples = samples.reshape(samples.shape[0], samples.shape[1], -1, samples.shape[-2], samples.shape[-1])
samples = samples.movedim(2, 1)
samples = samples.reshape(-1, orig_shape[1], orig_shape[-2], orig_shape[-1])
if crop == "center":
old_width = samples.shape[-1]
old_height = samples.shape[-2]
old_aspect = old_width / old_height
new_aspect = width / height
x = 0
y = 0
if old_aspect > new_aspect:
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
elif old_aspect < new_aspect:
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
s = samples.narrow(-2, y, old_height - y * 2).narrow(-1, x, old_width - x * 2)
else:
s = samples
if upscale_method == "bislerp":
out = bislerp(s, width, height)
elif upscale_method == "lanczos":
out = lanczos(s, width, height)
else:
out = torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
if len(orig_shape) == 4:
return out
out = out.reshape((orig_shape[0], -1, orig_shape[1]) + (height, width))
return out.movedim(2, 1).reshape(orig_shape[:-2] + (height, width))
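For orientation, here is a small usage sketch of the `common_upscale` helper added above. The shapes and parameter values are arbitrary, and the example sticks to the `lanczos` and built-in interpolation paths since `bislerp` is only partially shown in this hunk.

```python
# Illustrative usage of common_upscale on a dummy batch (values are arbitrary).
# Assumes lanczos/common_upscale from this file are in scope.
import torch

frames = torch.rand(2, 3, 480, 832)  # [N, C, H, W] in [0, 1]

up_bicubic = common_upscale(frames, 1280, 720, upscale_method="bicubic", crop="center")
up_lanczos = common_upscale(frames, 1280, 720, upscale_method="lanczos", crop="disabled")

print(up_bicubic.shape, up_lanczos.shape)  # torch.Size([2, 3, 720, 1280]) for both
```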
class WanI2V:

    def __init__(
@ -63,8 +139,8 @@ class WanI2V:
                Enable distribution strategy of USP.
            t5_cpu (`bool`, *optional*, defaults to False):
                Whether to place T5 model on CPU. Only works without t5_fsdp.
            init_on_cpu (`bool`, *optional*, defaults to True):
                Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
        """
        self.device = torch.device(f"cuda:{device_id}")
        self.config = config
@ -134,6 +210,7 @@ class WanI2V:
    def generate(self,
                 input_prompt,
                 img,
                 img2 = None,
                 max_area=720 * 1280,
                 frame_num=81,
                 shift=5.0,
@ -188,8 +265,14 @@ class WanI2V:
                - W: Frame width from max_area)
        """
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)

        lat_frames = int((frame_num - 1) // self.vae_stride[0] + 1)
        any_end_frame = img2 is not None
        if any_end_frame:
            img2 = TF.to_tensor(img2).sub_(0.5).div_(0.5).to(self.device)
            frame_num += 1
            lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)

        F = frame_num
        h, w = img.shape[1:]
        aspect_ratio = h / w
        lat_h = round(
@ -201,28 +284,21 @@ class WanI2V:
        h = lat_h * self.vae_stride[1]
        w = lat_w * self.vae_stride[2]

        max_seq_len = lat_frames * lat_h * lat_w // (self.patch_size[1] * self.patch_size[2])
        max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size

        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
        seed_g = torch.Generator(device=self.device)
        seed_g.manual_seed(seed)
        noise = torch.randn(16, lat_frames, lat_h, lat_w, dtype=torch.float32, generator=seed_g, device=self.device)

        msk = torch.ones(1, frame_num, lat_h, lat_w, device=self.device)
        if any_end_frame:
            msk[:, 1:-1] = 0
            msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:-1], torch.repeat_interleave(msk[:, -1:], repeats=4, dim=1) ], dim=1)
        else:
            msk[:, 1:] = 0
            msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] ], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]
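A standalone sketch of the end-frame bookkeeping above, assuming the default temporal VAE stride of 4 (`self.vae_stride[0]`):

```python
# Standalone sketch of the frame/latent arithmetic when an end frame is supplied.
# Assumes the default temporal VAE stride of 4 (self.vae_stride[0] in the class above).
vae_stride_t = 4
frame_num = 81                                         # user-requested frames

# end-frame path: one extra frame is appended, and the last frame gets its own latent slot
frame_num += 1                                         # 82
lat_frames = int((frame_num - 2) // vae_stride_t + 2)  # 22 latent frames

# the mask keeps the first and last frames, zeroes everything in between,
# and repeats both boundary frames 4x so its length matches lat_frames * 4
msk_len = 4 + (frame_num - 2) + 4                      # 88 == lat_frames * 4
assert msk_len == lat_frames * 4
print(frame_num, lat_frames, msk_len)                  # 82 22 88
```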
@ -242,7 +318,6 @@ class WanI2V:
        context = [t.to(self.device) for t in context]
        context_null = [t.to(self.device) for t in context_null]

        clip_context = self.clip.visual([img[:, None, :, :]])
        if offload_model:
            self.clip.model.cpu()
@ -250,16 +325,24 @@ class WanI2V:
        from mmgp import offload
        offload.last_offload_obj.unload_all()

        if any_end_frame:
            img_interpolated = torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1).to(torch.bfloat16)
            img2_interpolated = torch.nn.functional.interpolate(img2[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1).to(torch.bfloat16)
            mean2 = 0
            enc = torch.concat([
                img_interpolated,
                torch.full((3, frame_num - 2, h, w), mean2, device="cpu", dtype=torch.bfloat16),
                img2_interpolated,
            ], dim=1).to(self.device)
        else:
            enc = torch.concat([
                torch.nn.functional.interpolate(
                    img[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1).to(torch.bfloat16),
                torch.zeros(3, frame_num - 1, h, w, device="cpu", dtype=torch.bfloat16)
            ], dim=1).to(self.device)

        lat_y = self.vae.encode([enc], VAE_tile_size, any_end_frame=any_end_frame)[0]
        y = torch.concat([msk, lat_y])
        @contextmanager
        def noop_no_sync():
@ -293,7 +376,7 @@ class WanI2V:
        # sample videos
        latent = noise

        freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx=enable_RIFLEx)

        arg_c = {
            'context': [context[0]],
@ -344,8 +427,6 @@ class WanI2V:
            timestep = torch.stack(timestep).to(self.device)

            if joint_pass:
                noise_pred_cond, noise_pred_uncond = self.model(
                    latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both)
                if self._interrupt:
@ -396,14 +477,22 @@ class WanI2V:
                callback(i, latent)

        x0 = [latent.to(self.device, dtype=torch.bfloat16)]

        if offload_model:
            self.model.cpu()
            torch.cuda.empty_cache()

        if self.rank == 0:
            # x0 = [lat_y]
            video = self.vae.decode(x0, VAE_tile_size, any_end_frame=any_end_frame)[0]

            if any_end_frame:
                # video[:, -1:] = img2_interpolated
                video = video[:, :-1]
        else:
            video = None

        del noise, latent
        del sample_scheduler
@ -413,4 +502,4 @@ class WanI2V:
        if dist.is_initialized():
            dist.barrier()

        return video

View File

@ -429,11 +429,10 @@ def get_1d_rotary_pos_embed(
    ) # complex64 # [S, D/2]
    return freqs_cis

def get_rotary_pos_embed(latents_size, enable_RIFLEx = False):
    target_ndim = 3
    ndim = 5 - 2
    patch_size = [1, 2, 2]
    if isinstance(patch_size, int):
        assert all(s % patch_size == 0 for s in latents_size), (
@ -468,7 +467,7 @@ def get_rotary_pos_embed(video_length, height, width, enable_RIFLEx = False):
        theta=10000,
        use_real=True,
        theta_rescale_factor=1,
        L_test = latents_size[0],
        enable_riflex = enable_RIFLEx
    )
    return (freqs_cos, freqs_sin)
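Call sites now pass the latent grid sizes directly instead of the pixel-space frame count, height and width. A hedged sketch of the change, assuming the usual stride of 4 temporally and 8 spatially:

```python
# Sketch of the call-site change for get_rotary_pos_embed (shapes are illustrative).
# Old: freqs = get_rotary_pos_embed(frame_num, height, width, enable_RIFLEx=...)
# New: the caller passes the latent sizes, e.g. latent.shape[1:] == (lat_t, lat_h, lat_w)
frame_num, height, width = 81, 480, 832
latents_size = ((frame_num - 1) // 4 + 1, height // 8, width // 8)   # (21, 60, 104)
# freqs = get_rotary_pos_embed(latents_size, enable_RIFLEx=True)
print(latents_size)
```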

View File

@ -530,26 +530,37 @@ class WanVAE_(nn.Module):
        x_recon = self.decode(z)
        return x_recon, mu, log_var

    def encode(self, x, scale = None, any_end_frame = False):
        self.clear_cache()
        ## cache
        t = x.shape[2]
        if any_end_frame:
            iter_ = 2 + (t - 2) // 4
        else:
            iter_ = 1 + (t - 1) // 4
        ## split the input x along the time axis into chunks of 1, 4, 4, 4, ...
        out_list = []
        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                out_list.append(self.encoder(
                    x[:, :, :1, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx))
            elif any_end_frame and i == iter_ - 1:
                out_list.append(self.encoder(
                    x[:, :, -1:, :, :],
                    feat_cache=None,
                    feat_idx=self._enc_conv_idx))
            else:
                out_list.append(self.encoder(
                    x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx))

        self.clear_cache()
        out = torch.cat(out_list, 2)
        out_list = None

        mu, log_var = self.conv1(out).chunk(2, dim=1)
        if scale != None:
@ -558,11 +569,10 @@ class WanVAE_(nn.Module):
                1, self.z_dim, 1, 1, 1)
        else:
            mu = (mu - scale[0]) * scale[1]
        return mu

    def decode(self, z, scale=None, any_end_frame = False):
        self.clear_cache()
        # z: [b,c,t,h,w]
        if scale != None:
@ -573,20 +583,26 @@ class WanVAE_(nn.Module):
            z = z / scale[1] + scale[0]
        iter_ = z.shape[2]
        x = self.conv2(z)
        out_list = []
        for i in range(iter_):
            self._conv_idx = [0]
            if i == 0:
                out_list.append(self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx))
            elif any_end_frame and i == iter_ - 1:
                out_list.append(self.decoder(
                    x[:, :, -1:, :, :],
                    feat_cache=None,
                    feat_idx=self._conv_idx))
            else:
                out_list.append(self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx))
        self.clear_cache()
        out = torch.cat(out_list, 2)
        return out
    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
@ -601,7 +617,7 @@ class WanVAE_(nn.Module):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
        return b

    def spatial_tiled_decode(self, z, scale, tile_size, any_end_frame=False):
        tile_sample_min_size = tile_size
        tile_latent_min_size = int(tile_sample_min_size / 8)
        tile_overlap_factor = 0.25
@ -626,7 +642,7 @@ class WanVAE_(nn.Module):
            row = []
            for j in range(0, z.shape[-1], overlap_size):
                tile = z[:, :, :, i: i + tile_latent_min_size, j: j + tile_latent_min_size]
                decoded = self.decode(tile, any_end_frame=any_end_frame)
                row.append(decoded)
            rows.append(row)
        result_rows = []
@ -645,7 +661,7 @@ class WanVAE_(nn.Module):
        return torch.cat(result_rows, dim=-2)

    def spatial_tiled_encode(self, x, scale, tile_size, any_end_frame=False):
        tile_sample_min_size = tile_size
        tile_latent_min_size = int(tile_sample_min_size / 8)
        tile_overlap_factor = 0.25
@ -660,7 +676,7 @@ class WanVAE_(nn.Module):
            row = []
            for j in range(0, x.shape[-1], overlap_size):
                tile = x[:, :, :, i: i + tile_sample_min_size, j: j + tile_sample_min_size]
                tile = self.encode(tile, any_end_frame=any_end_frame)
                row.append(tile)
            rows.append(row)
        result_rows = []
@ -764,18 +780,18 @@ class WanVAE:
            z_dim=z_dim,
        ).eval().requires_grad_(False).to(device)

    def encode(self, videos, tile_size = 256, any_end_frame = False):
        """
        videos: A list of videos each with shape [C, T, H, W].
        """
        if tile_size > 0:
            return [ self.model.spatial_tiled_encode(u.unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
        else:
            return [ self.model.encode(u.unsqueeze(0), self.scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]

    def decode(self, zs, tile_size, any_end_frame = False):
        if tile_size > 0:
            return [ self.model.spatial_tiled_decode(u.unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
        else:
            return [ self.model.decode(u.unsqueeze(0), self.scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
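A hedged usage sketch of the extended `WanVAE` encode/decode API; the shapes are illustrative and it assumes an already constructed `vae = WanVAE(...)` instance.

```python
# Call-pattern sketch for the new any_end_frame flag (shapes are made up).
# `vae` is assumed to be an already constructed WanVAE instance from this module.
import torch

def roundtrip_with_end_frame(vae, height=480, width=832, frames=82):
    video = torch.zeros(3, frames, height, width)   # [C, T, H, W]; first/last frames are the key frames
    latents = vae.encode([video], tile_size=256, any_end_frame=True)
    decoded = vae.decode(latents, tile_size=256, any_end_frame=True)
    return latents[0].shape, decoded[0].shape
```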

View File

@ -241,7 +241,7 @@ class WanT2V:
        # sample videos
        latents = noise

        freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx=enable_RIFLEx)

        arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
        arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
        arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}