picky picky

This commit is contained in:
deepbeepmeep 2025-08-06 20:38:13 +02:00
parent b2c3416bcd
commit 979bc20625
175 changed files with 5331 additions and 1553 deletions

2
.gitignore vendored
View File

@ -14,7 +14,7 @@
*.pth *.pth
*.ckpt *.ckpt
*.safetensors *.safetensors
*.json #*.json
# *.txt # *.txt
*.backup *.backup
*.pkl *.pkl

View File

@ -20,6 +20,17 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates : ## 🔥 Latest Updates :
### August 6 2025: WanGP v7.7 - Picky, picky
This release comes with two new models :
- Qwen Image: a Commercial grade Image generator capable to inject full sentences in the generated Image while still offering incredible visuals
- Wan 2.2 TextImage to Video 5B: the last Wan 2.2 needed if you want to complete your Wan 2.2 collection (loras for this folder can be stored in "\loras\5B" )
There is catch though, they are very picky if you want to get good generations: first they both need lots of steps (50 ?) to show what they have to offer. Then for Qwen Image I had to hardcode the supported resolutions, because if you try anything else, you will get garbage. Likiwise Wan 2.2 5B will remind you of Wan 1.0 if you don't ask for at least 720p.
Please note that the VAE decoding of Wan 2.2 TextImage is not tiled yet and it may produce VRAM consumption peaks (this doens't mix well with the 720p requirement).
### August 4 2025: WanGP v7.6 - Remuxed ### August 4 2025: WanGP v7.6 - Remuxed
With this new version you won't have any excuse if there is no sound in your video. With this new version you won't have any excuse if there is no sound in your video.

View File

@ -0,0 +1,18 @@
{
"_class_name": "QwenImageTransformer2DModel",
"_diffusers_version": "0.34.0.dev0",
"attention_head_dim": 128,
"axes_dims_rope": [
16,
56,
56
],
"guidance_embeds": false,
"in_channels": 64,
"joint_attention_dim": 3584,
"num_attention_heads": 24,
"num_layers": 60,
"out_channels": 16,
"patch_size": 2,
"pooled_projection_dim": 768
}

14
configs/ti2v_2_2.json Normal file
View File

@ -0,0 +1,14 @@
{
"_class_name": "WanModel",
"_diffusers_version": "0.33.0",
"dim": 3072,
"eps": 1e-06,
"ffn_dim": 14336,
"freq_dim": 256,
"in_dim": 48,
"model_type": "ti2v2_2",
"num_heads": 24,
"num_layers": 30,
"out_dim": 48,
"text_len": 512
}

View File

@ -13,7 +13,7 @@
"https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv-097-ic-lora-depth-control-diffusers.safetensors", "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv-097-ic-lora-depth-control-diffusers.safetensors",
"https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv-097-ic-lora-canny-control-diffusers.safetensors" "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv-097-ic-lora-canny-control-diffusers.safetensors"
], ],
"LTXV_config": "ltx_video/configs/ltxv-13b-0.9.8-dev.yaml" "LTXV_config": "models/ltx_video/configs/ltxv-13b-0.9.8-dev.yaml"
}, },
"num_inference_steps": 30 "num_inference_steps": 30
} }

View File

@ -9,7 +9,7 @@
"https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.8_13B_distilled_quanto_bf16_int8.safetensors" "https://huggingface.co/DeepBeepMeep/LTX_Video/resolve/main/ltxv_0.9.8_13B_distilled_quanto_bf16_int8.safetensors"
], ],
"preload_URLs" : "ltxv_13B", "preload_URLs" : "ltxv_13B",
"LTXV_config": "ltx_video/configs/ltxv-13b-0.9.8-distilled.yaml" "LTXV_config": "models/ltx_video/configs/ltxv-13b-0.9.8-distilled.yaml"
}, },
"num_inference_steps": 6 "num_inference_steps": 6
} }

View File

@ -0,0 +1,21 @@
{
"model": {
"name": "Qwen Image 20B",
"architecture": "qwen_image_20B",
"description": "Qwen Image is generative model that will very high quality images. It is one of the few models capable to generate in the image very long texts.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Qwen/resolve/main/qwen_image_20B_bf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Qwen/resolve/main/qwen_image_20B_quanto_bf16_int8.safetensors"
],
"resolutions": [ ["1328x1328 (1:1)", "1328x1328"],
["1664x928 (16:9)", "1664x928"],
["928x1664 (9:16)", "928x1664"],
["1472x1140 (4:3)", "1472x1140"],
["1140x1472 (3:4)", "1140x1472"]
],
"image_outputs": true
},
"prompt": "draw a hat",
"resolution": "1280x720",
"batch_size": 1
}

View File

@ -1,13 +0,0 @@
try:
from ._version import (
version as __version__, # type: ignore
version_tuple,
)
except ImportError:
__version__ = "unknown (no version information available)"
version_tuple = (0, 0, "unknown", "noinfo")
from pathlib import Path
PACKAGE = __package__.replace("_", "-")
PACKAGE_ROOT = Path(__file__).parent

1
loras_qwen/Readme.txt Normal file
View File

@ -0,0 +1 @@
LTX Video loras

2
models/flux/__init__.py Normal file
View File

@ -0,0 +1,2 @@
from .flux_main import model_factory
from . import flux_handler

103
models/flux/flux_handler.py Normal file
View File

@ -0,0 +1,103 @@
import torch
def get_ltxv_text_encoder_filename(text_encoder_quantization):
text_encoder_filename = "ckpts/T5_xxl_1.1/T5_xxl_1.1_enc_bf16.safetensors"
if text_encoder_quantization =="int8":
text_encoder_filename = text_encoder_filename.replace("bf16", "quanto_bf16_int8")
return text_encoder_filename
class family_handler():
@staticmethod
def query_model_def(base_model_type, model_def):
flux_model = model_def.get("flux-model", "flux-dev")
flux_schnell = flux_model == "flux-schnell"
model_def_output = {
"image_outputs" : True,
"no_negative_prompt" : True,
}
if flux_schnell:
model_def_output["no_guidance"] = True
else:
model_def_output["embedded_guidance"] = True
return model_def_output
@staticmethod
def query_supported_types():
return ["flux"]
@staticmethod
def query_family_maps():
return {}, {}
@staticmethod
def get_rgb_factors(model_type):
from shared.RGB_factors import get_rgb_factors
latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("flux")
return latent_rgb_factors, latent_rgb_factors_bias
@staticmethod
def query_model_family():
return "flux"
@staticmethod
def query_family_infos():
return {"flux":(30, "Flux 1")}
@staticmethod
def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization):
text_encoder_filename = get_ltxv_text_encoder_filename(text_encoder_quantization)
return [
{
"repoId" : "DeepBeepMeep/Flux",
"sourceFolderList" : [""],
"fileList" : [ ["flux_vae.safetensors"] ]
},
{
"repoId" : "DeepBeepMeep/LTX_Video",
"sourceFolderList" : ["T5_xxl_1.1"],
"fileList" : [ ["added_tokens.json", "special_tokens_map.json", "spiece.model", "tokenizer_config.json"] + computeList(text_encoder_filename) ]
},
{
"repoId" : "DeepBeepMeep/HunyuanVideo",
"sourceFolderList" : [ "clip_vit_large_patch14", ],
"fileList" :[
["config.json", "merges.txt", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", "vocab.json"],
]
}
]
@staticmethod
def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False):
from .flux_main import model_factory
flux_model = model_factory(
checkpoint_dir="ckpts",
model_filename=model_filename,
model_type = model_type,
model_def = model_def,
base_model_type=base_model_type,
text_encoder_filename= get_ltxv_text_encoder_filename(text_encoder_quantization),
quantizeTransformer = quantizeTransformer,
dtype = dtype,
VAE_dtype = VAE_dtype,
mixed_precision_transformer = mixed_precision_transformer,
save_quantized = save_quantized
)
pipe = { "transformer": flux_model.model, "vae" : flux_model.vae, "text_encoder" : flux_model.clip, "text_encoder_2" : flux_model.t5}
return flux_model, pipe
@staticmethod
def update_default_settings(base_model_type, model_def, ui_defaults):
ui_defaults.update({
"embedded_guidance": 2.5,
})
if model_def.get("reference_image", False):
ui_defaults.update({
"video_prompt_type": "KI",
})

View File

@ -5,10 +5,10 @@ from dataclasses import dataclass
from glob import iglob from glob import iglob
from mmgp import offload as offload from mmgp import offload as offload
import torch import torch
from wan.utils.utils import calculate_new_dimensions from shared.utils.utils import calculate_new_dimensions
from flux.sampling import denoise, get_schedule, prepare_kontext, unpack from .sampling import denoise, get_schedule, prepare_kontext, unpack
from flux.modules.layers import get_linear_split_map from .modules.layers import get_linear_split_map
from flux.util import ( from .util import (
aspect_ratio_to_height_width, aspect_ratio_to_height_width,
load_ae, load_ae,
load_clip, load_clip,
@ -146,13 +146,3 @@ class model_factory:
x = x.transpose(0, 1) x = x.transpose(0, 1)
return x return x
def query_model_def(model_type, model_def):
flux_model = model_def.get("flux-model", "flux-dev")
flux_schnell = flux_model == "flux-schnell"
model_def_output = {
"image_outputs" : True,
}
if flux_schnell:
model_def_output["no_guidance"] = True
return model_def_output

View File

@ -1,7 +1,7 @@
import torch import torch
from einops import rearrange from einops import rearrange
from torch import Tensor from torch import Tensor
from wan.modules.attention import pay_attention from shared.attention import pay_attention
def attention(qkv_list, pe: Tensor) -> Tensor: def attention(qkv_list, pe: Tensor) -> Tensor:

View File

@ -3,7 +3,7 @@ from dataclasses import dataclass
import torch import torch
from torch import Tensor, nn from torch import Tensor, nn
from flux.modules.layers import ( from .modules.layers import (
DoubleStreamBlock, DoubleStreamBlock,
EmbedND, EmbedND,
LastLayer, LastLayer,
@ -11,7 +11,7 @@ from flux.modules.layers import (
SingleStreamBlock, SingleStreamBlock,
timestep_embedding, timestep_embedding,
) )
from flux.modules.lora import LinearLora, replace_linear_with_lora from .modules.lora import LinearLora, replace_linear_with_lora
@dataclass @dataclass

View File

@ -7,7 +7,7 @@ from safetensors.torch import load_file as load_sft
from torch import nn from torch import nn
from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel
from flux.util import print_load_warning from ..util import print_load_warning
class DepthImageEncoder: class DepthImageEncoder:

View File

@ -5,7 +5,7 @@ import torch
from einops import rearrange from einops import rearrange
from torch import Tensor, nn from torch import Tensor, nn
from flux.math import attention, rope from ..math import attention, rope
def get_linear_split_map(): def get_linear_split_map():
hidden_size = 3072 hidden_size = 3072

View File

@ -343,7 +343,7 @@ def denoise(
updated_num_steps= len(timesteps) -1 updated_num_steps= len(timesteps) -1
if callback != None: if callback != None:
from wan.utils.loras_mutipliers import update_loras_slists from shared.utils.loras_mutipliers import update_loras_slists
update_loras_slists(model, loras_slists, updated_num_steps) update_loras_slists(model, loras_slists, updated_num_steps)
callback(-1, None, True, override_num_inference_steps = updated_num_steps) callback(-1, None, True, override_num_inference_steps = updated_num_steps)
from mmgp import offload from mmgp import offload

View File

@ -11,9 +11,9 @@ from huggingface_hub import hf_hub_download, login
from PIL import ExifTags, Image from PIL import ExifTags, Image
from safetensors.torch import load_file as load_sft from safetensors.torch import load_file as load_sft
from flux.model import Flux, FluxLoraWrapper, FluxParams from .model import Flux, FluxLoraWrapper, FluxParams
from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams from .modules.autoencoder import AutoEncoder, AutoEncoderParams
from flux.modules.conditioner import HFEmbedder from .modules.conditioner import HFEmbedder
CHECKPOINTS_DIR = Path("checkpoints") CHECKPOINTS_DIR = Path("checkpoints")

View File

@ -0,0 +1,2 @@
from .hunyuan import HunyuanVideoSampler
from . import hunyuan_handler

View File

@ -41,9 +41,9 @@ from diffusers.utils import (
from diffusers.utils.torch_utils import randn_tensor from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from hyvideo.constants import PRECISION_TO_TYPE from ...constants import PRECISION_TO_TYPE
from hyvideo.vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D from ...vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
from hyvideo.text_encoder import TextEncoder from ...text_encoder import TextEncoder
from einops import rearrange from einops import rearrange
from ...modules import HYVideoDiffusionTransformer from ...modules import HYVideoDiffusionTransformer

View File

@ -8,24 +8,24 @@ from pathlib import Path
from einops import rearrange from einops import rearrange
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from hyvideo.constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE, NEGATIVE_PROMPT_I2V from .constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE, NEGATIVE_PROMPT_I2V
from hyvideo.vae import load_vae from .vae import load_vae
from hyvideo.modules import load_model from .modules import load_model
from hyvideo.text_encoder import TextEncoder from .text_encoder import TextEncoder
from hyvideo.utils.data_utils import align_to, get_closest_ratio, generate_crop_size_list from .utils.data_utils import align_to, get_closest_ratio, generate_crop_size_list
from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed, get_nd_rotary_pos_embed_new from .modules.posemb_layers import get_nd_rotary_pos_embed, get_nd_rotary_pos_embed_new
from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler from .diffusion.schedulers import FlowMatchDiscreteScheduler
from hyvideo.diffusion.pipelines import HunyuanVideoPipeline from .diffusion.pipelines import HunyuanVideoPipeline
from hyvideo.diffusion.pipelines import HunyuanVideoAudioPipeline from .diffusion.pipelines import HunyuanVideoAudioPipeline
from PIL import Image from PIL import Image
import numpy as np import numpy as np
import torchvision.transforms as transforms import torchvision.transforms as transforms
import cv2 import cv2
from wan.utils.utils import calculate_new_dimensions, convert_tensor_to_image from shared.utils.utils import calculate_new_dimensions, convert_tensor_to_image
from hyvideo.data_kits.audio_preprocessor import encode_audio, get_facemask from .data_kits.audio_preprocessor import encode_audio, get_facemask
from transformers import WhisperModel from transformers import WhisperModel
from transformers import AutoFeatureExtractor from transformers import AutoFeatureExtractor
from hyvideo.data_kits.face_align import AlignImage from .data_kits.face_align import AlignImage
import librosa import librosa
def get_audio_feature(feature_extractor, audio_path, duration): def get_audio_feature(feature_extractor, audio_path, duration):
@ -66,174 +66,174 @@ def pad_image(crop_img, size, color=(255, 255, 255), resize_ratio=1):
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): # def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
num_images, num_image_patches, embed_dim = image_features.shape # num_images, num_image_patches, embed_dim = image_features.shape
batch_size, sequence_length = input_ids.shape # batch_size, sequence_length = input_ids.shape
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) # left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
# 1. Create a mask to know where special image tokens are # # 1. Create a mask to know where special image tokens are
special_image_token_mask = input_ids == self.config.image_token_index # special_image_token_mask = input_ids == self.config.image_token_index
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) # num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
# Compute the maximum embed dimension # # Compute the maximum embed dimension
max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length # max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index) # batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
# 2. Compute the positions where text should be written # # 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged image-text sequence. # # Calculate new positions for text tokens in merged image-text sequence.
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. # # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
# `torch.cumsum` computes how each image token shifts subsequent text token positions. # # `torch.cumsum` computes how each image token shifts subsequent text token positions.
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. # # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1 # new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1] # nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
if left_padding: # if left_padding:
new_token_positions += nb_image_pad[:, None] # offset for left padding # new_token_positions += nb_image_pad[:, None] # offset for left padding
text_to_overwrite = new_token_positions[batch_indices, non_image_indices] # text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
# 3. Create the full embedding, already padded to the maximum position # # 3. Create the full embedding, already padded to the maximum position
final_embedding = torch.zeros( # final_embedding = torch.zeros(
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device # batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
) # )
final_attention_mask = torch.zeros( # final_attention_mask = torch.zeros(
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device # batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
) # )
if labels is not None: # if labels is not None:
final_labels = torch.full( # final_labels = torch.full(
(batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device # (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
) # )
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually # # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device. # # set the corresponding tensors into their correct target device.
target_device = inputs_embeds.device # target_device = inputs_embeds.device
batch_indices, non_image_indices, text_to_overwrite = ( # batch_indices, non_image_indices, text_to_overwrite = (
batch_indices.to(target_device), # batch_indices.to(target_device),
non_image_indices.to(target_device), # non_image_indices.to(target_device),
text_to_overwrite.to(target_device), # text_to_overwrite.to(target_device),
) # )
attention_mask = attention_mask.to(target_device) # attention_mask = attention_mask.to(target_device)
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"] # # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features # # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] # final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] # final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
if labels is not None: # if labels is not None:
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] # final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835) # # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
image_to_overwrite = torch.full( # image_to_overwrite = torch.full(
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device # (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
) # )
image_to_overwrite[batch_indices, text_to_overwrite] = False # image_to_overwrite[batch_indices, text_to_overwrite] = False
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) # image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
if image_to_overwrite.sum() != image_features.shape[:-1].numel(): # if image_to_overwrite.sum() != image_features.shape[:-1].numel():
raise ValueError( # raise ValueError(
f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while" # f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation." # f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
) # )
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) # final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
final_attention_mask |= image_to_overwrite # final_attention_mask |= image_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) # position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. # # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id) # batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
indices_to_mask = new_token_positions[batch_indices, pad_indices] # indices_to_mask = new_token_positions[batch_indices, pad_indices]
final_embedding[batch_indices, indices_to_mask] = 0 # final_embedding[batch_indices, indices_to_mask] = 0
if labels is None: # if labels is None:
final_labels = None # final_labels = None
return final_embedding, final_attention_mask, final_labels, position_ids # return final_embedding, final_attention_mask, final_labels, position_ids
def patched_llava_forward( # def patched_llava_forward(
self, # self,
input_ids: torch.LongTensor = None, # input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None, # pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None, # attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, # position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None, # past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, # inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[int] = None, # vision_feature_layer: Optional[int] = None,
vision_feature_select_strategy: Optional[str] = None, # vision_feature_select_strategy: Optional[str] = None,
labels: Optional[torch.LongTensor] = None, # labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None, # use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, # output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, # output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, # return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, # cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0, # num_logits_to_keep: int = 0,
): # ):
from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast # from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( # output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states # output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
) # )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict # return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layer = ( # vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer # vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
) # )
vision_feature_select_strategy = ( # vision_feature_select_strategy = (
vision_feature_select_strategy # vision_feature_select_strategy
if vision_feature_select_strategy is not None # if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy # else self.config.vision_feature_select_strategy
) # )
if (input_ids is None) ^ (inputs_embeds is not None): # if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds") # raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None: # if pixel_values is not None and inputs_embeds is not None:
raise ValueError( # raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" # "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
) # )
if inputs_embeds is None: # if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids) # inputs_embeds = self.get_input_embeddings()(input_ids)
image_features = None # image_features = None
if pixel_values is not None: # if pixel_values is not None:
image_features = self.get_image_features( # image_features = self.get_image_features(
pixel_values=pixel_values, # pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer, # vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy, # vision_feature_select_strategy=vision_feature_select_strategy,
) # )
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( # inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
image_features, inputs_embeds, input_ids, attention_mask, labels # image_features, inputs_embeds, input_ids, attention_mask, labels
) # )
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) # cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
outputs = self.language_model( # outputs = self.language_model(
attention_mask=attention_mask, # attention_mask=attention_mask,
position_ids=position_ids, # position_ids=position_ids,
past_key_values=past_key_values, # past_key_values=past_key_values,
inputs_embeds=inputs_embeds, # inputs_embeds=inputs_embeds,
use_cache=use_cache, # use_cache=use_cache,
output_attentions=output_attentions, # output_attentions=output_attentions,
output_hidden_states=output_hidden_states, # output_hidden_states=output_hidden_states,
return_dict=return_dict, # return_dict=return_dict,
cache_position=cache_position, # cache_position=cache_position,
num_logits_to_keep=num_logits_to_keep, # num_logits_to_keep=num_logits_to_keep,
) # )
logits = outputs[0] # logits = outputs[0]
loss = None # loss = None
if not return_dict: # if not return_dict:
output = (logits,) + outputs[1:] # output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output # return (loss,) + output if loss is not None else output
return LlavaCausalLMOutputWithPast( # return LlavaCausalLMOutputWithPast(
loss=loss, # loss=loss,
logits=logits, # logits=logits,
past_key_values=outputs.past_key_values, # past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states, # hidden_states=outputs.hidden_states,
attentions=outputs.attentions, # attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None, # image_hidden_states=image_features if pixel_values is not None else None,
) # )
def adapt_model(model, audio_block_name): def adapt_model(model, audio_block_name):
modules_dict= { k: m for k, m in model.named_modules()} modules_dict= { k: m for k, m in model.named_modules()}
@ -320,8 +320,8 @@ class Inference(object):
device = "cuda" device = "cuda"
import transformers import transformers
transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.forward = patched_llava_forward # force legacy behaviour to be able to use tansformers v>(4.47) # transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.forward = patched_llava_forward # force legacy behaviour to be able to use tansformers v>(4.47)
transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features = _merge_input_ids_with_image_features # transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features = _merge_input_ids_with_image_features
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
text_len = 512 text_len = 512
@ -778,7 +778,7 @@ class HunyuanVideoSampler(Inference):
raise ValueError( raise ValueError(
f"Seed must be an integer, a list of integers, or None, got {seed}." f"Seed must be an integer, a list of integers, or None, got {seed}."
) )
from wan.utils.utils import seed_everything from shared.utils.utils import seed_everything
seed_everything(seed) seed_everything(seed)
generator = [torch.Generator("cuda").manual_seed(seed) for seed in seeds] generator = [torch.Generator("cuda").manual_seed(seed) for seed in seeds]
# generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds] # generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds]
@ -956,7 +956,7 @@ class HunyuanVideoSampler(Inference):
# out_latents= ref_latents / self.vae.config.scaling_factor # out_latents= ref_latents / self.vae.config.scaling_factor
# image = self.vae.decode(out_latents, return_dict=False, generator=generator)[0] # image = self.vae.decode(out_latents, return_dict=False, generator=generator)[0]
# image = image.clamp(-1, 1) # image = image.clamp(-1, 1)
# from wan.utils.utils import cache_video # from shared.utils.utils import cache_video
# cache_video( tensor=image, save_file="decode.mp4", fps=25, nrow=1, normalize=True, value_range=(-1, 1)) # cache_video( tensor=image, save_file="decode.mp4", fps=25, nrow=1, normalize=True, value_range=(-1, 1))
motion_pose = np.array([25] * 4) motion_pose = np.array([25] * 4)
@ -1040,5 +1040,4 @@ class HunyuanVideoSampler(Inference):
return samples return samples
def query_model_def(model_type, model_def):
return None

