diff --git a/generate.py b/generate.py
index c841c19..c555449 100644
--- a/generate.py
+++ b/generate.py
@@ -155,6 +155,11 @@ def _parse_args():
         action="store_true",
         default=False,
         help="Whether to use FSDP for DiT.")
+    parser.add_argument(
+        "--cpu_offload",
+        action="store_true",
+        default=False,
+        help="Whether to use CPU offload for FSDP on the Wan model; only works with --dit_fsdp.")
     parser.add_argument(
         "--save_file",
         type=str,
@@ -421,6 +426,7 @@ def generate(args):
             rank=rank,
             t5_fsdp=args.t5_fsdp,
             dit_fsdp=args.dit_fsdp,
+            cpu_offload=args.cpu_offload,
             use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
             t5_cpu=args.t5_cpu,
         )
diff --git a/start_master.sh b/start_master.sh
new file mode 100644
index 0000000..92ec3dd
--- /dev/null
+++ b/start_master.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+
+# Set environment variables
+export PATH=/data/apps/wan-21/wan-env/bin:$PATH
+export MASTER_ADDR='10.79.79.197'
+export MASTER_PORT='7860'
+export RANK='0'
+export WORLD_SIZE='2'
+export LOCAL_RANK='0'
+export NCCL_SOCKET_IFNAME='enp65s0'
+
+# Run the generate.py script with the specified parameters, e.g.:
+# ./start_master.sh --dit_fsdp --task i2v-14B --size 480*640 --ckpt_dir ./Wan2.1-I2V-14B-480P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
+python generate.py "$@"
diff --git a/start_rank1.sh b/start_rank1.sh
new file mode 100644
index 0000000..e65d5c4
--- /dev/null
+++ b/start_rank1.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+
+# Set environment variables
+export PATH=/data/apps/wan-21/wan-env/bin:$PATH
+export MASTER_ADDR='10.79.79.197'
+export MASTER_PORT='7860'
+export RANK='1'
+export WORLD_SIZE='2'
+export LOCAL_RANK='0'
+export NCCL_SOCKET_IFNAME='enp33s0'
+
+# Run the generate.py script with the specified parameters, e.g.:
+# ./start_rank1.sh --dit_fsdp --task i2v-14B --size 480*640 --ckpt_dir ./Wan2.1-I2V-14B-480P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
+python generate.py "$@"
diff --git a/wan/configs/__init__.py b/wan/configs/__init__.py
index e7f95d7..9f729f2 100644
--- a/wan/configs/__init__.py
+++ b/wan/configs/__init__.py
@@ -32,6 +32,8 @@ SIZE_CONFIGS = {
     '1280*720': (1280, 720),
     '480*832': (480, 832),
     '832*480': (832, 480),
+    '480*640': (480, 640),
+    '640*480': (640, 480),
     '1024*1024': (1024, 1024),
 }
 
@@ -40,12 +42,14 @@ MAX_AREA_CONFIGS = {
     '1280*720': 1280 * 720,
     '480*832': 480 * 832,
     '832*480': 832 * 480,
+    '480*640': 480 * 640,
+    '640*480': 640 * 480,
 }
 
 SUPPORTED_SIZES = {
     't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
     't2v-1.3B': ('480*832', '832*480'),
-    'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
+    'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480', '480*640', '640*480'),
     'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
     't2i-14B': tuple(SIZE_CONFIGS.keys()),
     'vace-1.3B': ('480*832', '832*480'),
diff --git a/wan/distributed/fsdp.py b/wan/distributed/fsdp.py
index 6bb496d..03e077d 100644
--- a/wan/distributed/fsdp.py
+++ b/wan/distributed/fsdp.py
@@ -34,10 +34,88 @@ def shard_model(
     return model
 
 
+def shard_model_cpu_init(
+    model,
+    device_id,
+    param_dtype=torch.bfloat16,
+    reduce_dtype=torch.float32,
+    buffer_dtype=torch.float32,
+    process_group=None,
+    sharding_strategy=ShardingStrategy.FULL_SHARD,
+    sync_module_states=False,  # Key fix: disable sync_module_states for CPU init
+):
+    """
+    Initialize FSDP with a CPU-first approach to reduce the GPU memory peak.
+    """
+    # Ensure model is on CPU and clear GPU cache
+    model = model.cpu()
+    torch.cuda.empty_cache()
+
+    # Initialize FSDP with sync_module_states=False to allow CPU initialization
+    model = FSDP(
+        module=model,
+        process_group=process_group,
+        sharding_strategy=sharding_strategy,
+        auto_wrap_policy=partial(
+            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
+        mixed_precision=MixedPrecision(
+            param_dtype=param_dtype,
+            reduce_dtype=reduce_dtype,
+            buffer_dtype=buffer_dtype),
+        device_id=None,  # Keep on CPU during initialization
+        sync_module_states=sync_module_states)  # Disabled for CPU init
+
+    # Move to GPU after FSDP wrapping
+    model = model.to(f"cuda:{device_id}")
+
+    # Manual synchronization across ranks if needed
+    if sync_module_states and torch.distributed.is_initialized():
+        torch.distributed.barrier()
+
+    return model
+
+
+def shard_model_with_cpu_offload(
+    model,
+    device_id,
+    param_dtype=torch.bfloat16,
+    reduce_dtype=torch.float32,
+    buffer_dtype=torch.float32,
+    process_group=None,
+    sharding_strategy=ShardingStrategy.FULL_SHARD,
+    sync_module_states=True,
+):
+    """
+    Initialize FSDP with CPU offload to minimize GPU memory usage.
+    This keeps parameters on CPU and only loads them to GPU when needed.
+    """
+    from torch.distributed.fsdp import CPUOffload
+
+    # Ensure model is on CPU and clear GPU cache aggressively
+    model = model.cpu()
+    torch.cuda.empty_cache()
+
+    # Use CPU offload to minimize GPU memory usage
+    model = FSDP(
+        module=model,
+        process_group=process_group,
+        sharding_strategy=sharding_strategy,
+        auto_wrap_policy=partial(
+            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
+        mixed_precision=MixedPrecision(
+            param_dtype=param_dtype,
+            reduce_dtype=reduce_dtype,
+            buffer_dtype=buffer_dtype),
+        cpu_offload=CPUOffload(offload_params=True),  # Key: keep params on CPU
+        device_id=device_id,
+        sync_module_states=sync_module_states)
+
+    return model
+
 def free_model(model):
     for m in model.modules():
         if isinstance(m, FSDP):
             _free_storage(m._handle.flat_param.data)
     del model
     gc.collect()
-    torch.cuda.empty_cache()
+    torch.cuda.empty_cache()
\ No newline at end of file
diff --git a/wan/image2video.py b/wan/image2video.py
index 6882c53..0f44e00 100644
--- a/wan/image2video.py
+++ b/wan/image2video.py
@@ -16,7 +16,7 @@ import torch.distributed as dist
 import torchvision.transforms.functional as TF
 from tqdm import tqdm
 
-from .distributed.fsdp import shard_model
+from .distributed.fsdp import shard_model, shard_model_with_cpu_offload
 from .modules.clip import CLIPModel
 from .modules.model import WanModel
 from .modules.t5 import T5EncoderModel
@@ -39,6 +39,7 @@ class WanI2V:
         rank=0,
         t5_fsdp=False,
         dit_fsdp=False,
+        cpu_offload=False,
         use_usp=False,
         t5_cpu=False,
         init_on_cpu=True,
@@ -59,6 +60,8 @@ class WanI2V:
                 Enable FSDP sharding for T5 model
             dit_fsdp (`bool`, *optional*, defaults to False):
                 Enable FSDP sharding for DiT model
+            cpu_offload (`bool`, *optional*, defaults to False):
+                Enable CPU offload for FSDP on the Wan model; only works together with dit_fsdp.
             use_usp (`bool`, *optional*, defaults to False):
                 Enable distribution strategy of USP.
             t5_cpu (`bool`, *optional*, defaults to False):
@@ -71,10 +74,14 @@ class WanI2V:
         self.rank = rank
         self.use_usp = use_usp
         self.t5_cpu = t5_cpu
+        self.dit_fsdp = dit_fsdp
+        self.cpu_offload = cpu_offload
 
         self.num_train_timesteps = config.num_train_timesteps
         self.param_dtype = config.param_dtype
 
+        assert not cpu_offload or (cpu_offload and dit_fsdp), "When cpu_offload is True, dit_fsdp must also be True"
+
         shard_fn = partial(shard_model, device_id=device_id)
         self.text_encoder = T5EncoderModel(
             text_len=config.text_len,
@@ -99,9 +106,15 @@ class WanI2V:
             tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
 
         logging.info(f"Creating WanModel from {checkpoint_dir}")
-        self.model = WanModel.from_pretrained(checkpoint_dir)
+        if cpu_offload:
+            # For FSDP with cpu_offload, force CPU initialization to avoid OOM
+            with torch.device('cpu'):
+                self.model = WanModel.from_pretrained(checkpoint_dir)
+        else:
+            self.model = WanModel.from_pretrained(checkpoint_dir)
         self.model.eval().requires_grad_(False)
 
+
         if t5_fsdp or dit_fsdp or use_usp:
             init_on_cpu = False
 
@@ -123,7 +136,11 @@ class WanI2V:
         if dist.is_initialized():
             dist.barrier()
         if dit_fsdp:
-            self.model = shard_fn(self.model)
+            if cpu_offload:
+                torch.cuda.empty_cache()
+                self.model = shard_model_with_cpu_offload(self.model, device_id=device_id)
+            else:
+                self.model = shard_fn(self.model)
         else:
             if not init_on_cpu:
                 self.model.to(self.device)
@@ -298,7 +315,10 @@ class WanI2V:
 
            if offload_model:
                torch.cuda.empty_cache()
-           self.model.to(self.device)
+           if self.cpu_offload:
+               self.model.cpu()
+           else:
+               self.model.to(self.device)
            for _, t in enumerate(tqdm(timesteps)):
                latent_model_input = [latent.to(self.device)]
                timestep = [t]
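
Usage sketch, not part of the patch: the new --cpu_offload flag only takes effect together with --dit_fsdp (the assert in WanI2V enforces this), so a two-node run with the helper scripts above could look like the following, with the prompt elided.

# On the master node (RANK=0):
./start_master.sh --dit_fsdp --cpu_offload --task i2v-14B --size 480*640 \
    --ckpt_dir ./Wan2.1-I2V-14B-480P --image examples/i2v_input.JPG --prompt "..."
# On the second node (RANK=1), pass the same arguments to ./start_rank1.sh.

With CPUOffload(offload_params=True), FSDP keeps the sharded DiT parameters in host memory and moves each wrapped block to the GPU only for its forward pass, which is why image2video.py leaves self.model on the CPU instead of calling self.model.to(self.device) when cpu_offload is set.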