mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-06-05 14:54:54 +00:00
add free stroage api for FSDP (#277)
Signed-off-by: wenting.swt <wenting.swt@alibaba-inc.com>
This commit is contained in:
parent
11a6a00822
commit
bc3249d61c
@ -1,11 +1,12 @@
|
|||||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||||
|
import gc
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||||
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
|
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
|
||||||
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
|
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
|
||||||
|
from torch.distributed.utils import _free_storage
|
||||||
|
|
||||||
def shard_model(
|
def shard_model(
|
||||||
model,
|
model,
|
||||||
@ -30,3 +31,11 @@ def shard_model(
|
|||||||
device_id=device_id,
|
device_id=device_id,
|
||||||
sync_module_states=sync_module_states)
|
sync_module_states=sync_module_states)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def free_model(model):
|
||||||
|
for m in model.modules():
|
||||||
|
if isinstance(m, FSDP):
|
||||||
|
_free_storage(m._handle.flat_param.data)
|
||||||
|
del model
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
Loading…
Reference in New Issue
Block a user