diff --git a/wgp.py b/wgp.py index e0a0420..d2578b7 100644 --- a/wgp.py +++ b/wgp.py @@ -1746,7 +1746,7 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = ""): else: raw_filename = choices[0] - if dtype == torch.float16 and not any("fp16","FP16") in raw_filename and model_family == "wan" and finetune_def == None : + if dtype == torch.float16 and not any(x in raw_filename for x in ["fp16", "FP16"]) and model_family == "wan" and finetune_def == None: if "quanto_int8" in raw_filename: raw_filename = raw_filename.replace("quanto_int8", "quanto_fp16_int8") elif "quanto_bf16_int8" in raw_filename: