mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-05 14:33:15 +00:00
73 lines
2.0 KiB
Python
73 lines
2.0 KiB
Python
import logging
|
|
|
|
# Module-level logger. getLogger() is called without a name, so this is the
# root logger rather than a logger scoped to this module.
log = logging.getLogger()
|
|
|
|
|
|
def get_parameter_groups(model, cfg, print_log=False):
    """
    Collect the trainable parameters of a model into optimizer parameter groups.

    Every parameter yielded by ``model.named_parameters()`` that has
    ``requires_grad`` set is placed, once, into a single group that uses the
    base learning rate and weight decay taken from ``cfg``.

    Args:
        model: object exposing ``named_parameters()``, which yields
            ``(name, parameter)`` pairs. Each parameter must be hashable and
            have a ``requires_grad`` attribute.
        cfg: configuration object providing the ``learning_rate`` and
            ``weight_decay`` attributes.
        print_log: when true, log the name of every parameter that is added
            to the group.

    Returns:
        A list containing one dict with the keys ``'params'``, ``'lr'`` and
        ``'weight_decay'``, which can be passed to the optimizer.
    """
    # Read the config up front so that a missing attribute fails before the
    # model is traversed.
    weight_decay = cfg.weight_decay
    base_lr = cfg.learning_rate

    # NOTE(review): presumably the prefix added by DataParallel-style wrappers;
    # it is only stripped to make the logged names shorter -- confirm.
    wrapper_prefix = 'module.'

    # NOTE: separate backbone / embedding groups (with their own learning
    # rate ratio and weight decay) used to be sketched here as commented-out
    # code; at present every trainable parameter goes into one group.
    trainable_params = []

    # inspired by detectron2
    seen = set()
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Avoid duplicating parameters that are reachable under several names
        if param in seen:
            continue
        seen.add(param)
        trainable_params.append(param)

        if print_log:
            # Strip only the exact 'module.' prefix so that names which merely
            # begin with the letters 'module' are left intact.
            if name.startswith(wrapper_prefix):
                name = name[len(wrapper_prefix):]
            log.info('%s added to the default parameter group.', name)

    return [
        {
            'params': trainable_params,
            'lr': base_lr,
            'weight_decay': weight_decay,
        },
    ]
|