mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-12-21 06:32:07 +00:00
Compare commits
2 Commits
a76ccc9d36
...
22f4714dc1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
22f4714dc1 | ||
|
|
28919e5663 |
4
.gitignore
vendored
4
.gitignore
vendored
@ -21,7 +21,7 @@
|
|||||||
*.html
|
*.html
|
||||||
*.pdf
|
*.pdf
|
||||||
*.whl
|
*.whl
|
||||||
*cache/
|
cache
|
||||||
__pycache__/
|
__pycache__/
|
||||||
storage/
|
storage/
|
||||||
samples/
|
samples/
|
||||||
@ -29,11 +29,9 @@ samples/
|
|||||||
!requirements.txt
|
!requirements.txt
|
||||||
.DS_Store
|
.DS_Store
|
||||||
*DS_Store
|
*DS_Store
|
||||||
.vscode
|
|
||||||
google/
|
google/
|
||||||
Wan2.1-T2V-14B/
|
Wan2.1-T2V-14B/
|
||||||
Wan2.1-T2V-1.3B/
|
Wan2.1-T2V-1.3B/
|
||||||
Wan2.1-I2V-14B-480P/
|
Wan2.1-I2V-14B-480P/
|
||||||
Wan2.1-I2V-14B-720P/
|
Wan2.1-I2V-14B-720P/
|
||||||
poetry.lock
|
poetry.lock
|
||||||
logs/
|
|
||||||
|
|||||||
62
generate.py
62
generate.py
@ -25,12 +25,8 @@ import wan
|
|||||||
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
|
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
|
||||||
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
|
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
|
||||||
from wan.utils.utils import cache_image, cache_video, str2bool
|
from wan.utils.utils import cache_image, cache_video, str2bool
|
||||||
from wan.utils.platform import (
|
from wan.utils.platform import get_torch_distributed_backend
|
||||||
get_device_type,
|
from wan.utils.chrono_inspector import ChronoInspector
|
||||||
get_torch_distributed_backend,
|
|
||||||
get_torch_profiler_activities,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
EXAMPLE_PROMPT = {
|
EXAMPLE_PROMPT = {
|
||||||
@ -256,11 +252,6 @@ def _parse_args():
|
|||||||
type=float,
|
type=float,
|
||||||
default=5.0,
|
default=5.0,
|
||||||
help="Classifier free guidance scale.")
|
help="Classifier free guidance scale.")
|
||||||
parser.add_argument(
|
|
||||||
"--profile",
|
|
||||||
action="store_true",
|
|
||||||
default=False,
|
|
||||||
help="profile the generating procedure.")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -281,26 +272,6 @@ def _init_logging(rank):
|
|||||||
logging.basicConfig(level=logging.ERROR)
|
logging.basicConfig(level=logging.ERROR)
|
||||||
|
|
||||||
|
|
||||||
def _init_profiler():
|
|
||||||
return torch.profiler.profile(
|
|
||||||
activities=get_torch_profiler_activities(),
|
|
||||||
on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs'),
|
|
||||||
record_shapes=True,
|
|
||||||
profile_memory=True,
|
|
||||||
with_stack=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _finalize_profiler(profiler):
|
|
||||||
profiler.stop()
|
|
||||||
table = profiler.key_averages().table(
|
|
||||||
sort_by=f"{get_device_type()}_time_total",
|
|
||||||
row_limit=20,
|
|
||||||
)
|
|
||||||
with open(f"logs/{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}-profiling.txt", "w") as f:
|
|
||||||
f.write(table)
|
|
||||||
|
|
||||||
|
|
||||||
def generate(args):
|
def generate(args):
|
||||||
rank = int(os.getenv("RANK", 0))
|
rank = int(os.getenv("RANK", 0))
|
||||||
world_size = int(os.getenv("WORLD_SIZE", 1))
|
world_size = int(os.getenv("WORLD_SIZE", 1))
|
||||||
@ -368,10 +339,6 @@ def generate(args):
|
|||||||
dist.broadcast_object_list(base_seed, src=0)
|
dist.broadcast_object_list(base_seed, src=0)
|
||||||
args.base_seed = base_seed[0]
|
args.base_seed = base_seed[0]
|
||||||
|
|
||||||
profiler = None
|
|
||||||
if args.profile and rank == 0:
|
|
||||||
profiler = _init_profiler()
|
|
||||||
|
|
||||||
if "t2v" in args.task or "t2i" in args.task:
|
if "t2v" in args.task or "t2i" in args.task:
|
||||||
if args.prompt is None:
|
if args.prompt is None:
|
||||||
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
|
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
|
||||||
@ -399,6 +366,7 @@ def generate(args):
|
|||||||
logging.info(f"Extended prompt: {args.prompt}")
|
logging.info(f"Extended prompt: {args.prompt}")
|
||||||
|
|
||||||
logging.info("Creating WanT2V pipeline.")
|
logging.info("Creating WanT2V pipeline.")
|
||||||
|
with ChronoInspector("Creating WanT2V pipeline"):
|
||||||
wan_t2v = wan.WanT2V(
|
wan_t2v = wan.WanT2V(
|
||||||
config=cfg,
|
config=cfg,
|
||||||
checkpoint_dir=args.ckpt_dir,
|
checkpoint_dir=args.ckpt_dir,
|
||||||
@ -408,11 +376,10 @@ def generate(args):
|
|||||||
dit_fsdp=args.dit_fsdp,
|
dit_fsdp=args.dit_fsdp,
|
||||||
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
||||||
t5_cpu=args.t5_cpu,
|
t5_cpu=args.t5_cpu,
|
||||||
profiler=profiler,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Warming up WanT2V pipeline ...")
|
logging.info("Warming up WanT2V pipeline ...")
|
||||||
with torch.no_grad():
|
with ChronoInspector("Warming up WanT2V pipeline"):
|
||||||
_ = wan_t2v.generate(
|
_ = wan_t2v.generate(
|
||||||
args.prompt,
|
args.prompt,
|
||||||
size=SIZE_CONFIGS[args.size],
|
size=SIZE_CONFIGS[args.size],
|
||||||
@ -425,6 +392,7 @@ def generate(args):
|
|||||||
offload_model=args.offload_model)
|
offload_model=args.offload_model)
|
||||||
|
|
||||||
logging.info(f"Generating {'image' if 't2i' in args.task else 'video'} ...")
|
logging.info(f"Generating {'image' if 't2i' in args.task else 'video'} ...")
|
||||||
|
with ChronoInspector(f"Generating {'image' if 't2i' in args.task else 'video'}"):
|
||||||
video = wan_t2v.generate(
|
video = wan_t2v.generate(
|
||||||
args.prompt,
|
args.prompt,
|
||||||
size=SIZE_CONFIGS[args.size],
|
size=SIZE_CONFIGS[args.size],
|
||||||
@ -469,6 +437,7 @@ def generate(args):
|
|||||||
logging.info(f"Extended prompt: {args.prompt}")
|
logging.info(f"Extended prompt: {args.prompt}")
|
||||||
|
|
||||||
logging.info("Creating WanI2V pipeline.")
|
logging.info("Creating WanI2V pipeline.")
|
||||||
|
with ChronoInspector("Creating WanI2V pipeline"):
|
||||||
wan_i2v = wan.WanI2V(
|
wan_i2v = wan.WanI2V(
|
||||||
config=cfg,
|
config=cfg,
|
||||||
checkpoint_dir=args.ckpt_dir,
|
checkpoint_dir=args.ckpt_dir,
|
||||||
@ -478,11 +447,10 @@ def generate(args):
|
|||||||
dit_fsdp=args.dit_fsdp,
|
dit_fsdp=args.dit_fsdp,
|
||||||
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
||||||
t5_cpu=args.t5_cpu,
|
t5_cpu=args.t5_cpu,
|
||||||
profiler=profiler,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Warming up WanI2V pipeline ...")
|
logging.info("Warming up WanI2V pipeline ...")
|
||||||
with torch.no_grad():
|
with ChronoInspector("Warming up WanI2V pipeline"):
|
||||||
_ = wan_i2v.generate(
|
_ = wan_i2v.generate(
|
||||||
args.prompt,
|
args.prompt,
|
||||||
img,
|
img,
|
||||||
@ -496,6 +464,7 @@ def generate(args):
|
|||||||
offload_model=args.offload_model)
|
offload_model=args.offload_model)
|
||||||
|
|
||||||
logging.info("Generating video ...")
|
logging.info("Generating video ...")
|
||||||
|
with ChronoInspector("Generating video"):
|
||||||
video = wan_i2v.generate(
|
video = wan_i2v.generate(
|
||||||
args.prompt,
|
args.prompt,
|
||||||
img,
|
img,
|
||||||
@ -542,6 +511,7 @@ def generate(args):
|
|||||||
logging.info(f"Extended prompt: {args.prompt}")
|
logging.info(f"Extended prompt: {args.prompt}")
|
||||||
|
|
||||||
logging.info("Creating WanFLF2V pipeline.")
|
logging.info("Creating WanFLF2V pipeline.")
|
||||||
|
with ChronoInspector("Creating WanFLF2V pipeline"):
|
||||||
wan_flf2v = wan.WanFLF2V(
|
wan_flf2v = wan.WanFLF2V(
|
||||||
config=cfg,
|
config=cfg,
|
||||||
checkpoint_dir=args.ckpt_dir,
|
checkpoint_dir=args.ckpt_dir,
|
||||||
@ -551,11 +521,10 @@ def generate(args):
|
|||||||
dit_fsdp=args.dit_fsdp,
|
dit_fsdp=args.dit_fsdp,
|
||||||
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
||||||
t5_cpu=args.t5_cpu,
|
t5_cpu=args.t5_cpu,
|
||||||
profiler=profiler
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Warming up WanFLF2V pipeline ...")
|
logging.info("Warming up WanFLF2V pipeline ...")
|
||||||
with torch.no_grad():
|
with ChronoInspector("Warming up WanFLF2V pipeline"):
|
||||||
_ = wan_flf2v.generate(
|
_ = wan_flf2v.generate(
|
||||||
args.prompt,
|
args.prompt,
|
||||||
first_frame,
|
first_frame,
|
||||||
@ -570,6 +539,7 @@ def generate(args):
|
|||||||
offload_model=args.offload_model)
|
offload_model=args.offload_model)
|
||||||
|
|
||||||
logging.info("Generating video ...")
|
logging.info("Generating video ...")
|
||||||
|
with ChronoInspector("Generating video"):
|
||||||
video = wan_flf2v.generate(
|
video = wan_flf2v.generate(
|
||||||
args.prompt,
|
args.prompt,
|
||||||
first_frame,
|
first_frame,
|
||||||
@ -606,6 +576,7 @@ def generate(args):
|
|||||||
logging.info(f"Extended prompt: {args.prompt}")
|
logging.info(f"Extended prompt: {args.prompt}")
|
||||||
|
|
||||||
logging.info("Creating VACE pipeline.")
|
logging.info("Creating VACE pipeline.")
|
||||||
|
with ChronoInspector("Creating VACE pipeline"):
|
||||||
wan_vace = wan.WanVace(
|
wan_vace = wan.WanVace(
|
||||||
config=cfg,
|
config=cfg,
|
||||||
checkpoint_dir=args.ckpt_dir,
|
checkpoint_dir=args.ckpt_dir,
|
||||||
@ -615,7 +586,6 @@ def generate(args):
|
|||||||
dit_fsdp=args.dit_fsdp,
|
dit_fsdp=args.dit_fsdp,
|
||||||
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
||||||
t5_cpu=args.t5_cpu,
|
t5_cpu=args.t5_cpu,
|
||||||
profiler=profiler
|
|
||||||
)
|
)
|
||||||
|
|
||||||
src_video, src_mask, src_ref_images = wan_vace.prepare_source(
|
src_video, src_mask, src_ref_images = wan_vace.prepare_source(
|
||||||
@ -625,8 +595,8 @@ def generate(args):
|
|||||||
], args.frame_num, SIZE_CONFIGS[args.size], device)
|
], args.frame_num, SIZE_CONFIGS[args.size], device)
|
||||||
|
|
||||||
logging.info("Warming up VACE pipeline ...")
|
logging.info("Warming up VACE pipeline ...")
|
||||||
with torch.no_grad():
|
with ChronoInspector("Warming up VACE pipeline"):
|
||||||
_ = wan_vace.generate(
|
video = wan_vace.generate(
|
||||||
args.prompt,
|
args.prompt,
|
||||||
src_video,
|
src_video,
|
||||||
src_mask,
|
src_mask,
|
||||||
@ -641,6 +611,7 @@ def generate(args):
|
|||||||
offload_model=args.offload_model)
|
offload_model=args.offload_model)
|
||||||
|
|
||||||
logging.info(f"Generating video...")
|
logging.info(f"Generating video...")
|
||||||
|
with ChronoInspector("Generating video"):
|
||||||
video = wan_vace.generate(
|
video = wan_vace.generate(
|
||||||
args.prompt,
|
args.prompt,
|
||||||
src_video,
|
src_video,
|
||||||
@ -657,9 +628,6 @@ def generate(args):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unkown task type: {args.task}")
|
raise ValueError(f"Unkown task type: {args.task}")
|
||||||
|
|
||||||
if args.profile and rank == 0:
|
|
||||||
_finalize_profiler(profiler)
|
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
if args.save_file is None:
|
if args.save_file is None:
|
||||||
formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
|
formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
|||||||
@ -6,7 +6,6 @@ import os
|
|||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
import types
|
import types
|
||||||
from time import perf_counter
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
@ -55,7 +54,6 @@ class WanFLF2V:
|
|||||||
use_usp=False,
|
use_usp=False,
|
||||||
t5_cpu=False,
|
t5_cpu=False,
|
||||||
init_on_cpu=True,
|
init_on_cpu=True,
|
||||||
profiler=None,
|
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Initializes the image-to-video generation model components.
|
Initializes the image-to-video generation model components.
|
||||||
@ -145,7 +143,6 @@ class WanFLF2V:
|
|||||||
self.model.to(self.device)
|
self.model.to(self.device)
|
||||||
|
|
||||||
self.sample_neg_prompt = config.sample_neg_prompt
|
self.sample_neg_prompt = config.sample_neg_prompt
|
||||||
self.profiler = profiler
|
|
||||||
|
|
||||||
def generate(self,
|
def generate(self,
|
||||||
input_prompt,
|
input_prompt,
|
||||||
@ -200,14 +197,6 @@ class WanFLF2V:
|
|||||||
- H: Frame height (from max_area)
|
- H: Frame height (from max_area)
|
||||||
- W: Frame width from max_area)
|
- W: Frame width from max_area)
|
||||||
"""
|
"""
|
||||||
if self.profiler and self.rank == 0:
|
|
||||||
self.profiler.start()
|
|
||||||
|
|
||||||
start_time = 0.0
|
|
||||||
end_time = 0.0
|
|
||||||
if self.rank == 0:
|
|
||||||
start_time = perf_counter()
|
|
||||||
|
|
||||||
first_frame_size = first_frame.size
|
first_frame_size = first_frame.size
|
||||||
last_frame_size = last_frame.size
|
last_frame_size = last_frame.size
|
||||||
first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(
|
first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(
|
||||||
@ -300,10 +289,6 @@ class WanFLF2V:
|
|||||||
])[0]
|
])[0]
|
||||||
y = torch.concat([msk, y])
|
y = torch.concat([msk, y])
|
||||||
|
|
||||||
if self.rank == 0:
|
|
||||||
end_time = perf_counter()
|
|
||||||
logging.info(f"[preprocess and VAE encode] Elapsed time: {end_time - start_time:.2f} seconds")
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def noop_no_sync():
|
def noop_no_sync():
|
||||||
yield
|
yield
|
||||||
@ -354,14 +339,8 @@ class WanFLF2V:
|
|||||||
if offload_model:
|
if offload_model:
|
||||||
empty_cache()
|
empty_cache()
|
||||||
|
|
||||||
if self.rank == 0:
|
|
||||||
start_time = perf_counter()
|
|
||||||
|
|
||||||
self.model.to(self.device)
|
self.model.to(self.device)
|
||||||
for _, t in enumerate(tqdm(timesteps)):
|
for _, t in enumerate(tqdm(timesteps)):
|
||||||
if self.profiler and self.rank == 0:
|
|
||||||
self.profiler.step()
|
|
||||||
|
|
||||||
latent_model_input = [latent.to(self.device)]
|
latent_model_input = [latent.to(self.device)]
|
||||||
timestep = [t]
|
timestep = [t]
|
||||||
|
|
||||||
@ -391,10 +370,6 @@ class WanFLF2V:
|
|||||||
generator=seed_g)[0]
|
generator=seed_g)[0]
|
||||||
latent = temp_x0.squeeze(0)
|
latent = temp_x0.squeeze(0)
|
||||||
|
|
||||||
if self.rank == 0:
|
|
||||||
end_time = perf_counter()
|
|
||||||
logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds")
|
|
||||||
|
|
||||||
x0 = [latent.to(self.device)]
|
x0 = [latent.to(self.device)]
|
||||||
del latent_model_input, timestep
|
del latent_model_input, timestep
|
||||||
|
|
||||||
@ -403,10 +378,7 @@ class WanFLF2V:
|
|||||||
empty_cache()
|
empty_cache()
|
||||||
|
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
start_time = perf_counter()
|
|
||||||
videos = self.vae.decode(x0)
|
videos = self.vae.decode(x0)
|
||||||
end_time = perf_counter()
|
|
||||||
logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds")
|
|
||||||
|
|
||||||
del noise, latent
|
del noise, latent
|
||||||
del sample_scheduler
|
del sample_scheduler
|
||||||
|
|||||||
@ -6,7 +6,6 @@ import os
|
|||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
import types
|
import types
|
||||||
from time import perf_counter
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
@ -55,7 +54,6 @@ class WanI2V:
|
|||||||
use_usp=False,
|
use_usp=False,
|
||||||
t5_cpu=False,
|
t5_cpu=False,
|
||||||
init_on_cpu=True,
|
init_on_cpu=True,
|
||||||
profiler=None,
|
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Initializes the image-to-video generation model components.
|
Initializes the image-to-video generation model components.
|
||||||
@ -145,7 +143,6 @@ class WanI2V:
|
|||||||
self.model.to(self.device)
|
self.model.to(self.device)
|
||||||
|
|
||||||
self.sample_neg_prompt = config.sample_neg_prompt
|
self.sample_neg_prompt = config.sample_neg_prompt
|
||||||
self.profiler = profiler
|
|
||||||
|
|
||||||
def generate(self,
|
def generate(self,
|
||||||
input_prompt,
|
input_prompt,
|
||||||
@ -195,14 +192,6 @@ class WanI2V:
|
|||||||
- H: Frame height (from max_area)
|
- H: Frame height (from max_area)
|
||||||
- W: Frame width from max_area)
|
- W: Frame width from max_area)
|
||||||
"""
|
"""
|
||||||
if self.profiler and self.rank == 0:
|
|
||||||
self.profiler.start()
|
|
||||||
|
|
||||||
start_time = 0.0
|
|
||||||
end_time = 0.0
|
|
||||||
if self.rank == 0:
|
|
||||||
start_time = perf_counter()
|
|
||||||
|
|
||||||
img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
|
img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
|
||||||
|
|
||||||
F = frame_num
|
F = frame_num
|
||||||
@ -273,10 +262,6 @@ class WanI2V:
|
|||||||
])[0]
|
])[0]
|
||||||
y = torch.concat([msk, y])
|
y = torch.concat([msk, y])
|
||||||
|
|
||||||
if self.rank == 0:
|
|
||||||
end_time = perf_counter()
|
|
||||||
logging.info(f"[preprocess and VAE encode] Elapsed time: {end_time - start_time:.2f} seconds")
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def noop_no_sync():
|
def noop_no_sync():
|
||||||
yield
|
yield
|
||||||
@ -327,14 +312,8 @@ class WanI2V:
|
|||||||
if offload_model:
|
if offload_model:
|
||||||
empty_cache()
|
empty_cache()
|
||||||
|
|
||||||
if self.rank == 0:
|
|
||||||
start_time = perf_counter()
|
|
||||||
|
|
||||||
self.model.to(self.device)
|
self.model.to(self.device)
|
||||||
for _, t in enumerate(tqdm(timesteps)):
|
for _, t in enumerate(tqdm(timesteps)):
|
||||||
if self.profiler and self.rank == 0:
|
|
||||||
self.profiler.step()
|
|
||||||
|
|
||||||
latent_model_input = [latent.to(self.device)]
|
latent_model_input = [latent.to(self.device)]
|
||||||
timestep = [t]
|
timestep = [t]
|
||||||
|
|
||||||
@ -364,10 +343,6 @@ class WanI2V:
|
|||||||
generator=seed_g)[0]
|
generator=seed_g)[0]
|
||||||
latent = temp_x0.squeeze(0)
|
latent = temp_x0.squeeze(0)
|
||||||
|
|
||||||
if self.rank == 0:
|
|
||||||
end_time = perf_counter()
|
|
||||||
logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds")
|
|
||||||
|
|
||||||
x0 = [latent.to(self.device)]
|
x0 = [latent.to(self.device)]
|
||||||
del latent_model_input, timestep
|
del latent_model_input, timestep
|
||||||
|
|
||||||
@ -376,10 +351,7 @@ class WanI2V:
|
|||||||
empty_cache()
|
empty_cache()
|
||||||
|
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
start_time = perf_counter()
|
|
||||||
videos = self.vae.decode(x0)
|
videos = self.vae.decode(x0)
|
||||||
end_time = perf_counter()
|
|
||||||
logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds")
|
|
||||||
|
|
||||||
del noise, latent
|
del noise, latent
|
||||||
del sample_scheduler
|
del sample_scheduler
|
||||||
|
|||||||
@ -16,7 +16,7 @@ try:
|
|||||||
from wan.modules.attention import attention as flash_attention
|
from wan.modules.attention import attention as flash_attention
|
||||||
torch.backends.mudnn.allow_tf32 = True
|
torch.backends.mudnn.allow_tf32 = True
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
torch_musa = None
|
pass
|
||||||
|
|
||||||
__all__ = ['WanModel']
|
__all__ = ['WanModel']
|
||||||
|
|
||||||
|
|||||||
@ -6,7 +6,6 @@ import os
|
|||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
import types
|
import types
|
||||||
from time import perf_counter
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
@ -51,7 +50,6 @@ class WanT2V:
|
|||||||
dit_fsdp=False,
|
dit_fsdp=False,
|
||||||
use_usp=False,
|
use_usp=False,
|
||||||
t5_cpu=False,
|
t5_cpu=False,
|
||||||
profiler=None,
|
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Initializes the Wan text-to-video generation model components.
|
Initializes the Wan text-to-video generation model components.
|
||||||
@ -126,7 +124,6 @@ class WanT2V:
|
|||||||
self.model.to(self.device)
|
self.model.to(self.device)
|
||||||
|
|
||||||
self.sample_neg_prompt = config.sample_neg_prompt
|
self.sample_neg_prompt = config.sample_neg_prompt
|
||||||
self.profiler = profiler
|
|
||||||
|
|
||||||
def generate(self,
|
def generate(self,
|
||||||
input_prompt,
|
input_prompt,
|
||||||
@ -172,14 +169,6 @@ class WanT2V:
|
|||||||
- H: Frame height (from size)
|
- H: Frame height (from size)
|
||||||
- W: Frame width from size)
|
- W: Frame width from size)
|
||||||
"""
|
"""
|
||||||
if self.profiler and self.rank == 0:
|
|
||||||
self.profiler.start()
|
|
||||||
|
|
||||||
start_time = 0.0
|
|
||||||
end_time = 0.0
|
|
||||||
if self.rank == 0:
|
|
||||||
start_time = perf_counter()
|
|
||||||
|
|
||||||
# preprocess
|
# preprocess
|
||||||
F = frame_num
|
F = frame_num
|
||||||
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
|
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
|
||||||
@ -219,10 +208,6 @@ class WanT2V:
|
|||||||
generator=seed_g)
|
generator=seed_g)
|
||||||
]
|
]
|
||||||
|
|
||||||
if self.rank == 0:
|
|
||||||
end_time = perf_counter()
|
|
||||||
logging.info(f"[preprocess] Elapsed time: {end_time - start_time:.2f} seconds")
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def noop_no_sync():
|
def noop_no_sync():
|
||||||
yield
|
yield
|
||||||
@ -259,14 +244,8 @@ class WanT2V:
|
|||||||
arg_c = {'context': context, 'seq_len': seq_len}
|
arg_c = {'context': context, 'seq_len': seq_len}
|
||||||
arg_null = {'context': context_null, 'seq_len': seq_len}
|
arg_null = {'context': context_null, 'seq_len': seq_len}
|
||||||
|
|
||||||
if self.rank == 0:
|
|
||||||
start_time = perf_counter()
|
|
||||||
|
|
||||||
self.model.to(self.device)
|
self.model.to(self.device)
|
||||||
for _, t in enumerate(tqdm(timesteps)):
|
for _, t in enumerate(tqdm(timesteps)):
|
||||||
if self.profiler and self.rank == 0:
|
|
||||||
self.profiler.step()
|
|
||||||
|
|
||||||
latent_model_input = latents
|
latent_model_input = latents
|
||||||
timestep = [t]
|
timestep = [t]
|
||||||
|
|
||||||
@ -288,19 +267,12 @@ class WanT2V:
|
|||||||
generator=seed_g)[0]
|
generator=seed_g)[0]
|
||||||
latents = [temp_x0.squeeze(0)]
|
latents = [temp_x0.squeeze(0)]
|
||||||
|
|
||||||
if self.rank == 0:
|
|
||||||
end_time = perf_counter()
|
|
||||||
logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds")
|
|
||||||
|
|
||||||
x0 = latents
|
x0 = latents
|
||||||
if offload_model:
|
if offload_model:
|
||||||
self.model.cpu()
|
self.model.cpu()
|
||||||
empty_cache()
|
empty_cache()
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
start_time = perf_counter()
|
|
||||||
videos = self.vae.decode(x0)
|
videos = self.vae.decode(x0)
|
||||||
end_time = perf_counter()
|
|
||||||
logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds")
|
|
||||||
|
|
||||||
del noise, latents
|
del noise, latents
|
||||||
del sample_scheduler
|
del sample_scheduler
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Optional, List
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -8,7 +8,7 @@ except ModuleNotFoundError:
|
|||||||
torch_musa = None
|
torch_musa = None
|
||||||
|
|
||||||
|
|
||||||
def _is_musa() -> bool:
|
def _is_musa():
|
||||||
if torch_musa is None:
|
if torch_musa is None:
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
@ -48,14 +48,3 @@ def get_torch_distributed_backend() -> str:
|
|||||||
return "mccl"
|
return "mccl"
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("No Accelerators(NV/MTT GPU) available")
|
raise NotImplementedError("No Accelerators(NV/MTT GPU) available")
|
||||||
|
|
||||||
|
|
||||||
def get_torch_profiler_activities() -> List[torch.profiler.ProfilerActivity]:
|
|
||||||
activities: List[torch.profiler.ProfilerActivity] = [
|
|
||||||
torch.profiler.ProfilerActivity.CPU
|
|
||||||
]
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
activities.append(torch.profiler.ProfilerActivity.CUDA)
|
|
||||||
elif _is_musa():
|
|
||||||
activities.append(torch.profiler.ProfilerActivity.MUSA)
|
|
||||||
return activities
|
|
||||||
|
|||||||
31
wan/vace.py
31
wan/vace.py
@ -8,7 +8,6 @@ import sys
|
|||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
import types
|
import types
|
||||||
from time import perf_counter
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
@ -59,7 +58,6 @@ class WanVace(WanT2V):
|
|||||||
dit_fsdp=False,
|
dit_fsdp=False,
|
||||||
use_usp=False,
|
use_usp=False,
|
||||||
t5_cpu=False,
|
t5_cpu=False,
|
||||||
profiler=None,
|
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Initializes the Wan text-to-video generation model components.
|
Initializes the Wan text-to-video generation model components.
|
||||||
@ -152,8 +150,6 @@ class WanVace(WanT2V):
|
|||||||
seq_len=75600,
|
seq_len=75600,
|
||||||
keep_last=True)
|
keep_last=True)
|
||||||
|
|
||||||
self.profiler = profiler
|
|
||||||
|
|
||||||
def vace_encode_frames(self, frames, ref_images, masks=None, vae=None):
|
def vace_encode_frames(self, frames, ref_images, masks=None, vae=None):
|
||||||
vae = self.vae if vae is None else vae
|
vae = self.vae if vae is None else vae
|
||||||
if ref_images is None:
|
if ref_images is None:
|
||||||
@ -358,14 +354,6 @@ class WanVace(WanT2V):
|
|||||||
- H: Frame height (from size)
|
- H: Frame height (from size)
|
||||||
- W: Frame width from size)
|
- W: Frame width from size)
|
||||||
"""
|
"""
|
||||||
if self.profiler and self.rank == 0:
|
|
||||||
self.profiler.start()
|
|
||||||
|
|
||||||
start_time = 0.0
|
|
||||||
end_time = 0.0
|
|
||||||
if self.rank == 0:
|
|
||||||
start_time = perf_counter()
|
|
||||||
|
|
||||||
# preprocess
|
# preprocess
|
||||||
# F = frame_num
|
# F = frame_num
|
||||||
# target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
|
# target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
|
||||||
@ -416,10 +404,6 @@ class WanVace(WanT2V):
|
|||||||
(self.patch_size[1] * self.patch_size[2]) *
|
(self.patch_size[1] * self.patch_size[2]) *
|
||||||
target_shape[1] / self.sp_size) * self.sp_size
|
target_shape[1] / self.sp_size) * self.sp_size
|
||||||
|
|
||||||
if self.rank == 0:
|
|
||||||
end_time = perf_counter()
|
|
||||||
logging.info(f"[preprocess] Elapsed time: {end_time - start_time:.2f} seconds")
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def noop_no_sync():
|
def noop_no_sync():
|
||||||
yield
|
yield
|
||||||
@ -456,19 +440,13 @@ class WanVace(WanT2V):
|
|||||||
arg_c = {'context': context, 'seq_len': seq_len}
|
arg_c = {'context': context, 'seq_len': seq_len}
|
||||||
arg_null = {'context': context_null, 'seq_len': seq_len}
|
arg_null = {'context': context_null, 'seq_len': seq_len}
|
||||||
|
|
||||||
if self.rank == 0:
|
|
||||||
start_time = perf_counter()
|
|
||||||
|
|
||||||
self.model.to(self.device)
|
|
||||||
for _, t in enumerate(tqdm(timesteps)):
|
for _, t in enumerate(tqdm(timesteps)):
|
||||||
if self.profiler and self.rank == 0:
|
|
||||||
self.profiler.step()
|
|
||||||
|
|
||||||
latent_model_input = latents
|
latent_model_input = latents
|
||||||
timestep = [t]
|
timestep = [t]
|
||||||
|
|
||||||
timestep = torch.stack(timestep)
|
timestep = torch.stack(timestep)
|
||||||
|
|
||||||
|
self.model.to(self.device)
|
||||||
noise_pred_cond = self.model(
|
noise_pred_cond = self.model(
|
||||||
latent_model_input,
|
latent_model_input,
|
||||||
t=timestep,
|
t=timestep,
|
||||||
@ -493,19 +471,12 @@ class WanVace(WanT2V):
|
|||||||
generator=seed_g)[0]
|
generator=seed_g)[0]
|
||||||
latents = [temp_x0.squeeze(0)]
|
latents = [temp_x0.squeeze(0)]
|
||||||
|
|
||||||
if self.rank == 0:
|
|
||||||
end_time = perf_counter()
|
|
||||||
logging.info(f"[sampling time steps] Elapsed time: {end_time - start_time:.2f} seconds")
|
|
||||||
|
|
||||||
x0 = latents
|
x0 = latents
|
||||||
if offload_model:
|
if offload_model:
|
||||||
self.model.cpu()
|
self.model.cpu()
|
||||||
empty_cache()
|
empty_cache()
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
start_time = perf_counter()
|
|
||||||
videos = self.decode_latent(x0, input_ref_images)
|
videos = self.decode_latent(x0, input_ref_images)
|
||||||
end_time = perf_counter()
|
|
||||||
logging.info(f"[VAE decoding] Elapsed time: {end_time - start_time:.2f} seconds")
|
|
||||||
|
|
||||||
del noise, latents
|
del noise, latents
|
||||||
del sample_scheduler
|
del sample_scheduler
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user