Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-06-07 07:44:53 +00:00)

Commit cf578ab14b (parent b562f86ec5)
Adapted model for macOS with M1 Pro chip and other improvements
.gitignore (vendored), 2 additions:

@@ -34,3 +34,5 @@ Wan2.1-T2V-14B/
 Wan2.1-T2V-1.3B/
 Wan2.1-I2V-14B-480P/
 Wan2.1-I2V-14B-720P/
+venv_wan/
+venv_wan_py310/

generate.py, 187 lines changed:
@@ -186,6 +186,11 @@ def _parse_args():
         type=float,
         default=5.0,
         help="Classifier free guidance scale.")
+    parser.add_argument(
+        "--device",
+        type=str,
+        default=None,
+        help="Device to use for computation (mps, cpu).")

     args = parser.parse_args()

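Before using the new flag, it is worth confirming that the installed PyTorch build actually exposes the Apple GPU. A quick check, not part of the commit, using only standard PyTorch calls:

import torch

print("MPS built into this PyTorch build:", torch.backends.mps.is_built())
print("MPS available at runtime:", torch.backends.mps.is_available())
device = "mps" if torch.backends.mps.is_available() else "cpu"
print("suggested flag: --device", device)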
@@ -207,43 +212,21 @@ def _init_logging(rank):


 def generate(args):
-    rank = int(os.getenv("RANK", 0))
-    world_size = int(os.getenv("WORLD_SIZE", 1))
-    local_rank = int(os.getenv("LOCAL_RANK", 0))
-    device = local_rank
-    _init_logging(rank)
+    # Set device based on args or availability
+    if args.device:
+        device = torch.device(args.device)
+    else:
+        device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
+
+    _init_logging(0)  # Use rank 0 logging for single-device
+
+    # Ensure all torch operations use this device
+    torch.set_default_device(device)

     if args.offload_model is None:
-        args.offload_model = False if world_size > 1 else True
+        args.offload_model = True  # Default to True for single device to save memory
         logging.info(
             f"offload_model is not specified, set to {args.offload_model}.")
-    if world_size > 1:
-        torch.cuda.set_device(local_rank)
-        dist.init_process_group(
-            backend="nccl",
-            init_method="env://",
-            rank=rank,
-            world_size=world_size)
-    else:
-        assert not (
-            args.t5_fsdp or args.dit_fsdp
-        ), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments."
-        assert not (
-            args.ulysses_size > 1 or args.ring_size > 1
-        ), f"context parallel are not supported in non-distributed environments."
-
-    if args.ulysses_size > 1 or args.ring_size > 1:
-        assert args.ulysses_size * args.ring_size == world_size, f"The number of ulysses_size and ring_size should be equal to the world size."
-        from xfuser.core.distributed import (initialize_model_parallel,
-                                             init_distributed_environment)
-        init_distributed_environment(
-            rank=dist.get_rank(), world_size=dist.get_world_size())
-
-        initialize_model_parallel(
-            sequence_parallel_degree=dist.get_world_size(),
-            ring_degree=args.ring_size,
-            ulysses_degree=args.ulysses_size,
-        )

     if args.use_prompt_extend:
         if args.prompt_extend_method == "dashscope":
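The single-device logic above reduces to: honor an explicit --device, otherwise prefer MPS over CPU, then make that device the default for tensor creation. A condensed, hedged sketch of the same idea (the helper name is ours; torch.set_default_device needs PyTorch 2.0 or newer):

import torch

def resolve_device(requested=None):
    # Mirrors the patched generate(): explicit request wins, then MPS, then CPU.
    if requested:
        return torch.device(requested)
    return torch.device("mps" if torch.backends.mps.is_available() else "cpu")

device = resolve_device()          # e.g. device(type='mps') on an M1 Pro
torch.set_default_device(device)   # PyTorch >= 2.0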
@@ -253,58 +236,44 @@ def generate(args):
             prompt_expander = QwenPromptExpander(
                 model_name=args.prompt_extend_model,
                 is_vl="i2v" in args.task,
-                device=rank)
+                device=device)  # Use MPS/CPU device instead of rank
         else:
             raise NotImplementedError(
                 f"Unsupport prompt_extend_method: {args.prompt_extend_method}")

     cfg = WAN_CONFIGS[args.task]
-    if args.ulysses_size > 1:
-        assert cfg.num_heads % args.ulysses_size == 0, f"`num_heads` must be divisible by `ulysses_size`."
-
     logging.info(f"Generation job args: {args}")
     logging.info(f"Generation model config: {cfg}")

-    if dist.is_initialized():
-        base_seed = [args.base_seed] if rank == 0 else [None]
-        dist.broadcast_object_list(base_seed, src=0)
-        args.base_seed = base_seed[0]
-
     if "t2v" in args.task or "t2i" in args.task:
         if args.prompt is None:
             args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
         logging.info(f"Input prompt: {args.prompt}")
         if args.use_prompt_extend:
             logging.info("Extending prompt ...")
-            if rank == 0:
-                prompt_output = prompt_expander(
-                    args.prompt,
-                    tar_lang=args.prompt_extend_target_lang,
-                    seed=args.base_seed)
-                if prompt_output.status == False:
-                    logging.info(
-                        f"Extending prompt failed: {prompt_output.message}")
-                    logging.info("Falling back to original prompt.")
-                    input_prompt = args.prompt
-                else:
-                    input_prompt = prompt_output.prompt
-                input_prompt = [input_prompt]
-            else:
-                input_prompt = [None]
-            if dist.is_initialized():
-                dist.broadcast_object_list(input_prompt, src=0)
-            args.prompt = input_prompt[0]
+            prompt_output = prompt_expander(
+                args.prompt,
+                tar_lang=args.prompt_extend_target_lang,
+                seed=args.base_seed)
+            if prompt_output.status == False:
+                logging.info(
+                    f"Extending prompt failed: {prompt_output.message}")
+                logging.info("Falling back to original prompt.")
+                input_prompt = args.prompt
+            else:
+                input_prompt = prompt_output.prompt
+            args.prompt = input_prompt
             logging.info(f"Extended prompt: {args.prompt}")

     logging.info("Creating WanT2V pipeline.")
     wan_t2v = wan.WanT2V(
         config=cfg,
         checkpoint_dir=args.ckpt_dir,
-        device_id=device,
-        rank=rank,
-        t5_fsdp=args.t5_fsdp,
-        dit_fsdp=args.dit_fsdp,
-        use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
+        device_id=device,  # Use MPS/CPU device instead of local_rank
+        rank=0,  # Single device, so use rank 0
+        t5_fsdp=False,  # Disable FSDP (not supported on MPS)
+        dit_fsdp=False,  # Disable FSDP (not supported on MPS)
+        use_usp=False,  # Disable Ulysses/ring parallelism (single device)
         t5_cpu=args.t5_cpu,
     )

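Hard-coding t5_fsdp=False, dit_fsdp=False and use_usp=False silently ignores any distributed flags the user still passes. An alternative, not what this commit does, is to warn explicitly; a sketch using the argument names from the diff (helper name and warning text are ours):

import logging

def warn_unsupported_distributed_args(args):
    # Single MPS/CPU device: distributed options cannot be honored.
    for name in ("t5_fsdp", "dit_fsdp"):
        if getattr(args, name, False):
            logging.warning("--%s is ignored on a single MPS/CPU device.", name)
    if getattr(args, "ulysses_size", 1) > 1 or getattr(args, "ring_size", 1) > 1:
        logging.warning("Ulysses/ring parallelism is ignored on a single device.")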
@@ -332,36 +301,30 @@ def generate(args):
         img = Image.open(args.image).convert("RGB")
         if args.use_prompt_extend:
             logging.info("Extending prompt ...")
-            if rank == 0:
-                prompt_output = prompt_expander(
-                    args.prompt,
-                    tar_lang=args.prompt_extend_target_lang,
-                    image=img,
-                    seed=args.base_seed)
-                if prompt_output.status == False:
-                    logging.info(
-                        f"Extending prompt failed: {prompt_output.message}")
-                    logging.info("Falling back to original prompt.")
-                    input_prompt = args.prompt
-                else:
-                    input_prompt = prompt_output.prompt
-                input_prompt = [input_prompt]
-            else:
-                input_prompt = [None]
-            if dist.is_initialized():
-                dist.broadcast_object_list(input_prompt, src=0)
-            args.prompt = input_prompt[0]
+            prompt_output = prompt_expander(
+                args.prompt,
+                tar_lang=args.prompt_extend_target_lang,
+                image=img,
+                seed=args.base_seed)
+            if prompt_output.status == False:
+                logging.info(
+                    f"Extending prompt failed: {prompt_output.message}")
+                logging.info("Falling back to original prompt.")
+                input_prompt = args.prompt
+            else:
+                input_prompt = prompt_output.prompt
+            args.prompt = input_prompt
             logging.info(f"Extended prompt: {args.prompt}")

     logging.info("Creating WanI2V pipeline.")
     wan_i2v = wan.WanI2V(
         config=cfg,
         checkpoint_dir=args.ckpt_dir,
-        device_id=device,
-        rank=rank,
-        t5_fsdp=args.t5_fsdp,
-        dit_fsdp=args.dit_fsdp,
-        use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
+        device_id=device,  # Use MPS/CPU device instead of local_rank
+        rank=0,  # Single device, so use rank 0
+        t5_fsdp=False,  # Disable FSDP (not supported on MPS)
+        dit_fsdp=False,  # Disable FSDP (not supported on MPS)
+        use_usp=False,  # Disable Ulysses/ring parallelism (single device)
         t5_cpu=args.t5_cpu,
     )

@@ -378,34 +341,32 @@ def generate(args):
         seed=args.base_seed,
         offload_model=args.offload_model)

-    if rank == 0:
-        if args.save_file is None:
-            formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
-            formatted_prompt = args.prompt.replace(" ", "_").replace("/",
-                                                                     "_")[:50]
-            suffix = '.png' if "t2i" in args.task else '.mp4'
-            args.save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}" + suffix
-
-        if "t2i" in args.task:
-            logging.info(f"Saving generated image to {args.save_file}")
-            cache_image(
-                tensor=video.squeeze(1)[None],
-                save_file=args.save_file,
-                nrow=1,
-                normalize=True,
-                value_range=(-1, 1))
-        else:
-            logging.info(f"Saving generated video to {args.save_file}")
-            cache_video(
-                tensor=video[None],
-                save_file=args.save_file,
-                fps=cfg.sample_fps,
-                nrow=1,
-                normalize=True,
-                value_range=(-1, 1))
+    # Save output (single device, so no rank check needed)
+    if args.save_file is None:
+        formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
+        formatted_prompt = args.prompt.replace(" ", "_").replace("/", "_")[:50]
+        suffix = '.png' if "t2i" in args.task else '.mp4'
+        args.save_file = f"{args.task}_{args.size}_{formatted_prompt}_{formatted_time}" + suffix
+
+    if "t2i" in args.task:
+        logging.info(f"Saving generated image to {args.save_file}")
+        cache_image(
+            tensor=video.squeeze(1)[None],
+            save_file=args.save_file,
+            nrow=1,
+            normalize=True,
+            value_range=(-1, 1))
+    else:
+        logging.info(f"Saving generated video to {args.save_file}")
+        cache_video(
+            tensor=video[None],
+            save_file=args.save_file,
+            fps=cfg.sample_fps,
+            nrow=1,
+            normalize=True,
+            value_range=(-1, 1))
     logging.info("Finished.")


 if __name__ == "__main__":
     args = _parse_args()
     generate(args)
The remaining hunks are from other files; file paths were not preserved in this view.

Requirements, flash_attn made optional:

@@ -11,6 +11,6 @@ easydict
 ftfy
 dashscope
 imageio-ffmpeg
-flash_attn
+# flash_attn
 gradio>=5.0.0
 numpy>=1.23.5,<2
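flash_attn has no macOS wheels, which is why it is commented out. The attention code is then expected to detect its absence at import time; a hedged sketch of the usual import-guard pattern behind a flag such as FLASH_ATTN_2_AVAILABLE (the repository's actual guard may differ):

try:
    import flash_attn  # no macOS build; commented out in the requirements above
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    flash_attn = None
    FLASH_ATTN_2_AVAILABLE = False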
Shared model config, bfloat16 replaced with float32:

@@ -7,11 +7,11 @@ wan_shared_cfg = EasyDict()

 # t5
 wan_shared_cfg.t5_model = 'umt5_xxl'
-wan_shared_cfg.t5_dtype = torch.bfloat16
+wan_shared_cfg.t5_dtype = torch.float32
 wan_shared_cfg.text_len = 512

 # transformer
-wan_shared_cfg.param_dtype = torch.bfloat16
+wan_shared_cfg.param_dtype = torch.float32

 # inference
 wan_shared_cfg.num_train_timesteps = 1000
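MPS support for bfloat16 is incomplete, so the config pins float32; this roughly doubles parameter memory, which is also why offload_model defaults to True in the patched generate.py. A hedged helper showing the per-device choice (function name is ours, not part of the commit):

import torch

def pick_param_dtype(device: torch.device) -> torch.dtype:
    # bfloat16 is fine on recent CUDA GPUs; on MPS/CPU fall back to float32,
    # which is what this commit hard-codes in the shared config.
    if device.type == "cuda" and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float32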
FSDP sharding helper, default parameter dtype:

@@ -10,7 +10,7 @@ from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
 def shard_model(
     model,
     device_id,
-    param_dtype=torch.bfloat16,
+    param_dtype=torch.float32,
     reduce_dtype=torch.float32,
     buffer_dtype=torch.float32,
     process_group=None,
Context-parallel attention (usp_attn_forward), dtype defaults:

@@ -151,9 +151,9 @@ def usp_attn_forward(self,
                      seq_lens,
                      grid_sizes,
                      freqs,
-                     dtype=torch.bfloat16):
+                     dtype=torch.float32):
     b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
-    half_dtypes = (torch.float16, torch.bfloat16)
+    half_dtypes = (torch.float16, torch.float32)

     def half(x):
         return x if x.dtype in half_dtypes else x.to(dtype)
WanI2V pipeline, device resolution in __init__:

@@ -63,7 +63,14 @@ class WanI2V:
             init_on_cpu (`bool`, *optional*, defaults to True):
                 Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
         """
-        self.device = torch.device(f"cuda:{device_id}")
+        # Check if device_id is a torch.device instance
+        if isinstance(device_id, torch.device):
+            self.device = device_id
+        elif device_id == "mps" or (isinstance(device_id, int) and device_id == -1):
+            self.device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
+        else:
+            self.device = torch.device(f"cuda:{device_id}")
+
         self.config = config
         self.rank = rank
         self.use_usp = use_usp
Attention module, flash_attention gains a CPU/MPS fallback:

@@ -33,25 +33,26 @@ def flash_attention(
     causal=False,
     window_size=(-1, -1),
     deterministic=False,
-    dtype=torch.bfloat16,
+    dtype=torch.float32,
     version=None,
 ):
     """
-    q: [B, Lq, Nq, C1].
-    k: [B, Lk, Nk, C1].
-    v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
-    q_lens: [B].
-    k_lens: [B].
-    dropout_p: float. Dropout probability.
-    softmax_scale: float. The scaling of QK^T before applying softmax.
-    causal: bool. Whether to apply causal attention mask.
-    window_size: (left right). If not (-1, -1), apply sliding window local attention.
-    deterministic: bool. If True, slightly slower and uses more memory.
-    dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
+    Flash attention implementation with fallback for CPU and MPS devices
     """
-    half_dtypes = (torch.float16, torch.bfloat16)
+    half_dtypes = (torch.float16, torch.float32)
     assert dtype in half_dtypes
-    assert q.device.type == 'cuda' and q.size(-1) <= 256
+    assert q.size(-1) <= 256, "Sequence length exceeds the maximum limit."
+
+    # Add CPU/MPS fallback implementation
+    if not (FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE) or q.device.type in ['cpu', 'mps']:
+        # Implement standard attention for CPU/MPS
+        return attention(q, k, v,
+                         q_lens=q_lens,
+                         k_lens=k_lens,
+                         dropout_p=dropout_p,
+                         softmax_scale=softmax_scale,
+                         causal=causal,
+                         window_size=window_size)

     # params
     b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
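When flash_attn is unavailable or the tensors live on CPU/MPS, the patched function delegates to the module's vanilla attention() path. For orientation, a minimal sketch of what such a fallback can look like with PyTorch's built-in scaled dot-product attention; the helper name and tensor layout handling are ours, not the repository's exact implementation, and the scale keyword needs PyTorch 2.1 or newer:

import torch
import torch.nn.functional as F

def sdpa_fallback(q, k, v, causal=False, dropout_p=0.0, softmax_scale=None):
    # q, k, v arrive as [B, L, N, C] (see the docstring removed above);
    # scaled_dot_product_attention expects [B, N, L, C].
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(
        q, k, v, dropout_p=dropout_p, is_causal=causal, scale=softmax_scale)
    return out.transpose(1, 2).contiguous()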
@@ -142,7 +143,7 @@ def attention(
     causal=False,
     window_size=(-1, -1),
     deterministic=False,
-    dtype=torch.bfloat16,
+    dtype=torch.float32,
     fa_version=None,
 ):
     if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
Transformer model (sinusoidal embeddings and RoPE), float64 dropped to float32:

@@ -16,7 +16,7 @@ def sinusoidal_embedding_1d(dim, position):
     # preprocess
     assert dim % 2 == 0
     half = dim // 2
-    position = position.type(torch.float64)
+    position = position.type(torch.float32)

     # calculation
     sinusoid = torch.outer(
@@ -31,7 +31,7 @@ def rope_params(max_seq_len, dim, theta=10000):
     freqs = torch.outer(
         torch.arange(max_seq_len),
         1.0 / torch.pow(theta,
                         torch.arange(0, dim, 2).to(torch.float32).div(dim)))
     freqs = torch.polar(torch.ones_like(freqs), freqs)
     return freqs
@@ -49,7 +49,7 @@ def rope_apply(x, grid_sizes, freqs):
         seq_len = f * h * w

         # precompute multipliers
-        x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
+        x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float32).reshape(
             seq_len, n, -1, 2))
         freqs_i = torch.cat([
             freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
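The MPS backend does not support float64, so the three hunks above drop the rotary-embedding math to float32; view_as_complex on a float32 tensor then yields complex64 rather than complex128, at some cost in numerical precision. A tiny self-contained illustration (CPU tensors; complex support on MPS itself is still limited depending on the PyTorch version):

import torch

x = torch.randn(4, 8, 2)                       # [..., 2] real/imag pairs
xc = torch.view_as_complex(x.to(torch.float32))
print(xc.dtype)                                # torch.complex64 instead of complex128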
T5 text encoder, layer-norm dtype list and constructor defaults:

@@ -61,7 +61,7 @@ class T5LayerNorm(nn.Module):
     def forward(self, x):
         x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
                             self.eps)
-        if self.weight.dtype in [torch.float16, torch.bfloat16]:
+        if self.weight.dtype in [torch.float16, torch.float32]:
             x = x.type_as(self.weight)
         return self.weight * x
@@ -474,8 +474,8 @@ class T5EncoderModel:
     def __init__(
         self,
         text_len,
-        dtype=torch.bfloat16,
-        device=torch.cuda.current_device(),
+        dtype=torch.float32,
+        device='mps' if torch.backends.mps.is_available() else 'cpu',
         checkpoint_path=None,
         tokenizer_path=None,
         shard_fn=None,
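The old default, device=torch.cuda.current_device(), is evaluated when the module is imported, so on a CUDA-less Mac it raises before any code runs; the new default avoids that. A more conventional pattern, not what the commit does, is a lazy default resolved at call time (helper name is ours):

import torch

def default_device(device=None) -> torch.device:
    # Resolve the device at call time instead of at import time.
    if device is not None:
        return torch.device(device)
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")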
Prompt-extension system prompts, curly apostrophes replaced with ASCII ones:

@@ -44,7 +44,7 @@ LM_EN_SYS_PROMPT = \
     '''1. For overly concise user inputs, reasonably infer and add details to make the video more complete and appealing without altering the original intent;\n''' \
     '''2. Enhance the main features in user descriptions (e.g., appearance, expression, quantity, race, posture, etc.), visual style, spatial relationships, and shot scales;\n''' \
     '''3. Output the entire prompt in English, retaining original text in quotes and titles, and preserving key input information;\n''' \
-    '''4. Prompts should match the user’s intent and accurately reflect the specified style. If the user does not specify a style, choose the most appropriate style for the video;\n''' \
+    '''4. Prompts should match the user's intent and accurately reflect the specified style. If the user does not specify a style, choose the most appropriate style for the video;\n''' \
     '''5. Emphasize motion information and different camera movements present in the input description;\n''' \
     '''6. Your output should have natural motion attributes. For the target category described, add natural actions of the target using simple and direct verbs;\n''' \
     '''7. The revised prompt should be around 80-100 characters long.\n''' \

@@ -82,7 +82,7 @@ VL_EN_SYS_PROMPT = \
     '''1. For overly brief user inputs, reasonably infer and supplement details without changing the original meaning, making the image more complete and visually appealing;\n''' \
     '''2. Improve the characteristics of the main subject in the user's description (such as appearance, expression, quantity, ethnicity, posture, etc.), rendering style, spatial relationships, and camera angles;\n''' \
     '''3. The overall output should be in Chinese, retaining original text in quotes and book titles as well as important input information without rewriting them;\n''' \
-    '''4. The prompt should match the user’s intent and provide a precise and detailed style description. If the user has not specified a style, you need to carefully analyze the style of the user's provided photo and use that as a reference for rewriting;\n''' \
+    '''4. The prompt should match the user's intent and provide a precise and detailed style description. If the user has not specified a style, you need to carefully analyze the style of the user's provided photo and use that as a reference for rewriting;\n''' \
     '''5. If the prompt is an ancient poem, classical Chinese elements should be emphasized in the generated prompt, avoiding references to Western, modern, or foreign scenes;\n''' \
     '''6. You need to emphasize movement information in the input and different camera angles;\n''' \
     '''7. Your output should convey natural movement attributes, incorporating natural actions related to the described subject category, using simple and direct verbs as much as possible;\n''' \

@@ -93,7 +93,7 @@ VL_EN_SYS_PROMPT = \
     '''1. A Japanese fresh film-style photo of a young East Asian girl with double braids sitting by the boat. The girl wears a white square collar puff sleeve dress, decorated with pleats and buttons. She has fair skin, delicate features, and slightly melancholic eyes, staring directly at the camera. Her hair falls naturally, with bangs covering part of her forehead. She rests her hands on the boat, appearing natural and relaxed. The background features a blurred outdoor scene, with hints of blue sky, mountains, and some dry plants. The photo has a vintage film texture. A medium shot of a seated portrait.\n''' \
     '''2. An anime illustration in vibrant thick painting style of a white girl with cat ears holding a folder, showing a slightly dissatisfied expression. She has long dark purple hair and red eyes, wearing a dark gray skirt and a light gray top with a white waist tie and a name tag in bold Chinese characters that says "紫阳" (Ziyang). The background has a light yellow indoor tone, with faint outlines of some furniture visible. A pink halo hovers above her head, in a smooth Japanese cel-shading style. A close-up shot from a slightly elevated perspective.\n''' \
     '''3. CG game concept digital art featuring a huge crocodile with its mouth wide open, with trees and thorns growing on its back. The crocodile's skin is rough and grayish-white, resembling stone or wood texture. Its back is lush with trees, shrubs, and thorny protrusions. With its mouth agape, the crocodile reveals a pink tongue and sharp teeth. The background features a dusk sky with some distant trees, giving the overall scene a dark and cold atmosphere. A close-up from a low angle.\n''' \
-    '''4. In the style of an American drama promotional poster, Walter White sits in a metal folding chair wearing a yellow protective suit, with the words "Breaking Bad" written in sans-serif English above him, surrounded by piles of dollar bills and blue plastic storage boxes. He wears glasses, staring forward, dressed in a yellow jumpsuit, with his hands resting on his knees, exuding a calm and confident demeanor. The background shows an abandoned, dim factory with light filtering through the windows. There’s a noticeable grainy texture. A medium shot with a straight-on close-up of the character.\n''' \
+    '''4. In the style of an American drama promotional poster, Walter White sits in a metal folding chair wearing a yellow protective suit, with the words "Breaking Bad" written in sans-serif English above him, surrounded by piles of dollar bills and blue plastic storage boxes. He wears glasses, staring forward, dressed in a yellow jumpsuit, with his hands resting on his knees, exuding a calm and confident demeanor. The background shows an abandoned, dim factory with light filtering through the windows. There's a noticeable grainy texture. A medium shot with a straight-on close-up of the character.\n''' \
     '''Directly output the rewritten English text.'''
Qwen prompt expander, model dtype and device handling:

@@ -347,7 +347,7 @@ class QwenPromptExpander(PromptExpander):
                 use_fast=True)
             self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                 self.model_name,
-                torch_dtype=torch.bfloat16 if FLASH_VER == 2 else
+                torch_dtype=torch.float32 if FLASH_VER == 2 else
                 torch.float16 if "AWQ" in self.model_name else "auto",
                 attn_implementation="flash_attention_2"
                 if FLASH_VER == 2 else None,
@@ -363,6 +363,16 @@ class QwenPromptExpander(PromptExpander):
                 device_map="cpu")
             self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

+        # Initialize device
+        if isinstance(device, torch.device):
+            self.device = device
+        elif device == "mps" or (isinstance(device, str) and "mps" in device):
+            self.device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
+        elif isinstance(device, int) and device == -1:
+            self.device = torch.device("cpu")
+        else:
+            self.device = torch.device(f"cuda:{device}" if isinstance(device, int) else device)
+
     def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
         self.model = self.model.to(self.device)
         messages = [{
Video-loading utilities (fetch_video), MPS tensors hopped to CPU for decoding:

@@ -274,6 +274,12 @@ def get_video_reader_backend() -> str:
 def fetch_video(
         ele: dict,
         image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]:
+    # Handle MPS device compatibility
+    original_device = None
+    if isinstance(ele.get("video"), torch.Tensor) and ele["video"].device.type == "mps":
+        original_device = ele["video"].device
+        ele["video"] = ele["video"].cpu()
+
     if isinstance(ele["video"], str):
         video_reader_backend = get_video_reader_backend()
         video = VIDEO_READER_BACKENDS[video_reader_backend](ele)

@@ -324,6 +330,10 @@ def fetch_video(
         images.extend([images[-1]] * (nframes - len(images)))
         return images

+    # Return to original device if needed
+    if original_device is not None and isinstance(video, torch.Tensor):
+        video = video.to(original_device)
+

 def extract_vision_info(
         conversations: list[dict] | list[list[dict]]) -> list[dict]:
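The pattern here, move a tensor to CPU for an operation the MPS backend cannot run, then move the result back, generalizes beyond video decoding. A hedged generic helper illustrating the same idea (name is ours, not part of the commit):

import torch

def run_on_cpu_then_restore(t: torch.Tensor, fn):
    # Hop to CPU for an op unsupported on MPS, then return to the original device.
    original = t.device
    out = fn(t.cpu())
    return out.to(original) if isinstance(out, torch.Tensor) else out

# Example: frames = run_on_cpu_then_restore(video, lambda v: v.float().mean(dim=0))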