Mirror of https://github.com/Wan-Video/Wan2.1.git, synced 2025-11-04 14:16:57 +00:00
Add CPU offload support for FSDP in Wan model
- Introduced `--cpu_offload` argument in `generate.py` for enabling CPU offload.
- Updated `WanI2V` class in `image2video.py` to handle CPU offload during model initialization and sharding.
- Added new functions in `fsdp.py` for CPU initialization and sharding with CPU offload.
- Expanded supported sizes in `configs/__init__.py` to include additional resolutions.
This commit is contained in:
parent 7c81b2f27d
commit 9f90033ce4
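For orientation, here is a minimal Python sketch of how the new flag is meant to flow from the command line into the pipeline. The flag names, the WanI2V keyword, and the dit_fsdp requirement come from the hunks below; the surrounding wiring is illustrative and not copied from generate.py.

import argparse

# Trimmed-down argument wiring mirroring the generate.py hunks in this commit.
parser = argparse.ArgumentParser()
parser.add_argument("--dit_fsdp", action="store_true", default=False,
                    help="Whether to use FSDP for DiT.")
parser.add_argument("--cpu_offload", action="store_true", default=False,
                    help="Whether to use CPU offload for FSDP on Wan model, only works with dit_fsdp.")
args = parser.parse_args(["--dit_fsdp", "--cpu_offload"])

# WanI2V.__init__ asserts this combination, so it can also be checked up front.
if args.cpu_offload and not args.dit_fsdp:
    raise SystemExit("--cpu_offload requires --dit_fsdp")

# In generate.py the values are then forwarded into the pipeline, roughly:
# wan_i2v = wan.WanI2V(..., dit_fsdp=args.dit_fsdp, cpu_offload=args.cpu_offload, ...)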
generate.py
@@ -155,6 +155,11 @@ def _parse_args():
         action="store_true",
         default=False,
         help="Whether to use FSDP for DiT.")
+    parser.add_argument(
+        "--cpu_offload",
+        action="store_true",
+        default=False,
+        help="Whether to use CPU offload for FSDP on Wan model, only works with dit_fsdp.")
     parser.add_argument(
         "--save_file",
         type=str,
@@ -421,6 +426,7 @@ def generate(args):
             rank=rank,
             t5_fsdp=args.t5_fsdp,
             dit_fsdp=args.dit_fsdp,
+            cpu_offload=args.cpu_offload,
             use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
             t5_cpu=args.t5_cpu,
         )
start_master.sh (new file, 14 lines)
@@ -0,0 +1,14 @@
+#!/bin/bash
+
+# Set environment variables
+export PATH=/data/apps/wan-21/wan-env/bin:$PATH
+export MASTER_ADDR='10.79.79.197'
+export MASTER_PORT='7860'
+export RANK='0'
+export WORLD_SIZE='2'
+export LOCAL_RANK='0'
+export NCCL_SOCKET_IFNAME='enp65s0'
+
+# Run the generate.py script with specified parameters
+# like: ./start_master.sh --dit_fsdp --task i2v-14B --size 480*640 --ckpt_dir ./Wan2.1-I2V-14B-480P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
+python generate.py "$@"
start_rank1.sh (new file, 14 lines)
@@ -0,0 +1,14 @@
+#!/bin/bash
+
+# Set environment variables
+export PATH=/data/apps/wan-21/wan-env/bin:$PATH
+export MASTER_ADDR='10.79.79.197'
+export MASTER_PORT='7860'
+export RANK='1'
+export WORLD_SIZE='2'
+export LOCAL_RANK='0'
+export NCCL_SOCKET_IFNAME='enp33s0'
+
+# Run the generate.py script with specified parameters
+# like: ./start_rank1.sh --dit_fsdp --task i2v-14B --size 480*640 --ckpt_dir ./Wan2.1-I2V-14B-480P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
+python generate.py "$@"
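Both scripts drive the same two-node run by exporting the rendezvous variables by hand instead of using torchrun. As a rough sketch (not copied from generate.py), this is how such environment variables are typically consumed before distributed generation starts; the variable names match the scripts above, while the init call itself is illustrative.

import os

import torch
import torch.distributed as dist

# Read the rendezvous settings exported by start_master.sh / start_rank1.sh.
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))

if world_size > 1:
    torch.cuda.set_device(local_rank)
    # MASTER_ADDR and MASTER_PORT are picked up automatically via init_method="env://".
    dist.init_process_group(backend="nccl", init_method="env://",
                            rank=rank, world_size=world_size)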
configs/__init__.py
@@ -32,6 +32,8 @@ SIZE_CONFIGS = {
     '1280*720': (1280, 720),
     '480*832': (480, 832),
     '832*480': (832, 480),
+    '480*640': (480, 640),
+    '640*480': (640, 480),
     '1024*1024': (1024, 1024),
 }

@@ -40,12 +42,14 @@ MAX_AREA_CONFIGS = {
     '1280*720': 1280 * 720,
     '480*832': 480 * 832,
     '832*480': 832 * 480,
+    '480*640': 480 * 640,
+    '640*480': 640 * 480,
 }

 SUPPORTED_SIZES = {
     't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
     't2v-1.3B': ('480*832', '832*480'),
-    'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
+    'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480', '480*640', '640*480'),
     'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
     't2i-14B': tuple(SIZE_CONFIGS.keys()),
     'vace-1.3B': ('480*832', '832*480'),
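As a quick sanity check of the new entries, a small illustrative lookup for the added 480*640 resolution. The dictionary names come from the file above; the wan.configs import path is assumed, and the lookup code is not part of the commit.

# Illustrative lookup against the tables extended above (import path assumed).
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES

task, size = "i2v-14B", "480*640"
assert size in SUPPORTED_SIZES[task], f"{size} is not supported for {task}"

target_shape = SIZE_CONFIGS[size]   # (480, 640)
max_area = MAX_AREA_CONFIGS[size]   # 480 * 640 = 307200
print(target_shape, max_area)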
fsdp.py
@@ -34,6 +34,84 @@ def shard_model(
     return model


+def shard_model_cpu_init(
+    model,
+    device_id,
+    param_dtype=torch.bfloat16,
+    reduce_dtype=torch.float32,
+    buffer_dtype=torch.float32,
+    process_group=None,
+    sharding_strategy=ShardingStrategy.FULL_SHARD,
+    sync_module_states=False,  # Key fix: disable sync_module_states for CPU init
+):
+    """
+    Initialize FSDP with CPU-first approach to reduce GPU memory peak.
+    """
+    # Ensure model is on CPU and clear GPU cache
+    model = model.cpu()
+    torch.cuda.empty_cache()
+
+    # Initialize FSDP with sync_module_states=False to allow CPU initialization
+    model = FSDP(
+        module=model,
+        process_group=process_group,
+        sharding_strategy=sharding_strategy,
+        auto_wrap_policy=partial(
+            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
+        mixed_precision=MixedPrecision(
+            param_dtype=param_dtype,
+            reduce_dtype=reduce_dtype,
+            buffer_dtype=buffer_dtype),
+        device_id=None,  # Keep on CPU during initialization
+        sync_module_states=sync_module_states)  # Disabled for CPU init
+
+    # Move to GPU after FSDP wrapping
+    model = model.to(f"cuda:{device_id}")
+
+    # Manual synchronization across ranks if needed
+    if sync_module_states and torch.distributed.is_initialized():
+        torch.distributed.barrier()
+
+    return model
+
+
+def shard_model_with_cpu_offload(
+    model,
+    device_id,
+    param_dtype=torch.bfloat16,
+    reduce_dtype=torch.float32,
+    buffer_dtype=torch.float32,
+    process_group=None,
+    sharding_strategy=ShardingStrategy.FULL_SHARD,
+    sync_module_states=True,
+):
+    """
+    Initialize FSDP with CPU offload to minimize GPU memory usage.
+    This keeps parameters on CPU and only loads them to GPU when needed.
+    """
+    from torch.distributed.fsdp import CPUOffload
+
+    # Ensure model is on CPU and clear GPU cache aggressively
+    model = model.cpu()
+    torch.cuda.empty_cache()
+
+    # Use CPU offload to minimize GPU memory usage
+    model = FSDP(
+        module=model,
+        process_group=process_group,
+        sharding_strategy=sharding_strategy,
+        auto_wrap_policy=partial(
+            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
+        mixed_precision=MixedPrecision(
+            param_dtype=param_dtype,
+            reduce_dtype=reduce_dtype,
+            buffer_dtype=buffer_dtype),
+        cpu_offload=CPUOffload(offload_params=True),  # Key: keep params on CPU
+        device_id=device_id,
+        sync_module_states=sync_module_states)
+
+    return model
+
+
 def free_model(model):
     for m in model.modules():
         if isinstance(m, FSDP):
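For readers unfamiliar with the FSDP option that shard_model_with_cpu_offload relies on, here is a minimal standalone sketch of CPUOffload(offload_params=True) on a toy module. It is not taken from the repo: it assumes the script is launched under torchrun with one GPU per rank and NCCL available.

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes torchrun has exported the env:// rendezvous variables.
dist.init_process_group(backend="nccl")
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

toy = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024)).cpu()

# Parameters stay sharded on CPU; FSDP copies each shard to the GPU only for compute.
toy = FSDP(
    toy,
    cpu_offload=CPUOffload(offload_params=True),
    device_id=local_rank)

with torch.no_grad():
    out = toy(torch.randn(4, 1024, device=f"cuda:{local_rank}"))
print(out.shape)

dist.destroy_process_group()

The same trade-off applies to the Wan model: with offload_params=True the full parameter set never has to sit in GPU memory at once, at the cost of host-to-device copies on every forward pass.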
image2video.py
@@ -16,7 +16,7 @@ import torch.distributed as dist
 import torchvision.transforms.functional as TF
 from tqdm import tqdm

-from .distributed.fsdp import shard_model
+from .distributed.fsdp import shard_model, shard_model_with_cpu_offload
 from .modules.clip import CLIPModel
 from .modules.model import WanModel
 from .modules.t5 import T5EncoderModel
@@ -39,6 +39,7 @@ class WanI2V:
         rank=0,
         t5_fsdp=False,
         dit_fsdp=False,
+        cpu_offload=False,
         use_usp=False,
         t5_cpu=False,
         init_on_cpu=True,
@@ -59,6 +60,8 @@ class WanI2V:
                 Enable FSDP sharding for T5 model
             dit_fsdp (`bool`, *optional*, defaults to False):
                 Enable FSDP sharding for DiT model
+            cpu_offload (`bool`, *optional*, defaults to False):
+                Enable CPU offload for FSDP on Wan model, only works with dit_fsdp.
             use_usp (`bool`, *optional*, defaults to False):
                 Enable distribution strategy of USP.
             t5_cpu (`bool`, *optional*, defaults to False):
@@ -71,10 +74,14 @@ class WanI2V:
         self.rank = rank
         self.use_usp = use_usp
         self.t5_cpu = t5_cpu
+        self.dit_fsdp = dit_fsdp
+        self.cpu_offload = cpu_offload

         self.num_train_timesteps = config.num_train_timesteps
         self.param_dtype = config.param_dtype

+        assert not cpu_offload or (cpu_offload and dit_fsdp), "When cpu_offload is True, dit_fsdp must also be True"
+
         shard_fn = partial(shard_model, device_id=device_id)
         self.text_encoder = T5EncoderModel(
             text_len=config.text_len,
@@ -99,9 +106,15 @@ class WanI2V:
             tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))

         logging.info(f"Creating WanModel from {checkpoint_dir}")
-        self.model = WanModel.from_pretrained(checkpoint_dir)
+        if cpu_offload:
+            # For FSDP and cpu_offload, force CPU initialization to avoid OOM
+            with torch.device('cpu'):
+                self.model = WanModel.from_pretrained(checkpoint_dir)
+        else:
+            self.model = WanModel.from_pretrained(checkpoint_dir)
         self.model.eval().requires_grad_(False)

         if t5_fsdp or dit_fsdp or use_usp:
             init_on_cpu = False
@@ -123,6 +136,10 @@ class WanI2V:
         if dist.is_initialized():
             dist.barrier()
         if dit_fsdp:
-            self.model = shard_fn(self.model)
+            if cpu_offload:
+                torch.cuda.empty_cache()
+                self.model = shard_model_with_cpu_offload(self.model, device_id=device_id)
+            else:
+                self.model = shard_fn(self.model)
         else:
             if not init_on_cpu:
@@ -298,6 +315,9 @@ class WanI2V:
            if offload_model:
                torch.cuda.empty_cache()

-            self.model.to(self.device)
+            if self.cpu_offload:
+                self.model.cpu()
+            else:
+                self.model.to(self.device)
            for _, t in enumerate(tqdm(timesteps)):
                latent_model_input = [latent.to(self.device)]