isort the code

This commit is contained in:
wangang.wa 2025-05-15 16:41:07 +08:00
parent c709fcf0e7
commit 33daae0941
20 changed files with 118 additions and 69 deletions

View File

@ -1,21 +1,23 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse import argparse
from datetime import datetime
import logging import logging
import os import os
import sys import sys
import warnings import warnings
from datetime import datetime
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
import torch, random import random
import torch
import torch.distributed as dist import torch.distributed as dist
from PIL import Image from PIL import Image
import wan import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video, cache_image, str2bool from wan.utils.utils import cache_image, cache_video, str2bool
EXAMPLE_PROMPT = { EXAMPLE_PROMPT = {
"t2v-1.3B": { "t2v-1.3B": {
@ -281,8 +283,10 @@ def generate(args):
if args.ulysses_size > 1 or args.ring_size > 1: if args.ulysses_size > 1 or args.ring_size > 1:
assert args.ulysses_size * args.ring_size == world_size, f"The number of ulysses_size and ring_size should be equal to the world size." assert args.ulysses_size * args.ring_size == world_size, f"The number of ulysses_size and ring_size should be equal to the world size."
from xfuser.core.distributed import (initialize_model_parallel, from xfuser.core.distributed import (
init_distributed_environment) init_distributed_environment,
initialize_model_parallel,
)
init_distributed_environment( init_distributed_environment(
rank=dist.get_rank(), world_size=dist.get_world_size()) rank=dist.get_rank(), world_size=dist.get_world_size())
@ -485,6 +489,7 @@ def generate(args):
offload_model=args.offload_model offload_model=args.offload_model
) )
elif "vace" in args.task: elif "vace" in args.task:
torch.cuda.memory._record_memory_history(max_entries=1000000)
if args.prompt is None: if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
args.src_video = EXAMPLE_PROMPT[args.task].get("src_video", None) args.src_video = EXAMPLE_PROMPT[args.task].get("src_video", None)
@ -564,6 +569,8 @@ def generate(args):
nrow=1, nrow=1,
normalize=True, normalize=True,
value_range=(-1, 1)) value_range=(-1, 1))
torch.cuda.memory._record_memory_history(enabled=None)
torch.cuda.memory._dump_snapshot(f"memory.pickle")
logging.info("Finished.") logging.info("Finished.")

View File

@ -1,8 +1,8 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse import argparse
import gc import gc
import os.path as osp
import os import os
import os.path as osp
import sys import sys
import warnings import warnings

View File

@ -1,8 +1,8 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse import argparse
import gc import gc
import os.path as osp
import os import os
import os.path as osp
import sys import sys
import warnings import warnings

View File

@ -1,7 +1,7 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse import argparse
import os.path as osp
import os import os
import os.path as osp
import sys import sys
import warnings import warnings

View File

@ -1,7 +1,7 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse import argparse
import os.path as osp
import os import os
import os.path as osp
import sys import sys
import warnings import warnings

View File

@ -1,7 +1,7 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse import argparse
import os.path as osp
import os import os
import os.path as osp
import sys import sys
import warnings import warnings

View File

@ -2,18 +2,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
import argparse import argparse
import datetime
import os import os
import sys import sys
import datetime
import imageio import imageio
import numpy as np import numpy as np
import torch import torch
import gradio as gr import gradio as gr
sys.path.insert(0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2])) sys.path.insert(0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2]))
import wan import wan
from wan import WanVace, WanVaceMP from wan import WanVace, WanVaceMP
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS from wan.configs import SIZE_CONFIGS, WAN_CONFIGS
class FixedSizeQueue: class FixedSizeQueue:

View File

@ -1,5 +1,5 @@
from . import configs, distributed, modules from . import configs, distributed, modules
from .first_last_frame2video import WanFLF2V
from .image2video import WanI2V from .image2video import WanI2V
from .text2video import WanT2V from .text2video import WanT2V
from .first_last_frame2video import WanFLF2V
from .vace import WanVace, WanVaceMP from .vace import WanVace, WanVaceMP

View File

