add free storage API for FSDP

Signed-off-by: wenting.swt <wenting.swt@alibaba-inc.com>
Author: wenting.swt
Date: 2025-03-26 13:54:58 +08:00
parent 11a6a00822
commit c9cb9eb9cf


@@ -1,11 +1,12 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
from functools import partial
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from torch.distributed.utils import _free_storage
def shard_model(
    model,
@@ -30,3 +31,11 @@ def shard_model(
        device_id=device_id,
        sync_module_states=sync_module_states)
    return model

def free_model(model):
    # Release the flattened parameter storage held by each FSDP-wrapped
    # submodule; _free_storage resizes the underlying storage to zero bytes.
    for m in model.modules():
        if isinstance(m, FSDP):
            _free_storage(m._handle.flat_param.data)
    # Drop the local reference, force a collection pass, and return cached
    # blocks to the CUDA caching allocator.
    del model
    gc.collect()
    torch.cuda.empty_cache()
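
For context, a minimal usage sketch of the new API; `build_model()` is a hypothetical factory standing in for whatever constructs the un-sharded module, and is not part of this commit. Note that `del model` inside `free_model` only unbinds the function-local name, so the caller must also drop its own reference before the module can actually be garbage-collected.

    import torch

    # `build_model()` is a placeholder for the code that creates the module.
    model = shard_model(build_model(), device_id=torch.cuda.current_device())
    # ... run inference with the sharded model ...
    free_model(model)  # releases each wrapped module's flat-parameter storage
    model = None       # drop the caller's reference as well so gc can reclaim it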