From 33daae0941193befb450a9bc1367085f41851435 Mon Sep 17 00:00:00 2001 From: "wangang.wa" Date: Thu, 15 May 2025 16:41:07 +0800 Subject: [PATCH] isort the code --- generate.py | 19 +++++--- gradio/fl2v_14B_singleGPU.py | 2 +- gradio/i2v_14B_singleGPU.py | 2 +- gradio/t2i_14B_singleGPU.py | 2 +- gradio/t2v_1.3B_singleGPU.py | 2 +- gradio/t2v_14B_singleGPU.py | 2 +- gradio/vace.py | 6 ++- wan/__init__.py | 2 +- wan/distributed/fsdp.py | 1 + wan/distributed/xdit_context_parallel.py | 8 ++-- wan/first_last_frame2video.py | 16 ++++--- wan/image2video.py | 16 ++++--- wan/modules/vace_model.py | 3 +- wan/text2video.py | 16 ++++--- wan/utils/__init__.py | 7 ++- wan/utils/fm_solvers.py | 8 ++-- wan/utils/fm_solvers_unipc.py | 8 ++-- wan/utils/prompt_extend.py | 9 ++-- wan/utils/vace_processor.py | 2 +- wan/vace.py | 56 +++++++++++++++--------- 20 files changed, 118 insertions(+), 69 deletions(-) diff --git a/generate.py b/generate.py index 2e6b35c..a0fafbc 100644 --- a/generate.py +++ b/generate.py @@ -1,21 +1,23 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import argparse -from datetime import datetime import logging import os import sys import warnings +from datetime import datetime warnings.filterwarnings('ignore') -import torch, random +import random + +import torch import torch.distributed as dist from PIL import Image import wan -from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES +from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander -from wan.utils.utils import cache_video, cache_image, str2bool +from wan.utils.utils import cache_image, cache_video, str2bool EXAMPLE_PROMPT = { "t2v-1.3B": { @@ -281,8 +283,10 @@ def generate(args): if args.ulysses_size > 1 or args.ring_size > 1: assert args.ulysses_size * args.ring_size == world_size, f"The number of ulysses_size and ring_size should be equal to the world size." - from xfuser.core.distributed import (initialize_model_parallel, - init_distributed_environment) + from xfuser.core.distributed import ( + init_distributed_environment, + initialize_model_parallel, + ) init_distributed_environment( rank=dist.get_rank(), world_size=dist.get_world_size()) @@ -485,6 +489,7 @@ def generate(args): offload_model=args.offload_model ) elif "vace" in args.task: + torch.cuda.memory._record_memory_history(max_entries=1000000) if args.prompt is None: args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] args.src_video = EXAMPLE_PROMPT[args.task].get("src_video", None) @@ -564,6 +569,8 @@ def generate(args): nrow=1, normalize=True, value_range=(-1, 1)) + torch.cuda.memory._record_memory_history(enabled=None) + torch.cuda.memory._dump_snapshot(f"memory.pickle") logging.info("Finished.") diff --git a/gradio/fl2v_14B_singleGPU.py b/gradio/fl2v_14B_singleGPU.py index 476a136..346cb7e 100644 --- a/gradio/fl2v_14B_singleGPU.py +++ b/gradio/fl2v_14B_singleGPU.py @@ -1,8 +1,8 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import argparse import gc -import os.path as osp import os +import os.path as osp import sys import warnings diff --git a/gradio/i2v_14B_singleGPU.py b/gradio/i2v_14B_singleGPU.py index 35c1e08..959550f 100644 --- a/gradio/i2v_14B_singleGPU.py +++ b/gradio/i2v_14B_singleGPU.py @@ -1,8 +1,8 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import argparse import gc -import os.path as osp import os +import os.path as osp import sys import warnings diff --git a/gradio/t2i_14B_singleGPU.py b/gradio/t2i_14B_singleGPU.py index 1ccc229..541796f 100644 --- a/gradio/t2i_14B_singleGPU.py +++ b/gradio/t2i_14B_singleGPU.py @@ -1,7 +1,7 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import argparse -import os.path as osp import os +import os.path as osp import sys import warnings diff --git a/gradio/t2v_1.3B_singleGPU.py b/gradio/t2v_1.3B_singleGPU.py index 987634b..865774f 100644 --- a/gradio/t2v_1.3B_singleGPU.py +++ b/gradio/t2v_1.3B_singleGPU.py @@ -1,7 +1,7 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import argparse -import os.path as osp import os +import os.path as osp import sys import warnings diff --git a/gradio/t2v_14B_singleGPU.py b/gradio/t2v_14B_singleGPU.py index 37c11ae..667ff85 100644 --- a/gradio/t2v_14B_singleGPU.py +++ b/gradio/t2v_14B_singleGPU.py @@ -1,7 +1,7 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import argparse -import os.path as osp import os +import os.path as osp import sys import warnings diff --git a/gradio/vace.py b/gradio/vace.py index 75f780a..85d3dc7 100644 --- a/gradio/vace.py +++ b/gradio/vace.py @@ -2,18 +2,20 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import argparse +import datetime import os import sys -import datetime + import imageio import numpy as np import torch + import gradio as gr sys.path.insert(0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2])) import wan from wan import WanVace, WanVaceMP -from wan.configs import WAN_CONFIGS, SIZE_CONFIGS +from wan.configs import SIZE_CONFIGS, WAN_CONFIGS class FixedSizeQueue: diff --git a/wan/__init__.py b/wan/__init__.py index 45d555d..afed024 100644 --- a/wan/__init__.py +++ b/wan/__init__.py @@ -1,5 +1,5 @@ from . import configs, distributed, modules +from .first_last_frame2video import WanFLF2V from .image2video import WanI2V from .text2video import WanT2V -from .first_last_frame2video import WanFLF2V from .vace import WanVace, WanVaceMP diff --git a/wan/distributed/fsdp.py b/wan/distributed/fsdp.py index 18ba2f3..f4db5d3 100644 --- a/wan/distributed/fsdp.py +++ b/wan/distributed/fsdp.py @@ -8,6 +8,7 @@ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy from torch.distributed.utils import _free_storage + def shard_model( model, device_id, diff --git a/wan/distributed/xdit_context_parallel.py b/wan/distributed/xdit_context_parallel.py index e0be6c7..87fd22a 100644 --- a/wan/distributed/xdit_context_parallel.py +++ b/wan/distributed/xdit_context_parallel.py @@ -1,9 +1,11 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import torch import torch.cuda.amp as amp -from xfuser.core.distributed import (get_sequence_parallel_rank, - get_sequence_parallel_world_size, - get_sp_group) +from xfuser.core.distributed import ( + get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group, +) from xfuser.core.long_ctx_attention import xFuserLongContextAttention from ..modules.model import sinusoidal_embedding_1d diff --git a/wan/first_last_frame2video.py b/wan/first_last_frame2video.py index 4f300ca..449b657 100644 --- a/wan/first_last_frame2video.py +++ b/wan/first_last_frame2video.py @@ -21,8 +21,11 @@ from .modules.clip import CLIPModel from .modules.model import WanModel from .modules.t5 import T5EncoderModel from .modules.vae import WanVAE -from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, - get_sampling_sigmas, retrieve_timesteps) +from .utils.fm_solvers import ( + FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, + retrieve_timesteps, +) from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler @@ -103,11 +106,12 @@ class WanFLF2V: init_on_cpu = False if use_usp: - from xfuser.core.distributed import \ - get_sequence_parallel_world_size + from xfuser.core.distributed import get_sequence_parallel_world_size - from .distributed.xdit_context_parallel import (usp_attn_forward, - usp_dit_forward) + from .distributed.xdit_context_parallel import ( + usp_attn_forward, + usp_dit_forward, + ) for block in self.model.blocks: block.self_attn.forward = types.MethodType( usp_attn_forward, block.self_attn) diff --git a/wan/image2video.py b/wan/image2video.py index 5004f46..90cd682 100644 --- a/wan/image2video.py +++ b/wan/image2video.py @@ -21,8 +21,11 @@ from .modules.clip import CLIPModel from .modules.model import WanModel from .modules.t5 import T5EncoderModel from .modules.vae import WanVAE -from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, - get_sampling_sigmas, retrieve_timesteps) +from .utils.fm_solvers import ( + FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, + retrieve_timesteps, +) from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler @@ -103,11 +106,12 @@ class WanI2V: init_on_cpu = False if use_usp: - from xfuser.core.distributed import \ - get_sequence_parallel_world_size + from xfuser.core.distributed import get_sequence_parallel_world_size - from .distributed.xdit_context_parallel import (usp_attn_forward, - usp_dit_forward) + from .distributed.xdit_context_parallel import ( + usp_attn_forward, + usp_dit_forward, + ) for block in self.model.blocks: block.self_attn.forward = types.MethodType( usp_attn_forward, block.self_attn) diff --git a/wan/modules/vace_model.py b/wan/modules/vace_model.py index 60178a9..46edad5 100644 --- a/wan/modules/vace_model.py +++ b/wan/modules/vace_model.py @@ -3,7 +3,8 @@ import torch import torch.cuda.amp as amp import torch.nn as nn from diffusers.configuration_utils import register_to_config -from .model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d + +from .model import WanAttentionBlock, WanModel, sinusoidal_embedding_1d class VaceWanAttentionBlock(WanAttentionBlock): diff --git a/wan/text2video.py b/wan/text2video.py index 2400545..c518b61 100644 --- a/wan/text2video.py +++ b/wan/text2video.py @@ -18,8 +18,11 @@ from .distributed.fsdp import shard_model from .modules.model import WanModel from .modules.t5 import T5EncoderModel from .modules.vae import WanVAE -from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, - get_sampling_sigmas, retrieve_timesteps) +from .utils.fm_solvers import ( + FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, + retrieve_timesteps, +) from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler @@ -85,11 +88,12 @@ class WanT2V: self.model.eval().requires_grad_(False) if use_usp: - from xfuser.core.distributed import \ - get_sequence_parallel_world_size + from xfuser.core.distributed import get_sequence_parallel_world_size - from .distributed.xdit_context_parallel import (usp_attn_forward, - usp_dit_forward) + from .distributed.xdit_context_parallel import ( + usp_attn_forward, + usp_dit_forward, + ) for block in self.model.blocks: block.self_attn.forward = types.MethodType( usp_attn_forward, block.self_attn) diff --git a/wan/utils/__init__.py b/wan/utils/__init__.py index ba3fe7d..2e9b33d 100644 --- a/wan/utils/__init__.py +++ b/wan/utils/__init__.py @@ -1,5 +1,8 @@ -from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, - retrieve_timesteps) +from .fm_solvers import ( + FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, + retrieve_timesteps, +) from .fm_solvers_unipc import FlowUniPCMultistepScheduler from .vace_processor import VaceVideoProcessor diff --git a/wan/utils/fm_solvers.py b/wan/utils/fm_solvers.py index c908969..17bef85 100644 --- a/wan/utils/fm_solvers.py +++ b/wan/utils/fm_solvers.py @@ -9,9 +9,11 @@ from typing import List, Optional, Tuple, Union import numpy as np import torch from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers, - SchedulerMixin, - SchedulerOutput) +from diffusers.schedulers.scheduling_utils import ( + KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput, +) from diffusers.utils import deprecate, is_scipy_available from diffusers.utils.torch_utils import randn_tensor diff --git a/wan/utils/fm_solvers_unipc.py b/wan/utils/fm_solvers_unipc.py index 57321ba..fb502f2 100644 --- a/wan/utils/fm_solvers_unipc.py +++ b/wan/utils/fm_solvers_unipc.py @@ -8,9 +8,11 @@ from typing import List, Optional, Tuple, Union import numpy as np import torch from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers, - SchedulerMixin, - SchedulerOutput) +from diffusers.schedulers.scheduling_utils import ( + KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput, +) from diffusers.utils import deprecate, is_scipy_available if is_scipy_available(): diff --git a/wan/utils/prompt_extend.py b/wan/utils/prompt_extend.py index 5e3a216..d3b2c71 100644 --- a/wan/utils/prompt_extend.py +++ b/wan/utils/prompt_extend.py @@ -7,7 +7,7 @@ import sys import tempfile from dataclasses import dataclass from http import HTTPStatus -from typing import Optional, Union, List +from typing import List, Optional, Union import dashscope import torch @@ -393,8 +393,11 @@ class QwenPromptExpander(PromptExpander): if self.is_vl: # default: Load the model on the available device(s) - from transformers import (AutoProcessor, AutoTokenizer, - Qwen2_5_VLForConditionalGeneration) + from transformers import ( + AutoProcessor, + AutoTokenizer, + Qwen2_5_VLForConditionalGeneration, + ) try: from .qwen_vl_utils import process_vision_info except: diff --git a/wan/utils/vace_processor.py b/wan/utils/vace_processor.py index 5f7224f..4b742cd 100644 --- a/wan/utils/vace_processor.py +++ b/wan/utils/vace_processor.py @@ -1,9 +1,9 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import numpy as np -from PIL import Image import torch import torch.nn.functional as F import torchvision.transforms.functional as TF +from PIL import Image class VaceImageProcessor(object): diff --git a/wan/vace.py b/wan/vace.py index d792e9b..66c6474 100644 --- a/wan/vace.py +++ b/wan/vace.py @@ -1,28 +1,36 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. -import os -import sys import gc -import math -import time -import random -import types import logging +import math +import os +import random +import sys +import time import traceback +import types from contextlib import contextmanager from functools import partial -from PIL import Image -import torchvision.transforms.functional as TF import torch -import torch.nn.functional as F import torch.cuda.amp as amp import torch.distributed as dist import torch.multiprocessing as mp +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from PIL import Image from tqdm import tqdm -from .text2video import (WanT2V, T5EncoderModel, WanVAE, shard_model, FlowDPMSolverMultistepScheduler, - get_sampling_sigmas, retrieve_timesteps, FlowUniPCMultistepScheduler) from .modules.vace_model import VaceWanModel +from .text2video import ( + FlowDPMSolverMultistepScheduler, + FlowUniPCMultistepScheduler, + T5EncoderModel, + WanT2V, + WanVAE, + get_sampling_sigmas, + retrieve_timesteps, + shard_model, +) from .utils.vace_processor import VaceVideoProcessor @@ -87,12 +95,13 @@ class WanVace(WanT2V): self.model.eval().requires_grad_(False) if use_usp: - from xfuser.core.distributed import \ - get_sequence_parallel_world_size + from xfuser.core.distributed import get_sequence_parallel_world_size - from .distributed.xdit_context_parallel import (usp_attn_forward, - usp_dit_forward, - usp_dit_forward_vace) + from .distributed.xdit_context_parallel import ( + usp_attn_forward, + usp_dit_forward, + usp_dit_forward_vace, + ) for block in self.model.blocks: block.self_attn.forward = types.MethodType( usp_attn_forward, block.self_attn) @@ -514,8 +523,10 @@ class WanVaceMP(WanVace): world_size=world_size ) - from xfuser.core.distributed import (initialize_model_parallel, - init_distributed_environment) + from xfuser.core.distributed import ( + init_distributed_environment, + initialize_model_parallel, + ) init_distributed_environment( rank=dist.get_rank(), world_size=dist.get_world_size()) @@ -547,9 +558,12 @@ class WanVaceMP(WanVace): if self.use_usp: from xfuser.core.distributed import get_sequence_parallel_world_size - from .distributed.xdit_context_parallel import (usp_attn_forward, - usp_dit_forward, - usp_dit_forward_vace) + + from .distributed.xdit_context_parallel import ( + usp_attn_forward, + usp_dit_forward, + usp_dit_forward_vace, + ) for block in model.blocks: block.self_attn.forward = types.MethodType( usp_attn_forward, block.self_attn)