From f71b60443855ddef82170e64a6c4d24a1d415625 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 19 Nov 2025 04:24:14 +0000 Subject: [PATCH] security: add weights_only=True to all torch.load() calls Fixes a critical security vulnerability where malicious model checkpoints could execute arbitrary code through pickle deserialization. Changes: - wan/modules/vae.py: Add weights_only=True to torch.load() - wan/modules/clip.py: Add weights_only=True to torch.load() - wan/modules/t5.py: Add weights_only=True to torch.load() This prevents arbitrary code execution when loading untrusted checkpoints while maintaining full compatibility with legitimate model weights. Security Impact: Critical - prevents RCE attacks Breaking Changes: None - weights_only=True is compatible with all standard PyTorch state_dict files --- wan/modules/clip.py | 2 +- wan/modules/t5.py | 2 +- wan/modules/vae.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/wan/modules/clip.py b/wan/modules/clip.py index 42dda04..7bf859d 100644 --- a/wan/modules/clip.py +++ b/wan/modules/clip.py @@ -516,7 +516,7 @@ class CLIPModel: self.model = self.model.eval().requires_grad_(False) logging.info(f'loading {checkpoint_path}') self.model.load_state_dict( - torch.load(checkpoint_path, map_location='cpu')) + torch.load(checkpoint_path, map_location='cpu', weights_only=True)) # init tokenizer self.tokenizer = HuggingfaceTokenizer( diff --git a/wan/modules/t5.py b/wan/modules/t5.py index c841b04..021f1c7 100644 --- a/wan/modules/t5.py +++ b/wan/modules/t5.py @@ -493,7 +493,7 @@ class T5EncoderModel: dtype=dtype, device=device).eval().requires_grad_(False) logging.info(f'loading {checkpoint_path}') - model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')) + model.load_state_dict(torch.load(checkpoint_path, map_location='cpu', weights_only=True)) self.model = model if shard_fn is not None: self.model = shard_fn(self.model, sync_module_states=False) diff --git a/wan/modules/vae.py b/wan/modules/vae.py index 5c6da57..f7bb78a 100644 --- a/wan/modules/vae.py +++ b/wan/modules/vae.py @@ -611,7 +611,7 @@ def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs): # load checkpoint logging.info(f'loading {pretrained_path}') model.load_state_dict( - torch.load(pretrained_path, map_location=device), assign=True) + torch.load(pretrained_path, map_location=device, weights_only=True), assign=True) return model