@ -8,6 +8,7 @@ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from torch.distributed.utils import _free_storage from torch.distributed.utils import _free_storage
def shard_model( def shard_model(
model, model,
device_id, device_id,

View File

@ -1,9 +1,11 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch import torch
import torch.cuda.amp as amp import torch.cuda.amp as amp
from xfuser.core.distributed import (get_sequence_parallel_rank, from xfuser.core.distributed import (
get_sequence_parallel_rank,
get_sequence_parallel_world_size, get_sequence_parallel_world_size,
get_sp_group) get_sp_group,
)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from ..modules.model import sinusoidal_embedding_1d from ..modules.model import sinusoidal_embedding_1d

View File

@ -21,8 +21,11 @@ from .modules.clip import CLIPModel
from .modules.model import WanModel from .modules.model import WanModel
from .modules.t5 import T5EncoderModel from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE from .modules.vae import WanVAE
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, from .utils.fm_solvers import (
get_sampling_sigmas, retrieve_timesteps) FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
@ -103,11 +106,12 @@ class WanFLF2V:
init_on_cpu = False init_on_cpu = False
if use_usp: if use_usp:
from xfuser.core.distributed import \ from xfuser.core.distributed import get_sequence_parallel_world_size
get_sequence_parallel_world_size
from .distributed.xdit_context_parallel import (usp_attn_forward, from .distributed.xdit_context_parallel import (
usp_dit_forward) usp_attn_forward,
usp_dit_forward,
)
for block in self.model.blocks: for block in self.model.blocks:
block.self_attn.forward = types.MethodType( block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn) usp_attn_forward, block.self_attn)

View File

@ -21,8 +21,11 @@ from .modules.clip import CLIPModel
from .modules.model import WanModel from .modules.model import WanModel
from .modules.t5 import T5EncoderModel from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE from .modules.vae import WanVAE
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, from .utils.fm_solvers import (
get_sampling_sigmas, retrieve_timesteps) FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
@ -103,11 +106,12 @@ class WanI2V:
init_on_cpu = False init_on_cpu = False
if use_usp: if use_usp:
from xfuser.core.distributed import \ from xfuser.core.distributed import get_sequence_parallel_world_size
get_sequence_parallel_world_size
from .distributed.xdit_context_parallel import (usp_attn_forward, from .distributed.xdit_context_parallel import (
usp_dit_forward) usp_attn_forward,
usp_dit_forward,
)
for block in self.model.blocks: for block in self.model.blocks:
block.self_attn.forward = types.MethodType( block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn) usp_attn_forward, block.self_attn)

View File

@ -3,7 +3,8 @@ import torch
import torch.cuda.amp as amp import torch.cuda.amp as amp
import torch.nn as nn import torch.nn as nn
from diffusers.configuration_utils import register_to_config from diffusers.configuration_utils import register_to_config
from .model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d
from .model import WanAttentionBlock, WanModel, sinusoidal_embedding_1d
class VaceWanAttentionBlock(WanAttentionBlock): class VaceWanAttentionBlock(WanAttentionBlock):

View File

@ -18,8 +18,11 @@ from .distributed.fsdp import shard_model
from .modules.model import WanModel from .modules.model import WanModel
from .modules.t5 import T5EncoderModel from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE from .modules.vae import WanVAE
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, from .utils.fm_solvers import (
get_sampling_sigmas, retrieve_timesteps) FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
@ -85,11 +88,12 @@ class WanT2V:
self.model.eval().requires_grad_(False) self.model.eval().requires_grad_(False)
if use_usp: if use_usp:
from xfuser.core.distributed import \ from xfuser.core.distributed import get_sequence_parallel_world_size
get_sequence_parallel_world_size
from .distributed.xdit_context_parallel import (usp_attn_forward, from .distributed.xdit_context_parallel import (
usp_dit_forward) usp_attn_forward,
usp_dit_forward,
)
for block in self.model.blocks: for block in self.model.blocks:
block.self_attn.forward = types.MethodType( block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn) usp_attn_forward, block.self_attn)

View File

@ -1,5 +1,8 @@
from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, from .fm_solvers import (
retrieve_timesteps) FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .fm_solvers_unipc import FlowUniPCMultistepScheduler from .fm_solvers_unipc import FlowUniPCMultistepScheduler
from .vace_processor import VaceVideoProcessor from .vace_processor import VaceVideoProcessor

View File

@ -9,9 +9,11 @@ from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers, from diffusers.schedulers.scheduling_utils import (
KarrasDiffusionSchedulers,
SchedulerMixin, SchedulerMixin,
SchedulerOutput) SchedulerOutput,
)
from diffusers.utils import deprecate, is_scipy_available from diffusers.utils import deprecate, is_scipy_available
from diffusers.utils.torch_utils import randn_tensor from diffusers.utils.torch_utils import randn_tensor

View File

