# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
from functools import partial

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from torch.distributed.utils import _free_storage


def shard_model(
    model,
    device_id,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
    process_group=None,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    sync_module_states=True,
):
    """Wrap ``model`` in FSDP, sharding each transformer block separately.

    Each module in ``model.blocks`` becomes its own FSDP unit, so only one
    block's full parameters need to be gathered at a time during the forward
    and backward passes. Parameters are held in ``param_dtype`` (bf16 by
    default), while gradient reduction and buffers stay in fp32 for
    numerical stability.
    """
    model = FSDP(
        module=model,
        process_group=process_group,
        sharding_strategy=sharding_strategy,
        # Wrap every entry of model.blocks as an independent FSDP unit.
        auto_wrap_policy=partial(
            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
        mixed_precision=MixedPrecision(
            param_dtype=param_dtype,
            reduce_dtype=reduce_dtype,
            buffer_dtype=buffer_dtype),
        device_id=device_id,
        sync_module_states=sync_module_states)
    return model


def free_model(model):
    """Eagerly release the storage backing an FSDP-wrapped model's parameters.

    ``_free_storage`` is a private torch.distributed utility that frees the
    flattened parameter storage immediately, without waiting for Python
    garbage collection to reclaim it.
    """
    for m in model.modules():
        if isinstance(m, FSDP):
            _free_storage(m._handle.flat_param.data)
    del model
    gc.collect()
    torch.cuda.empty_cache()
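

# Usage sketch (illustrative, not part of the original module): shards a toy
# model that exposes a ``.blocks`` ModuleList, which is the attribute
# shard_model's auto-wrap policy keys off. Assumes a distributed launch such
# as ``torchrun --nproc_per_node=N`` with one CUDA device per rank; the
# ToyModel class below is hypothetical.
if __name__ == "__main__":
    import torch.distributed as dist
    import torch.nn as nn

    class ToyModel(nn.Module):

        def __init__(self, dim=64, n_blocks=4):
            super().__init__()
            # shard_model wraps each of these blocks as its own FSDP unit.
            self.blocks = nn.ModuleList(
                nn.Linear(dim, dim) for _ in range(n_blocks))

        def forward(self, x):
            for block in self.blocks:
                x = block(x)
            return x

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    model = shard_model(ToyModel(), device_id=rank)
    # MixedPrecision casts root forward inputs to param_dtype (bf16) for us.
    out = model(torch.randn(2, 64, device=f"cuda:{rank}"))
    print(f"rank {rank}: output shape {tuple(out.shape)}")

    # Release the sharded parameter storage once the model is no longer needed.
    free_model(model)
    dist.destroy_process_group()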