Adapted model for macOS with M1 Pro chip and other improvements

Bakhtiyor Sulaymonov 2025-02-26 22:20:28 +05:00
parent b562f86ec5
commit cf578ab14b
12 changed files with 136 additions and 145 deletions

.gitignore
View File

@@ -34,3 +34,5 @@ Wan2.1-T2V-14B/
Wan2.1-T2V-1.3B/
Wan2.1-I2V-14B-480P/
Wan2.1-I2V-14B-720P/
venv_wan/
venv_wan_py310/

View File

@@ -186,6 +186,11 @@ def _parse_args():
type=float,
default=5.0,
help="Classifier free guidance scale.")
parser.add_argument(
"--device",
type=str,
default=None,
help="Device to use for computation (mps, cpu).")
args = parser.parse_args()
@@ -207,43 +212,21 @@ def _init_logging(rank):
def generate(args):
rank = int(os.getenv("RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))
local_rank = int(os.getenv("LOCAL_RANK", 0))
device = local_rank
_init_logging(rank)
# Set device based on args or availability
if args.device:
device = torch.device(args.device)
else:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
_init_logging(0) # Use rank 0 logging for single-device
# Ensure all torch operations use this device
torch.set_default_device(device)
if args.offload_model is None:
args.offload_model = False if world_size > 1 else True
args.offload_model = True # Default to True for single device to save memory
logging.info(
f"offload_model is not specified, set to {args.offload_model}.")
if world_size > 1:
torch.cuda.set_device(local_rank)
dist.init_process_group(
backend="nccl",
init_method="env://",
rank=rank,
world_size=world_size)
else:
assert not (
args.t5_fsdp or args.dit_fsdp
), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments."
assert not (
args.ulysses_size > 1 or args.ring_size > 1
), f"context parallel are not supported in non-distributed environments."
if args.ulysses_size > 1 or args.ring_size > 1:
assert args.ulysses_size * args.ring_size == world_size, f"The number of ulysses_size and ring_size should be equal to the world size."
from xfuser.core.distributed import (initialize_model_parallel,
init_distributed_environment)
init_distributed_environment(
rank=dist.get_rank(), world_size=dist.get_world_size())
initialize_model_parallel(
sequence_parallel_degree=dist.get_world_size(),
ring_degree=args.ring_size,
ulysses_degree=args.ulysses_size,
)
if args.use_prompt_extend:
if args.prompt_extend_method == "dashscope":
@@ -253,30 +236,21 @@ def generate(args):
prompt_expander = QwenPromptExpander(
model_name=args.prompt_extend_model,
is_vl="i2v" in args.task,
device=rank)
device=device) # Use MPS/CPU device instead of rank
else:
raise NotImplementedError(
f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
cfg = WAN_CONFIGS[args.task]
if args.ulysses_size > 1:
assert cfg.num_heads % args.ulysses_size == 0, f"`num_heads` must be divisible by `ulysses_size`."
logging.info(f"Generation job args: {args}")
logging.info(f"Generation model config: {cfg}")
if dist.is_initialized():
base_seed = [args.base_seed] if rank == 0 else [None]
dist.broadcast_object_list(base_seed, src=0)
args.base_seed = base_seed[0]
if "t2v" in args.task or "t2i" in args.task:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
logging.info(f"Input prompt: {args.prompt}")
if args.use_prompt_extend:
logging.info("Extending prompt ...")
if rank == 0:
prompt_output = prompt_expander(
args.prompt,
tar_lang=args.prompt_extend_target_lang,
@@ -288,23 +262,18 @@ def generate(args):
input_prompt = args.prompt
else:
input_prompt = prompt_output.prompt
input_prompt = [input_prompt]
else:
input_prompt = [None]
if dist.is_initialized():
dist.broadcast_object_list(input_prompt, src=0)
args.prompt = input_prompt[0]
args.prompt = input_prompt
logging.info(f"Extended prompt: {args.prompt}")
logging.info("Creating WanT2V pipeline.")
wan_t2v = wan.WanT2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=device,
rank=rank,
t5_fsdp=args.t5_fsdp,
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
device_id=device, # Use MPS/CPU device instead of local_rank
rank=0, # Single device, so use rank 0
t5_fsdp=False, # Disable FSDP (not supported on MPS)
dit_fsdp=False, # Disable FSDP (not supported on MPS)
use_usp=False, # Disable Ulysses/ring parallelism (single device)
t5_cpu=args.t5_cpu,
)
@@ -332,7 +301,6 @@ def generate(args):
img = Image.open(args.image).convert("RGB")
if args.use_prompt_extend:
logging.info("Extending prompt ...")
if rank == 0:
prompt_output = prompt_expander(
args.prompt,
tar_lang=args.prompt_extend_target_lang,
@@ -345,23 +313,18 @@ def generate(args):
input_prompt = args.prompt
else:
input_prompt = prompt_output.prompt
input_prompt = [input_prompt]
else:
input_prompt = [None]
if dist.is_initialized():
dist.broadcast_object_list(input_prompt, src=0)
args.prompt = input_prompt[0]
args.prompt = input_prompt
logging.info(f"Extended prompt: {args.prompt}")
logging.info("Creating WanI2V pipeline.")
wan_i2v = wan.WanI2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=device,
rank=rank,
t5_fsdp=args.t5_fsdp,
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
device_id=device, # Use MPS/CPU device instead of local_rank
rank=0, # Single device, so use rank 0
t5_fsdp=False, # Disable FSDP (not supported on MPS)
dit_fsdp=False, # Disable FSDP (not supported on MPS)
use_usp=False, # Disable Ulysses/ring parallelism (single device)
t5_cpu=args.t5_cpu,
)
@@ -378,13 +341,12 @@ def generate(args):
seed=args.base_seed,
offload_model=args.offload_model)
if rank == 0:
# Save output (single device, so no rank check needed)
if args.save_file is None:
formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
formatted_prompt = args.prompt.replace(" ", "_").replace("/",
"_")[:50]
formatted_prompt = args.prompt.replace(" ", "_").replace("/", "_")[:50]
suffix = '.png' if "t2i" in args.task else '.mp4'
args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
args.save_file = f"{args.task}_{args.size}_{formatted_prompt}_{formatted_time}" + suffix
if "t2i" in args.task:
logging.info(f"Saving generated image to {args.save_file}")
@@ -405,7 +367,6 @@ def generate(args):
value_range=(-1, 1))
logging.info("Finished.")
if __name__ == "__main__":
args = _parse_args()
generate(args)
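
The device handling added above boils down to a small resolution step: honor --device when given, otherwise prefer MPS on Apple silicon and fall back to CPU. A minimal sketch of that pattern, with a hypothetical resolve_device helper that is not part of the commit:

import torch

def resolve_device(requested=None):
    # Honor an explicit request (e.g. "mps" or "cpu"), mirroring the --device flag above.
    if requested:
        return torch.device(requested)
    # Otherwise prefer Apple's Metal backend when available, as generate() now does.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = resolve_device(None)        # e.g. mps on an M1 Pro, cpu elsewhere
torch.set_default_device(device)     # same call generate() makes after resolving the device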

View File

@@ -11,6 +11,6 @@ easydict
ftfy
dashscope
imageio-ffmpeg
flash_attn
# flash_attn
gradio>=5.0.0
numpy>=1.23.5,<2
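
With flash_attn commented out of the requirements (it needs CUDA and cannot be installed on macOS), the attention code presumably has to treat it as optional. A sketch of the kind of import guard that produces the FLASH_ATTN_2_AVAILABLE flag used further down; the exact guard in the repository may differ:

# Optional-dependency guard (illustrative; the repository's actual guard may differ).
try:
    import flash_attn                 # unavailable on macOS, hence dropped from requirements
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    flash_attn = None
    FLASH_ATTN_2_AVAILABLE = False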

View File

@@ -7,11 +7,11 @@ wan_shared_cfg = EasyDict()
# t5
wan_shared_cfg.t5_model = 'umt5_xxl'
wan_shared_cfg.t5_dtype = torch.bfloat16
wan_shared_cfg.t5_dtype = torch.float32
wan_shared_cfg.text_len = 512
# transformer
wan_shared_cfg.param_dtype = torch.bfloat16
wan_shared_cfg.param_dtype = torch.float32
# inference
wan_shared_cfg.num_train_timesteps = 1000
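
The bfloat16-to-float32 switches here and in the following hunks presumably track the MPS backend's limited low-precision support. A hedged sketch of a per-device dtype choice; the commit itself simply hard-codes float32 rather than branching like this:

import torch

def pick_param_dtype(device):
    # bfloat16 stays attractive on CUDA hardware that supports it;
    # the MPS/CPU path in this commit uses float32 instead.
    if device.type == "cuda" and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float32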

View File

@@ -10,7 +10,7 @@ from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
def shard_model(
model,
device_id,
param_dtype=torch.bfloat16,
param_dtype=torch.float32,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
process_group=None,

View File

@@ -151,9 +151,9 @@ def usp_attn_forward(self,
seq_lens,
grid_sizes,
freqs,
dtype=torch.bfloat16):
dtype=torch.float32):
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
half_dtypes = (torch.float16, torch.bfloat16)
half_dtypes = (torch.float16, torch.float32)
def half(x):
return x if x.dtype in half_dtypes else x.to(dtype)

View File

@@ -63,7 +63,14 @@ class WanI2V:
init_on_cpu (`bool`, *optional*, defaults to True):
Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
"""
# Check if device_id is a torch.device instance
if isinstance(device_id, torch.device):
self.device = device_id
elif device_id == "mps" or (isinstance(device_id, int) and device_id == -1):
self.device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
else:
self.device = torch.device(f"cuda:{device_id}")
self.config = config
self.rank = rank
self.use_usp = use_usp
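
The same device_id normalization recurs below in QwenPromptExpander, so a shared helper could express it once. A sketch under that assumption (the helper name normalize_device is hypothetical):

import torch

def normalize_device(device_id):
    # Pass through ready-made torch.device objects.
    if isinstance(device_id, torch.device):
        return device_id
    # "mps" or -1 means "best local device": MPS when available, else CPU.
    if device_id == "mps" or device_id == -1:
        return torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    # Integers keep their old meaning as CUDA ordinals.
    if isinstance(device_id, int):
        return torch.device(f"cuda:{device_id}")
    return torch.device(device_id)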

View File

@@ -33,25 +33,26 @@ def flash_attention(
causal=False,
window_size=(-1, -1),
deterministic=False,
dtype=torch.bfloat16,
dtype=torch.float32,
version=None,
):
"""
q: [B, Lq, Nq, C1].
k: [B, Lk, Nk, C1].
v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
q_lens: [B].
k_lens: [B].
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
causal: bool. Whether to apply causal attention mask.
window_size: (left right). If not (-1, -1), apply sliding window local attention.
deterministic: bool. If True, slightly slower and uses more memory.
dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
Flash attention implementation with fallback for CPU and MPS devices
"""
half_dtypes = (torch.float16, torch.bfloat16)
half_dtypes = (torch.float16, torch.float32)
assert dtype in half_dtypes
assert q.device.type == 'cuda' and q.size(-1) <= 256
assert q.size(-1) <= 256, "Sequence length exceeds the maximum limit."
# Add CPU/MPS fallback implementation
if not (FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE) or q.device.type in ['cpu', 'mps']:
# Implement standard attention for CPU/MPS
return attention(q, k, v,
q_lens=q_lens,
k_lens=k_lens,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size)
# params
b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
@@ -142,7 +143,7 @@ def attention(
causal=False,
window_size=(-1, -1),
deterministic=False,
dtype=torch.bfloat16,
dtype=torch.float32,
fa_version=None,
):
if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
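
When flash-attn is missing or the tensors live on CPU/MPS, the code above routes through the plain attention() implementation. A hedged sketch of what such a fallback can look like using PyTorch's built-in scaled_dot_product_attention; the repository's attention() also handles q_lens, k_lens and window_size, which are omitted here:

import torch
import torch.nn.functional as F

def sdpa_fallback(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False):
    # flash_attention uses [B, L, N, C]; SDPA expects [B, N, L, C].
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(
        q, k, v,
        dropout_p=dropout_p,
        is_causal=causal,
        scale=softmax_scale,   # the scale argument needs PyTorch >= 2.1
    )
    return out.transpose(1, 2).contiguous()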

View File

@@ -16,7 +16,7 @@ def sinusoidal_embedding_1d(dim, position):
# preprocess
assert dim % 2 == 0
half = dim // 2
position = position.type(torch.float64)
position = position.type(torch.float32)
# calculation
sinusoid = torch.outer(
@@ -31,7 +31,7 @@ def rope_params(max_seq_len, dim, theta=10000):
freqs = torch.outer(
torch.arange(max_seq_len),
1.0 / torch.pow(theta,
torch.arange(0, dim, 2).to(torch.float64).div(dim)))
torch.arange(0, dim, 2).to(torch.float32).div(dim)))
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
@@ -49,7 +49,7 @@ def rope_apply(x, grid_sizes, freqs):
seq_len = f * h * w
# precompute multipliers
x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float32).reshape(
seq_len, n, -1, 2))
freqs_i = torch.cat([
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
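
The float64-to-float32 changes above exist because the MPS backend has no float64 support. Rebuilt in isolation, the RoPE frequency table from rope_params then looks roughly like this (complex64 instead of complex128):

import torch

def rope_params_f32(max_seq_len, dim, theta=10000.0):
    # dim is assumed even, as in the original rope_params.
    freqs = torch.outer(
        torch.arange(max_seq_len, dtype=torch.float32),
        1.0 / torch.pow(theta, torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    return torch.polar(torch.ones_like(freqs), freqs)  # complex64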

View File

@@ -61,7 +61,7 @@ class T5LayerNorm(nn.Module):
def forward(self, x):
x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
self.eps)
if self.weight.dtype in [torch.float16, torch.bfloat16]:
if self.weight.dtype in [torch.float16, torch.float32]:
x = x.type_as(self.weight)
return self.weight * x
@@ -474,8 +474,8 @@ class T5EncoderModel:
def __init__(
self,
text_len,
dtype=torch.bfloat16,
device=torch.cuda.current_device(),
dtype=torch.float32,
device='mps' if torch.backends.mps.is_available() else 'cpu',
checkpoint_path=None,
tokenizer_path=None,
shard_fn=None,

View File

@@ -44,7 +44,7 @@ LM_EN_SYS_PROMPT = \
'''1. For overly concise user inputs, reasonably infer and add details to make the video more complete and appealing without altering the original intent;\n''' \
'''2. Enhance the main features in user descriptions (e.g., appearance, expression, quantity, race, posture, etc.), visual style, spatial relationships, and shot scales;\n''' \
'''3. Output the entire prompt in English, retaining original text in quotes and titles, and preserving key input information;\n''' \
'''4. Prompts should match the users intent and accurately reflect the specified style. If the user does not specify a style, choose the most appropriate style for the video;\n''' \
'''4. Prompts should match the user's intent and accurately reflect the specified style. If the user does not specify a style, choose the most appropriate style for the video;\n''' \
'''5. Emphasize motion information and different camera movements present in the input description;\n''' \
'''6. Your output should have natural motion attributes. For the target category described, add natural actions of the target using simple and direct verbs;\n''' \
'''7. The revised prompt should be around 80-100 characters long.\n''' \
@@ -82,7 +82,7 @@ VL_EN_SYS_PROMPT = \
'''1. For overly brief user inputs, reasonably infer and supplement details without changing the original meaning, making the image more complete and visually appealing;\n''' \
'''2. Improve the characteristics of the main subject in the user's description (such as appearance, expression, quantity, ethnicity, posture, etc.), rendering style, spatial relationships, and camera angles;\n''' \
'''3. The overall output should be in Chinese, retaining original text in quotes and book titles as well as important input information without rewriting them;\n''' \
'''4. The prompt should match the users intent and provide a precise and detailed style description. If the user has not specified a style, you need to carefully analyze the style of the user's provided photo and use that as a reference for rewriting;\n''' \
'''4. The prompt should match the user's intent and provide a precise and detailed style description. If the user has not specified a style, you need to carefully analyze the style of the user's provided photo and use that as a reference for rewriting;\n''' \
'''5. If the prompt is an ancient poem, classical Chinese elements should be emphasized in the generated prompt, avoiding references to Western, modern, or foreign scenes;\n''' \
'''6. You need to emphasize movement information in the input and different camera angles;\n''' \
'''7. Your output should convey natural movement attributes, incorporating natural actions related to the described subject category, using simple and direct verbs as much as possible;\n''' \
@@ -93,7 +93,7 @@ VL_EN_SYS_PROMPT = \
'''1. A Japanese fresh film-style photo of a young East Asian girl with double braids sitting by the boat. The girl wears a white square collar puff sleeve dress, decorated with pleats and buttons. She has fair skin, delicate features, and slightly melancholic eyes, staring directly at the camera. Her hair falls naturally, with bangs covering part of her forehead. She rests her hands on the boat, appearing natural and relaxed. The background features a blurred outdoor scene, with hints of blue sky, mountains, and some dry plants. The photo has a vintage film texture. A medium shot of a seated portrait.\n''' \
'''2. An anime illustration in vibrant thick painting style of a white girl with cat ears holding a folder, showing a slightly dissatisfied expression. She has long dark purple hair and red eyes, wearing a dark gray skirt and a light gray top with a white waist tie and a name tag in bold Chinese characters that says "紫阳" (Ziyang). The background has a light yellow indoor tone, with faint outlines of some furniture visible. A pink halo hovers above her head, in a smooth Japanese cel-shading style. A close-up shot from a slightly elevated perspective.\n''' \
'''3. CG game concept digital art featuring a huge crocodile with its mouth wide open, with trees and thorns growing on its back. The crocodile's skin is rough and grayish-white, resembling stone or wood texture. Its back is lush with trees, shrubs, and thorny protrusions. With its mouth agape, the crocodile reveals a pink tongue and sharp teeth. The background features a dusk sky with some distant trees, giving the overall scene a dark and cold atmosphere. A close-up from a low angle.\n''' \
'''4. In the style of an American drama promotional poster, Walter White sits in a metal folding chair wearing a yellow protective suit, with the words "Breaking Bad" written in sans-serif English above him, surrounded by piles of dollar bills and blue plastic storage boxes. He wears glasses, staring forward, dressed in a yellow jumpsuit, with his hands resting on his knees, exuding a calm and confident demeanor. The background shows an abandoned, dim factory with light filtering through the windows. Theres a noticeable grainy texture. A medium shot with a straight-on close-up of the character.\n''' \
'''4. In the style of an American drama promotional poster, Walter White sits in a metal folding chair wearing a yellow protective suit, with the words "Breaking Bad" written in sans-serif English above him, surrounded by piles of dollar bills and blue plastic storage boxes. He wears glasses, staring forward, dressed in a yellow jumpsuit, with his hands resting on his knees, exuding a calm and confident demeanor. The background shows an abandoned, dim factory with light filtering through the windows. There's a noticeable grainy texture. A medium shot with a straight-on close-up of the character.\n''' \
'''Directly output the rewritten English text.'''
@@ -347,7 +347,7 @@ class QwenPromptExpander(PromptExpander):
use_fast=True)
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
self.model_name,
torch_dtype=torch.bfloat16 if FLASH_VER == 2 else
torch_dtype=torch.float32 if FLASH_VER == 2 else
torch.float16 if "AWQ" in self.model_name else "auto",
attn_implementation="flash_attention_2"
if FLASH_VER == 2 else None,
@@ -363,6 +363,16 @@ class QwenPromptExpander(PromptExpander):
device_map="cpu")
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
# Initialize device
if isinstance(device, torch.device):
self.device = device
elif device == "mps" or (isinstance(device, str) and "mps" in device):
self.device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
elif isinstance(device, int) and device == -1:
self.device = torch.device("cpu")
else:
self.device = torch.device(f"cuda:{device}" if isinstance(device, int) else device)
def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
self.model = self.model.to(self.device)
messages = [{
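
extend() moves the Qwen expander onto the resolved device right before generation. A sketch of that move-then-offload pattern in isolation; the offload back to CPU is an assumption to keep unified memory free for the video model, not something shown in this hunk:

import torch

def run_on_device(model, inputs, device):
    model.to(device)               # e.g. mps, as resolved in __init__ above
    try:
        with torch.no_grad():
            return model(**inputs)
    finally:
        model.to("cpu")            # assumed offload step; frees unified memory afterwards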

View File

@@ -274,6 +274,12 @@ def get_video_reader_backend() -> str:
def fetch_video(
ele: dict,
image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]:
# Handle MPS device compatibility
original_device = None
if isinstance(ele.get("video"), torch.Tensor) and ele["video"].device.type == "mps":
original_device = ele["video"].device
ele["video"] = ele["video"].cpu()
if isinstance(ele["video"], str):
video_reader_backend = get_video_reader_backend()
video = VIDEO_READER_BACKENDS[video_reader_backend](ele)
@@ -324,6 +330,10 @@ def fetch_video(
images.extend([images[-1]] * (nframes - len(images)))
return images
# Return to original device if needed
if original_device is not None and isinstance(video, torch.Tensor):
video = video.to(original_device)
def extract_vision_info(
conversations: list[dict] | list[list[dict]]) -> list[dict]:
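
The fetch_video changes route MPS tensors through the CPU, apparently because the video decoding and resizing utilities involved do not accept MPS tensors, and then move the result back. The same round trip as a small standalone sketch (the op argument stands in for whatever unsupported operation is being wrapped):

import torch

def cpu_roundtrip(video, op):
    original_device = video.device
    if original_device.type == "mps":
        video = video.cpu()            # run the unsupported op on CPU
    out = op(video)
    return out.to(original_device)     # return to the original device afterwards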