mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-04 06:15:17 +00:00 
			
		
		
		
	Add CPU offload support for FSDP in Wan model
- Introduced `--cpu_offload` argument in `generate.py` for enabling CPU offload. - Updated `WanI2V` class in `image2video.py` to handle CPU offload during model initialization and sharding. - Added new functions in `fsdp.py` for CPU initialization and sharding with CPU offload. - Expanded supported sizes in `configs/__init__.py` to include additional resolutions.
This commit is contained in:
		
							parent
							
								
									7c81b2f27d
								
							
						
					
					
						commit
						9f90033ce4
					
				@ -155,6 +155,11 @@ def _parse_args():
 | 
				
			|||||||
        action="store_true",
 | 
					        action="store_true",
 | 
				
			||||||
        default=False,
 | 
					        default=False,
 | 
				
			||||||
        help="Whether to use FSDP for DiT.")
 | 
					        help="Whether to use FSDP for DiT.")
 | 
				
			||||||
 | 
					    parser.add_argument(
 | 
				
			||||||
 | 
					        "--cpu_offload",
 | 
				
			||||||
 | 
					        action="store_true",
 | 
				
			||||||
 | 
					        default=False,
 | 
				
			||||||
 | 
					        help="Whether to use CPU offload for FSDP on Wan model, only works with dit_fsdp.")
 | 
				
			||||||
    parser.add_argument(
 | 
					    parser.add_argument(
 | 
				
			||||||
        "--save_file",
 | 
					        "--save_file",
 | 
				
			||||||
        type=str,
 | 
					        type=str,
 | 
				
			||||||
@ -421,6 +426,7 @@ def generate(args):
 | 
				
			|||||||
            rank=rank,
 | 
					            rank=rank,
 | 
				
			||||||
            t5_fsdp=args.t5_fsdp,
 | 
					            t5_fsdp=args.t5_fsdp,
 | 
				
			||||||
            dit_fsdp=args.dit_fsdp,
 | 
					            dit_fsdp=args.dit_fsdp,
 | 
				
			||||||
 | 
					            cpu_offload=args.cpu_offload,
 | 
				
			||||||
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
 | 
					            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
 | 
				
			||||||
            t5_cpu=args.t5_cpu,
 | 
					            t5_cpu=args.t5_cpu,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										14
									
								
								start_master.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								start_master.sh
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,14 @@
 | 
				
			|||||||
 | 
					#!/bin/bash
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Set environment variables
 | 
				
			||||||
 | 
					export PATH=/data/apps/wan-21/wan-env/bin:$PATH
 | 
				
			||||||
 | 
					export MASTER_ADDR='10.79.79.197'
 | 
				
			||||||
 | 
					export MASTER_PORT='7860'
 | 
				
			||||||
 | 
					export RANK='0'
 | 
				
			||||||
 | 
					export WORLD_SIZE='2'
 | 
				
			||||||
 | 
					export LOCAL_RANK='0'
 | 
				
			||||||
 | 
					export NCCL_SOCKET_IFNAME='enp65s0'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Run the generate.py script with specified parameters
 | 
				
			||||||
 | 
					# like: ./start_master.sh --dit_fsdp --task i2v-14B --size 480*640 --ckpt_dir ./Wan2.1-I2V-14B-480P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
 | 
				
			||||||
 | 
					python generate.py "$@"
 | 
				
			||||||
							
								
								
									
										14
									
								
								start_rank1.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								start_rank1.sh
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,14 @@
 | 
				
			|||||||
 | 
					#!/bin/bash
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Set environment variables
 | 
				
			||||||
 | 
					export PATH=/data/apps/wan-21/wan-env/bin:$PATH
 | 
				
			||||||
 | 
					export MASTER_ADDR='10.79.79.197'
 | 
				
			||||||
 | 
					export MASTER_PORT='7860'
 | 
				
			||||||
 | 
					export RANK='1'
 | 
				
			||||||
 | 
					export WORLD_SIZE='2'
 | 
				
			||||||
 | 
					export LOCAL_RANK='0'
 | 
				
			||||||
 | 
					export NCCL_SOCKET_IFNAME='enp33s0'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Run the generate.py script with specified parameters
 | 
				
			||||||
 | 
					# like: ./start_rank1.sh --dit_fsdp --task i2v-14B --size 480*640 --ckpt_dir ./Wan2.1-I2V-14B-480P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
 | 
				
			||||||
 | 
					python generate.py "$@"
 | 
				
			||||||
@ -32,6 +32,8 @@ SIZE_CONFIGS = {
 | 
				
			|||||||
    '1280*720': (1280, 720),
 | 
					    '1280*720': (1280, 720),
 | 
				
			||||||
    '480*832': (480, 832),
 | 
					    '480*832': (480, 832),
 | 
				
			||||||
    '832*480': (832, 480),
 | 
					    '832*480': (832, 480),
 | 
				
			||||||
 | 
					    '480*640': (480, 640),
 | 
				
			||||||
 | 
					    '640*480': (640, 480),
 | 
				
			||||||
    '1024*1024': (1024, 1024),
 | 
					    '1024*1024': (1024, 1024),
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -40,12 +42,14 @@ MAX_AREA_CONFIGS = {
 | 
				
			|||||||
    '1280*720': 1280 * 720,
 | 
					    '1280*720': 1280 * 720,
 | 
				
			||||||
    '480*832': 480 * 832,
 | 
					    '480*832': 480 * 832,
 | 
				
			||||||
    '832*480': 832 * 480,
 | 
					    '832*480': 832 * 480,
 | 
				
			||||||
 | 
					    '480*640': 480 * 640,
 | 
				
			||||||
 | 
					    '640*480': 640 * 480,
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
SUPPORTED_SIZES = {
 | 
					SUPPORTED_SIZES = {
 | 
				
			||||||
    't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
 | 
					    't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
 | 
				
			||||||
    't2v-1.3B': ('480*832', '832*480'),
 | 
					    't2v-1.3B': ('480*832', '832*480'),
 | 
				
			||||||
    'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
 | 
					    'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480', '480*640', '640*480'),
 | 
				
			||||||
    'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
 | 
					    'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
 | 
				
			||||||
    't2i-14B': tuple(SIZE_CONFIGS.keys()),
 | 
					    't2i-14B': tuple(SIZE_CONFIGS.keys()),
 | 
				
			||||||
    'vace-1.3B': ('480*832', '832*480'),
 | 
					    'vace-1.3B': ('480*832', '832*480'),
 | 
				
			||||||
 | 
				
			|||||||
@ -34,10 +34,88 @@ def shard_model(
 | 
				
			|||||||
    return model
 | 
					    return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def shard_model_cpu_init(
 | 
				
			||||||
 | 
					    model,
 | 
				
			||||||
 | 
					    device_id,
 | 
				
			||||||
 | 
					    param_dtype=torch.bfloat16,
 | 
				
			||||||
 | 
					    reduce_dtype=torch.float32,
 | 
				
			||||||
 | 
					    buffer_dtype=torch.float32,
 | 
				
			||||||
 | 
					    process_group=None,
 | 
				
			||||||
 | 
					    sharding_strategy=ShardingStrategy.FULL_SHARD,
 | 
				
			||||||
 | 
					    sync_module_states=False,  # Key fix: disable sync_module_states for CPU init
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Initialize FSDP with CPU-first approach to reduce GPU memory peak.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    # Ensure model is on CPU and clear GPU cache
 | 
				
			||||||
 | 
					    model = model.cpu()
 | 
				
			||||||
 | 
					    torch.cuda.empty_cache()
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    # Initialize FSDP with sync_module_states=False to allow CPU initialization
 | 
				
			||||||
 | 
					    model = FSDP(
 | 
				
			||||||
 | 
					        module=model,
 | 
				
			||||||
 | 
					        process_group=process_group,
 | 
				
			||||||
 | 
					        sharding_strategy=sharding_strategy,
 | 
				
			||||||
 | 
					        auto_wrap_policy=partial(
 | 
				
			||||||
 | 
					            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
 | 
				
			||||||
 | 
					        mixed_precision=MixedPrecision(
 | 
				
			||||||
 | 
					            param_dtype=param_dtype,
 | 
				
			||||||
 | 
					            reduce_dtype=reduce_dtype,
 | 
				
			||||||
 | 
					            buffer_dtype=buffer_dtype),
 | 
				
			||||||
 | 
					        device_id=None,  # Keep on CPU during initialization
 | 
				
			||||||
 | 
					        sync_module_states=sync_module_states)  # Disabled for CPU init
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    # Move to GPU after FSDP wrapping
 | 
				
			||||||
 | 
					    model = model.to(f"cuda:{device_id}")
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    # Manual synchronization across ranks if needed
 | 
				
			||||||
 | 
					    if sync_module_states and torch.distributed.is_initialized():
 | 
				
			||||||
 | 
					        torch.distributed.barrier()
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def shard_model_with_cpu_offload(
 | 
				
			||||||
 | 
					    model,
 | 
				
			||||||
 | 
					    device_id,
 | 
				
			||||||
 | 
					    param_dtype=torch.bfloat16,
 | 
				
			||||||
 | 
					    reduce_dtype=torch.float32,
 | 
				
			||||||
 | 
					    buffer_dtype=torch.float32,
 | 
				
			||||||
 | 
					    process_group=None,
 | 
				
			||||||
 | 
					    sharding_strategy=ShardingStrategy.FULL_SHARD,
 | 
				
			||||||
 | 
					    sync_module_states=True,
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Initialize FSDP with CPU offload to minimize GPU memory usage.
 | 
				
			||||||
 | 
					    This keeps parameters on CPU and only loads them to GPU when needed.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    from torch.distributed.fsdp import CPUOffload
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    # Ensure model is on CPU and clear GPU cache aggressively
 | 
				
			||||||
 | 
					    model = model.cpu()
 | 
				
			||||||
 | 
					    torch.cuda.empty_cache()
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    # Use CPU offload to minimize GPU memory usage
 | 
				
			||||||
 | 
					    model = FSDP(
 | 
				
			||||||
 | 
					        module=model,
 | 
				
			||||||
 | 
					        process_group=process_group,
 | 
				
			||||||
 | 
					        sharding_strategy=sharding_strategy,
 | 
				
			||||||
 | 
					        auto_wrap_policy=partial(
 | 
				
			||||||
 | 
					            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
 | 
				
			||||||
 | 
					        mixed_precision=MixedPrecision(
 | 
				
			||||||
 | 
					            param_dtype=param_dtype,
 | 
				
			||||||
 | 
					            reduce_dtype=reduce_dtype,
 | 
				
			||||||
 | 
					            buffer_dtype=buffer_dtype),
 | 
				
			||||||
 | 
					        cpu_offload=CPUOffload(offload_params=True),  # Key: keep params on CPU
 | 
				
			||||||
 | 
					        device_id=device_id,
 | 
				
			||||||
 | 
					        sync_module_states=sync_module_states)
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def free_model(model):
 | 
					def free_model(model):
 | 
				
			||||||
    for m in model.modules():
 | 
					    for m in model.modules():
 | 
				
			||||||
        if isinstance(m, FSDP):
 | 
					        if isinstance(m, FSDP):
 | 
				
			||||||
            _free_storage(m._handle.flat_param.data)
 | 
					            _free_storage(m._handle.flat_param.data)
 | 
				
			||||||
    del model
 | 
					    del model
 | 
				
			||||||
    gc.collect()
 | 
					    gc.collect()
 | 
				
			||||||
    torch.cuda.empty_cache()
 | 
					    torch.cuda.empty_cache()
 | 
				
			||||||
@ -16,7 +16,7 @@ import torch.distributed as dist
 | 
				
			|||||||
import torchvision.transforms.functional as TF
 | 
					import torchvision.transforms.functional as TF
 | 
				
			||||||
from tqdm import tqdm
 | 
					from tqdm import tqdm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .distributed.fsdp import shard_model
 | 
					from .distributed.fsdp import shard_model, shard_model_with_cpu_offload
 | 
				
			||||||
from .modules.clip import CLIPModel
 | 
					from .modules.clip import CLIPModel
 | 
				
			||||||
from .modules.model import WanModel
 | 
					from .modules.model import WanModel
 | 
				
			||||||
from .modules.t5 import T5EncoderModel
 | 
					from .modules.t5 import T5EncoderModel
 | 
				
			||||||
@ -39,6 +39,7 @@ class WanI2V:
 | 
				
			|||||||
        rank=0,
 | 
					        rank=0,
 | 
				
			||||||
        t5_fsdp=False,
 | 
					        t5_fsdp=False,
 | 
				
			||||||
        dit_fsdp=False,
 | 
					        dit_fsdp=False,
 | 
				
			||||||
 | 
					        cpu_offload=False,
 | 
				
			||||||
        use_usp=False,
 | 
					        use_usp=False,
 | 
				
			||||||
        t5_cpu=False,
 | 
					        t5_cpu=False,
 | 
				
			||||||
        init_on_cpu=True,
 | 
					        init_on_cpu=True,
 | 
				
			||||||
@ -59,6 +60,8 @@ class WanI2V:
 | 
				
			|||||||
                Enable FSDP sharding for T5 model
 | 
					                Enable FSDP sharding for T5 model
 | 
				
			||||||
            dit_fsdp (`bool`, *optional*, defaults to False):
 | 
					            dit_fsdp (`bool`, *optional*, defaults to False):
 | 
				
			||||||
                Enable FSDP sharding for DiT model
 | 
					                Enable FSDP sharding for DiT model
 | 
				
			||||||
 | 
					            cpu_offload (`bool`, *optional*, defaults to False):
 | 
				
			||||||
 | 
					                Enable CPU offload for FSDP on Wan model, only works with dit_fsdp.
 | 
				
			||||||
            use_usp (`bool`, *optional*, defaults to False):
 | 
					            use_usp (`bool`, *optional*, defaults to False):
 | 
				
			||||||
                Enable distribution strategy of USP.
 | 
					                Enable distribution strategy of USP.
 | 
				
			||||||
            t5_cpu (`bool`, *optional*, defaults to False):
 | 
					            t5_cpu (`bool`, *optional*, defaults to False):
 | 
				
			||||||
@ -71,10 +74,14 @@ class WanI2V:
 | 
				
			|||||||
        self.rank = rank
 | 
					        self.rank = rank
 | 
				
			||||||
        self.use_usp = use_usp
 | 
					        self.use_usp = use_usp
 | 
				
			||||||
        self.t5_cpu = t5_cpu
 | 
					        self.t5_cpu = t5_cpu
 | 
				
			||||||
 | 
					        self.dit_fsdp = dit_fsdp
 | 
				
			||||||
 | 
					        self.cpu_offload = cpu_offload
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.num_train_timesteps = config.num_train_timesteps
 | 
					        self.num_train_timesteps = config.num_train_timesteps
 | 
				
			||||||
        self.param_dtype = config.param_dtype
 | 
					        self.param_dtype = config.param_dtype
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        assert not cpu_offload or (cpu_offload and dit_fsdp), "When cpu_offload is True, dit_fsdp must also be True"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        shard_fn = partial(shard_model, device_id=device_id)
 | 
					        shard_fn = partial(shard_model, device_id=device_id)
 | 
				
			||||||
        self.text_encoder = T5EncoderModel(
 | 
					        self.text_encoder = T5EncoderModel(
 | 
				
			||||||
            text_len=config.text_len,
 | 
					            text_len=config.text_len,
 | 
				
			||||||
@ -99,9 +106,15 @@ class WanI2V:
 | 
				
			|||||||
            tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
 | 
					            tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        logging.info(f"Creating WanModel from {checkpoint_dir}")
 | 
					        logging.info(f"Creating WanModel from {checkpoint_dir}")
 | 
				
			||||||
        self.model = WanModel.from_pretrained(checkpoint_dir)
 | 
					        if cpu_offload:
 | 
				
			||||||
 | 
					            # For FSDP and cpu_offload, force CPU initialization to avoid OOM
 | 
				
			||||||
 | 
					            with torch.device('cpu'):
 | 
				
			||||||
 | 
					                self.model = WanModel.from_pretrained(checkpoint_dir)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.model = WanModel.from_pretrained(checkpoint_dir)
 | 
				
			||||||
        self.model.eval().requires_grad_(False)
 | 
					        self.model.eval().requires_grad_(False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if t5_fsdp or dit_fsdp or use_usp:
 | 
					        if t5_fsdp or dit_fsdp or use_usp:
 | 
				
			||||||
            init_on_cpu = False
 | 
					            init_on_cpu = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -123,7 +136,11 @@ class WanI2V:
 | 
				
			|||||||
        if dist.is_initialized():
 | 
					        if dist.is_initialized():
 | 
				
			||||||
            dist.barrier()
 | 
					            dist.barrier()
 | 
				
			||||||
        if dit_fsdp:
 | 
					        if dit_fsdp:
 | 
				
			||||||
            self.model = shard_fn(self.model)
 | 
					            if cpu_offload:
 | 
				
			||||||
 | 
					                torch.cuda.empty_cache()
 | 
				
			||||||
 | 
					                self.model = shard_model_with_cpu_offload(self.model, device_id=device_id)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                self.model = shard_fn(self.model)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            if not init_on_cpu:
 | 
					            if not init_on_cpu:
 | 
				
			||||||
                self.model.to(self.device)
 | 
					                self.model.to(self.device)
 | 
				
			||||||
@ -298,7 +315,10 @@ class WanI2V:
 | 
				
			|||||||
            if offload_model:
 | 
					            if offload_model:
 | 
				
			||||||
                torch.cuda.empty_cache()
 | 
					                torch.cuda.empty_cache()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            self.model.to(self.device)
 | 
					            if self.cpu_offload:
 | 
				
			||||||
 | 
					                self.model.cpu()
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                self.model.to(self.device)
 | 
				
			||||||
            for _, t in enumerate(tqdm(timesteps)):
 | 
					            for _, t in enumerate(tqdm(timesteps)):
 | 
				
			||||||
                latent_model_input = [latent.to(self.device)]
 | 
					                latent_model_input = [latent.to(self.device)]
 | 
				
			||||||
                timestep = [t]
 | 
					                timestep = [t]
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user