@ -8,9 +8,11 @@ from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers, from diffusers.schedulers.scheduling_utils import (
KarrasDiffusionSchedulers,
SchedulerMixin, SchedulerMixin,
SchedulerOutput) SchedulerOutput,
)
from diffusers.utils import deprecate, is_scipy_available from diffusers.utils import deprecate, is_scipy_available
if is_scipy_available(): if is_scipy_available():

View File

@ -7,7 +7,7 @@ import sys
import tempfile import tempfile
from dataclasses import dataclass from dataclasses import dataclass
from http import HTTPStatus from http import HTTPStatus
from typing import Optional, Union, List from typing import List, Optional, Union
import dashscope import dashscope
import torch import torch
@ -393,8 +393,11 @@ class QwenPromptExpander(PromptExpander):
if self.is_vl: if self.is_vl:
# default: Load the model on the available device(s) # default: Load the model on the available device(s)
from transformers import (AutoProcessor, AutoTokenizer, from transformers import (
Qwen2_5_VLForConditionalGeneration) AutoProcessor,
AutoTokenizer,
Qwen2_5_VLForConditionalGeneration,
)
try: try:
from .qwen_vl_utils import process_vision_info from .qwen_vl_utils import process_vision_info
except: except:

View File

@ -1,9 +1,9 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import numpy as np import numpy as np
from PIL import Image
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torchvision.transforms.functional as TF import torchvision.transforms.functional as TF
from PIL import Image
class VaceImageProcessor(object): class VaceImageProcessor(object):

View File

@ -1,28 +1,36 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import os
import sys
import gc import gc
import math
import time
import random
import types
import logging import logging
import math
import os
import random
import sys
import time
import traceback import traceback
import types
from contextlib import contextmanager from contextlib import contextmanager
from functools import partial from functools import partial
from PIL import Image
import torchvision.transforms.functional as TF
import torch import torch
import torch.nn.functional as F
import torch.cuda.amp as amp import torch.cuda.amp as amp
import torch.distributed as dist import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
from tqdm import tqdm from tqdm import tqdm
from .text2video import (WanT2V, T5EncoderModel, WanVAE, shard_model, FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps, FlowUniPCMultistepScheduler)
from .modules.vace_model import VaceWanModel from .modules.vace_model import VaceWanModel
from .text2video import (
FlowDPMSolverMultistepScheduler,
FlowUniPCMultistepScheduler,
T5EncoderModel,
WanT2V,
WanVAE,
get_sampling_sigmas,
retrieve_timesteps,
shard_model,
)
from .utils.vace_processor import VaceVideoProcessor from .utils.vace_processor import VaceVideoProcessor
@ -87,12 +95,13 @@ class WanVace(WanT2V):
self.model.eval().requires_grad_(False) self.model.eval().requires_grad_(False)
if use_usp: if use_usp:
from xfuser.core.distributed import \ from xfuser.core.distributed import get_sequence_parallel_world_size
get_sequence_parallel_world_size
from .distributed.xdit_context_parallel import (usp_attn_forward, from .distributed.xdit_context_parallel import (
usp_attn_forward,
usp_dit_forward, usp_dit_forward,
usp_dit_forward_vace) usp_dit_forward_vace,
)
for block in self.model.blocks: for block in self.model.blocks:
block.self_attn.forward = types.MethodType( block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn) usp_attn_forward, block.self_attn)
@ -514,8 +523,10 @@ class WanVaceMP(WanVace):
world_size=world_size world_size=world_size
) )
from xfuser.core.distributed import (initialize_model_parallel, from xfuser.core.distributed import (
init_distributed_environment) init_distributed_environment,
initialize_model_parallel,
)
init_distributed_environment( init_distributed_environment(
rank=dist.get_rank(), world_size=dist.get_world_size()) rank=dist.get_rank(), world_size=dist.get_world_size())
@ -547,9 +558,12 @@ class WanVaceMP(WanVace):
if self.use_usp: if self.use_usp:
from xfuser.core.distributed import get_sequence_parallel_world_size from xfuser.core.distributed import get_sequence_parallel_world_size
from .distributed.xdit_context_parallel import (usp_attn_forward,
from .distributed.xdit_context_parallel import (
usp_attn_forward,
usp_dit_forward, usp_dit_forward,
usp_dit_forward_vace) usp_dit_forward_vace,
)
for block in model.blocks: for block in model.blocks:
block.self_attn.forward = types.MethodType( block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn) usp_attn_forward, block.self_attn)