- Introduced `--cpu_offload` argument in `generate.py` for enabling CPU offload.
- Updated `WanI2V` class in `image2video.py` to handle CPU offload during model initialization and sharding.
- Added new functions in `fsdp.py` for CPU initialization and sharding with CPU offload.
- Expanded supported sizes in `configs/__init__.py` to include additional resolutions.
121 lines · 3.8 KiB · Python
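Before the file itself, a minimal sketch of how the new flag might be threaded from `generate.py` into the helpers below. The flag is modelled here as a boolean switch, and the `WanI2V` keyword (`cpu_offload`), `cfg`, and `local_rank` are assumptions for illustration, not the repository's verified signatures.

    import argparse

    parser = argparse.ArgumentParser(description="Wan2.1 generation (sketch)")
    parser.add_argument(
        "--cpu_offload",
        action="store_true",
        help="Keep FSDP parameters on CPU and copy them to the GPU per block.")
    args = parser.parse_args()

    # Hypothetical: WanI2V would pick shard_model_with_cpu_offload() when the
    # flag is set, and the plain shard_model() path otherwise.
    # pipeline = WanI2V(config=cfg, device_id=local_rank, cpu_offload=args.cpu_offload)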
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
from functools import partial

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from torch.distributed.utils import _free_storage


def shard_model(
    model,
    device_id,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
    process_group=None,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    sync_module_states=True,
):
    model = FSDP(
        module=model,
        process_group=process_group,
        sharding_strategy=sharding_strategy,
        auto_wrap_policy=partial(
            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
        mixed_precision=MixedPrecision(
            param_dtype=param_dtype,
            reduce_dtype=reduce_dtype,
            buffer_dtype=buffer_dtype),
        device_id=device_id,
        sync_module_states=sync_module_states)
    return model
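
# shard_model() is the default sharding path: the module is fully sharded
# across ranks and moved to `device_id`, with bf16 parameters and fp32
# gradient reduction and buffers. The wrapped module is expected to expose a
# `.blocks` list (the DiT transformer blocks), since the auto-wrap policy
# wraps exactly those submodules. Illustrative call only; `dit` and
# `local_rank` are assumed to come from the caller (e.g. the WanI2V setup):
#
#     torch.distributed.init_process_group(backend="nccl")
#     local_rank = torch.distributed.get_rank() % torch.cuda.device_count()
#     dit = shard_model(dit, device_id=local_rank)
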
def shard_model_cpu_init(
    model,
    device_id,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
    process_group=None,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    sync_module_states=False,  # Key fix: disable sync_module_states for CPU init
):
    """
    Initialize FSDP with a CPU-first approach to reduce the GPU memory peak.
    """
    # Ensure the model is on CPU and clear the GPU cache
    model = model.cpu()
    torch.cuda.empty_cache()

    # Initialize FSDP with sync_module_states=False to allow CPU initialization
    model = FSDP(
        module=model,
        process_group=process_group,
        sharding_strategy=sharding_strategy,
        auto_wrap_policy=partial(
            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
        mixed_precision=MixedPrecision(
            param_dtype=param_dtype,
            reduce_dtype=reduce_dtype,
            buffer_dtype=buffer_dtype),
        device_id=None,  # Keep on CPU during initialization
        sync_module_states=sync_module_states)  # Disabled for CPU init

    # Move to GPU after FSDP wrapping
    model = model.to(f"cuda:{device_id}")

    # Barrier so all ranks finish wrapping before continuing
    # (note: this does not broadcast parameters across ranks)
    if sync_module_states and torch.distributed.is_initialized():
        torch.distributed.barrier()

    return model
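
# shard_model_cpu_init() trades an extra host-to-device copy for a lower GPU
# memory peak: the module is wrapped while still on the CPU (device_id=None)
# and only moved to the GPU after FSDP has built its flat parameters. Because
# sync_module_states defaults to False here, FSDP does not broadcast weights
# from rank 0, so every rank must already hold identical weights (e.g. loaded
# from the same checkpoint). Illustrative call, assumed names:
#
#     dit = shard_model_cpu_init(dit, device_id=local_rank)
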
def shard_model_with_cpu_offload(
    model,
    device_id,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
    process_group=None,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    sync_module_states=True,
):
    """
    Initialize FSDP with CPU offload to minimize GPU memory usage.
    This keeps parameters on CPU and only loads them to GPU when needed.
    """
    from torch.distributed.fsdp import CPUOffload

    # Ensure the model is on CPU and clear the GPU cache aggressively
    model = model.cpu()
    torch.cuda.empty_cache()

    # Use CPU offload to minimize GPU memory usage
    model = FSDP(
        module=model,
        process_group=process_group,
        sharding_strategy=sharding_strategy,
        auto_wrap_policy=partial(
            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
        mixed_precision=MixedPrecision(
            param_dtype=param_dtype,
            reduce_dtype=reduce_dtype,
            buffer_dtype=buffer_dtype),
        cpu_offload=CPUOffload(offload_params=True),  # Key: keep params on CPU
        device_id=device_id,
        sync_module_states=sync_module_states)

    return model
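
# shard_model_with_cpu_offload() is the most memory-frugal variant: with
# CPUOffload(offload_params=True), the sharded parameters stay in host memory
# and are copied to the GPU only when a wrapped block actually runs, at the
# cost of extra transfers on every step. This is the function the new
# --cpu_offload flag is intended to select; sketch with assumed names:
#
#     if args.cpu_offload:
#         dit = shard_model_with_cpu_offload(dit, device_id=local_rank)
#     else:
#         dit = shard_model(dit, device_id=local_rank)
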
def free_model(model):
    # Free the flat-parameter storage of every FSDP-wrapped submodule before
    # dropping the model; this relies on private FSDP internals (_handle,
    # _free_storage).
    for m in model.modules():
        if isinstance(m, FSDP):
            _free_storage(m._handle.flat_param.data)
    del model
    gc.collect()
    torch.cuda.empty_cache()
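
# free_model() can be called once sampling is finished so that a later stage
# (e.g. VAE decoding) gets the GPU memory back. Illustrative, assumed name:
#
#     free_model(dit)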