Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-04 06:15:17 +00:00)
Latest commit: * isort the code * format the code * Add yapf config file * Remove torch cuda memory profiler

44 lines · 1.3 KiB · Python
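
The file below defines the repository's FSDP helpers: shard_model, which wraps a model's transformer blocks as individually sharded FSDP units with mixed precision, and free_model, which releases the sharded parameter storage once a model is no longer needed.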
		
	
	
	
	
	
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
from functools import partial

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
# Note: _free_storage is a private PyTorch utility and may move or change
# between releases.
from torch.distributed.utils import _free_storage

def shard_model(
    model,
    device_id,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
    process_group=None,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    sync_module_states=True,
):
    # Wrap `model` in FSDP so that each transformer block in `model.blocks`
    # becomes its own sharded unit. By default parameters are held in bf16
    # while gradient reduction and buffers stay in fp32.
    model = FSDP(
        module=model,
        process_group=process_group,
        sharding_strategy=sharding_strategy,
        # Shard exactly the submodules that appear in model.blocks.
        auto_wrap_policy=partial(
            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
        mixed_precision=MixedPrecision(
            param_dtype=param_dtype,
            reduce_dtype=reduce_dtype,
            buffer_dtype=buffer_dtype),
        device_id=device_id,
        # Broadcast module states from rank 0 so every rank starts identical.
        sync_module_states=sync_module_states)
    return model
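

# Usage sketch (added for illustration; not part of the upstream file).
# shard_model assumes the module exposes a `.blocks` container of transformer
# blocks, since the auto-wrap policy shards exactly those submodules. The
# helper below is hypothetical and assumes a launcher such as torchrun has
# already spawned one process per GPU on a single node.
def _example_shard(model):
    import torch.distributed as dist
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    # Defaults give bf16 parameters with fp32 gradient reduction.
    return shard_model(model, device_id=rank)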


def free_model(model):
    # Explicitly release the flat-parameter storage held by every FSDP unit,
    # then drop the reference and flush the CUDA caching allocator so the
    # memory becomes reusable immediately instead of waiting for GC.
    # Note: m._handle and _free_storage are private PyTorch internals.
    for m in model.modules():
        if isinstance(m, FSDP):
            _free_storage(m._handle.flat_param.data)
    del model
    gc.collect()
    torch.cuda.empty_cache()
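

# Lifecycle sketch (illustration only; every name except shard_model and
# free_model is hypothetical). A typical pattern is to shard a large model,
# run it, then free it before loading the next one, so the two never need
# to fit in GPU memory at the same time.
def _example_lifecycle(build_encoder, build_generator, device_id, inputs):
    encoder = shard_model(build_encoder(), device_id=device_id)
    features = encoder(inputs)
    free_model(encoder)  # reclaim sharded storage before the next model
    generator = shard_model(build_generator(), device_id=device_id)
    return generator, features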