diff --git a/.gitignore b/.gitignore index 1f975d8..358fda4 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,5 @@ Wan2.1-T2V-14B/ Wan2.1-T2V-1.3B/ Wan2.1-I2V-14B-480P/ Wan2.1-I2V-14B-720P/ +venv_wan/ +venv_wan_py310/ diff --git a/generate.py b/generate.py index f27bb98..5d57d2c 100644 --- a/generate.py +++ b/generate.py @@ -186,6 +186,11 @@ def _parse_args(): type=float, default=5.0, help="Classifier free guidance scale.") + parser.add_argument( + "--device", + type=str, + default=None, + help="Device to use for computation (mps, cpu).") args = parser.parse_args() @@ -207,43 +212,21 @@ def _init_logging(rank): def generate(args): - rank = int(os.getenv("RANK", 0)) - world_size = int(os.getenv("WORLD_SIZE", 1)) - local_rank = int(os.getenv("LOCAL_RANK", 0)) - device = local_rank - _init_logging(rank) - + # Set device based on args or availability + if args.device: + device = torch.device(args.device) + else: + device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") + + _init_logging(0) # Use rank 0 logging for single-device + + # Ensure all torch operations use this device + torch.set_default_device(device) + if args.offload_model is None: - args.offload_model = False if world_size > 1 else True + args.offload_model = True # Default to True for single device to save memory logging.info( f"offload_model is not specified, set to {args.offload_model}.") - if world_size > 1: - torch.cuda.set_device(local_rank) - dist.init_process_group( - backend="nccl", - init_method="env://", - rank=rank, - world_size=world_size) - else: - assert not ( - args.t5_fsdp or args.dit_fsdp - ), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments." - assert not ( - args.ulysses_size > 1 or args.ring_size > 1 - ), f"context parallel are not supported in non-distributed environments." - - if args.ulysses_size > 1 or args.ring_size > 1: - assert args.ulysses_size * args.ring_size == world_size, f"The number of ulysses_size and ring_size should be equal to the world size." - from xfuser.core.distributed import (initialize_model_parallel, - init_distributed_environment) - init_distributed_environment( - rank=dist.get_rank(), world_size=dist.get_world_size()) - - initialize_model_parallel( - sequence_parallel_degree=dist.get_world_size(), - ring_degree=args.ring_size, - ulysses_degree=args.ulysses_size, - ) if args.use_prompt_extend: if args.prompt_extend_method == "dashscope": @@ -253,58 +236,44 @@ def generate(args): prompt_expander = QwenPromptExpander( model_name=args.prompt_extend_model, is_vl="i2v" in args.task, - device=rank) + device=device) # Use MPS/CPU device instead of rank else: raise NotImplementedError( f"Unsupport prompt_extend_method: {args.prompt_extend_method}") cfg = WAN_CONFIGS[args.task] - if args.ulysses_size > 1: - assert cfg.num_heads % args.ulysses_size == 0, f"`num_heads` must be divisible by `ulysses_size`." - logging.info(f"Generation job args: {args}") logging.info(f"Generation model config: {cfg}") - if dist.is_initialized(): - base_seed = [args.base_seed] if rank == 0 else [None] - dist.broadcast_object_list(base_seed, src=0) - args.base_seed = base_seed[0] - if "t2v" in args.task or "t2i" in args.task: if args.prompt is None: args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] logging.info(f"Input prompt: {args.prompt}") if args.use_prompt_extend: logging.info("Extending prompt ...") - if rank == 0: - prompt_output = prompt_expander( - args.prompt, - tar_lang=args.prompt_extend_target_lang, - seed=args.base_seed) - if prompt_output.status == False: - logging.info( - f"Extending prompt failed: {prompt_output.message}") - logging.info("Falling back to original prompt.") - input_prompt = args.prompt - else: - input_prompt = prompt_output.prompt - input_prompt = [input_prompt] + prompt_output = prompt_expander( + args.prompt, + tar_lang=args.prompt_extend_target_lang, + seed=args.base_seed) + if prompt_output.status == False: + logging.info( + f"Extending prompt failed: {prompt_output.message}") + logging.info("Falling back to original prompt.") + input_prompt = args.prompt else: - input_prompt = [None] - if dist.is_initialized(): - dist.broadcast_object_list(input_prompt, src=0) - args.prompt = input_prompt[0] + input_prompt = prompt_output.prompt + args.prompt = input_prompt logging.info(f"Extended prompt: {args.prompt}") logging.info("Creating WanT2V pipeline.") wan_t2v = wan.WanT2V( config=cfg, checkpoint_dir=args.ckpt_dir, - device_id=device, - rank=rank, - t5_fsdp=args.t5_fsdp, - dit_fsdp=args.dit_fsdp, - use_usp=(args.ulysses_size > 1 or args.ring_size > 1), + device_id=device, # Use MPS/CPU device instead of local_rank + rank=0, # Single device, so use rank 0 + t5_fsdp=False, # Disable FSDP (not supported on MPS) + dit_fsdp=False, # Disable FSDP (not supported on MPS) + use_usp=False, # Disable Ulysses/ring parallelism (single device) t5_cpu=args.t5_cpu, ) @@ -332,36 +301,30 @@ def generate(args): img = Image.open(args.image).convert("RGB") if args.use_prompt_extend: logging.info("Extending prompt ...") - if rank == 0: - prompt_output = prompt_expander( - args.prompt, - tar_lang=args.prompt_extend_target_lang, - image=img, - seed=args.base_seed) - if prompt_output.status == False: - logging.info( - f"Extending prompt failed: {prompt_output.message}") - logging.info("Falling back to original prompt.") - input_prompt = args.prompt - else: - input_prompt = prompt_output.prompt - input_prompt = [input_prompt] + prompt_output = prompt_expander( + args.prompt, + tar_lang=args.prompt_extend_target_lang, + image=img, + seed=args.base_seed) + if prompt_output.status == False: + logging.info( + f"Extending prompt failed: {prompt_output.message}") + logging.info("Falling back to original prompt.") + input_prompt = args.prompt else: - input_prompt = [None] - if dist.is_initialized(): - dist.broadcast_object_list(input_prompt, src=0) - args.prompt = input_prompt[0] + input_prompt = prompt_output.prompt + args.prompt = input_prompt logging.info(f"Extended prompt: {args.prompt}") logging.info("Creating WanI2V pipeline.") wan_i2v = wan.WanI2V( config=cfg, checkpoint_dir=args.ckpt_dir, - device_id=device, - rank=rank, - t5_fsdp=args.t5_fsdp, - dit_fsdp=args.dit_fsdp, - use_usp=(args.ulysses_size > 1 or args.ring_size > 1), + device_id=device, # Use MPS/CPU device instead of local_rank + rank=0, # Single device, so use rank 0 + t5_fsdp=False, # Disable FSDP (not supported on MPS) + dit_fsdp=False, # Disable FSDP (not supported on MPS) + use_usp=False, # Disable Ulysses/ring parallelism (single device) t5_cpu=args.t5_cpu, ) @@ -378,34 +341,32 @@ def generate(args): seed=args.base_seed, offload_model=args.offload_model) - if rank == 0: - if args.save_file is None: - formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S") - formatted_prompt = args.prompt.replace(" ", "_").replace("/", - "_")[:50] - suffix = '.png' if "t2i" in args.task else '.mp4' - args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix + # Save output (single device, so no rank check needed) + if args.save_file is None: + formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S") + formatted_prompt = args.prompt.replace(" ", "_").replace("/", "_")[:50] + suffix = '.png' if "t2i" in args.task else '.mp4' + args.save_file = f"{args.task}_{args.size}_{formatted_prompt}_{formatted_time}" + suffix - if "t2i" in args.task: - logging.info(f"Saving generated image to {args.save_file}") - cache_image( - tensor=video.squeeze(1)[None], - save_file=args.save_file, - nrow=1, - normalize=True, - value_range=(-1, 1)) - else: - logging.info(f"Saving generated video to {args.save_file}") - cache_video( - tensor=video[None], - save_file=args.save_file, - fps=cfg.sample_fps, - nrow=1, - normalize=True, - value_range=(-1, 1)) + if "t2i" in args.task: + logging.info(f"Saving generated image to {args.save_file}") + cache_image( + tensor=video.squeeze(1)[None], + save_file=args.save_file, + nrow=1, + normalize=True, + value_range=(-1, 1)) + else: + logging.info(f"Saving generated video to {args.save_file}") + cache_video( + tensor=video[None], + save_file=args.save_file, + fps=cfg.sample_fps, + nrow=1, + normalize=True, + value_range=(-1, 1)) logging.info("Finished.") - if __name__ == "__main__": args = _parse_args() generate(args) diff --git a/requirements.txt b/requirements.txt index d416e7b..e40a45d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,6 +11,6 @@ easydict ftfy dashscope imageio-ffmpeg -flash_attn +# flash_attn gradio>=5.0.0 numpy>=1.23.5,<2 diff --git a/wan/configs/shared_config.py b/wan/configs/shared_config.py index 04a9f45..62e634d 100644 --- a/wan/configs/shared_config.py +++ b/wan/configs/shared_config.py @@ -7,11 +7,11 @@ wan_shared_cfg = EasyDict() # t5 wan_shared_cfg.t5_model = 'umt5_xxl' -wan_shared_cfg.t5_dtype = torch.bfloat16 +wan_shared_cfg.t5_dtype = torch.float32 wan_shared_cfg.text_len = 512 # transformer -wan_shared_cfg.param_dtype = torch.bfloat16 +wan_shared_cfg.param_dtype = torch.float32 # inference wan_shared_cfg.num_train_timesteps = 1000 diff --git a/wan/distributed/fsdp.py b/wan/distributed/fsdp.py index 258d4af..ccbf21f 100644 --- a/wan/distributed/fsdp.py +++ b/wan/distributed/fsdp.py @@ -10,7 +10,7 @@ from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy def shard_model( model, device_id, - param_dtype=torch.bfloat16, + param_dtype=torch.float32, reduce_dtype=torch.float32, buffer_dtype=torch.float32, process_group=None, diff --git a/wan/distributed/xdit_context_parallel.py b/wan/distributed/xdit_context_parallel.py index 01936ce..29dd50f 100644 --- a/wan/distributed/xdit_context_parallel.py +++ b/wan/distributed/xdit_context_parallel.py @@ -151,9 +151,9 @@ def usp_attn_forward(self, seq_lens, grid_sizes, freqs, - dtype=torch.bfloat16): + dtype=torch.float32): b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim - half_dtypes = (torch.float16, torch.bfloat16) + half_dtypes = (torch.float16, torch.float32) def half(x): return x if x.dtype in half_dtypes else x.to(dtype) diff --git a/wan/image2video.py b/wan/image2video.py index 468f17c..55d3f57 100644 --- a/wan/image2video.py +++ b/wan/image2video.py @@ -63,7 +63,14 @@ class WanI2V: init_on_cpu (`bool`, *optional*, defaults to True): Enable initializing Transformer Model on CPU. Only works without FSDP or USP. """ - self.device = torch.device(f"cuda:{device_id}") + # Check if device_id is a torch.device instance + if isinstance(device_id, torch.device): + self.device = device_id + elif device_id == "mps" or (isinstance(device_id, int) and device_id == -1): + self.device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") + else: + self.device = torch.device(f"cuda:{device_id}") + self.config = config self.rank = rank self.use_usp = use_usp diff --git a/wan/modules/attention.py b/wan/modules/attention.py index 4dbbe03..6dcc54a 100644 --- a/wan/modules/attention.py +++ b/wan/modules/attention.py @@ -33,25 +33,26 @@ def flash_attention( causal=False, window_size=(-1, -1), deterministic=False, - dtype=torch.bfloat16, + dtype=torch.float32, version=None, ): """ - q: [B, Lq, Nq, C1]. - k: [B, Lk, Nk, C1]. - v: [B, Lk, Nk, C2]. Nq must be divisible by Nk. - q_lens: [B]. - k_lens: [B]. - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - causal: bool. Whether to apply causal attention mask. - window_size: (left right). If not (-1, -1), apply sliding window local attention. - deterministic: bool. If True, slightly slower and uses more memory. - dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16. + Flash attention implementation with fallback for CPU and MPS devices """ - half_dtypes = (torch.float16, torch.bfloat16) + half_dtypes = (torch.float16, torch.float32) assert dtype in half_dtypes - assert q.device.type == 'cuda' and q.size(-1) <= 256 + assert q.size(-1) <= 256, "Sequence length exceeds the maximum limit." + + # Add CPU/MPS fallback implementation + if not (FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE) or q.device.type in ['cpu', 'mps']: + # Implement standard attention for CPU/MPS + return attention(q, k, v, + q_lens=q_lens, + k_lens=k_lens, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size) # params b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype @@ -142,7 +143,7 @@ def attention( causal=False, window_size=(-1, -1), deterministic=False, - dtype=torch.bfloat16, + dtype=torch.float32, fa_version=None, ): if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE: diff --git a/wan/modules/model.py b/wan/modules/model.py index b65021c..7144055 100644 --- a/wan/modules/model.py +++ b/wan/modules/model.py @@ -16,7 +16,7 @@ def sinusoidal_embedding_1d(dim, position): # preprocess assert dim % 2 == 0 half = dim // 2 - position = position.type(torch.float64) + position = position.type(torch.float32) # calculation sinusoid = torch.outer( @@ -31,7 +31,7 @@ def rope_params(max_seq_len, dim, theta=10000): freqs = torch.outer( torch.arange(max_seq_len), 1.0 / torch.pow(theta, - torch.arange(0, dim, 2).to(torch.float64).div(dim))) + torch.arange(0, dim, 2).to(torch.float32).div(dim))) freqs = torch.polar(torch.ones_like(freqs), freqs) return freqs @@ -49,7 +49,7 @@ def rope_apply(x, grid_sizes, freqs): seq_len = f * h * w # precompute multipliers - x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape( + x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float32).reshape( seq_len, n, -1, 2)) freqs_i = torch.cat([ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), diff --git a/wan/modules/t5.py b/wan/modules/t5.py index c841b04..3a26ed1 100644 --- a/wan/modules/t5.py +++ b/wan/modules/t5.py @@ -61,7 +61,7 @@ class T5LayerNorm(nn.Module): def forward(self, x): x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps) - if self.weight.dtype in [torch.float16, torch.bfloat16]: + if self.weight.dtype in [torch.float16, torch.float32]: x = x.type_as(self.weight) return self.weight * x @@ -474,8 +474,8 @@ class T5EncoderModel: def __init__( self, text_len, - dtype=torch.bfloat16, - device=torch.cuda.current_device(), + dtype=torch.float32, + device='mps' if torch.backends.mps.is_available() else 'cpu', checkpoint_path=None, tokenizer_path=None, shard_fn=None, diff --git a/wan/utils/prompt_extend.py b/wan/utils/prompt_extend.py index e7a21b5..00f4f20 100644 --- a/wan/utils/prompt_extend.py +++ b/wan/utils/prompt_extend.py @@ -44,7 +44,7 @@ LM_EN_SYS_PROMPT = \ '''1. For overly concise user inputs, reasonably infer and add details to make the video more complete and appealing without altering the original intent;\n''' \ '''2. Enhance the main features in user descriptions (e.g., appearance, expression, quantity, race, posture, etc.), visual style, spatial relationships, and shot scales;\n''' \ '''3. Output the entire prompt in English, retaining original text in quotes and titles, and preserving key input information;\n''' \ - '''4. Prompts should match the user’s intent and accurately reflect the specified style. If the user does not specify a style, choose the most appropriate style for the video;\n''' \ + '''4. Prompts should match the user's intent and accurately reflect the specified style. If the user does not specify a style, choose the most appropriate style for the video;\n''' \ '''5. Emphasize motion information and different camera movements present in the input description;\n''' \ '''6. Your output should have natural motion attributes. For the target category described, add natural actions of the target using simple and direct verbs;\n''' \ '''7. The revised prompt should be around 80-100 characters long.\n''' \ @@ -82,7 +82,7 @@ VL_EN_SYS_PROMPT = \ '''1. For overly brief user inputs, reasonably infer and supplement details without changing the original meaning, making the image more complete and visually appealing;\n''' \ '''2. Improve the characteristics of the main subject in the user's description (such as appearance, expression, quantity, ethnicity, posture, etc.), rendering style, spatial relationships, and camera angles;\n''' \ '''3. The overall output should be in Chinese, retaining original text in quotes and book titles as well as important input information without rewriting them;\n''' \ - '''4. The prompt should match the user’s intent and provide a precise and detailed style description. If the user has not specified a style, you need to carefully analyze the style of the user's provided photo and use that as a reference for rewriting;\n''' \ + '''4. The prompt should match the user's intent and provide a precise and detailed style description. If the user has not specified a style, you need to carefully analyze the style of the user's provided photo and use that as a reference for rewriting;\n''' \ '''5. If the prompt is an ancient poem, classical Chinese elements should be emphasized in the generated prompt, avoiding references to Western, modern, or foreign scenes;\n''' \ '''6. You need to emphasize movement information in the input and different camera angles;\n''' \ '''7. Your output should convey natural movement attributes, incorporating natural actions related to the described subject category, using simple and direct verbs as much as possible;\n''' \ @@ -93,7 +93,7 @@ VL_EN_SYS_PROMPT = \ '''1. A Japanese fresh film-style photo of a young East Asian girl with double braids sitting by the boat. The girl wears a white square collar puff sleeve dress, decorated with pleats and buttons. She has fair skin, delicate features, and slightly melancholic eyes, staring directly at the camera. Her hair falls naturally, with bangs covering part of her forehead. She rests her hands on the boat, appearing natural and relaxed. The background features a blurred outdoor scene, with hints of blue sky, mountains, and some dry plants. The photo has a vintage film texture. A medium shot of a seated portrait.\n''' \ '''2. An anime illustration in vibrant thick painting style of a white girl with cat ears holding a folder, showing a slightly dissatisfied expression. She has long dark purple hair and red eyes, wearing a dark gray skirt and a light gray top with a white waist tie and a name tag in bold Chinese characters that says "紫阳" (Ziyang). The background has a light yellow indoor tone, with faint outlines of some furniture visible. A pink halo hovers above her head, in a smooth Japanese cel-shading style. A close-up shot from a slightly elevated perspective.\n''' \ '''3. CG game concept digital art featuring a huge crocodile with its mouth wide open, with trees and thorns growing on its back. The crocodile's skin is rough and grayish-white, resembling stone or wood texture. Its back is lush with trees, shrubs, and thorny protrusions. With its mouth agape, the crocodile reveals a pink tongue and sharp teeth. The background features a dusk sky with some distant trees, giving the overall scene a dark and cold atmosphere. A close-up from a low angle.\n''' \ - '''4. In the style of an American drama promotional poster, Walter White sits in a metal folding chair wearing a yellow protective suit, with the words "Breaking Bad" written in sans-serif English above him, surrounded by piles of dollar bills and blue plastic storage boxes. He wears glasses, staring forward, dressed in a yellow jumpsuit, with his hands resting on his knees, exuding a calm and confident demeanor. The background shows an abandoned, dim factory with light filtering through the windows. There’s a noticeable grainy texture. A medium shot with a straight-on close-up of the character.\n''' \ + '''4. In the style of an American drama promotional poster, Walter White sits in a metal folding chair wearing a yellow protective suit, with the words "Breaking Bad" written in sans-serif English above him, surrounded by piles of dollar bills and blue plastic storage boxes. He wears glasses, staring forward, dressed in a yellow jumpsuit, with his hands resting on his knees, exuding a calm and confident demeanor. The background shows an abandoned, dim factory with light filtering through the windows. There's a noticeable grainy texture. A medium shot with a straight-on close-up of the character.\n''' \ '''Directly output the rewritten English text.''' @@ -347,7 +347,7 @@ class QwenPromptExpander(PromptExpander): use_fast=True) self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained( self.model_name, - torch_dtype=torch.bfloat16 if FLASH_VER == 2 else + torch_dtype=torch.float32 if FLASH_VER == 2 else torch.float16 if "AWQ" in self.model_name else "auto", attn_implementation="flash_attention_2" if FLASH_VER == 2 else None, @@ -363,6 +363,16 @@ class QwenPromptExpander(PromptExpander): device_map="cpu") self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + # Initialize device + if isinstance(device, torch.device): + self.device = device + elif device == "mps" or (isinstance(device, str) and "mps" in device): + self.device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") + elif isinstance(device, int) and device == -1: + self.device = torch.device("cpu") + else: + self.device = torch.device(f"cuda:{device}" if isinstance(device, int) else device) + def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs): self.model = self.model.to(self.device) messages = [{ diff --git a/wan/utils/qwen_vl_utils.py b/wan/utils/qwen_vl_utils.py index 3c682e6..c8194c2 100644 --- a/wan/utils/qwen_vl_utils.py +++ b/wan/utils/qwen_vl_utils.py @@ -274,6 +274,12 @@ def get_video_reader_backend() -> str: def fetch_video( ele: dict, image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]: + # Handle MPS device compatibility + original_device = None + if isinstance(ele.get("video"), torch.Tensor) and ele["video"].device.type == "mps": + original_device = ele["video"].device + ele["video"] = ele["video"].cpu() + if isinstance(ele["video"], str): video_reader_backend = get_video_reader_backend() video = VIDEO_READER_BACKENDS[video_reader_backend](ele) @@ -324,6 +330,10 @@ def fetch_video( images.extend([images[-1]] * (nframes - len(images))) return images + # Return to original device if needed + if original_device is not None and isinstance(video, torch.Tensor): + video = video.to(original_device) + def extract_vision_info( conversations: list[dict] | list[list[dict]]) -> list[dict]: