mirror of https://github.com/Wan-Video/Wan2.1.git
synced 2025-12-21 06:32:07 +00:00

Compare commits: a76ccc9d36 ... 22f4714dc1 (2 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 22f4714dc1 |  |
|  | 28919e5663 |  |
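Taken together, these two commits strip the `torch.profiler` / `perf_counter` instrumentation out of this fork (the `--profile` flag, `_init_profiler` / `_finalize_profiler` in `generate.py`, the per-rank timing blocks, and the `profiler=` argument threaded through the pipeline classes) and replace it with a `ChronoInspector` context manager wrapped around pipeline creation, warm-up, and generation.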
.gitignore (vendored): 4 changed lines
@@ -21,7 +21,7 @@
 *.html
 *.pdf
 *.whl
-*cache/
+cache
 __pycache__/
 storage/
 samples/
@@ -29,11 +29,9 @@ samples/
 !requirements.txt
 .DS_Store
 *DS_Store
 .vscode
 google/
 Wan2.1-T2V-14B/
 Wan2.1-T2V-1.3B/
 Wan2.1-I2V-14B-480P/
 Wan2.1-I2V-14B-720P/
-poetry.lock
-logs/
generate.py: 234 changed lines
@@ -25,12 +25,8 @@ import wan
 from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
 from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
 from wan.utils.utils import cache_image, cache_video, str2bool
-from wan.utils.platform import (
-    get_device_type,
-    get_torch_distributed_backend,
-    get_torch_profiler_activities,
-)
-
+from wan.utils.platform import get_torch_distributed_backend
+from wan.utils.chrono_inspector import ChronoInspector
 
 
 EXAMPLE_PROMPT = {
@@ -256,11 +252,6 @@ def _parse_args():
         type=float,
         default=5.0,
         help="Classifier free guidance scale.")
-    parser.add_argument(
-        "--profile",
-        action="store_true",
-        default=False,
-        help="profile the generating procedure.")
 
     args = parser.parse_args()
 
@@ -281,26 +272,6 @@ def _init_logging(rank):
         logging.basicConfig(level=logging.ERROR)
 
 
-def _init_profiler():
-    return torch.profiler.profile(
-        activities=get_torch_profiler_activities(),
-        on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs'),
-        record_shapes=True,
-        profile_memory=True,
-        with_stack=True,
-    )
-
-
-def _finalize_profiler(profiler):
-    profiler.stop()
-    table = profiler.key_averages().table(
-        sort_by=f"{get_device_type()}_time_total",
-        row_limit=20,
-    )
-    with open(f"logs/{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}-profiling.txt", "w") as f:
-        f.write(table)
-
-
 def generate(args):
     rank = int(os.getenv("RANK", 0))
     world_size = int(os.getenv("WORLD_SIZE", 1))
@@ -367,10 +338,6 @@ def generate(args):
     base_seed = [args.base_seed] if rank == 0 else [None]
     dist.broadcast_object_list(base_seed, src=0)
     args.base_seed = base_seed[0]
 
-    profiler = None
-    if args.profile and rank == 0:
-        profiler = _init_profiler()
-
     if "t2v" in args.task or "t2i" in args.task:
         if args.prompt is None:
@@ -399,20 +366,20 @@ def generate(args):
             logging.info(f"Extended prompt: {args.prompt}")
 
         logging.info("Creating WanT2V pipeline.")
-        wan_t2v = wan.WanT2V(
-            config=cfg,
-            checkpoint_dir=args.ckpt_dir,
-            device_id=device,
-            rank=rank,
-            t5_fsdp=args.t5_fsdp,
-            dit_fsdp=args.dit_fsdp,
-            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
-            t5_cpu=args.t5_cpu,
-            profiler=profiler,
-        )
+        with ChronoInspector("Creating WanT2V pipeline"):
+            wan_t2v = wan.WanT2V(
+                config=cfg,
+                checkpoint_dir=args.ckpt_dir,
+                device_id=device,
+                rank=rank,
+                t5_fsdp=args.t5_fsdp,
+                dit_fsdp=args.dit_fsdp,
+                use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
+                t5_cpu=args.t5_cpu,
+            )
 
         logging.info("Warming up WanT2V pipeline ...")
-        with torch.no_grad():
+        with ChronoInspector("Warming up WanT2V pipeline"):
             _ = wan_t2v.generate(
                 args.prompt,
                 size=SIZE_CONFIGS[args.size],
@@ -425,16 +392,17 @@ def generate(args):
                 offload_model=args.offload_model)
 
         logging.info(f"Generating {'image' if 't2i' in args.task else 'video'} ...")
-        video = wan_t2v.generate(
-            args.prompt,
-            size=SIZE_CONFIGS[args.size],
-            frame_num=args.frame_num,
-            shift=args.sample_shift,
-            sample_solver=args.sample_solver,
-            sampling_steps=args.sample_steps,
-            guide_scale=args.sample_guide_scale,
-            seed=args.base_seed,
-            offload_model=args.offload_model)
+        with ChronoInspector(f"Generating {'image' if 't2i' in args.task else 'video'}"):
+            video = wan_t2v.generate(
+                args.prompt,
+                size=SIZE_CONFIGS[args.size],
+                frame_num=args.frame_num,
+                shift=args.sample_shift,
+                sample_solver=args.sample_solver,
+                sampling_steps=args.sample_steps,
+                guide_scale=args.sample_guide_scale,
+                seed=args.base_seed,
+                offload_model=args.offload_model)
 
     elif "i2v" in args.task:
         if args.prompt is None:
@@ -469,20 +437,20 @@ def generate(args):
             logging.info(f"Extended prompt: {args.prompt}")
 
         logging.info("Creating WanI2V pipeline.")
-        wan_i2v = wan.WanI2V(
-            config=cfg,
-            checkpoint_dir=args.ckpt_dir,
-            device_id=device,
-            rank=rank,
-            t5_fsdp=args.t5_fsdp,
-            dit_fsdp=args.dit_fsdp,
-            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
-            t5_cpu=args.t5_cpu,
-            profiler=profiler,
-        )
+        with ChronoInspector("Creating WanI2V pipeline"):
+            wan_i2v = wan.WanI2V(
+                config=cfg,
+                checkpoint_dir=args.ckpt_dir,
+                device_id=device,
+                rank=rank,
+                t5_fsdp=args.t5_fsdp,
+                dit_fsdp=args.dit_fsdp,
+                use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
+                t5_cpu=args.t5_cpu,
+            )
 
         logging.info("Warming up WanI2V pipeline ...")
-        with torch.no_grad():
+        with ChronoInspector("Warming up WanI2V pipeline"):
             _ = wan_i2v.generate(
                 args.prompt,
                 img,
@@ -496,17 +464,18 @@ def generate(args):
                 offload_model=args.offload_model)
 
         logging.info("Generating video ...")
-        video = wan_i2v.generate(
-            args.prompt,
-            img,
-            max_area=MAX_AREA_CONFIGS[args.size],
-            frame_num=args.frame_num,
-            shift=args.sample_shift,
-            sample_solver=args.sample_solver,
-            sampling_steps=args.sample_steps,
-            guide_scale=args.sample_guide_scale,
-            seed=args.base_seed,
-            offload_model=args.offload_model)
+        with ChronoInspector("Generating video"):
+            video = wan_i2v.generate(
+                args.prompt,
+                img,
+                max_area=MAX_AREA_CONFIGS[args.size],
+                frame_num=args.frame_num,
+                shift=args.sample_shift,
+                sample_solver=args.sample_solver,
+                sampling_steps=args.sample_steps,
+                guide_scale=args.sample_guide_scale,
+                seed=args.base_seed,
+                offload_model=args.offload_model)
     elif "flf2v" in args.task:
         if args.prompt is None:
             args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
@@ -542,20 +511,20 @@ def generate(args):
             logging.info(f"Extended prompt: {args.prompt}")
 
         logging.info("Creating WanFLF2V pipeline.")
-        wan_flf2v = wan.WanFLF2V(
-            config=cfg,
-            checkpoint_dir=args.ckpt_dir,
-            device_id=device,
-            rank=rank,
-            t5_fsdp=args.t5_fsdp,
-            dit_fsdp=args.dit_fsdp,
-            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
-            t5_cpu=args.t5_cpu,
-            profiler=profiler
-        )
+        with ChronoInspector("Creating WanFLF2V pipeline"):
+            wan_flf2v = wan.WanFLF2V(
+                config=cfg,
+                checkpoint_dir=args.ckpt_dir,
+                device_id=device,
+                rank=rank,
+                t5_fsdp=args.t5_fsdp,
+                dit_fsdp=args.dit_fsdp,
+                use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
+                t5_cpu=args.t5_cpu,
+            )
 
         logging.info("Warming up WanFLF2V pipeline ...")
-        with torch.no_grad():
+        with ChronoInspector("Warming up WanFLF2V pipeline"):
             _ = wan_flf2v.generate(
                 args.prompt,
                 first_frame,
@@ -570,18 +539,19 @@ def generate(args):
                 offload_model=args.offload_model)
 
         logging.info("Generating video ...")
-        video = wan_flf2v.generate(
-            args.prompt,
-            first_frame,
-            last_frame,
-            max_area=MAX_AREA_CONFIGS[args.size],
-            frame_num=args.frame_num,
-            shift=args.sample_shift,
-            sample_solver=args.sample_solver,
-            sampling_steps=args.sample_steps,
-            guide_scale=args.sample_guide_scale,
-            seed=args.base_seed,
-            offload_model=args.offload_model)
+        with ChronoInspector("Generating video"):
+            video = wan_flf2v.generate(
+                args.prompt,
+                first_frame,
+                last_frame,
+                max_area=MAX_AREA_CONFIGS[args.size],
+                frame_num=args.frame_num,
+                shift=args.sample_shift,
+                sample_solver=args.sample_solver,
+                sampling_steps=args.sample_steps,
+                guide_scale=args.sample_guide_scale,
+                seed=args.base_seed,
+                offload_model=args.offload_model)
     elif "vace" in args.task:
         if args.prompt is None:
             args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
@@ -606,17 +576,17 @@ def generate(args):
             logging.info(f"Extended prompt: {args.prompt}")
 
         logging.info("Creating VACE pipeline.")
-        wan_vace = wan.WanVace(
-            config=cfg,
-            checkpoint_dir=args.ckpt_dir,
-            device_id=device,
-            rank=rank,
-            t5_fsdp=args.t5_fsdp,
-            dit_fsdp=args.dit_fsdp,
-            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
-            t5_cpu=args.t5_cpu,
-            profiler=profiler
-        )
+        with ChronoInspector("Creating VACE pipeline"):
+            wan_vace = wan.WanVace(
+                config=cfg,
+                checkpoint_dir=args.ckpt_dir,
+                device_id=device,
+                rank=rank,
+                t5_fsdp=args.t5_fsdp,
+                dit_fsdp=args.dit_fsdp,
+                use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
+                t5_cpu=args.t5_cpu,
+            )
 
         src_video, src_mask, src_ref_images = wan_vace.prepare_source(
             [args.src_video], [args.src_mask], [
@@ -625,8 +595,8 @@ def generate(args):
             ], args.frame_num, SIZE_CONFIGS[args.size], device)
 
         logging.info("Warming up VACE pipeline ...")
-        with torch.no_grad():
-            _ = wan_vace.generate(
+        with ChronoInspector("Warming up VACE pipeline"):
+            video = wan_vace.generate(
                 args.prompt,
                 src_video,
                 src_mask,
@@ -641,25 +611,23 @@ def generate(args):
                 offload_model=args.offload_model)
 
-        logging.info(f"Generating video...")
-        video = wan_vace.generate(
-            args.prompt,
-            src_video,
-            src_mask,
-            src_ref_images,
-            size=SIZE_CONFIGS[args.size],
-            frame_num=args.frame_num,
-            shift=args.sample_shift,
-            sample_solver=args.sample_solver,
-            sampling_steps=args.sample_steps,
-            guide_scale=args.sample_guide_scale,
-            seed=args.base_seed,
-            offload_model=args.offload_model)
+        with ChronoInspector("Generating video"):
+            video = wan_vace.generate(
+                args.prompt,
+                src_video,
+                src_mask,
+                src_ref_images,
+                size=SIZE_CONFIGS[args.size],
+                frame_num=args.frame_num,
+                shift=args.sample_shift,
+                sample_solver=args.sample_solver,
+                sampling_steps=args.sample_steps,
+                guide_scale=args.sample_guide_scale,
+                seed=args.base_seed,
+                offload_model=args.offload_model)
     else:
         raise ValueError(f"Unkown task type: {args.task}")
 
-    if args.profile and rank == 0:
-        _finalize_profiler(profiler)
 
     if rank == 0:
         if args.save_file is None:
             formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
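The new `wan/utils/chrono_inspector.py` module itself is not included in this compare, only its call sites. Judging from the usage above (a context manager constructed with a label, replacing the paired `perf_counter()` calls and their `[...] Elapsed time` log lines), a minimal sketch could look like the following; the class body, log format, and lack of rank-awareness here are assumptions, not the fork's actual code:

```python
# Hypothetical sketch of wan/utils/chrono_inspector.py (not shown in this diff).
import logging
from time import perf_counter


class ChronoInspector:
    """Log the wall-clock time spent inside a `with` block under a label."""

    def __init__(self, label: str):
        self.label = label

    def __enter__(self):
        self.start = perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        elapsed = perf_counter() - self.start
        # Mirrors the "[...] Elapsed time: X.XX seconds" messages this compare removes.
        logging.info(f"[{self.label}] Elapsed time: {elapsed:.2f} seconds")
        return False  # never suppress exceptions from the block
```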
wan/first_last_frame2video.py

@@ -6,7 +6,6 @@ import os
 import random
 import sys
 import types
-from time import perf_counter
 from contextlib import contextmanager
 from functools import partial
 
@@ -55,7 +54,6 @@ class WanFLF2V:
         use_usp=False,
         t5_cpu=False,
         init_on_cpu=True,
-        profiler=None,
     ):
         r"""
         Initializes the image-to-video generation model components.
@@ -145,7 +143,6 @@ class WanFLF2V:
             self.model.to(self.device)
 
         self.sample_neg_prompt = config.sample_neg_prompt
-        self.profiler = profiler
 
     def generate(self,
                  input_prompt,
@@ -200,14 +197,6 @@ class WanFLF2V:
                 - H: Frame height (from max_area)
                 - W: Frame width from max_area)
         """
-        if self.profiler and self.rank == 0:
-            self.profiler.start()
-
-        start_time = 0.0
-        end_time = 0.0
-        if self.rank == 0:
-            start_time = perf_counter()
-
        first_frame_size = first_frame.size
        last_frame_size = last_frame.size
        first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(
@@ -300,10 +289,6 @@ class WanFLF2V:
             ])[0]
         y = torch.concat([msk, y])
 
-        if self.rank == 0:
-            end_time = perf_counter()
-            logging.info(f"[preprocess and VAE encode] Elapsed time: {end_time - start_time:.2f} seconds")
-
         @contextmanager
         def noop_no_sync():
             yield
@@ -354,14 +339,8 @@ class WanFLF2V:
             if offload_model:
                 empty_cache()
 
-            if self.rank == 0:
-                start_time = perf_counter()
-
             self.model.to(self.device)
             for _, t in enumerate(tqdm(timesteps)):
-                if self.profiler and self.rank == 0:
-                    self.profiler.step()
-
                 latent_model_input = [latent.to(self.device)]
                 timestep = [t]
 
@@ -391,22 +370,15 @@ class WanFLF2V:
                     generator=seed_g)[0]
                 latent = temp_x0.squeeze(0)
 
-            if self.rank == 0:
-                end_time = perf_counter()
-                logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds")
-
-                x0 = [latent.to(self.device)]
-                del latent_model_input, timestep
+            x0 = [latent.to(self.device)]
+            del latent_model_input, timestep
 
             if offload_model:
                 self.model.cpu()
                 empty_cache()
 
-            if self.rank == 0:
-                start_time = perf_counter()
             videos = self.vae.decode(x0)
-            end_time = perf_counter()
-            logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds")
 
         del noise, latent
         del sample_scheduler
wan/image2video.py

@@ -6,7 +6,6 @@ import os
 import random
 import sys
 import types
-from time import perf_counter
 from contextlib import contextmanager
 from functools import partial
 
@@ -55,7 +54,6 @@ class WanI2V:
         use_usp=False,
         t5_cpu=False,
         init_on_cpu=True,
-        profiler=None,
     ):
         r"""
         Initializes the image-to-video generation model components.
@@ -145,7 +143,6 @@ class WanI2V:
             self.model.to(self.device)
 
         self.sample_neg_prompt = config.sample_neg_prompt
-        self.profiler = profiler
 
     def generate(self,
                  input_prompt,
@@ -195,14 +192,6 @@ class WanI2V:
                 - H: Frame height (from max_area)
                 - W: Frame width from max_area)
         """
-        if self.profiler and self.rank == 0:
-            self.profiler.start()
-
-        start_time = 0.0
-        end_time = 0.0
-        if self.rank == 0:
-            start_time = perf_counter()
-
         img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
 
         F = frame_num
@@ -273,10 +262,6 @@ class WanI2V:
             ])[0]
         y = torch.concat([msk, y])
 
-        if self.rank == 0:
-            end_time = perf_counter()
-            logging.info(f"[preprocess and VAE encode] Elapsed time: {end_time - start_time:.2f} seconds")
-
         @contextmanager
         def noop_no_sync():
             yield
@@ -327,14 +312,8 @@ class WanI2V:
             if offload_model:
                 empty_cache()
 
-            if self.rank == 0:
-                start_time = perf_counter()
-
             self.model.to(self.device)
             for _, t in enumerate(tqdm(timesteps)):
-                if self.profiler and self.rank == 0:
-                    self.profiler.step()
-
                 latent_model_input = [latent.to(self.device)]
                 timestep = [t]
 
@@ -364,22 +343,15 @@ class WanI2V:
                     generator=seed_g)[0]
                 latent = temp_x0.squeeze(0)
 
-            if self.rank == 0:
-                end_time = perf_counter()
-                logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds")
-
-                x0 = [latent.to(self.device)]
-                del latent_model_input, timestep
+            x0 = [latent.to(self.device)]
+            del latent_model_input, timestep
 
             if offload_model:
                 self.model.cpu()
                 empty_cache()
 
-            if self.rank == 0:
-                start_time = perf_counter()
             videos = self.vae.decode(x0)
-            end_time = perf_counter()
-            logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds")
 
         del noise, latent
         del sample_scheduler
wan/modules/model.py

@@ -16,7 +16,7 @@ try:
     from wan.modules.attention import attention as flash_attention
     torch.backends.mudnn.allow_tf32 = True
 except ModuleNotFoundError:
-    torch_musa = None
+    pass
 
 __all__ = ['WanModel']
 
wan/text2video.py

@@ -6,7 +6,6 @@ import os
 import random
 import sys
 import types
-from time import perf_counter
 from contextlib import contextmanager
 from functools import partial
 
@@ -51,7 +50,6 @@ class WanT2V:
         dit_fsdp=False,
         use_usp=False,
         t5_cpu=False,
-        profiler=None,
     ):
         r"""
         Initializes the Wan text-to-video generation model components.
@@ -126,7 +124,6 @@ class WanT2V:
             self.model.to(self.device)
 
         self.sample_neg_prompt = config.sample_neg_prompt
-        self.profiler = profiler
 
     def generate(self,
                  input_prompt,
@@ -172,14 +169,6 @@ class WanT2V:
                 - H: Frame height (from size)
                 - W: Frame width from size)
         """
-        if self.profiler and self.rank == 0:
-            self.profiler.start()
-
-        start_time = 0.0
-        end_time = 0.0
-        if self.rank == 0:
-            start_time = perf_counter()
-
         # preprocess
         F = frame_num
         target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
@@ -219,10 +208,6 @@ class WanT2V:
                 generator=seed_g)
         ]
 
-        if self.rank == 0:
-            end_time = perf_counter()
-            logging.info(f"[preprocess] Elapsed time: {end_time - start_time:.2f} seconds")
-
         @contextmanager
         def noop_no_sync():
             yield
@@ -259,14 +244,8 @@ class WanT2V:
         arg_c = {'context': context, 'seq_len': seq_len}
         arg_null = {'context': context_null, 'seq_len': seq_len}
 
-        if self.rank == 0:
-            start_time = perf_counter()
-
         self.model.to(self.device)
         for _, t in enumerate(tqdm(timesteps)):
-            if self.profiler and self.rank == 0:
-                self.profiler.step()
-
             latent_model_input = latents
             timestep = [t]
 
@@ -288,19 +267,12 @@ class WanT2V:
                     generator=seed_g)[0]
             latents = [temp_x0.squeeze(0)]
 
-            if self.rank == 0:
-                end_time = perf_counter()
-                logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds")
-
         x0 = latents
         if offload_model:
             self.model.cpu()
             empty_cache()
-        if self.rank == 0:
-            start_time = perf_counter()
         videos = self.vae.decode(x0)
-        end_time = perf_counter()
-        logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds")
 
         del noise, latents
         del sample_scheduler
wan/utils/platform.py

@@ -1,4 +1,4 @@
-from typing import Optional, List
+from typing import Optional
 
 import torch
 
@@ -8,7 +8,7 @@ except ModuleNotFoundError:
     torch_musa = None
 
 
-def _is_musa() -> bool:
+def _is_musa():
     if torch_musa is None:
         return False
     else:
@@ -48,14 +48,3 @@ def get_torch_distributed_backend() -> str:
         return "mccl"
     else:
         raise NotImplementedError("No Accelerators(NV/MTT GPU) available")
-
-
-def get_torch_profiler_activities() -> List[torch.profiler.ProfilerActivity]:
-    activities: List[torch.profiler.ProfilerActivity] = [
-        torch.profiler.ProfilerActivity.CPU
-    ]
-    if torch.cuda.is_available():
-        activities.append(torch.profiler.ProfilerActivity.CUDA)
-    elif _is_musa():
-        activities.append(torch.profiler.ProfilerActivity.MUSA)
-    return activities
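The removed `_finalize_profiler` in `generate.py` sorted the profiler table by `f"{get_device_type()}_time_total"`, so `wan/utils/platform.py` evidently also exposes a `get_device_type()` helper that this compare leaves untouched. A plausible sketch, assumed from the `get_torch_distributed_backend()` dispatch pattern above rather than taken from the repository:

```python
# Assumed shape of get_device_type() in wan/utils/platform.py; it sits next to
# _is_musa() above and is only referenced, never shown, in this compare.
def get_device_type() -> str:
    if torch.cuda.is_available():
        return "cuda"  # yields profiler sort keys like "cuda_time_total"
    elif _is_musa():
        return "musa"  # Moore Threads GPU via torch_musa
    else:
        raise NotImplementedError("No Accelerators(NV/MTT GPU) available")
```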
wan/vace.py: 31 changed lines
@@ -8,7 +8,6 @@ import sys
 import time
 import traceback
 import types
-from time import perf_counter
 from contextlib import contextmanager
 from functools import partial
 
@@ -59,7 +58,6 @@ class WanVace(WanT2V):
         dit_fsdp=False,
         use_usp=False,
         t5_cpu=False,
-        profiler=None,
     ):
         r"""
         Initializes the Wan text-to-video generation model components.
@@ -152,8 +150,6 @@ class WanVace(WanT2V):
             seq_len=75600,
             keep_last=True)
 
-        self.profiler = profiler
-
     def vace_encode_frames(self, frames, ref_images, masks=None, vae=None):
         vae = self.vae if vae is None else vae
         if ref_images is None:
@@ -358,14 +354,6 @@ class WanVace(WanT2V):
                 - H: Frame height (from size)
                 - W: Frame width from size)
         """
-        if self.profiler and self.rank == 0:
-            self.profiler.start()
-
-        start_time = 0.0
-        end_time = 0.0
-        if self.rank == 0:
-            start_time = perf_counter()
-
         # preprocess
         # F = frame_num
         # target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
@@ -416,10 +404,6 @@ class WanVace(WanT2V):
                                 (self.patch_size[1] * self.patch_size[2]) *
                                 target_shape[1] / self.sp_size) * self.sp_size
 
-        if self.rank == 0:
-            end_time = perf_counter()
-            logging.info(f"[preprocess] Elapsed time: {end_time - start_time:.2f} seconds")
-
         @contextmanager
         def noop_no_sync():
             yield
@@ -456,19 +440,13 @@ class WanVace(WanT2V):
         arg_c = {'context': context, 'seq_len': seq_len}
         arg_null = {'context': context_null, 'seq_len': seq_len}
 
-        if self.rank == 0:
-            start_time = perf_counter()
-
-        self.model.to(self.device)
         for _, t in enumerate(tqdm(timesteps)):
-            if self.profiler and self.rank == 0:
-                self.profiler.step()
-
             latent_model_input = latents
             timestep = [t]
 
             timestep = torch.stack(timestep)
 
+            self.model.to(self.device)
             noise_pred_cond = self.model(
                 latent_model_input,
                 t=timestep,
@@ -493,19 +471,12 @@ class WanVace(WanT2V):
                     generator=seed_g)[0]
             latents = [temp_x0.squeeze(0)]
 
-            if self.rank == 0:
-                end_time = perf_counter()
-                logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds")
-
         x0 = latents
         if offload_model:
             self.model.cpu()
             empty_cache()
-        if self.rank == 0:
-            start_time = perf_counter()
         videos = self.decode_latent(x0, input_ref_images)
-        end_time = perf_counter()
-        logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds")
 
         del noise, latents
         del sample_scheduler