mirror of https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-05 14:33:15 +00:00
New Multitabs, Save Settings, End Frame
This commit is contained in:
parent
6c8cd8b163
commit
f2c5a06626
README.md (14 changed lines)
@@ -19,7 +19,14 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid

 ## 🔥 Latest News!!
-* Mar 18 2025: 👋 Wan2.1GP v2.11: Added more command line parameters to prefill the generation settings + customizable output directory and choice of type of metadata for generated videos. Many thanks to Tophness for his contributionsgit. You will need one more *pip install -r requirements.txt* to reflect new dependencies\
+* Mar 18 2025: 👋 Wan2.1GP v3.0:
+  - New Tab based interface: you can switch from i2v to t2v and conversely without restarting the app
+  - Experimental Dual Frames mode for i2v: you can also specify an End frame. It doesn't always work, so you may need a few attempts.
+  - You can save default settings in the files *i2v_settings.json* and *t2v_settings.json* that will be used when launching the app (you can also specify the path to different settings files)
+  - Slight acceleration with loras
+  You will need one more *pip install -r requirements.txt*
+  Many thanks to *Tophness*, who created the framework (and did a big part of the work) for the multitabs and saved-settings features
+* Mar 18 2025: 👋 Wan2.1GP v2.11: Added more command line parameters to prefill the generation settings + customizable output directory and choice of type of metadata for generated videos. Many thanks to *Tophness* for his contributions. You will need one more *pip install -r requirements.txt* to reflect new dependencies\
 * Mar 18 2025: 👋 Wan2.1GP v2.1: More Loras!: added support for 'Safetensors' and 'Replicate' Lora formats.\
   You will need to refresh the requirements with a *pip install -r requirements.txt*
 * Mar 17 2025: 👋 Wan2.1GP v2.0: The Lora festival continues:
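Note: the *i2v_settings.json* / *t2v_settings.json* defaults mentioned above are read by *gradio_server.py* at launch, but their contents are not part of this diff. The snippet below is only a minimal sketch of the idea; the key names (`seed`, `frames`, `steps`) and the helper name are illustrative assumptions, not taken from the repository:

```python
# Minimal sketch (assumed key names): merge user defaults from a JSON
# settings file over built-in defaults before building the UI.
import json
from pathlib import Path

BUILTIN_DEFAULTS = {"seed": -1, "frames": 81, "steps": 30}  # assumed keys/values

def load_launch_settings(path="i2v_settings.json"):
    """Return built-in defaults overridden by whatever the settings file provides."""
    settings = dict(BUILTIN_DEFAULTS)
    p = Path(path)
    if p.is_file():
        settings.update(json.loads(p.read_text(encoding="utf-8")))
    return settings

if __name__ == "__main__":
    print(load_launch_settings())  # falls back to the built-in defaults if the file is absent
```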
@@ -243,11 +250,12 @@ You can define multiple lines of macros. If there is only one macro line, the ap

 --seed no : set default seed value\
 --frames no : set the default number of frames to generate\
 --steps no : set the default number of denoising steps\
 --res resolution : default resolution, choices=["480p", "720p", "823p", "1024p", "1280p"]\
 --teacache speed multiplier: Tea cache speed multiplier, choices=["0", "1.5", "1.75", "2.0", "2.25", "2.5"]\
 --slg : turn on skip layer guidance for improved quality\
 --check-loras : filter loras that are incompatible (will take a few seconds while refreshing the lora list or while starting the app)\
---advanced : turn on the advanced mode while launching the app
+--advanced : turn on the advanced mode while launching the app\
+--i2v-settings : path to launch settings for i2v\
+--t2v-settings : path to launch settings for t2v

 ### Profiles (for power users only)
 You can choose between 5 profiles, but two are really relevant here :
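For example, the new switches let each tab start from its own defaults, e.g. *python gradio_server.py --i2v-settings path/to/my_i2v_settings.json --advanced* (the settings path here is purely illustrative).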
gradio_server.py (643 changed lines)
File diff suppressed because it is too large.
requirements.txt
@@ -16,6 +16,6 @@ gradio>=5.0.0
 numpy>=1.23.5,<2
 einops
 moviepy==1.0.3
-mmgp==3.2.8
+mmgp==3.3.0
 peft==0.14.0
 mutagen
wan/image2video.py
@@ -26,6 +26,82 @@ from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
 from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
 from wan.modules.posemb_layers import get_rotary_pos_embed

+from PIL import Image
+
+def lanczos(samples, width, height):
+    images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
+    images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
+    images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
+    result = torch.stack(images)
+    return result.to(samples.device, samples.dtype)
+
+def bislerp(samples, width, height):
+    def slerp(b1, b2, r):
+        '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
+        c = b1.shape[-1]
+
+        # norms
+        b1_norms = torch.norm(b1, dim=-1, keepdim=True)
+        b2_norms = torch.norm(b2, dim=-1, keepdim=True)
+
+        # normalize
+        b1_normalized = b1 / b1_norms
+        b2_normalized = b2 / b2_norms
+
+        # zero when norms are zero
+        b1_normalized[b1_norms.expand(-1, c) == 0.0] = 0.0
+        b2_normalized[b2_norms.expand(-1, c) == 0.0] = 0.0
+
+        # slerp
+        dot = (b1_normalized * b2_normalized).sum(1)
+        omega = torch.acos(dot)
+        so = torch.sin(omega)
+
+        # technically not mathematically correct, but more pleasing?
+        res = (torch.sin((1.0 - r.squeeze(1)) * omega) / so).unsqueeze(1) * b1_normalized + (torch.sin(r.squeeze(1) * omega) / so).unsqueeze(1) * b2_normalized
+        res *= (b1_norms * (1.0 - r) + b2_norms * r).expand(-1, c)
+
+        # edge cases for same or polar opposites
+        res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
+        res[dot < 1e-5 - 1] = (b1 * (1.0 - r) + b2 * r)[dot < 1e-5 - 1]
+        return res
+
+def common_upscale(samples, width, height, upscale_method, crop):
+    orig_shape = tuple(samples.shape)
+    if len(orig_shape) > 4:
+        samples = samples.reshape(samples.shape[0], samples.shape[1], -1, samples.shape[-2], samples.shape[-1])
+        samples = samples.movedim(2, 1)
+        samples = samples.reshape(-1, orig_shape[1], orig_shape[-2], orig_shape[-1])
+    if crop == "center":
+        old_width = samples.shape[-1]
+        old_height = samples.shape[-2]
+        old_aspect = old_width / old_height
+        new_aspect = width / height
+        x = 0
+        y = 0
+        if old_aspect > new_aspect:
+            x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
+        elif old_aspect < new_aspect:
+            y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
+        s = samples.narrow(-2, y, old_height - y * 2).narrow(-1, x, old_width - x * 2)
+    else:
+        s = samples
+
+    if upscale_method == "bislerp":
+        out = bislerp(s, width, height)
+    elif upscale_method == "lanczos":
+        out = lanczos(s, width, height)
+    else:
+        out = torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
+
+    if len(orig_shape) == 4:
+        return out
+
+    out = out.reshape((orig_shape[0], -1, orig_shape[1]) + (height, width))
+    return out.movedim(2, 1).reshape(orig_shape[:-2] + (height, width))
+
 class WanI2V:

     def __init__(
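The helpers added above are generic frame-resizing utilities (Lanczos via PIL, a spherical-interpolation variant, and a plain torch interpolate fallback). As a rough usage sketch, assuming the module above is importable as wan.image2video and the helpers stay module-level:

```python
# Sketch only: resize a batch of frames with the helpers added in this commit.
# `frames` follows the shape the code expects: [N, C, H, W] float values.
import torch
from wan.image2video import common_upscale  # assumed module path

frames = torch.rand(4, 3, 480, 832)                       # 4 RGB frames, 480x832, values in 0..1
up = common_upscale(frames, 1280, 720, "lanczos", "center")
print(up.shape)                                           # torch.Size([4, 3, 720, 1280])
```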
@@ -63,8 +139,8 @@ class WanI2V:
                 Enable distribution strategy of USP.
             t5_cpu (`bool`, *optional*, defaults to False):
                 Whether to place T5 model on CPU. Only works without t5_fsdp.
-            init_on_cpu (`bool`, *optional*, defaults to True):
                 Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
+            init_on_cpu (`bool`, *optional*, defaults to True):
         """
         self.device = torch.device(f"cuda:{device_id}")
         self.config = config
@@ -134,6 +210,7 @@ class WanI2V:
     def generate(self,
                  input_prompt,
                  img,
+                 img2 = None,
                  max_area=720 * 1280,
                  frame_num=81,
                  shift=5.0,
@@ -188,8 +265,14 @@ class WanI2V:
                 - W: Frame width from max_area)
         """
         img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
+        lat_frames = int((frame_num - 1) // self.vae_stride[0] + 1)
+        any_end_frame = img2 != None
+        if any_end_frame:
+            any_end_frame = True
+            img2 = TF.to_tensor(img2).sub_(0.5).div_(0.5).to(self.device)
+            frame_num += 1
+            lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
-        F = frame_num
         h, w = img.shape[1:]
         aspect_ratio = h / w
         lat_h = round(
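For reference, with the temporal VAE stride of 4 implied by the replaced `int((frame_num - 1)/4 + 1)` expression elsewhere in this file, the latent-frame arithmetic above works out as in this small sketch:

```python
# Sketch of the latent-frame arithmetic, assuming vae_stride[0] == 4.
vae_stride_t = 4

frame_num = 81
lat_frames = (frame_num - 1) // vae_stride_t + 1          # 21 latent frames, as before

# End-frame mode: one extra pixel frame is appended and encoded as its own chunk.
frame_num += 1                                            # 82
lat_frames_end = (frame_num - 2) // vae_stride_t + 2      # 20 + 2 = 22 latent frames
print(lat_frames, lat_frames_end)                         # 21 22
```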
@@ -201,28 +284,21 @@ class WanI2V:
         h = lat_h * self.vae_stride[1]
         w = lat_w * self.vae_stride[2]

-        max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
-            self.patch_size[1] * self.patch_size[2])
+        max_seq_len = lat_frames * lat_h * lat_w // (self.patch_size[1] * self.patch_size[2])
         max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size

         seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
         seed_g = torch.Generator(device=self.device)
         seed_g.manual_seed(seed)
-        noise = torch.randn(
-            16,
-            int((frame_num - 1)/4 + 1), #21,
-            lat_h,
-            lat_w,
-            dtype=torch.float32,
-            generator=seed_g,
-            device=self.device)
+        noise = torch.randn(16, lat_frames, lat_h, lat_w, dtype=torch.float32, generator=seed_g, device=self.device)

         msk = torch.ones(1, frame_num, lat_h, lat_w, device=self.device)
+        if any_end_frame:
+            msk[:, 1:-1] = 0
+            msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:-1], torch.repeat_interleave(msk[:, -1:], repeats=4, dim=1) ], dim=1)
+        else:
-        msk[:, 1:] = 0
-        msk = torch.concat([
-            torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
-        ], dim=1)
+            msk[:, 1:] = 0
+            msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] ], dim=1)
         msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
         msk = msk.transpose(1, 2)[0]
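The mask layout mirrors the VAE's temporal packing: the first (and, in end-frame mode, the last) pixel frame is marked as conditioning and repeated 4 times so the mask count stays divisible by 4 and lines up with the latent frame count. A quick shape check under the same stride-4 assumption (the lat_h / lat_w values are just example numbers for a 480x832 clip):

```python
# Sketch: verify the mask bookkeeping for end-frame mode.
import torch

frame_num, lat_h, lat_w = 82, 60, 104          # 81 requested frames + 1 end frame
msk = torch.ones(1, frame_num, lat_h, lat_w)
msk[:, 1:-1] = 0                               # only the first and last frames are known
msk = torch.concat([
    torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1),
    msk[:, 1:-1],
    torch.repeat_interleave(msk[:, -1:], repeats=4, dim=1),
], dim=1)                                      # 4 + 80 + 4 = 88 mask frames
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1, 2)[0]
print(msk.shape)                               # torch.Size([4, 22, 60, 104]) -> 22 latent frames
```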
@@ -242,7 +318,6 @@ class WanI2V:
         context = [t.to(self.device) for t in context]
         context_null = [t.to(self.device) for t in context_null]

-        # self.clip.model.to(self.device)
         clip_context = self.clip.visual([img[:, None, :, :]])
         if offload_model:
             self.clip.model.cpu()
@@ -250,16 +325,24 @@ class WanI2V:
             from mmgp import offload
             offload.last_offload_obj.unload_all()

+        if any_end_frame:
+            img_interpolated = torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1).to(torch.bfloat16)
+            img2_interpolated = torch.nn.functional.interpolate(img2[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1).to(torch.bfloat16)
+            mean2 = 0
+            enc = torch.concat([
+                img_interpolated,
+                torch.full((3, frame_num - 2, h, w), mean2, device="cpu", dtype=torch.bfloat16),
+                img2_interpolated,
+            ], dim=1).to(self.device)
+        else:
-        enc = torch.concat([
-            torch.nn.functional.interpolate(
-                img[None].cpu(), size=(h, w), mode='bicubic').transpose(
-                0, 1).to(torch.bfloat16),
-            torch.zeros(3, frame_num-1, h, w, device="cpu", dtype=torch.bfloat16)
-        ], dim=1).to(self.device)
-        # enc = None
+            enc = torch.concat([
+                torch.nn.functional.interpolate(
+                    img[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1).to(torch.bfloat16),
+                torch.zeros(3, frame_num-1, h, w, device="cpu", dtype=torch.bfloat16)
+            ], dim=1).to(self.device)

-        y = self.vae.encode([enc], VAE_tile_size)[0]
-        y = torch.concat([msk, y])
+        lat_y = self.vae.encode([enc], VAE_tile_size, any_end_frame=any_end_frame)[0]
+        y = torch.concat([msk, lat_y])

         @contextmanager
         def noop_no_sync():
@@ -293,7 +376,7 @@ class WanI2V:
         # sample videos
         latent = noise

-        freqs = get_rotary_pos_embed(frame_num, h, w, enable_RIFLEx= enable_RIFLEx )
+        freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx=enable_RIFLEx)

         arg_c = {
             'context': [context[0]],
@@ -344,8 +427,6 @@ class WanI2V:

                 timestep = torch.stack(timestep).to(self.device)
                 if joint_pass:
-                    # if slg_layers is not None:
-                    #     raise ValueError('Can not use SLG and joint-pass')
                     noise_pred_cond, noise_pred_uncond = self.model(
                         latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both)
                     if self._interrupt:
@@ -396,14 +477,22 @@ class WanI2V:
                     callback(i, latent)

-        x0 = [latent.to(self.device)]
+        x0 = [latent.to(self.device, dtype=torch.bfloat16)]

         if offload_model:
             self.model.cpu()
             torch.cuda.empty_cache()

         if self.rank == 0:
-            videos = self.vae.decode(x0, VAE_tile_size)
+            # x0 = [lat_y]
+            video = self.vae.decode(x0, VAE_tile_size, any_end_frame=any_end_frame)[0]
+
+            if any_end_frame:
+                # video[:, -1:] = img2_interpolated
+                video = video[:, :-1]
+        else:
+            video = None

         del noise, latent
         del sample_scheduler
@@ -413,4 +502,4 @@ class WanI2V:
         if dist.is_initialized():
             dist.barrier()

-        return videos[0] if self.rank == 0 else None
+        return video
wan/modules/posemb_layers.py
@@ -429,11 +429,10 @@ def get_1d_rotary_pos_embed(
     )  # complex64 # [S, D/2]
     return freqs_cis

-def get_rotary_pos_embed(video_length, height, width, enable_RIFLEx = False):
+def get_rotary_pos_embed(latents_size, enable_RIFLEx = False):
     target_ndim = 3
     ndim = 5 - 2

-    latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
     patch_size = [1, 2, 2]
     if isinstance(patch_size, int):
         assert all(s % patch_size == 0 for s in latents_size), (
@@ -468,7 +467,7 @@ def get_rotary_pos_embed(video_length, height, width, enable_RIFLEx = False):
         theta=10000,
         use_real=True,
         theta_rescale_factor=1,
-        L_test = (video_length - 1) // 4 + 1,
+        L_test = latents_size[0],
         enable_riflex = enable_RIFLEx
     )
     return (freqs_cos, freqs_sin)
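Callers now pass the latent grid size directly instead of pixel-space frame count and dimensions, as the image2video and text2video hunks show (`latent.shape[1:]` is `[lat_frames, lat_h, lat_w]`). A minimal sketch of the new call, with example latent dimensions for an 81-frame 480x832 clip:

```python
# Sketch of the new call convention (example numbers only: 81 frames at 480x832
# give a 16 x 21 x 60 x 104 latent with the stride-4 / stride-8 VAE).
import torch
from wan.modules.posemb_layers import get_rotary_pos_embed  # import path as used in image2video.py

latent = torch.zeros(16, 21, 60, 104)
freqs_cos, freqs_sin = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx=False)
```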
wan/modules/vae.py
@@ -530,26 +530,37 @@ class WanVAE_(nn.Module):
         x_recon = self.decode(z)
         return x_recon, mu, log_var

-    def encode(self, x, scale = None):
+    def encode(self, x, scale = None, any_end_frame = False):
         self.clear_cache()
         ## cache
         t = x.shape[2]
+        if any_end_frame:
+            iter_ = 2 + (t - 2) // 4
+        else:
-        iter_ = 1 + (t - 1) // 4
+            iter_ = 1 + (t - 1) // 4
         ## split the x passed to encode along time into chunks of 1, 4, 4, 4, ...
+        out_list = []
         for i in range(iter_):
             self._enc_conv_idx = [0]
             if i == 0:
-                out = self.encoder(
+                out_list.append(self.encoder(
                     x[:, :, :1, :, :],
                     feat_cache=self._enc_feat_map,
-                    feat_idx=self._enc_conv_idx)
+                    feat_idx=self._enc_conv_idx))
+            elif any_end_frame and i == iter_ - 1:
+                out_list.append(self.encoder(
+                    x[:, :, -1:, :, :],
+                    feat_cache=None,
+                    feat_idx=self._enc_conv_idx))
             else:
-                out_ = self.encoder(
+                out_list.append(self.encoder(
                     x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
                     feat_cache=self._enc_feat_map,
-                    feat_idx=self._enc_conv_idx)
+                    feat_idx=self._enc_conv_idx))
-                out = torch.cat([out, out_], 2)

+        self.clear_cache()
+        out = torch.cat(out_list, 2)
+        out_list = None

         mu, log_var = self.conv1(out).chunk(2, dim=1)
         if scale != None:
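With the end-frame flag, the encoder handles the appended last frame as its own chunk (and without the temporal feature cache), so the temporal split becomes 1, 4, 4, ..., 4, 1 instead of 1, 4, 4, ..., 4. A small sketch of the chunk counts:

```python
# Sketch of the temporal chunking used by WanVAE_.encode above.
def encode_chunks(t, any_end_frame=False):
    """Return the per-chunk frame counts for a clip of t frames."""
    if any_end_frame:
        iter_ = 2 + (t - 2) // 4
        return [1] + [4] * (iter_ - 2) + [1]
    iter_ = 1 + (t - 1) // 4
    return [1] + [4] * (iter_ - 1)

print(len(encode_chunks(81)), sum(encode_chunks(81)))          # 21 chunks covering 81 frames
print(len(encode_chunks(82, True)), sum(encode_chunks(82, True)))  # 22 chunks covering 82 frames
```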
@@ -558,11 +569,10 @@ class WanVAE_(nn.Module):
                     1, self.z_dim, 1, 1, 1)
             else:
                 mu = (mu - scale[0]) * scale[1]
-        self.clear_cache()
         return mu

-    def decode(self, z, scale=None):
+    def decode(self, z, scale=None, any_end_frame = False):
         self.clear_cache()
         # z: [b,c,t,h,w]
         if scale != None:
@@ -573,20 +583,26 @@ class WanVAE_(nn.Module):
             z = z / scale[1] + scale[0]
         iter_ = z.shape[2]
         x = self.conv2(z)
+        out_list = []
         for i in range(iter_):
             self._conv_idx = [0]
             if i == 0:
-                out = self.decoder(
+                out_list.append(self.decoder(
                     x[:, :, i:i + 1, :, :],
                     feat_cache=self._feat_map,
-                    feat_idx=self._conv_idx)
+                    feat_idx=self._conv_idx))
+            elif any_end_frame and i == iter_ - 1:
+                out_list.append(self.decoder(
+                    x[:, :, -1:, :, :],
+                    feat_cache=None,
+                    feat_idx=self._conv_idx))
             else:
-                out_ = self.decoder(
+                out_list.append(self.decoder(
                     x[:, :, i:i + 1, :, :],
                     feat_cache=self._feat_map,
-                    feat_idx=self._conv_idx)
+                    feat_idx=self._conv_idx))
-                out = torch.cat([out, out_], 2)
         self.clear_cache()
+        out = torch.cat(out_list, 2)
         return out

     def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
@@ -601,7 +617,7 @@ class WanVAE_(nn.Module):
             b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
         return b

-    def spatial_tiled_decode(self, z, scale, tile_size):
+    def spatial_tiled_decode(self, z, scale, tile_size, any_end_frame=False):
         tile_sample_min_size = tile_size
         tile_latent_min_size = int(tile_sample_min_size / 8)
         tile_overlap_factor = 0.25
@@ -626,7 +642,7 @@ class WanVAE_(nn.Module):
             row = []
             for j in range(0, z.shape[-1], overlap_size):
                 tile = z[:, :, :, i: i + tile_latent_min_size, j: j + tile_latent_min_size]
-                decoded = self.decode(tile)
+                decoded = self.decode(tile, any_end_frame=any_end_frame)
                 row.append(decoded)
             rows.append(row)
         result_rows = []
@@ -645,7 +661,7 @@ class WanVAE_(nn.Module):
         return torch.cat(result_rows, dim=-2)

-    def spatial_tiled_encode(self, x, scale, tile_size):
+    def spatial_tiled_encode(self, x, scale, tile_size, any_end_frame=False):
         tile_sample_min_size = tile_size
         tile_latent_min_size = int(tile_sample_min_size / 8)
         tile_overlap_factor = 0.25
@@ -660,7 +676,7 @@ class WanVAE_(nn.Module):
             row = []
             for j in range(0, x.shape[-1], overlap_size):
                 tile = x[:, :, :, i: i + tile_sample_min_size, j: j + tile_sample_min_size]
-                tile = self.encode(tile)
+                tile = self.encode(tile, any_end_frame=any_end_frame)
                 row.append(tile)
             rows.append(row)
         result_rows = []
@@ -764,18 +780,18 @@ class WanVAE:
             z_dim=z_dim,
         ).eval().requires_grad_(False).to(device)

-    def encode(self, videos, tile_size = 256):
+    def encode(self, videos, tile_size = 256, any_end_frame = False):
         """
         videos: A list of videos each with shape [C, T, H, W].
         """
         if tile_size > 0:
-            return [ self.model.spatial_tiled_encode(u.unsqueeze(0), self.scale, tile_size).float().squeeze(0) for u in videos ]
+            return [ self.model.spatial_tiled_encode(u.unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
         else:
-            return [ self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) for u in videos ]
+            return [ self.model.encode(u.unsqueeze(0), self.scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]

-    def decode(self, zs, tile_size):
+    def decode(self, zs, tile_size, any_end_frame = False):
         if tile_size > 0:
-            return [ self.model.spatial_tiled_decode(u.unsqueeze(0), self.scale, tile_size).float().clamp_(-1, 1).squeeze(0) for u in zs ]
+            return [ self.model.spatial_tiled_decode(u.unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
         else:
-            return [ self.model.decode(u.unsqueeze(0), self.scale).float().clamp_(-1, 1).squeeze(0) for u in zs ]
+            return [ self.model.decode(u.unsqueeze(0), self.scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
wan/text2video.py
@@ -241,7 +241,7 @@ class WanT2V:
         # sample videos
         latents = noise

-        freqs = get_rotary_pos_embed(frame_num, size[1], size[0], enable_RIFLEx= enable_RIFLEx)
+        freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx=enable_RIFLEx)
         arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
         arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
         arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}