# Mirror of https://github.com/Wan-Video/Wan2.1.git
# (synced 2025-11-04 14:16:57 +00:00; 19 lines, 460 B, Python)
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch


class BaseModel(torch.nn.Module):
    """``torch.nn.Module`` base class that adds a helper for loading weights from disk."""

    def load(self, path):
        """Read a checkpoint file and copy its weights into this module.

        The file is deserialized onto the CPU with ``weights_only=True``, so
        only tensors and plain containers are unpickled. If the loaded object
        contains an ``'optimizer'`` entry, it is treated as a training
        checkpoint and its ``'model'`` entry is used as the state dict.
        Otherwise the loaded object itself is used as the state dict.

        Args:
            path (str): file path of the checkpoint.

        Raises:
            Whatever ``torch.load`` / ``load_state_dict`` raise, e.g. a
            ``RuntimeError`` on mismatched keys under the default strict mode.
        """
        cpu = torch.device('cpu')
        checkpoint = torch.load(path, map_location=cpu, weights_only=True)

        # The presence of an optimizer entry is the heuristic used to detect a
        # full training checkpoint whose weights are nested under 'model'.
        state_dict = checkpoint['model'] if 'optimizer' in checkpoint else checkpoint

        self.load_state_dict(state_dict)
|