diff --git a/wan/distributed/fsdp.py b/wan/distributed/fsdp.py
index 258d4af..18ba2f3 100644
--- a/wan/distributed/fsdp.py
+++ b/wan/distributed/fsdp.py
@@ -1,11 +1,12 @@
 # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import gc
 from functools import partial
 
 import torch
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
 from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
-
+from torch.distributed.utils import _free_storage
 
 def shard_model(
     model,
@@ -30,3 +31,11 @@ def shard_model(
         device_id=device_id,
         sync_module_states=sync_module_states)
     return model
+
+def free_model(model):
+    for m in model.modules():
+        if isinstance(m, FSDP):
+            _free_storage(m._handle.flat_param.data)
+    del model
+    gc.collect()
+    torch.cuda.empty_cache()
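
Not part of the patch: a minimal usage sketch of the new free_model helper. It assumes a process group is already initialized (e.g. launched with torchrun), and MyBlockModel is a hypothetical stand-in for any module exposing a `blocks` ModuleList, which shard_model's lambda_auto_wrap_policy expects.

import torch
import torch.nn as nn

from wan.distributed.fsdp import free_model, shard_model


class MyBlockModel(nn.Module):
    # Hypothetical module; only the `.blocks` attribute matters to shard_model,
    # since its auto-wrap policy wraps exactly the submodules found in `blocks`.
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(64, 64) for _ in range(4))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


def run(rank):
    # Shard the module across ranks; each entry of `.blocks` becomes its own FSDP unit.
    model = shard_model(MyBlockModel(), device_id=rank)
    with torch.no_grad():
        out = model(torch.randn(1, 64, device=f"cuda:{rank}"))
    # Once this module's outputs have been computed and cached, free_model releases
    # the sharded flat-parameter storage, runs garbage collection, and empties the
    # CUDA caching allocator so the memory can be reused by later models.
    free_model(model)
    return out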