mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-03 22:04:21 +00:00 
			
		
		
		
	add free stroage api for FSDP (#277)
Signed-off-by: wenting.swt <wenting.swt@alibaba-inc.com>
This commit is contained in:
		
							parent
							
								
									11a6a00822
								
							
						
					
					
						commit
						bc3249d61c
					
				@ -1,11 +1,12 @@
 | 
			
		||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 | 
			
		||||
import gc
 | 
			
		||||
from functools import partial
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 | 
			
		||||
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
 | 
			
		||||
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
 | 
			
		||||
 | 
			
		||||
from torch.distributed.utils import _free_storage
 | 
			
		||||
 | 
			
		||||
def shard_model(
 | 
			
		||||
    model,
 | 
			
		||||
@ -30,3 +31,11 @@ def shard_model(
 | 
			
		||||
        device_id=device_id,
 | 
			
		||||
        sync_module_states=sync_module_states)
 | 
			
		||||
    return model
 | 
			
		||||
 | 
			
		||||
def free_model(model):
 | 
			
		||||
    for m in model.modules():
 | 
			
		||||
        if isinstance(m, FSDP):
 | 
			
		||||
            _free_storage(m._handle.flat_param.data)
 | 
			
		||||
    del model
 | 
			
		||||
    gc.collect()
 | 
			
		||||
    torch.cuda.empty_cache()
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user