isort the code

wangang.wa 2025-05-15 16:41:07 +08:00
parent c709fcf0e7
commit 33daae0941
20 changed files with 118 additions and 69 deletions

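The hunks below are consistent with isort's default settings: comma-separated imports are split, plain `import` statements are ordered before `from ... import` statements within each group, imports are grouped stdlib / third-party / local with a blank line between groups, and long parenthesized imports are rewrapped one name per line with a trailing comma. A minimal sketch of reproducing this through isort's Python API (the exact isort version and configuration used for this commit are not shown on this page, so defaults are an assumption):

import isort

# Deliberately unsorted input, mirroring the first hunk below.
messy = (
    "from datetime import datetime\n"
    "import torch, random\n"
    "import argparse\n"
)

# With default settings isort splits the comma import, alphabetizes each
# group, and separates stdlib from third-party:
#   import argparse
#   import random
#   from datetime import datetime
#
#   import torch
print(isort.code(messy))

The CLI equivalent is running `python -m isort .` from the repository root.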
View File

@@ -1,21 +1,23 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import argparse
-from datetime import datetime
 import logging
 import os
 import sys
 import warnings
+from datetime import datetime
 
 warnings.filterwarnings('ignore')
 
-import torch, random
+import random
+
+import torch
 import torch.distributed as dist
 from PIL import Image
 
 import wan
-from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
+from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
 from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
-from wan.utils.utils import cache_video, cache_image, str2bool
+from wan.utils.utils import cache_image, cache_video, str2bool
 
 EXAMPLE_PROMPT = {
     "t2v-1.3B": {
@@ -281,8 +283,10 @@ def generate(args):
 
     if args.ulysses_size > 1 or args.ring_size > 1:
         assert args.ulysses_size * args.ring_size == world_size, f"The number of ulysses_size and ring_size should be equal to the world size."
-        from xfuser.core.distributed import (initialize_model_parallel,
-                                             init_distributed_environment)
+        from xfuser.core.distributed import (
+            init_distributed_environment,
+            initialize_model_parallel,
+        )
         init_distributed_environment(
             rank=dist.get_rank(), world_size=dist.get_world_size())
 
@@ -485,6 +489,7 @@ def generate(args):
             offload_model=args.offload_model
         )
     elif "vace" in args.task:
+        torch.cuda.memory._record_memory_history(max_entries=1000000)
         if args.prompt is None:
             args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
             args.src_video = EXAMPLE_PROMPT[args.task].get("src_video", None)
@@ -564,6 +569,8 @@ def generate(args):
                 nrow=1,
                 normalize=True,
                 value_range=(-1, 1))
+    torch.cuda.memory._record_memory_history(enabled=None)
+    torch.cuda.memory._dump_snapshot(f"memory.pickle")
     logging.info("Finished.")
 
 
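Two of the additions above are not import sorting: they wrap the VACE branch with PyTorch's CUDA memory-history recorder and dump a snapshot when generation finishes. A minimal sketch of that record/dump pattern in isolation is below; `_record_memory_history` and `_dump_snapshot` are private `torch.cuda.memory` APIs (available since roughly PyTorch 2.1 and subject to change), and the matmul is a hypothetical stand-in for the real VACE workload. The resulting pickle can be inspected with the viewer at https://pytorch.org/memory_viz.

import torch

# Begin recording a history of CUDA allocations and frees (private API).
torch.cuda.memory._record_memory_history(max_entries=1000000)
try:
    x = torch.randn(4096, 4096, device="cuda")  # stand-in workload
    y = x @ x
    torch.cuda.synchronize()
finally:
    # Dump the recorded history, then stop recording.
    torch.cuda.memory._dump_snapshot("memory.pickle")
    torch.cuda.memory._record_memory_history(enabled=None)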

View File

@@ -1,8 +1,8 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import argparse
 import gc
-import os.path as osp
 import os
+import os.path as osp
 import sys
 import warnings
 

View File

@@ -1,8 +1,8 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import argparse
 import gc
-import os.path as osp
 import os
+import os.path as osp
 import sys
 import warnings
 

View File

@@ -1,7 +1,7 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import argparse
-import os.path as osp
 import os
+import os.path as osp
 import sys
 import warnings
 

View File

@@ -1,7 +1,7 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import argparse
-import os.path as osp
 import os
+import os.path as osp
 import sys
 import warnings
 

View File

@@ -1,7 +1,7 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import argparse
-import os.path as osp
 import os
+import os.path as osp
 import sys
 import warnings
 

View File

@@ -2,18 +2,20 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import argparse
+import datetime
 import os
 import sys
-import datetime
+
 import imageio
 import numpy as np
 import torch
 
 import gradio as gr
+
 sys.path.insert(0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2]))
 
 import wan
 from wan import WanVace, WanVaceMP
-from wan.configs import WAN_CONFIGS, SIZE_CONFIGS
+from wan.configs import SIZE_CONFIGS, WAN_CONFIGS
 
 
 class FixedSizeQueue:

View File

@@ -1,5 +1,5 @@
 from . import configs, distributed, modules
+from .first_last_frame2video import WanFLF2V
 from .image2video import WanI2V
 from .text2video import WanT2V
-from .first_last_frame2video import WanFLF2V
 from .vace import WanVace, WanVaceMP

View File

@@ -8,6 +8,7 @@ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
 from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
 from torch.distributed.utils import _free_storage
 
+
 def shard_model(
     model,
     device_id,

View File

@@ -1,9 +1,11 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import torch
 import torch.cuda.amp as amp
-from xfuser.core.distributed import (get_sequence_parallel_rank,
-                                     get_sequence_parallel_world_size,
-                                     get_sp_group)
+from xfuser.core.distributed import (
+    get_sequence_parallel_rank,
+    get_sequence_parallel_world_size,
+    get_sp_group,
+)
 from xfuser.core.long_ctx_attention import xFuserLongContextAttention
 
 from ..modules.model import sinusoidal_embedding_1d

View File

@@ -21,8 +21,11 @@ from .modules.clip import CLIPModel
 from .modules.model import WanModel
 from .modules.t5 import T5EncoderModel
 from .modules.vae import WanVAE
-from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
-                               get_sampling_sigmas, retrieve_timesteps)
+from .utils.fm_solvers import (
+    FlowDPMSolverMultistepScheduler,
+    get_sampling_sigmas,
+    retrieve_timesteps,
+)
 from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
 
 
@@ -103,11 +106,12 @@ class WanFLF2V:
             init_on_cpu = False
 
         if use_usp:
-            from xfuser.core.distributed import \
-                get_sequence_parallel_world_size
+            from xfuser.core.distributed import get_sequence_parallel_world_size
 
-            from .distributed.xdit_context_parallel import (usp_attn_forward,
-                                                            usp_dit_forward)
+            from .distributed.xdit_context_parallel import (
+                usp_attn_forward,
+                usp_dit_forward,
+            )
             for block in self.model.blocks:
                 block.self_attn.forward = types.MethodType(
                     usp_attn_forward, block.self_attn)

View File

@@ -21,8 +21,11 @@ from .modules.clip import CLIPModel
 from .modules.model import WanModel
 from .modules.t5 import T5EncoderModel
 from .modules.vae import WanVAE
-from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
-                               get_sampling_sigmas, retrieve_timesteps)
+from .utils.fm_solvers import (
+    FlowDPMSolverMultistepScheduler,
+    get_sampling_sigmas,
+    retrieve_timesteps,
+)
 from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
 
 
@@ -103,11 +106,12 @@ class WanI2V:
             init_on_cpu = False
 
         if use_usp:
-            from xfuser.core.distributed import \
-                get_sequence_parallel_world_size
+            from xfuser.core.distributed import get_sequence_parallel_world_size
 
-            from .distributed.xdit_context_parallel import (usp_attn_forward,
-                                                            usp_dit_forward)
+            from .distributed.xdit_context_parallel import (
+                usp_attn_forward,
+                usp_dit_forward,
+            )
             for block in self.model.blocks:
                 block.self_attn.forward = types.MethodType(
                     usp_attn_forward, block.self_attn)

View File

@@ -3,7 +3,8 @@ import torch
 import torch.cuda.amp as amp
 import torch.nn as nn
 from diffusers.configuration_utils import register_to_config
-from .model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d
+
+from .model import WanAttentionBlock, WanModel, sinusoidal_embedding_1d
 
 
 class VaceWanAttentionBlock(WanAttentionBlock):

View File

@@ -18,8 +18,11 @@ from .distributed.fsdp import shard_model
 from .modules.model import WanModel
 from .modules.t5 import T5EncoderModel
 from .modules.vae import WanVAE
-from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
-                               get_sampling_sigmas, retrieve_timesteps)
+from .utils.fm_solvers import (
+    FlowDPMSolverMultistepScheduler,
+    get_sampling_sigmas,
+    retrieve_timesteps,
+)
 from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
 
 
@@ -85,11 +88,12 @@ class WanT2V:
         self.model.eval().requires_grad_(False)
 
         if use_usp:
-            from xfuser.core.distributed import \
-                get_sequence_parallel_world_size
+            from xfuser.core.distributed import get_sequence_parallel_world_size
 
-            from .distributed.xdit_context_parallel import (usp_attn_forward,
-                                                            usp_dit_forward)
+            from .distributed.xdit_context_parallel import (
+                usp_attn_forward,
+                usp_dit_forward,
+            )
             for block in self.model.blocks:
                 block.self_attn.forward = types.MethodType(
                     usp_attn_forward, block.self_attn)

View File

@@ -1,5 +1,8 @@
-from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas,
-                         retrieve_timesteps)
+from .fm_solvers import (
+    FlowDPMSolverMultistepScheduler,
+    get_sampling_sigmas,
+    retrieve_timesteps,
+)
 from .fm_solvers_unipc import FlowUniPCMultistepScheduler
 from .vace_processor import VaceVideoProcessor
 

View File

@@ -9,9 +9,11 @@ from typing import List, Optional, Tuple, Union
 import numpy as np
 import torch
 from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
-                                                   SchedulerMixin,
-                                                   SchedulerOutput)
+from diffusers.schedulers.scheduling_utils import (
+    KarrasDiffusionSchedulers,
+    SchedulerMixin,
+    SchedulerOutput,
+)
 from diffusers.utils import deprecate, is_scipy_available
 from diffusers.utils.torch_utils import randn_tensor
 

View File

@@ -8,9 +8,11 @@ from typing import List, Optional, Tuple, Union
 import numpy as np
 import torch
 from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
-                                                   SchedulerMixin,
-                                                   SchedulerOutput)
+from diffusers.schedulers.scheduling_utils import (
+    KarrasDiffusionSchedulers,
+    SchedulerMixin,
+    SchedulerOutput,
+)
 from diffusers.utils import deprecate, is_scipy_available
 
 if is_scipy_available():

View File

@@ -7,7 +7,7 @@ import sys
 import tempfile
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import Optional, Union, List
+from typing import List, Optional, Union
 
 import dashscope
 import torch
@@ -393,8 +393,11 @@ class QwenPromptExpander(PromptExpander):
 
         if self.is_vl:
             # default: Load the model on the available device(s)
-            from transformers import (AutoProcessor, AutoTokenizer,
-                                      Qwen2_5_VLForConditionalGeneration)
+            from transformers import (
+                AutoProcessor,
+                AutoTokenizer,
+                Qwen2_5_VLForConditionalGeneration,
+            )
             try:
                 from .qwen_vl_utils import process_vision_info
             except:

View File

@@ -1,9 +1,9 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 import numpy as np
-from PIL import Image
 import torch
 import torch.nn.functional as F
 import torchvision.transforms.functional as TF
+from PIL import Image
 
 
 class VaceImageProcessor(object):

View File

@@ -1,28 +1,36 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
-import os
-import sys
 import gc
-import math
-import time
-import random
-import types
 import logging
+import math
+import os
+import random
+import sys
+import time
 import traceback
+import types
 from contextlib import contextmanager
 from functools import partial
 
-from PIL import Image
-import torchvision.transforms.functional as TF
 import torch
-import torch.nn.functional as F
 import torch.cuda.amp as amp
 import torch.distributed as dist
 import torch.multiprocessing as mp
+import torch.nn.functional as F
+import torchvision.transforms.functional as TF
+from PIL import Image
 from tqdm import tqdm
 
-from .text2video import (WanT2V, T5EncoderModel, WanVAE, shard_model, FlowDPMSolverMultistepScheduler,
-                         get_sampling_sigmas, retrieve_timesteps, FlowUniPCMultistepScheduler)
 from .modules.vace_model import VaceWanModel
+from .text2video import (
+    FlowDPMSolverMultistepScheduler,
+    FlowUniPCMultistepScheduler,
+    T5EncoderModel,
+    WanT2V,
+    WanVAE,
+    get_sampling_sigmas,
+    retrieve_timesteps,
+    shard_model,
+)
 from .utils.vace_processor import VaceVideoProcessor
 
 
@@ -87,12 +95,13 @@ class WanVace(WanT2V):
         self.model.eval().requires_grad_(False)
 
         if use_usp:
-            from xfuser.core.distributed import \
-                get_sequence_parallel_world_size
+            from xfuser.core.distributed import get_sequence_parallel_world_size
 
-            from .distributed.xdit_context_parallel import (usp_attn_forward,
-                                                            usp_dit_forward,
-                                                            usp_dit_forward_vace)
+            from .distributed.xdit_context_parallel import (
+                usp_attn_forward,
+                usp_dit_forward,
+                usp_dit_forward_vace,
+            )
             for block in self.model.blocks:
                 block.self_attn.forward = types.MethodType(
                     usp_attn_forward, block.self_attn)
@@ -514,8 +523,10 @@ class WanVaceMP(WanVace):
                 world_size=world_size
             )
 
-            from xfuser.core.distributed import (initialize_model_parallel,
-                                                 init_distributed_environment)
+            from xfuser.core.distributed import (
+                init_distributed_environment,
+                initialize_model_parallel,
+            )
             init_distributed_environment(
                 rank=dist.get_rank(), world_size=dist.get_world_size())
 
@@ -547,9 +558,12 @@
 
         if self.use_usp:
             from xfuser.core.distributed import get_sequence_parallel_world_size
-            from .distributed.xdit_context_parallel import (usp_attn_forward,
-                                                            usp_dit_forward,
-                                                            usp_dit_forward_vace)
+
+            from .distributed.xdit_context_parallel import (
+                usp_attn_forward,
+                usp_dit_forward,
+                usp_dit_forward_vace,
+            )
             for block in model.blocks:
                 block.self_attn.forward = types.MethodType(
                     usp_attn_forward, block.self_attn)
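
A formatting-only commit like this one stays stable if the sort is enforced in CI: `python -m isort --check-only --diff .` exits non-zero and prints hunks like the ones above whenever a file would be re-sorted. The same check through the Python API (again assuming default settings):

import isort

# check_code() returns True when the imports are already sorted;
# show_diff=True prints the unified diff of what would change.
ok = isort.check_code("import torch\nimport argparse\n", show_diff=True)
print("imports sorted" if ok else "needs isort")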