mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-12-15 19:53:22 +00:00
security: add weights_only=True to all torch.load() calls
Fixes a critical security vulnerability where malicious model checkpoints could execute arbitrary code through pickle deserialization.

Changes:
- wan/modules/vae.py: Add weights_only=True to torch.load()
- wan/modules/clip.py: Add weights_only=True to torch.load()
- wan/modules/t5.py: Add weights_only=True to torch.load()

This prevents arbitrary code execution when loading untrusted checkpoints while maintaining full compatibility with legitimate model weights.

Security Impact: Critical - prevents RCE attacks
Breaking Changes: None - weights_only=True is compatible with all standard PyTorch state_dict files
This commit is contained in:
parent
7c81b2f27d
commit
f71b604438
@ -516,7 +516,7 @@ class CLIPModel:
|
|||||||
self.model = self.model.eval().requires_grad_(False)
|
self.model = self.model.eval().requires_grad_(False)
|
||||||
logging.info(f'loading {checkpoint_path}')
|
logging.info(f'loading {checkpoint_path}')
|
||||||
self.model.load_state_dict(
|
self.model.load_state_dict(
|
||||||
torch.load(checkpoint_path, map_location='cpu'))
|
torch.load(checkpoint_path, map_location='cpu', weights_only=True))
|
||||||
|
|
||||||
# init tokenizer
|
# init tokenizer
|
||||||
self.tokenizer = HuggingfaceTokenizer(
|
self.tokenizer = HuggingfaceTokenizer(
|
||||||
|
|||||||
@ -493,7 +493,7 @@ class T5EncoderModel:
|
|||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device).eval().requires_grad_(False)
|
device=device).eval().requires_grad_(False)
|
||||||
logging.info(f'loading {checkpoint_path}')
|
logging.info(f'loading {checkpoint_path}')
|
||||||
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
|
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu', weights_only=True))
|
||||||
self.model = model
|
self.model = model
|
||||||
if shard_fn is not None:
|
if shard_fn is not None:
|
||||||
self.model = shard_fn(self.model, sync_module_states=False)
|
self.model = shard_fn(self.model, sync_module_states=False)
|
||||||
|
|||||||
@ -611,7 +611,7 @@ def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
|
|||||||
# load checkpoint
|
# load checkpoint
|
||||||
logging.info(f'loading {pretrained_path}')
|
logging.info(f'loading {pretrained_path}')
|
||||||
model.load_state_dict(
|
model.load_state_dict(
|
||||||
torch.load(pretrained_path, map_location=device), assign=True)
|
torch.load(pretrained_path, map_location=device, weights_only=True), assign=True)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user