mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-12-16 12:13:27 +00:00
fixed LTX Video support for RTX10xx , 20xx
This commit is contained in:
parent
5ba133cae9
commit
8b9929f744
@ -189,7 +189,7 @@ class TextEncoder(nn.Module):
|
|||||||
if "llm" in text_encoder_type:
|
if "llm" in text_encoder_type:
|
||||||
from mmgp import offload
|
from mmgp import offload
|
||||||
forcedConfigPath= None if "i2v" in text_encoder_type else "ckpts/llava-llama-3-8b/config.json"
|
forcedConfigPath= None if "i2v" in text_encoder_type else "ckpts/llava-llama-3-8b/config.json"
|
||||||
self.model= offload.fast_load_transformers_model(self.model_path, forcedConfigPath=forcedConfigPath)
|
self.model= offload.fast_load_transformers_model(self.model_path, modelPrefix="language_model" if forcedConfigPath != None else None, forcedConfigPath=forcedConfigPath)
|
||||||
if forcedConfigPath != None:
|
if forcedConfigPath != None:
|
||||||
self.model.final_layer_norm = self.model.model.norm
|
self.model.final_layer_norm = self.model.model.norm
|
||||||
|
|
||||||
|
|||||||
@ -134,6 +134,11 @@ def get_supported_attention_modes():
|
|||||||
if not sage2_supported:
|
if not sage2_supported:
|
||||||
if "sage2" in ret:
|
if "sage2" in ret:
|
||||||
ret.remove("sage2")
|
ret.remove("sage2")
|
||||||
|
|
||||||
|
major, minor = torch.cuda.get_device_capability()
|
||||||
|
if major < 7:
|
||||||
|
if "sage" in ret:
|
||||||
|
ret.remove("sage")
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -173,6 +178,8 @@ def pay_attention(
|
|||||||
# assume if k_lens is non null, each k is padded up to lk (one k out of two will need to be discarded or ignored)
|
# assume if k_lens is non null, each k is padded up to lk (one k out of two will need to be discarded or ignored)
|
||||||
if attention_mask != None:
|
if attention_mask != None:
|
||||||
force_attention = "sdpa"
|
force_attention = "sdpa"
|
||||||
|
if attention_mask.dtype == torch.bfloat16 and not bfloat16_supported:
|
||||||
|
attention_mask = attention_mask.to(torch.float16)
|
||||||
attn = offload.shared_state["_attention"] if force_attention== None else force_attention
|
attn = offload.shared_state["_attention"] if force_attention== None else force_attention
|
||||||
|
|
||||||
q,k,v = qkv_list
|
q,k,v = qkv_list
|
||||||
|
|||||||
4
wgp.py
4
wgp.py
@ -1869,9 +1869,9 @@ def get_ltxv_text_encoder_filename(text_encoder_quantization):
|
|||||||
|
|
||||||
def get_hunyuan_text_encoder_filename(text_encoder_quantization):
|
def get_hunyuan_text_encoder_filename(text_encoder_quantization):
|
||||||
if text_encoder_quantization =="int8":
|
if text_encoder_quantization =="int8":
|
||||||
text_encoder_filename = "ckpts/llava-llama-3-8b/llava-llama-3-8b-v1_1_vlm_fp16.safetensors"
|
|
||||||
else:
|
|
||||||
text_encoder_filename = "ckpts/llava-llama-3-8b/llava-llama-3-8b-v1_1_vlm_quanto_int8.safetensors"
|
text_encoder_filename = "ckpts/llava-llama-3-8b/llava-llama-3-8b-v1_1_vlm_quanto_int8.safetensors"
|
||||||
|
else:
|
||||||
|
text_encoder_filename = "ckpts/llava-llama-3-8b/llava-llama-3-8b-v1_1_vlm_fp16.safetensors"
|
||||||
|
|
||||||
return text_encoder_filename
|
return text_encoder_filename
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user