View File

@ -0,0 +1,147 @@
import torch
def get_hunyuan_text_encoder_filename(text_encoder_quantization):
if text_encoder_quantization =="int8":
text_encoder_filename = "ckpts/llava-llama-3-8b/llava-llama-3-8b-v1_1_vlm_quanto_int8.safetensors"
else:
text_encoder_filename = "ckpts/llava-llama-3-8b/llava-llama-3-8b-v1_1_vlm_fp16.safetensors"
return text_encoder_filename
class family_handler():
@staticmethod
def query_model_def(base_model_type, model_def):
extra_model_def = {}
if base_model_type in ["hunyuan_avatar", "hunyuan_custom_audio"]:
fps = 25
elif base_model_type in ["hunyuan", "hunyuan_i2v", "hunyuan_custom_edit", "hunyuan_custom"]:
fps = 24
else:
fps = 16
extra_model_def["fps"] = fps
extra_model_def["frames_minimum"] = 5
extra_model_def["frames_steps"] = 4
extra_model_def["sliding_window"] = False
extra_model_def["embedded_guidance"] = base_model_type in ["hunyuan", "hunyuan_i2v"]
extra_model_def["cfg_star"] = base_model_type in [ "hunyuan_avatar", "hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_custom"]
extra_model_def["skip_steps_cache"] = True
return extra_model_def
@staticmethod
def query_supported_types():
return ["hunyuan", "hunyuan_i2v", "hunyuan_custom", "hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_avatar"]
@staticmethod
def query_family_maps():
models_eqv_map = {
}
models_comp_map = {
"hunyuan_custom": ["hunyuan_custom_edit", "hunyuan_custom_audio"],
}
return models_eqv_map, models_comp_map
@staticmethod
def query_model_family():
return "hunyuan"
@staticmethod
def query_family_infos():
return {"hunyuan":(20, "Hunyuan Video")}
@staticmethod
def get_rgb_factors(model_type):
from shared.RGB_factors import get_rgb_factors
latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("hunyuan")
return latent_rgb_factors, latent_rgb_factors_bias
@staticmethod
def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization):
text_encoder_filename = get_hunyuan_text_encoder_filename(text_encoder_quantization)
return {
"repoId" : "DeepBeepMeep/HunyuanVideo",
"sourceFolderList" : [ "llava-llama-3-8b", "clip_vit_large_patch14", "whisper-tiny" , "det_align", "" ],
"fileList" :[ ["config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", "preprocessor_config.json"] + computeList(text_encoder_filename) ,
["config.json", "merges.txt", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json", "vocab.json"],
["config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json"],
["detface.pt"],
[ "hunyuan_video_720_quanto_int8_map.json", "hunyuan_video_custom_VAE_fp32.safetensors", "hunyuan_video_custom_VAE_config.json", "hunyuan_video_VAE_fp32.safetensors", "hunyuan_video_VAE_config.json" , "hunyuan_video_720_quanto_int8_map.json" ] + computeList(model_filename)
]
}
@staticmethod
def load_model(model_filename, model_type = None, base_model_type = None, model_def = None, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False):
from .hunyuan import HunyuanVideoSampler
from mmgp import offload
hunyuan_model = HunyuanVideoSampler.from_pretrained(
model_filepath = model_filename,
model_type = model_type,
base_model_type = base_model_type,
text_encoder_filepath = get_hunyuan_text_encoder_filename(text_encoder_quantization),
dtype = dtype,
quantizeTransformer = quantizeTransformer,
VAE_dtype = VAE_dtype,
mixed_precision_transformer = mixed_precision_transformer,
save_quantized = save_quantized
)
pipe = { "transformer" : hunyuan_model.model, "text_encoder" : hunyuan_model.text_encoder, "text_encoder_2" : hunyuan_model.text_encoder_2, "vae" : hunyuan_model.vae }
if hunyuan_model.wav2vec != None:
pipe["wav2vec"] = hunyuan_model.wav2vec
# if hunyuan_model.align_instance != None:
# pipe["align_instance"] = hunyuan_model.align_instance.facedet.model
from .modules.models import get_linear_split_map
split_linear_modules_map = get_linear_split_map()
hunyuan_model.model.split_linear_modules_map = split_linear_modules_map
offload.split_linear_modules(hunyuan_model.model, split_linear_modules_map )
return hunyuan_model, pipe
@staticmethod
def update_default_settings(base_model_type, model_def, ui_defaults):
ui_defaults["embedded_guidance_scale"]= 6.0
if base_model_type in ["hunyuan","hunyuan_i2v"]:
ui_defaults.update({
"guidance_scale": 7.0,
})
elif base_model_type in ["hunyuan_custom"]:
ui_defaults.update({
"guidance_scale": 7.5,
"flow_shift": 13,
"resolution": "1280x720",
"video_prompt_type": "I",
})
elif base_model_type in ["hunyuan_custom_audio"]:
ui_defaults.update({
"guidance_scale": 7.5,
"flow_shift": 13,
"video_prompt_type": "I",
})
elif base_model_type in ["hunyuan_custom_edit"]:
ui_defaults.update({
"guidance_scale": 7.5,
"flow_shift": 13,
"video_prompt_type": "MVAI",
"sliding_window_size": 129,
})
elif base_model_type in ["hunyuan_avatar"]:
ui_defaults.update({
"guidance_scale": 7.5,
"flow_shift": 5,
"remove_background_images_ref": 0,
"skip_steps_start_step_perc": 25,
"video_length": 129,
"video_prompt_type": "I",
})

View File

@ -18,7 +18,7 @@ from .modulate_layers import ModulateDiT, modulate, modulate_ , apply_gate, appl
from .token_refiner import SingleTokenRefiner from .token_refiner import SingleTokenRefiner
import numpy as np import numpy as np
from mmgp import offload from mmgp import offload
from wan.modules.attention import pay_attention from shared.attention import pay_attention
from .audio_adapters import AudioProjNet2, PerceiverAttentionCA from .audio_adapters import AudioProjNet2, PerceiverAttentionCA
def get_linear_split_map(): def get_linear_split_map():

View File

@ -15,6 +15,7 @@ from transformers.utils import ModelOutput
from ..constants import TEXT_ENCODER_PATH, TOKENIZER_PATH from ..constants import TEXT_ENCODER_PATH, TOKENIZER_PATH
from ..constants import PRECISION_TO_TYPE from ..constants import PRECISION_TO_TYPE
from .llava.modeling_llava import LlavaForConditionalGeneration
def use_default(value, default): def use_default(value, default):
@ -188,11 +189,17 @@ class TextEncoder(nn.Module):
if "llm" in text_encoder_type: if "llm" in text_encoder_type:
from mmgp import offload from mmgp import offload
forcedConfigPath= None if "i2v" in text_encoder_type else "ckpts/llava-llama-3-8b/config.json" # forcedConfigPath= None if "i2v" in text_encoder_type else "ckpts/llava-llama-3-8b/config.json"
self.model= offload.fast_load_transformers_model(self.model_path, modelPrefix="language_model" if forcedConfigPath != None else None, forcedConfigPath=forcedConfigPath) # self.model= offload.fast_load_transformers_model(self.model_path, modelPrefix="language_model" if forcedConfigPath != None else None, forcedConfigPath=forcedConfigPath)
if forcedConfigPath != None:
if "i2v" in text_encoder_type:
self.model= offload.fast_load_transformers_model(self.model_path, modelClass= LlavaForConditionalGeneration)
else:
self.model= offload.fast_load_transformers_model(self.model_path, modelPrefix="language_model", forcedConfigPath = "ckpts/llava-llama-3-8b/config.json")
self.model.final_layer_norm = self.model.model.norm self.model.final_layer_norm = self.model.model.norm
else: else:
self.model, self.model_path = load_text_encoder( self.model, self.model_path = load_text_encoder(
text_encoder_type=self.text_encoder_type, text_encoder_type=self.text_encoder_type,

View File

@ -0,0 +1,29 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# from typing import TYPE_CHECKING
# from ...utils import _LazyModule
# from ...utils.import_utils import define_import_structure
# if TYPE_CHECKING:
# from .configuration_llava import *
# from .image_processing_llava_fast import *
# from .modeling_llava import *
# from .processing_llava import *
# else:
# import sys
# _file = globals()["__file__"]
# sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@ -0,0 +1,137 @@
# coding=utf-8
# Copyright 2023 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Llava model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers.models.auto import CONFIG_MAPPING, AutoConfig
logger = logging.get_logger(__name__)
class LlavaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`LlavaForConditionalGeneration`]. It is used to instantiate an
Llava model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the Llava-9B.
e.g. [llava-hf/llava-9b](https://huggingface.co/llava-hf/llava-9b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`):
The config object or dictionary of the vision backbone.
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
The config object or dictionary of the text backbone.
image_token_index (`int`, *optional*, defaults to 32000):
The image token index to encode the image prompt.
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
The activation function used by the multimodal projector.
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`.
vision_feature_layer (`Union[int, List[int]]`, *optional*, defaults to -2):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
image_seq_length (`int`, *optional*, defaults to 576):
Sequence length of one image embedding.
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias in the multimodal projector.
Example:
```python
>>> from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig
>>> # Initializing a CLIP-vision config
>>> vision_config = CLIPVisionConfig()
>>> # Initializing a Llama config
>>> text_config = LlamaConfig()
>>> # Initializing a Llava llava-1.5-7b style configuration
>>> configuration = LlavaConfig(vision_config, text_config)
>>> # Initializing a model from the llava-1.5-7b style configuration
>>> model = LlavaForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "llava"
sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
is_composition = True
def __init__(
self,
vision_config=None,
text_config=None,
image_token_index=32000,
projector_hidden_act="gelu",
vision_feature_select_strategy="default",
vision_feature_layer=-2,
image_seq_length=576,
multimodal_projector_bias=True,
**kwargs,
):
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.image_seq_length = image_seq_length
if vision_feature_select_strategy not in ["default", "full"]:
raise ValueError(
"vision_feature_select_strategy should be one of 'default', 'full'."
f"Got: {vision_feature_select_strategy}"
)
self.vision_feature_select_strategy = vision_feature_select_strategy
self.vision_feature_layer = vision_feature_layer
if isinstance(vision_config, dict):
vision_config["model_type"] = (
vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model"
)
vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
elif vision_config is None:
vision_config = CONFIG_MAPPING["clip_vision_model"](
intermediate_size=4096,
hidden_size=1024,
patch_size=14,
image_size=336,
num_hidden_layers=24,
num_attention_heads=16,
vocab_size=32000,
projection_dim=768,
)
self.vision_config = vision_config
if isinstance(text_config, dict):
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
elif text_config is None:
text_config = CONFIG_MAPPING["llama"]()
self.text_config = text_config
self.multimodal_projector_bias = multimodal_projector_bias
super().__init__(**kwargs)
__all__ = ["LlavaConfig"]

View File

@ -0,0 +1,436 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for LLaVa."""
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
convert_to_rgb,
get_resize_output_image_size,
resize,
to_channel_dimension_format,
)
from ...image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
validate_kwargs,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_vision_available, logging
logger = logging.get_logger(__name__)
if is_vision_available():
import PIL
class LlavaImageProcessor(BaseImageProcessor):
r"""
Constructs a LLaVa image processor.
Args:
do_pad (`bool`, *optional*, defaults to `False`):
Whether to pad the image to a square based on the longest edge.
The padding value is determined by the `image_mean` parameter.
Can be overridden by `do_pad` in the `preprocess` method.
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
`do_resize` in the `preprocess` method.
size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`):
Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
method.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
do_center_crop (`bool`, *optional*, defaults to `True`):
Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
`preprocess` method.
crop_size (`Dict[str, int]` *optional*, defaults to 224):
Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_pad: bool = False,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_center_crop: bool = True,
crop_size: Dict[str, int] = None,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"shortest_edge": 224}
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
self.do_pad = do_pad
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
self.do_convert_rgb = do_convert_rgb
self._valid_processor_keys = [
"images",
"do_pad",
"do_resize",
"size",
"resample",
"do_center_crop",
"crop_size",
"do_rescale",
"rescale_factor",
"do_normalize",
"image_mean",
"image_std",
"do_convert_rgb",
"return_tensors",
"data_format",
"input_data_format",
]
def pad_to_square(
self,
image: np.ndarray,
background_color: Union[int, Tuple[int, int, int]] = 0,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.array:
"""
Pads an image to a square based on the longest edge.
Args:
image (`np.ndarray`):
The image to pad.
background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0):
The color to use for the padding. Can be an integer for single channel or a
tuple of integers representing for multi-channel images. If passed as integer
in mutli-channel mode, it will default to `0` in subsequent channels.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use same as the input image.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use the inferred format of the input image.
Returns:
`np.ndarray`: The padded image.
"""
height, width = get_image_size(image, input_data_format)
num_channels = image.shape[0] if input_data_format == ChannelDimension.FIRST else image.shape[-1]
if height == width:
image = (
to_channel_dimension_format(image, data_format, input_data_format)
if data_format is not None
else image
)
return image
max_dim = max(height, width)
# Ensure background_color is the correct shape
if isinstance(background_color, int):
background_color = [background_color]
elif len(background_color) != num_channels:
raise ValueError(
f"background_color must have no more than {num_channels} elements to match the number of channels"
)
if input_data_format == ChannelDimension.FIRST:
result = np.zeros((num_channels, max_dim, max_dim), dtype=image.dtype)
for i, color in enumerate(background_color):
result[i, :, :] = color
if width > height:
start = (max_dim - height) // 2
result[:, start : start + height, :] = image
else:
start = (max_dim - width) // 2
result[:, :, start : start + width] = image
else:
result = np.zeros((max_dim, max_dim, num_channels), dtype=image.dtype)
for i, color in enumerate(background_color):
result[:, :, i] = color
if width > height:
start = (max_dim - height) // 2
result[start : start + height, :, :] = image
else:
start = (max_dim - width) // 2
result[:, start : start + width, :] = image
image = (
to_channel_dimension_format(result, data_format, input_data_format) if data_format is not None else result
)
return image
# Copied from transformers.models.clip.image_processing_clip.CLIPImageProcessor.resize
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
resized to keep the input aspect ratio.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Size of the output image.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use when resiizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
default_to_square = True
if "shortest_edge" in size:
size = size["shortest_edge"]
default_to_square = False
elif "height" in size and "width" in size:
size = (size["height"], size["width"])
else:
raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
output_size = get_resize_output_image_size(
image,
size=size,
default_to_square=default_to_square,
input_data_format=input_data_format,
)
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
def preprocess(
self,
images: ImageInput,
do_pad: Optional[bool] = None,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: Optional[PILImageResampling] = None,
do_center_crop: Optional[bool] = None,
crop_size: Optional[int] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> PIL.Image.Image:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
Whether to pad the image to a square based on the longest edge.
The padding value is determined by the `image_mean` parameter.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
the longest edge resized to keep the input aspect ratio.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
Whether to center crop the image.
crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
`True`.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_pad = do_pad if do_pad is not None else self.do_pad
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size, param_name="size", default_to_square=False)
resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
images = make_list_of_images(images)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
# we don't pass `do_pad` here since LLaVa uses a custom padding to a square
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if is_scaled_image(images[0]) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
processed_images = []
for image in images:
if do_pad:
image = self.pad_to_square(
image=image,
background_color=tuple(int(x * 255) for x in self.image_mean),
input_data_format=input_data_format,
)
if do_resize:
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
if do_center_crop:
image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
if do_rescale:
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
if do_normalize:
image = self.normalize(
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
)
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
processed_images.append(image)
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
__all__ = ["LlavaImageProcessor"]

View File

@ -0,0 +1,201 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for LLaVa."""
from typing import List, Optional, Tuple, Union
from ...image_processing_utils import (
BatchFeature,
)
from ...image_processing_utils_fast import (
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
group_images_by_shape,
reorder_images,
)
from ...image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
SizeDict,
get_image_size,
)
from ...processing_utils import Unpack
from ...utils import (
TensorType,
add_start_docstrings,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
is_vision_available,
)
if is_vision_available():
from ...image_utils import PILImageResampling
if is_torch_available():
import torch
if is_torchvision_available():
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
from torchvision.transforms import functional as F
class LlavaFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
do_pad: Optional[bool]
@add_start_docstrings(
"Constructs a fast Llava image processor.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
"""
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
Whether to pad the image to a square based on the longest edge. Can be overridden by the `do_pad` parameter
""",
)
class LlavaImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BICUBIC
image_mean = OPENAI_CLIP_MEAN
image_std = OPENAI_CLIP_STD
size = {"shortest_edge": 224}
default_to_square = False
crop_size = {"height": 224, "width": 224}
do_pad = False
do_resize = True
do_center_crop = True
do_rescale = True
do_normalize = True
do_convert_rgb = True
valid_kwargs = LlavaFastImageProcessorKwargs
def __init__(self, **kwargs: Unpack[LlavaFastImageProcessorKwargs]) -> None:
super().__init__(**kwargs)
@add_start_docstrings(
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
"""
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
Whether to pad the image to a square based on the longest edge. Can be overridden by the `do_pad` parameter
""",
)
def preprocess(self, images: ImageInput, **kwargs: Unpack[LlavaFastImageProcessorKwargs]) -> BatchFeature:
return super().preprocess(images, **kwargs)
def pad_to_square(
self,
images: "torch.Tensor",
background_color: Union[int, Tuple[int, int, int]] = 0,
) -> "torch.Tensor":
"""
Pads an image to a square based on the longest edge.
Args:
images (`np.ndarray`):
The images to pad.
background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0):
The color to use for the padding. Can be an integer for single channel or a
tuple of integers representing for multi-channel images. If passed as integer
in mutli-channel mode, it will default to `0` in subsequent channels.
Returns:
`torch.Tensor`: The padded images.
"""
height, width = get_image_size(images, ChannelDimension.FIRST)
if height == width:
return images
num_channels = images.shape[1] if len(images.shape) == 4 else images.shape[0]
if isinstance(background_color, int):
background_color = [background_color] + [0] * (num_channels - 1)
elif len(background_color) != num_channels:
raise ValueError(
f"background_color must have no more than {num_channels} elements to match the number of channels"
)
max_dim = max(height, width)
paste_x_left = (max_dim - width) // 2
paste_y_left = (max_dim - height) // 2
paste_x_right = max_dim - width - paste_x_left
paste_y_right = max_dim - height - paste_y_left
padded_images = F.pad(
images, padding=[paste_x_left, paste_y_left, paste_x_right, paste_y_right], fill=background_color
)
return padded_images
def _preprocess(
self,
images: List["torch.Tensor"],
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
do_pad: bool,
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
return_tensors: Optional[Union[str, TensorType]],
) -> BatchFeature:
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_pad:
stacked_images = self.pad_to_square(
images=stacked_images, background_color=tuple(int(x * 255) for x in self.image_mean)
)
resized_images_grouped[shape] = stacked_images
padded_images = reorder_images(resized_images_grouped, grouped_images_index)
# Group images by size for batched resizing
# Needed in case do_pad is False, or padding returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(padded_images)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
resized_images_grouped[shape] = stacked_images
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_center_crop:
stacked_images = self.center_crop(stacked_images, crop_size)
# Fused rescale and normalize
stacked_images = self.rescale_and_normalize(
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_images_grouped[shape] = stacked_images
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
__all__ = ["LlavaImageProcessorFast"]

View File

@ -0,0 +1,531 @@
# coding=utf-8
# Copyright 2023 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Llava model."""
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
from transformers.models.auto import AutoModel, AutoModelForCausalLM
from .configuration_llava import LlavaConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LlavaConfig"
# Base docstring
_CHECKPOINT_FOR_DOC = "llava-hf/llava-1.5-7b-hf"
@dataclass
class LlavaCausalLMOutputWithPast(ModelOutput):
"""
Base class for Llava causal language model (or autoregressive) outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
loss: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None
class LlavaMultiModalProjector(nn.Module):
def __init__(self, config: LlavaConfig):
super().__init__()
# We have hidden_size * the number of vision feature layers
num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer)
self.linear_1 = nn.Linear(
config.vision_config.hidden_size * num_feature_layers,
config.text_config.hidden_size,
bias=config.multimodal_projector_bias,
)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = nn.Linear(
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
)
def forward(self, image_features):
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
LLAVA_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`LlavaConfig`] or [`LlavaVisionConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAVA_START_DOCSTRING,
)
class LlavaPreTrainedModel(PreTrainedModel):
config_class = LlavaConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["LlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
# important: this ported version of Llava isn't meant for training from scratch - only
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
# https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
LLAVA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
The tensors corresponding to the input images. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`LlavaProcessor`] uses
[`CLIPImageProcessor`] for processing images).
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
vision_feature_layer (`Union[int, List[int]], *optional*, defaults to -2`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
@add_start_docstrings(
"""The LLAVA model which consists of a vision backbone and a language model.""",
LLAVA_START_DOCSTRING,
)
class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
def __init__(self, config: LlavaConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = LlavaMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def get_image_features(
self,
pixel_values: torch.FloatTensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
**kwargs,
):
"""
Obtains image last hidden states from the vision tower and apply multimodal projection.
Args:
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
The tensors corresponding to the input images.
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
"""
if vision_feature_select_strategy not in ["default", "full"]:
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
kwargs = {k: v for k, v in kwargs.items() if v is not None}
# this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs)
# If we have one vision feature layer, return the corresponding hidden states,
# otherwise, select the hidden states of each feature layer and concatenate them
if isinstance(vision_feature_layer, int):
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
else:
hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
# For default; crop CLS from each hidden state in the hidden state pool
if vision_feature_select_strategy == "default":
hs_pool = [hs[:, 1:] for hs in hs_pool]
selected_image_feature = torch.cat(hs_pool, dim=-1)
image_features = self.multi_modal_projector(selected_image_feature)
return image_features
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
num_images, num_image_patches, embed_dim = image_features.shape
batch_size, sequence_length = input_ids.shape
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
# 1. Create a mask to know where special image tokens are
special_image_token_mask = input_ids == self.config.image_token_index
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
# Compute the maximum embed dimension
max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged image-text sequence.
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
if left_padding:
new_token_positions += nb_image_pad[:, None] # offset for left padding
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
# 3. Create the full embedding, already padded to the maximum position
final_embedding = torch.zeros(
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
final_attention_mask = torch.zeros(
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
)
if labels is not None:
final_labels = torch.full(
(batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
)
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device.
target_device = inputs_embeds.device
batch_indices, non_image_indices, text_to_overwrite = (
batch_indices.to(target_device),
non_image_indices.to(target_device),
text_to_overwrite.to(target_device),
)
attention_mask = attention_mask.to(target_device)
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
if labels is not None:
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
image_to_overwrite = torch.full(
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
)
image_to_overwrite[batch_indices, text_to_overwrite] = False
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
if image_to_overwrite.sum() != image_features.shape[:-1].numel():
raise ValueError(
f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
)
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
final_attention_mask |= image_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
indices_to_mask = new_token_positions[batch_indices, pad_indices]
final_embedding[batch_indices, indices_to_mask] = 0
if labels is None:
final_labels = None
return final_embedding, final_attention_mask, final_labels, position_ids
# @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
# @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
# @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[int] = None,
vision_feature_select_strategy: Optional[str] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
):
from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
image_features = None
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
)
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
image_features, inputs_embeds, input_ids, attention_mask, labels
)
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
num_logits_to_keep=num_logits_to_keep,
)
logits = outputs[0]
loss = None
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return LlavaCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
inputs_embeds=None,
pixel_values=None,
attention_mask=None,
cache_position=None,
logits_to_keep=None,
**kwargs,
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
model_inputs = self.language_model.prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**kwargs,
)
if cache_position[0] == 0:
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
model_inputs["pixel_values"] = pixel_values
return model_inputs
__all__ = ["LlavaForConditionalGeneration", "LlavaPreTrainedModel"]

