diff --git a/hyvideo/text_encoder/__init__.py b/hyvideo/text_encoder/__init__.py
index a51aa3f..1376718 100644
--- a/hyvideo/text_encoder/__init__.py
+++ b/hyvideo/text_encoder/__init__.py
@@ -189,7 +189,7 @@ class TextEncoder(nn.Module):
         if "llm" in text_encoder_type:
             from mmgp import offload
             forcedConfigPath= None if "i2v" in text_encoder_type else "ckpts/llava-llama-3-8b/config.json"
-            self.model= offload.fast_load_transformers_model(self.model_path, forcedConfigPath=forcedConfigPath)
+            self.model= offload.fast_load_transformers_model(self.model_path, modelPrefix="language_model" if forcedConfigPath != None else None, forcedConfigPath=forcedConfigPath)
             if forcedConfigPath != None:
                 self.model.final_layer_norm = self.model.model.norm

diff --git a/wan/modules/attention.py b/wan/modules/attention.py
index 635b6d2..3523e48 100644
--- a/wan/modules/attention.py
+++ b/wan/modules/attention.py
@@ -134,6 +134,11 @@ def get_supported_attention_modes():
     if not sage2_supported:
         if "sage2" in ret:
             ret.remove("sage2")
+
+    major, minor = torch.cuda.get_device_capability()
+    if major < 7:
+        if "sage" in ret:
+            ret.remove("sage")
     return ret

 __all__ = [
@@ -173,6 +178,8 @@ def pay_attention(
     # assume if k_lens is non null, each k is padded up to lk (one k out of two will need to be discarded or ignored)
     if attention_mask != None:
         force_attention = "sdpa"
+        if attention_mask.dtype == torch.bfloat16 and not bfloat16_supported:
+            attention_mask = attention_mask.to(torch.float16)
     attn = offload.shared_state["_attention"] if force_attention== None else force_attention
     q,k,v = qkv_list

diff --git a/wgp.py b/wgp.py
index 352a877..2eb80bf 100644
--- a/wgp.py
+++ b/wgp.py
@@ -1869,9 +1869,9 @@ def get_ltxv_text_encoder_filename(text_encoder_quantization):

 def get_hunyuan_text_encoder_filename(text_encoder_quantization):
     if text_encoder_quantization =="int8":
-        text_encoder_filename = "ckpts/llava-llama-3-8b/llava-llama-3-8b-v1_1_vlm_fp16.safetensors"
-    else:
         text_encoder_filename = "ckpts/llava-llama-3-8b/llava-llama-3-8b-v1_1_vlm_quanto_int8.safetensors"
+    else:
+        text_encoder_filename = "ckpts/llava-llama-3-8b/llava-llama-3-8b-v1_1_vlm_fp16.safetensors"
     return text_encoder_filename

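For context on the first hunk: passing modelPrefix="language_model" suggests the non-i2v llava-llama-3 checkpoint stores its text-model weights under a `language_model.` sub-tree, which the loader must select when a text-only config is forced. A minimal sketch of what such prefix selection could look like, assuming the real logic lives inside mmgp's `fast_load_transformers_model` (the helper name below is hypothetical):

```python
# Hypothetical sketch of prefix-based weight selection: keep only the
# language-model tensors from a multi-modal checkpoint and strip the
# prefix so the keys match a text-only config. Illustrative only; the
# actual handling of modelPrefix is implemented inside mmgp.
def select_prefixed_weights(state_dict, prefix="language_model"):
    marker = prefix + "."
    return {
        key[len(marker):]: tensor
        for key, tensor in state_dict.items()
        if key.startswith(marker)
    }
```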
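The `get_supported_attention_modes` change follows the pattern of the existing sage2 check: "sage" is dropped on GPUs with compute capability below 7.0. A standalone sketch of the same gate, assuming CUDA availability is checked first:

```python
import torch

# Sketch: per the patch, sage attention requires an SM 7.x (Volta) or
# newer GPU; older devices fall back to the remaining attention modes.
def sage_supported() -> bool:
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 7
```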
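The `pay_attention` change guards the forced sdpa path on GPUs without native bfloat16: a bfloat16 attention mask is downcast to float16 before use. A minimal sketch of the same coercion, assuming `bfloat16_supported` is derived from the device capability (the capability threshold below is an assumption; the real flag is computed elsewhere in the module):

```python
import torch

# Assumption: bfloat16 is treated as supported from Ampere (SM 8.x) up;
# the real bfloat16_supported flag comes from elsewhere in attention.py.
bfloat16_supported = (
    torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
)

def coerce_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # Mirror the patched pay_attention behavior: never hand sdpa a
    # bfloat16 mask on a GPU that cannot run bfloat16 kernels.
    if attention_mask.dtype == torch.bfloat16 and not bfloat16_supported:
        return attention_mask.to(torch.float16)
    return attention_mask
```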
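The wgp.py hunk fixes inverted branches: requesting int8 previously returned the fp16 checkpoint and vice versa. One way to make such a mapping harder to invert by accident is a lookup table; a sketch using the same paths, not the patched code itself:

```python
# Sketch: table-driven variant of the patched function. Paths match
# wgp.py; keying by quantization makes any swap visible at a glance.
HUNYUAN_TEXT_ENCODER_FILES = {
    "int8": "ckpts/llava-llama-3-8b/llava-llama-3-8b-v1_1_vlm_quanto_int8.safetensors",
    "fp16": "ckpts/llava-llama-3-8b/llava-llama-3-8b-v1_1_vlm_fp16.safetensors",
}

def get_hunyuan_text_encoder_filename(text_encoder_quantization):
    key = "int8" if text_encoder_quantization == "int8" else "fp16"
    return HUNYUAN_TEXT_ENCODER_FILES[key]
```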