From bc3249d61c2de11ca37c74440fb67114fbaa4860 Mon Sep 17 00:00:00 2001 From: Wenting Shen <17aloha@gmail.com> Date: Wed, 26 Mar 2025 15:49:37 +0800 Subject: [PATCH] add free storage api for FSDP (#277) Signed-off-by: wenting.swt --- wan/distributed/fsdp.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/wan/distributed/fsdp.py b/wan/distributed/fsdp.py index 258d4af..18ba2f3 100644 --- a/wan/distributed/fsdp.py +++ b/wan/distributed/fsdp.py @@ -1,11 +1,12 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import gc from functools import partial import torch from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import MixedPrecision, ShardingStrategy from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy - +from torch.distributed.utils import _free_storage def shard_model( model, @@ -30,3 +31,11 @@ def shard_model( device_id=device_id, sync_module_states=sync_module_states) return model + +def free_model(model): + for m in model.modules(): + if isinstance(m, FSDP): + _free_storage(m._handle.flat_param.data) + del model + gc.collect() + torch.cuda.empty_cache()