mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 14:16:57 +00:00
Add CPU offload support for FSDP in Wan model
- Introduced `--cpu_offload` argument in `generate.py` for enabling CPU offload. - Updated `WanI2V` class in `image2video.py` to handle CPU offload during model initialization and sharding. - Added new functions in `fsdp.py` for CPU initialization and sharding with CPU offload. - Expanded supported sizes in `configs/__init__.py` to include additional resolutions.
This commit is contained in:
parent
7c81b2f27d
commit
9f90033ce4
@ -155,6 +155,11 @@ def _parse_args():
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Whether to use FSDP for DiT.")
|
||||
parser.add_argument(
|
||||
"--cpu_offload",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Whether to use CPU offload for FSDP on Wan model, only works with dit_fsdp.")
|
||||
parser.add_argument(
|
||||
"--save_file",
|
||||
type=str,
|
||||
@ -421,6 +426,7 @@ def generate(args):
|
||||
rank=rank,
|
||||
t5_fsdp=args.t5_fsdp,
|
||||
dit_fsdp=args.dit_fsdp,
|
||||
cpu_offload=args.cpu_offload,
|
||||
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
||||
t5_cpu=args.t5_cpu,
|
||||
)
|
||||
|
||||
14
start_master.sh
Normal file
14
start_master.sh
Normal file
@ -0,0 +1,14 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Set environment variables
|
||||
export PATH=/data/apps/wan-21/wan-env/bin:$PATH
|
||||
export MASTER_ADDR='10.79.79.197'
|
||||
export MASTER_PORT='7860'
|
||||
export RANK='0'
|
||||
export WORLD_SIZE='2'
|
||||
export LOCAL_RANK='0'
|
||||
export NCCL_SOCKET_IFNAME='enp65s0'
|
||||
|
||||
# Run the generate.py script with specified parameters
|
||||
# like: ./start_master.sh --dit_fsdp --task i2v-14B --size 480*640 --ckpt_dir ./Wan2.1-I2V-14B-480P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
|
||||
python generate.py "$@"
|
||||
14
start_rank1.sh
Normal file
14
start_rank1.sh
Normal file
@ -0,0 +1,14 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Set environment variables
|
||||
export PATH=/data/apps/wan-21/wan-env/bin:$PATH
|
||||
export MASTER_ADDR='10.79.79.197'
|
||||
export MASTER_PORT='7860'
|
||||
export RANK='1'
|
||||
export WORLD_SIZE='2'
|
||||
export LOCAL_RANK='0'
|
||||
export NCCL_SOCKET_IFNAME='enp33s0'
|
||||
|
||||
# Run the generate.py script with specified parameters
|
||||
# like: ./start_rank1.sh --dit_fsdp --task i2v-14B --size 480*640 --ckpt_dir ./Wan2.1-I2V-14B-480P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
|
||||
python generate.py "$@"
|
||||
@ -32,6 +32,8 @@ SIZE_CONFIGS = {
|
||||
'1280*720': (1280, 720),
|
||||
'480*832': (480, 832),
|
||||
'832*480': (832, 480),
|
||||
'480*640': (480, 640),
|
||||
'640*480': (640, 480),
|
||||
'1024*1024': (1024, 1024),
|
||||
}
|
||||
|
||||
@ -40,12 +42,14 @@ MAX_AREA_CONFIGS = {
|
||||
'1280*720': 1280 * 720,
|
||||
'480*832': 480 * 832,
|
||||
'832*480': 832 * 480,
|
||||
'480*640': 480 * 640,
|
||||
'640*480': 640 * 480,
|
||||
}
|
||||
|
||||
SUPPORTED_SIZES = {
|
||||
't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
|
||||
't2v-1.3B': ('480*832', '832*480'),
|
||||
'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
|
||||
'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480', '480*640', '640*480'),
|
||||
'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
|
||||
't2i-14B': tuple(SIZE_CONFIGS.keys()),
|
||||
'vace-1.3B': ('480*832', '832*480'),
|
||||
|
||||
@ -34,10 +34,88 @@ def shard_model(
|
||||
return model
|
||||
|
||||
|
||||
def shard_model_cpu_init(
|
||||
model,
|
||||
device_id,
|
||||
param_dtype=torch.bfloat16,
|
||||
reduce_dtype=torch.float32,
|
||||
buffer_dtype=torch.float32,
|
||||
process_group=None,
|
||||
sharding_strategy=ShardingStrategy.FULL_SHARD,
|
||||
sync_module_states=False, # Key fix: disable sync_module_states for CPU init
|
||||
):
|
||||
"""
|
||||
Initialize FSDP with CPU-first approach to reduce GPU memory peak.
|
||||
"""
|
||||
# Ensure model is on CPU and clear GPU cache
|
||||
model = model.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Initialize FSDP with sync_module_states=False to allow CPU initialization
|
||||
model = FSDP(
|
||||
module=model,
|
||||
process_group=process_group,
|
||||
sharding_strategy=sharding_strategy,
|
||||
auto_wrap_policy=partial(
|
||||
lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
|
||||
mixed_precision=MixedPrecision(
|
||||
param_dtype=param_dtype,
|
||||
reduce_dtype=reduce_dtype,
|
||||
buffer_dtype=buffer_dtype),
|
||||
device_id=None, # Keep on CPU during initialization
|
||||
sync_module_states=sync_module_states) # Disabled for CPU init
|
||||
|
||||
# Move to GPU after FSDP wrapping
|
||||
model = model.to(f"cuda:{device_id}")
|
||||
|
||||
# Manual synchronization across ranks if needed
|
||||
if sync_module_states and torch.distributed.is_initialized():
|
||||
torch.distributed.barrier()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def shard_model_with_cpu_offload(
|
||||
model,
|
||||
device_id,
|
||||
param_dtype=torch.bfloat16,
|
||||
reduce_dtype=torch.float32,
|
||||
buffer_dtype=torch.float32,
|
||||
process_group=None,
|
||||
sharding_strategy=ShardingStrategy.FULL_SHARD,
|
||||
sync_module_states=True,
|
||||
):
|
||||
"""
|
||||
Initialize FSDP with CPU offload to minimize GPU memory usage.
|
||||
This keeps parameters on CPU and only loads them to GPU when needed.
|
||||
"""
|
||||
from torch.distributed.fsdp import CPUOffload
|
||||
|
||||
# Ensure model is on CPU and clear GPU cache aggressively
|
||||
model = model.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Use CPU offload to minimize GPU memory usage
|
||||
model = FSDP(
|
||||
module=model,
|
||||
process_group=process_group,
|
||||
sharding_strategy=sharding_strategy,
|
||||
auto_wrap_policy=partial(
|
||||
lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
|
||||
mixed_precision=MixedPrecision(
|
||||
param_dtype=param_dtype,
|
||||
reduce_dtype=reduce_dtype,
|
||||
buffer_dtype=buffer_dtype),
|
||||
cpu_offload=CPUOffload(offload_params=True), # Key: keep params on CPU
|
||||
device_id=device_id,
|
||||
sync_module_states=sync_module_states)
|
||||
|
||||
return model
|
||||
|
||||
def free_model(model):
|
||||
for m in model.modules():
|
||||
if isinstance(m, FSDP):
|
||||
_free_storage(m._handle.flat_param.data)
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.empty_cache()
|
||||
@ -16,7 +16,7 @@ import torch.distributed as dist
|
||||
import torchvision.transforms.functional as TF
|
||||
from tqdm import tqdm
|
||||
|
||||
from .distributed.fsdp import shard_model
|
||||
from .distributed.fsdp import shard_model, shard_model_with_cpu_offload
|
||||
from .modules.clip import CLIPModel
|
||||
from .modules.model import WanModel
|
||||
from .modules.t5 import T5EncoderModel
|
||||
@ -39,6 +39,7 @@ class WanI2V:
|
||||
rank=0,
|
||||
t5_fsdp=False,
|
||||
dit_fsdp=False,
|
||||
cpu_offload=False,
|
||||
use_usp=False,
|
||||
t5_cpu=False,
|
||||
init_on_cpu=True,
|
||||
@ -59,6 +60,8 @@ class WanI2V:
|
||||
Enable FSDP sharding for T5 model
|
||||
dit_fsdp (`bool`, *optional*, defaults to False):
|
||||
Enable FSDP sharding for DiT model
|
||||
cpu_offload (`bool`, *optional*, defaults to False):
|
||||
Enable CPU offload for FSDP on Wan model, only works with dit_fsdp.
|
||||
use_usp (`bool`, *optional*, defaults to False):
|
||||
Enable distribution strategy of USP.
|
||||
t5_cpu (`bool`, *optional*, defaults to False):
|
||||
@ -71,10 +74,14 @@ class WanI2V:
|
||||
self.rank = rank
|
||||
self.use_usp = use_usp
|
||||
self.t5_cpu = t5_cpu
|
||||
self.dit_fsdp = dit_fsdp
|
||||
self.cpu_offload = cpu_offload
|
||||
|
||||
self.num_train_timesteps = config.num_train_timesteps
|
||||
self.param_dtype = config.param_dtype
|
||||
|
||||
assert not cpu_offload or (cpu_offload and dit_fsdp), "When cpu_offload is True, dit_fsdp must also be True"
|
||||
|
||||
shard_fn = partial(shard_model, device_id=device_id)
|
||||
self.text_encoder = T5EncoderModel(
|
||||
text_len=config.text_len,
|
||||
@ -99,9 +106,15 @@ class WanI2V:
|
||||
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
|
||||
|
||||
logging.info(f"Creating WanModel from {checkpoint_dir}")
|
||||
self.model = WanModel.from_pretrained(checkpoint_dir)
|
||||
if cpu_offload:
|
||||
# For FSDP and cpu_offload, force CPU initialization to avoid OOM
|
||||
with torch.device('cpu'):
|
||||
self.model = WanModel.from_pretrained(checkpoint_dir)
|
||||
else:
|
||||
self.model = WanModel.from_pretrained(checkpoint_dir)
|
||||
self.model.eval().requires_grad_(False)
|
||||
|
||||
|
||||
if t5_fsdp or dit_fsdp or use_usp:
|
||||
init_on_cpu = False
|
||||
|
||||
@ -123,7 +136,11 @@ class WanI2V:
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
if dit_fsdp:
|
||||
self.model = shard_fn(self.model)
|
||||
if cpu_offload:
|
||||
torch.cuda.empty_cache()
|
||||
self.model = shard_model_with_cpu_offload(self.model, device_id=device_id)
|
||||
else:
|
||||
self.model = shard_fn(self.model)
|
||||
else:
|
||||
if not init_on_cpu:
|
||||
self.model.to(self.device)
|
||||
@ -298,7 +315,10 @@ class WanI2V:
|
||||
if offload_model:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
self.model.to(self.device)
|
||||
if self.cpu_offload:
|
||||
self.model.cpu()
|
||||
else:
|
||||
self.model.to(self.device)
|
||||
for _, t in enumerate(tqdm(timesteps)):
|
||||
latent_model_input = [latent.to(self.device)]
|
||||
timestep = [t]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user