mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-04 06:15:17 +00:00 
			
		
		
		
	add free storage api for FSDP (#277)
Signed-off-by: wenting.swt <wenting.swt@alibaba-inc.com>
This commit is contained in:
		
							parent
							
								
									11a6a00822
								
							
						
					
					
						commit
						bc3249d61c
					
				@ -1,11 +1,12 @@
 | 
				
			|||||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 | 
					# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 | 
				
			||||||
 | 
					import gc
 | 
				
			||||||
from functools import partial
 | 
					from functools import partial
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 | 
					from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 | 
				
			||||||
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
 | 
					from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
 | 
				
			||||||
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
 | 
					from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
 | 
				
			||||||
 | 
					from torch.distributed.utils import _free_storage
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def shard_model(
 | 
					def shard_model(
 | 
				
			||||||
    model,
 | 
					    model,
 | 
				
			||||||
@ -30,3 +31,11 @@ def shard_model(
 | 
				
			|||||||
        device_id=device_id,
 | 
					        device_id=device_id,
 | 
				
			||||||
        sync_module_states=sync_module_states)
 | 
					        sync_module_states=sync_module_states)
 | 
				
			||||||
    return model
 | 
					    return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def free_model(model):
 | 
				
			||||||
 | 
					    for m in model.modules():
 | 
				
			||||||
 | 
					        if isinstance(m, FSDP):
 | 
				
			||||||
 | 
					            _free_storage(m._handle.flat_param.data)
 | 
				
			||||||
 | 
					    del model
 | 
				
			||||||
 | 
					    gc.collect()
 | 
				
			||||||
 | 
					    torch.cuda.empty_cache()
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user