View File

@ -0,0 +1,203 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Llava.
"""
from typing import List, Union
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, get_image_size, to_numpy_array
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging
logger = logging.get_logger(__name__)
class LlavaProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
},
"images_kwargs": {},
}
class LlavaProcessor(ProcessorMixin):
r"""
Constructs a LLaVa processor which wraps a LLaVa image processor and a LLaMa tokenizer into a single processor.
[`LlavaProcessor`] offers all the functionalities of [`LlavaImageProcessor`] and [`LlamaTokenizerFast`]. See the
[`~LlavaProcessor.__call__`] and [`~LlavaProcessor.decode`] for more information.
Args:
image_processor ([`LlavaImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`LlamaTokenizerFast`], *optional*):
The tokenizer is a required input.
patch_size (`int`, *optional*):
Patch size from the vision tower.
vision_feature_select_strategy (`str`, *optional*):
The feature selection strategy used to select the vision feature from the vision backbone.
Shoudl be same as in model's config
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
image_token (`str`, *optional*, defaults to `"<image>"`):
Special token used to denote image location.
num_additional_image_tokens (`int`, *optional*, defaults to 0):
Number of additional tokens added to the image embeddings, such as CLS (+1). If the backbone has no CLS or other
extra tokens appended, no need to set this arg.
"""
attributes = ["image_processor", "tokenizer"]
valid_kwargs = [
"chat_template",
"patch_size",
"vision_feature_select_strategy",
"image_token",
"num_additional_image_tokens",
]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(
self,
image_processor=None,
tokenizer=None,
patch_size=None,
vision_feature_select_strategy=None,
chat_template=None,
image_token="<image>", # set the default and let users change if they have peculiar special tokens in rare cases
num_additional_image_tokens=0,
**kwargs,
):
self.patch_size = patch_size
self.num_additional_image_tokens = num_additional_image_tokens
self.vision_feature_select_strategy = vision_feature_select_strategy
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
self.image_token_id = (
tokenizer.image_token_id
if getattr(tokenizer, "image_token_id", None)
else tokenizer.convert_tokens_to_ids(self.image_token)
)
super().__init__(image_processor, tokenizer, chat_template=chat_template)
def __call__(
self,
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
audio=None,
videos=None,
**kwargs: Unpack[LlavaProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
of the above two methods for more information.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
if images is None and text is None:
raise ValueError("You have to specify at least one of `images` or `text`.")
# check if images and text inputs are reversed for BC
images, text = _validate_images_text_input_order(images, text)
output_kwargs = self._merge_kwargs(
LlavaProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if images is not None:
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
else:
image_inputs = {}
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
# try to expand inputs in processing if we have the necessary parts
prompt_strings = text
if image_inputs.get("pixel_values") is not None:
# Replace the image token with the expanded image token sequence
pixel_values = image_inputs["pixel_values"]
height, width = get_image_size(to_numpy_array(pixel_values[0]))
num_image_tokens = (height // self.patch_size) * (
width // self.patch_size
) + self.num_additional_image_tokens
if self.vision_feature_select_strategy == "default":
num_image_tokens -= 1
prompt_strings = []
for sample in text:
sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
prompt_strings.append(sample)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
return BatchFeature(data={**text_inputs, **image_inputs})
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
__all__ = ["LlavaProcessor"]

View File

@ -0,0 +1,2 @@
from .ltxv import LTXV
from . import ltxv_handler

View File

@ -7,7 +7,7 @@ from pathlib import Path
from diffusers.utils import logging from diffusers.utils import logging
from typing import Optional, List, Union from typing import Optional, List, Union
import yaml import yaml
from wan.utils.utils import calculate_new_dimensions from shared.utils.utils import calculate_new_dimensions
import imageio import imageio
import json import json
import numpy as np import numpy as np
@ -605,16 +605,4 @@ def load_media_file(
raise Exception("video format not supported") raise Exception("video format not supported")
return media_tensor return media_tensor
def query_model_def(model_type, model_def):
LTXV_config = model_def.get("LTXV_config", "")
distilled= "distilled" in LTXV_config
model_def_output = {
"no_guidance": True,
}
if distilled:
model_def_output.update({
"lock_inference_steps": True,
"no_negative_prompt" : True,
})
return model_def_output

View File

@ -0,0 +1,92 @@
import torch
def get_ltxv_text_encoder_filename(text_encoder_quantization):
text_encoder_filename = "ckpts/T5_xxl_1.1/T5_xxl_1.1_enc_bf16.safetensors"
if text_encoder_quantization =="int8":
text_encoder_filename = text_encoder_filename.replace("bf16", "quanto_bf16_int8")
return text_encoder_filename
class family_handler():
@staticmethod
def query_model_def(base_model_type, model_def):
LTXV_config = model_def.get("LTXV_config", "")
distilled= "distilled" in LTXV_config
extra_model_def = {
"no_guidance": True,
}
if distilled:
extra_model_def.update({
"lock_inference_steps": True,
"no_negative_prompt" : True,
})
extra_model_def["fps"] = 30
extra_model_def["frames_minimum"] = 17
extra_model_def["frames_steps"] = 8
extra_model_def["sliding_window"] = True
return extra_model_def
@staticmethod
def query_supported_types():
return ["ltxv_13B"]
@staticmethod
def query_family_maps():
return {}, {}
@staticmethod
def get_rgb_factors(model_type):
from shared.RGB_factors import get_rgb_factors
latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("ltxv")
return latent_rgb_factors, latent_rgb_factors_bias
@staticmethod
def query_model_family():
return "ltxv"
@staticmethod
def query_family_infos():
return {"ltxv":(10, "LTX Video")}
@staticmethod
def get_vae_block_size(base_model_type):
return 32
@staticmethod
def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization):
text_encoder_filename = get_ltxv_text_encoder_filename(text_encoder_quantization)
return {
"repoId" : "DeepBeepMeep/LTX_Video",
"sourceFolderList" : ["T5_xxl_1.1", "" ],
"fileList" : [ ["added_tokens.json", "special_tokens_map.json", "spiece.model", "tokenizer_config.json"] + computeList(text_encoder_filename), ["ltxv_0.9.7_VAE.safetensors", "ltxv_0.9.7_spatial_upscaler.safetensors", "ltxv_scheduler.json"] + computeList(model_filename) ]
}
@staticmethod
def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False):
from .ltxv import LTXV
ltxv_model = LTXV(
model_filepath = model_filename,
text_encoder_filepath = get_ltxv_text_encoder_filename(text_encoder_quantization),
model_type = model_type,
base_model_type = base_model_type,
model_def = model_def,
dtype = dtype,
# quantizeTransformer = quantizeTransformer,
VAE_dtype = VAE_dtype,
mixed_precision_transformer = mixed_precision_transformer
)
pipeline = ltxv_model.pipeline
pipe = {"transformer" : pipeline.video_pipeline.transformer, "vae" : pipeline.vae, "text_encoder" : pipeline.video_pipeline.text_encoder, "latent_upsampler" : pipeline.latent_upsampler}
return ltxv_model, pipe
@staticmethod
def update_default_settings(base_model_type, model_def, ui_defaults):
pass

View File

@ -15,12 +15,12 @@ from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbedding
from safetensors import safe_open from safetensors import safe_open
from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd from ..autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
from ltx_video.models.autoencoders.pixel_norm import PixelNorm from ...models.autoencoders.pixel_norm import PixelNorm
from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND from ...models.autoencoders.pixel_shuffle import PixelShuffleND
from ltx_video.models.autoencoders.vae import AutoencoderKLWrapper from ...models.autoencoders.vae import AutoencoderKLWrapper
from ltx_video.models.transformers.attention import Attention from ...models.transformers.attention import Attention
from ltx_video.utils.diffusers_config_mapping import ( from ...utils.diffusers_config_mapping import (
diffusers_and_ours_config_mapping, diffusers_and_ours_config_mapping,
make_hashable_key, make_hashable_key,
VAE_KEYS_RENAME_DICT, VAE_KEYS_RENAME_DICT,

View File

@ -2,8 +2,8 @@ from typing import Tuple, Union
import torch import torch
from ltx_video.models.autoencoders.dual_conv3d import DualConv3d from ..autoencoders.dual_conv3d import DualConv3d
from ltx_video.models.autoencoders.causal_conv3d import CausalConv3d from ..autoencoders.causal_conv3d import CausalConv3d
def make_conv_nd( def make_conv_nd(

View File

@ -9,7 +9,7 @@ from einops import rearrange
from diffusers import ConfigMixin, ModelMixin from diffusers import ConfigMixin, ModelMixin
from safetensors.torch import safe_open from safetensors.torch import safe_open
from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND from ...models.autoencoders.pixel_shuffle import PixelShuffleND
class ResBlock(nn.Module): class ResBlock(nn.Module):

View File

@ -10,7 +10,7 @@ from diffusers.models.autoencoders.vae import (
DiagonalGaussianDistribution, DiagonalGaussianDistribution,
) )
from diffusers.models.modeling_outputs import AutoencoderKLOutput from diffusers.models.modeling_outputs import AutoencoderKLOutput
from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd from ...models.autoencoders.conv_nd_factory import make_conv_nd
class AutoencoderKLWrapper(ModelMixin, ConfigMixin): class AutoencoderKLWrapper(ModelMixin, ConfigMixin):

View File

@ -5,10 +5,10 @@ from einops import rearrange
from torch import Tensor from torch import Tensor
from ltx_video.models.autoencoders.causal_video_autoencoder import ( from ...models.autoencoders.causal_video_autoencoder import (
CausalVideoAutoencoder, CausalVideoAutoencoder,
) )
from ltx_video.models.autoencoders.video_autoencoder import ( from ...models.autoencoders.video_autoencoder import (
Downsample3D, Downsample3D,
VideoAutoencoder, VideoAutoencoder,
) )

View File

@ -11,10 +11,10 @@ from torch.nn import functional
from diffusers.utils import logging from diffusers.utils import logging
from ltx_video.utils.torch_utils import Identity from ...utils.torch_utils import Identity
from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd from ...models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
from ltx_video.models.autoencoders.pixel_norm import PixelNorm from ...models.autoencoders.pixel_norm import PixelNorm
from ltx_video.models.autoencoders.vae import AutoencoderKLWrapper from ...models.autoencoders.vae import AutoencoderKLWrapper
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)

View File

@ -19,15 +19,9 @@ from diffusers.utils import deprecate, logging
from diffusers.utils.torch_utils import maybe_allow_in_graph from diffusers.utils.torch_utils import maybe_allow_in_graph
from einops import rearrange from einops import rearrange
from torch import nn from torch import nn
from wan.modules.attention import pay_attention from shared.attention import pay_attention
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy from ...utils.skip_layer_strategy import SkipLayerStrategy
try:
from torch_xla.experimental.custom_kernel import flash_attention
except ImportError:
# workaround for automatic tests. Currently this function is manually patched
# to the torch_xla lib on setup of container
pass
# code adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py # code adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py

View File

@ -16,10 +16,10 @@ from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils import logging from diffusers.utils import logging
from torch import nn from torch import nn
from safetensors import safe_open from safetensors import safe_open
from ltx_video.models.transformers.attention import BasicTransformerBlock, reshape_hidden_states, restore_hidden_states_shape from .attention import BasicTransformerBlock, reshape_hidden_states, restore_hidden_states_shape
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy from ...utils.skip_layer_strategy import SkipLayerStrategy
from ltx_video.utils.diffusers_config_mapping import ( from ...utils.diffusers_config_mapping import (
diffusers_and_ours_config_mapping, diffusers_and_ours_config_mapping,
make_hashable_key, make_hashable_key,
TRANSFORMER_KEYS_RENAME_DICT, TRANSFORMER_KEYS_RENAME_DICT,

Some files were not shown because too many files have changed in this diff Show More