Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-06-05 14:54:54 +00:00)
isort the code
parent c709fcf0e7
commit 33daae0941
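This commit applies isort to the repository: within each import block it groups imports into standard-library, third-party, and local sections, alphabetizes each section, and rewrites over-long `from ... import` statements into a one-name-per-line parenthesized form. A minimal sketch of the same transformation through isort's Python API follows; the project's actual isort options are not shown in this diff, so default settings are assumed here.

# Sketch only: default isort settings are an assumption, and "wan" is treated
# as a third-party package unless the project configures it as first-party.
import isort

messy = (
    "import torch, random\n"
    "from wan.utils.utils import cache_video, cache_image, str2bool\n"
)

print(isort.code(messy))
# Expected shape of the output: stdlib first, then third-party, each section
# alphabetized and separated by a blank line, and names inside a from-import
# sorted in place:
#
# import random
#
# import torch
# from wan.utils.utils import cache_image, cache_video, str2bool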
generate.py (19 changed lines)
@@ -1,21 +1,23 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
from datetime import datetime
import logging
import os
import sys
import warnings
from datetime import datetime

warnings.filterwarnings('ignore')

import torch, random
import random

import torch
import torch.distributed as dist
from PIL import Image

import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video, cache_image, str2bool
from wan.utils.utils import cache_image, cache_video, str2bool

EXAMPLE_PROMPT = {
"t2v-1.3B": {
@@ -281,8 +283,10 @@ def generate(args):

if args.ulysses_size > 1 or args.ring_size > 1:
assert args.ulysses_size * args.ring_size == world_size, f"The number of ulysses_size and ring_size should be equal to the world size."
from xfuser.core.distributed import (initialize_model_parallel,
init_distributed_environment)
from xfuser.core.distributed import (
init_distributed_environment,
initialize_model_parallel,
)
init_distributed_environment(
rank=dist.get_rank(), world_size=dist.get_world_size())

@@ -485,6 +489,7 @@ def generate(args):
offload_model=args.offload_model
)
elif "vace" in args.task:
torch.cuda.memory._record_memory_history(max_entries=1000000)
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
args.src_video = EXAMPLE_PROMPT[args.task].get("src_video", None)
@@ -564,6 +569,8 @@ def generate(args):
nrow=1,
normalize=True,
value_range=(-1, 1))
torch.cuda.memory._record_memory_history(enabled=None)
torch.cuda.memory._dump_snapshot(f"memory.pickle")
logging.info("Finished.")

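The multi-line form used for the long imports in the hunks above and below (one name per line, trailing comma, closing parenthesis on its own line) matches isort's vertical-hanging-indent output. The repository's isort configuration is not part of this diff, so the black-compatible profile in this sketch is only a guess at settings that reproduce that wrapping.

# Sketch only: profile="black" (line_length=88, multi_line_output=3,
# include_trailing_comma=True) is an assumed configuration, not one taken
# from this repository.
import isort

long_import = (
    "from xfuser.core.distributed import (initialize_model_parallel, "
    "init_distributed_environment)\n"
)

print(isort.code(long_import, profile="black"))
# from xfuser.core.distributed import (
#     init_distributed_environment,
#     initialize_model_parallel,
# )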
@@ -1,8 +1,8 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import gc
import os.path as osp
import os
import os.path as osp
import sys
import warnings

@@ -1,8 +1,8 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import gc
import os.path as osp
import os
import os.path as osp
import sys
import warnings

@@ -1,7 +1,7 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os.path as osp
import os
import os.path as osp
import sys
import warnings

@@ -1,7 +1,7 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os.path as osp
import os
import os.path as osp
import sys
import warnings

@@ -1,7 +1,7 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os.path as osp
import os
import os.path as osp
import sys
import warnings

@@ -2,18 +2,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import argparse
import datetime
import os
import sys
import datetime

import imageio
import numpy as np
import torch

import gradio as gr

sys.path.insert(0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan import WanVace, WanVaceMP
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS
from wan.configs import SIZE_CONFIGS, WAN_CONFIGS


class FixedSizeQueue:
@@ -1,5 +1,5 @@
from . import configs, distributed, modules
from .first_last_frame2video import WanFLF2V
from .image2video import WanI2V
from .text2video import WanT2V
from .first_last_frame2video import WanFLF2V
from .vace import WanVace, WanVaceMP
@@ -8,6 +8,7 @@ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from torch.distributed.utils import _free_storage


def shard_model(
model,
device_id,
@@ -1,9 +1,11 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.cuda.amp as amp
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group)
from xfuser.core.distributed import (
get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group,
)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention

from ..modules.model import sinusoidal_embedding_1d
@@ -21,8 +21,11 @@ from .modules.clip import CLIPModel
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler


@@ -103,11 +106,12 @@ class WanFLF2V:
init_on_cpu = False

if use_usp:
from xfuser.core.distributed import \
get_sequence_parallel_world_size
from xfuser.core.distributed import get_sequence_parallel_world_size

from .distributed.xdit_context_parallel import (usp_attn_forward,
usp_dit_forward)
from .distributed.xdit_context_parallel import (
usp_attn_forward,
usp_dit_forward,
)
for block in self.model.blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
@@ -21,8 +21,11 @@ from .modules.clip import CLIPModel
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler


@@ -103,11 +106,12 @@ class WanI2V:
init_on_cpu = False

if use_usp:
from xfuser.core.distributed import \
get_sequence_parallel_world_size
from xfuser.core.distributed import get_sequence_parallel_world_size

from .distributed.xdit_context_parallel import (usp_attn_forward,
usp_dit_forward)
from .distributed.xdit_context_parallel import (
usp_attn_forward,
usp_dit_forward,
)
for block in self.model.blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
@@ -3,7 +3,8 @@ import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import register_to_config
from .model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d

from .model import WanAttentionBlock, WanModel, sinusoidal_embedding_1d


class VaceWanAttentionBlock(WanAttentionBlock):
@@ -18,8 +18,11 @@ from .distributed.fsdp import shard_model
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler


@@ -85,11 +88,12 @@ class WanT2V:
self.model.eval().requires_grad_(False)

if use_usp:
from xfuser.core.distributed import \
get_sequence_parallel_world_size
from xfuser.core.distributed import get_sequence_parallel_world_size

from .distributed.xdit_context_parallel import (usp_attn_forward,
usp_dit_forward)
from .distributed.xdit_context_parallel import (
usp_attn_forward,
usp_dit_forward,
)
for block in self.model.blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
@@ -1,5 +1,8 @@
from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas,
retrieve_timesteps)
from .fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .fm_solvers_unipc import FlowUniPCMultistepScheduler
from .vace_processor import VaceVideoProcessor

@@ -9,9 +9,11 @@ from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
SchedulerMixin,
SchedulerOutput)
from diffusers.schedulers.scheduling_utils import (
KarrasDiffusionSchedulers,
SchedulerMixin,
SchedulerOutput,
)
from diffusers.utils import deprecate, is_scipy_available
from diffusers.utils.torch_utils import randn_tensor

@@ -8,9 +8,11 @@ from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
SchedulerMixin,
SchedulerOutput)
from diffusers.schedulers.scheduling_utils import (
KarrasDiffusionSchedulers,
SchedulerMixin,
SchedulerOutput,
)
from diffusers.utils import deprecate, is_scipy_available

if is_scipy_available():
@@ -7,7 +7,7 @@ import sys
import tempfile
from dataclasses import dataclass
from http import HTTPStatus
from typing import Optional, Union, List
from typing import List, Optional, Union

import dashscope
import torch
@@ -393,8 +393,11 @@ class QwenPromptExpander(PromptExpander):

if self.is_vl:
# default: Load the model on the available device(s)
from transformers import (AutoProcessor, AutoTokenizer,
Qwen2_5_VLForConditionalGeneration)
from transformers import (
AutoProcessor,
AutoTokenizer,
Qwen2_5_VLForConditionalGeneration,
)
try:
from .qwen_vl_utils import process_vision_info
except:
@@ -1,9 +1,9 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image


class VaceImageProcessor(object):
wan/vace.py (56 changed lines)
@@ -1,28 +1,36 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import os
import sys
import gc
import math
import time
import random
import types
import logging
import math
import os
import random
import sys
import time
import traceback
import types
from contextlib import contextmanager
from functools import partial

from PIL import Image
import torchvision.transforms.functional as TF
import torch
import torch.nn.functional as F
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
from tqdm import tqdm

from .text2video import (WanT2V, T5EncoderModel, WanVAE, shard_model, FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps, FlowUniPCMultistepScheduler)
from .modules.vace_model import VaceWanModel
from .text2video import (
FlowDPMSolverMultistepScheduler,
FlowUniPCMultistepScheduler,
T5EncoderModel,
WanT2V,
WanVAE,
get_sampling_sigmas,
retrieve_timesteps,
shard_model,
)
from .utils.vace_processor import VaceVideoProcessor

@@ -87,12 +95,13 @@ class WanVace(WanT2V):
self.model.eval().requires_grad_(False)

if use_usp:
from xfuser.core.distributed import \
get_sequence_parallel_world_size
from xfuser.core.distributed import get_sequence_parallel_world_size

from .distributed.xdit_context_parallel import (usp_attn_forward,
usp_dit_forward,
usp_dit_forward_vace)
from .distributed.xdit_context_parallel import (
usp_attn_forward,
usp_dit_forward,
usp_dit_forward_vace,
)
for block in self.model.blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
@@ -514,8 +523,10 @@ class WanVaceMP(WanVace):
world_size=world_size
)

from xfuser.core.distributed import (initialize_model_parallel,
init_distributed_environment)
from xfuser.core.distributed import (
init_distributed_environment,
initialize_model_parallel,
)
init_distributed_environment(
rank=dist.get_rank(), world_size=dist.get_world_size())

@@ -547,9 +558,12 @@ class WanVaceMP(WanVace):

if self.use_usp:
from xfuser.core.distributed import get_sequence_parallel_world_size
from .distributed.xdit_context_parallel import (usp_attn_forward,
usp_dit_forward,
usp_dit_forward_vace)

from .distributed.xdit_context_parallel import (
usp_attn_forward,
usp_dit_forward,
usp_dit_forward_vace,
)
for block in model.blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)