mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-06-05 14:54:54 +00:00
add free stroage api for FSDP (#277)
Signed-off-by: wenting.swt <wenting.swt@alibaba-inc.com>
This commit is contained in:
parent
11a6a00822
commit
bc3249d61c
@ -1,11 +1,12 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import gc
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
|
||||
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
|
||||
|
||||
from torch.distributed.utils import _free_storage
|
||||
|
||||
def shard_model(
|
||||
model,
|
||||
@ -30,3 +31,11 @@ def shard_model(
|
||||
device_id=device_id,
|
||||
sync_module_states=sync_module_states)
|
||||
return model
|
||||
|
||||
def free_model(model):
|
||||
for m in model.modules():
|
||||
if isinstance(m, FSDP):
|
||||
_free_storage(m._handle.flat_param.data)
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
Loading…
Reference in New Issue
Block a user