Wan2.1/wan/distributed/fsdp.py
Commit 9f90033ce4 by ErrorDing: Add CPU offload support for FSDP in Wan model
- Introduced `--cpu_offload` argument in `generate.py` for enabling CPU offload.
- Updated `WanI2V` class in `image2video.py` to handle CPU offload during model initialization and sharding.
- Added new functions in `fsdp.py` for CPU initialization and sharding with CPU offload.
- Expanded supported sizes in `configs/__init__.py` to include additional resolutions.
2025-07-22 21:03:47 +08:00
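
A hedged sketch of how the new flag might be wired in generate.py (the `--cpu_offload` flag name comes from the commit message above; everything else here is illustrative, not the actual generate.py code):

    # Illustrative wiring only; names other than --cpu_offload are assumptions.
    parser.add_argument(
        "--cpu_offload",
        action="store_true",
        help="Keep FSDP parameters on CPU and move them to GPU on demand.")
    args = parser.parse_args()

    shard_fn = (shard_model_with_cpu_offload
                if args.cpu_offload else shard_model)
    # WanI2V would then apply shard_fn(self.model) when dit_fsdp is enabled.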


# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
from functools import partial

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from torch.distributed.utils import _free_storage


def shard_model(
    model,
    device_id,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
    process_group=None,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    sync_module_states=True,
):
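    """
    Wrap ``model`` with FSDP, sharding each transformer block in
    ``model.blocks`` individually and applying the given mixed-precision
    dtypes. The wrapped model is placed on GPU ``device_id``.
    """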
    model = FSDP(
        module=model,
        process_group=process_group,
        sharding_strategy=sharding_strategy,
        auto_wrap_policy=partial(
            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
        mixed_precision=MixedPrecision(
            param_dtype=param_dtype,
            reduce_dtype=reduce_dtype,
            buffer_dtype=buffer_dtype),
        device_id=device_id,
        sync_module_states=sync_module_states)
    return model


def shard_model_cpu_init(
    model,
    device_id,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
    process_group=None,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    sync_module_states=False,  # Key fix: disable sync_module_states for CPU init
):
"""
Initialize FSDP with CPU-first approach to reduce GPU memory peak.
"""
    # Ensure the model is on CPU and clear the GPU cache before wrapping
    model = model.cpu()
    torch.cuda.empty_cache()

    # Initialize FSDP with sync_module_states=False so the wrap can happen
    # entirely on the CPU
    model = FSDP(
        module=model,
        process_group=process_group,
        sharding_strategy=sharding_strategy,
        auto_wrap_policy=partial(
            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
        mixed_precision=MixedPrecision(
            param_dtype=param_dtype,
            reduce_dtype=reduce_dtype,
            buffer_dtype=buffer_dtype),
        device_id=None,  # keep parameters on CPU during initialization
        sync_module_states=sync_module_states)

    # Move the sharded model to the GPU after FSDP wrapping
    model = model.to(f"cuda:{device_id}")

    # A barrier only aligns the ranks; it does not broadcast parameters the
    # way sync_module_states=True would
    if sync_module_states and torch.distributed.is_initialized():
        torch.distributed.barrier()

    return model
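
# Typical call pattern (illustrative; the real wiring lives in image2video.py,
# where every rank loads the checkpoint itself before sharding):
#
#   model = WanModel.from_pretrained(checkpoint_dir)  # weights still on CPU
#   model = shard_model_cpu_init(model, device_id=local_rank)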


def shard_model_with_cpu_offload(
    model,
    device_id,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
    process_group=None,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    sync_module_states=True,
):
"""
Initialize FSDP with CPU offload to minimize GPU memory usage.
This keeps parameters on CPU and only loads them to GPU when needed.
"""
    from torch.distributed.fsdp import CPUOffload

    # Ensure the model is on CPU and clear the GPU cache aggressively
    model = model.cpu()
    torch.cuda.empty_cache()

    # Use CPU offload to minimize GPU memory usage
    model = FSDP(
        module=model,
        process_group=process_group,
        sharding_strategy=sharding_strategy,
        auto_wrap_policy=partial(
            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
        mixed_precision=MixedPrecision(
            param_dtype=param_dtype,
            reduce_dtype=reduce_dtype,
            buffer_dtype=buffer_dtype),
        cpu_offload=CPUOffload(offload_params=True),  # key: keep params on CPU
        device_id=device_id,
        sync_module_states=sync_module_states)
    return model
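
# NOTE: with offload_params=True the sharded parameters live in CPU memory and
# are copied to the GPU only while each wrapped block is unsharded for its
# forward pass, trading extra host-to-device traffic for a lower peak GPU
# memory footprint.
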
def free_model(model):
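    """
    Free the flat-parameter storage of every FSDP submodule in ``model``,
    then drop the reference and release cached GPU memory.
    """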
    for m in model.modules():
        if isinstance(m, FSDP):
            _free_storage(m._handle.flat_param.data)

    del model
    gc.collect()
    torch.cuda.empty_cache()
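

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative only, not part of the original Wan2.1
# module): builds a tiny stand-in model that exposes ``.blocks`` the way
# WanModel does, shards it with CPU offload, runs one forward pass, and frees
# it. Assumes a distributed launch on a single node, e.g.
#   torchrun --nproc_per_node=1 -m wan.distributed.fsdp
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.distributed as dist
    import torch.nn as nn

    class _ToyModel(nn.Module):
        """Stand-in for WanModel: a stack of blocks plus a head."""

        def __init__(self, dim=64, num_blocks=4):
            super().__init__()
            self.blocks = nn.ModuleList(
                [nn.Linear(dim, dim) for _ in range(num_blocks)])
            self.head = nn.Linear(dim, dim)

        def forward(self, x):
            for block in self.blocks:
                x = block(x)
            return self.head(x)

    dist.init_process_group(backend="nccl")
    local_rank = dist.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    toy = shard_model_with_cpu_offload(_ToyModel(), device_id=local_rank)
    with torch.no_grad():
        out = toy(torch.randn(2, 64, device=f"cuda:{local_rank}"))
    print(f"rank {dist.get_rank()}: output shape {tuple(out.shape)}, "
          f"peak GPU mem {torch.cuda.max_memory_allocated() / 2**20:.1f} MiB")

    free_model(toy)
    dist.destroy_process_group()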