security: add weights_only=True to all torch.load() calls

Fixes a critical security vulnerability where malicious model checkpoints
could execute arbitrary code through pickle deserialization.

Changes:
- wan/modules/vae.py: Add weights_only=True to torch.load()
- wan/modules/clip.py: Add weights_only=True to torch.load()
- wan/modules/t5.py: Add weights_only=True to torch.load()

This prevents arbitrary code execution when loading untrusted checkpoints
while maintaining full compatibility with legitimate model weights.

Security Impact: Critical - prevents RCE attacks
Breaking Changes: None - weights_only=True is compatible with all standard
PyTorch state_dict files
This commit is contained in:
Claude 2025-11-19 04:24:14 +00:00
parent 7c81b2f27d
commit f71b604438
No known key found for this signature in database
3 changed files with 3 additions and 3 deletions

View File

@ -516,7 +516,7 @@ class CLIPModel:
self.model = self.model.eval().requires_grad_(False) self.model = self.model.eval().requires_grad_(False)
logging.info(f'loading {checkpoint_path}') logging.info(f'loading {checkpoint_path}')
self.model.load_state_dict( self.model.load_state_dict(
torch.load(checkpoint_path, map_location='cpu')) torch.load(checkpoint_path, map_location='cpu', weights_only=True))
# init tokenizer # init tokenizer
self.tokenizer = HuggingfaceTokenizer( self.tokenizer = HuggingfaceTokenizer(

View File

@ -493,7 +493,7 @@ class T5EncoderModel:
dtype=dtype, dtype=dtype,
device=device).eval().requires_grad_(False) device=device).eval().requires_grad_(False)
logging.info(f'loading {checkpoint_path}') logging.info(f'loading {checkpoint_path}')
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')) model.load_state_dict(torch.load(checkpoint_path, map_location='cpu', weights_only=True))
self.model = model self.model = model
if shard_fn is not None: if shard_fn is not None:
self.model = shard_fn(self.model, sync_module_states=False) self.model = shard_fn(self.model, sync_module_states=False)

View File

@ -611,7 +611,7 @@ def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
# load checkpoint # load checkpoint
logging.info(f'loading {pretrained_path}') logging.info(f'loading {pretrained_path}')
model.load_state_dict( model.load_state_dict(
torch.load(pretrained_path, map_location=device), assign=True) torch.load(pretrained_path, map_location=device, weights_only=True), assign=True)
return model return model