# Wan2.1/models/qwen/qwen_main.py
from mmgp import offload
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch, json, os
from diffusers.image_processor import VaeImageProcessor
from .transformer_qwenimage import QwenImageTransformer2DModel
from diffusers.utils import logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, AutoTokenizer
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
from diffusers import FlowMatchEulerDiscreteScheduler
from .pipeline_qwenimage import QwenImagePipeline
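
# Wrapper that loads the Qwen-Image components (scheduler, tokenizer, transformer,
# text encoder, VAE) from a local checkpoint directory and exposes a generate() method.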
class model_factory():
    def __init__(
        self,
        checkpoint_dir,
        model_filename = None,
        model_type = None,
        model_def = None,
        base_model_type = None,
        text_encoder_filename = None,
        quantizeTransformer = False,
        save_quantized = False,
        dtype = torch.bfloat16,
        VAE_dtype = torch.float32,
        mixed_precision_transformer = False
    ):
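        # Flow-matching Euler scheduler, configured from the JSON shipped with the
        # checkpoint (diffusers metadata keys are stripped before instantiation).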
        with open(os.path.join(checkpoint_dir, "qwen_scheduler_config.json"), 'r', encoding='utf-8') as f:
            scheduler_config = json.load(f)
        scheduler_config.pop("_class_name")
        scheduler_config.pop("_diffusers_version")
        scheduler = FlowMatchEulerDiscreteScheduler(**scheduler_config)
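
        # model_filename is a list; the first entry is the transformer weight file.
        # The tokenizer comes from the bundled Qwen2.5-VL-7B-Instruct folder.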
        transformer_filename = model_filename[0]
        tokenizer = AutoTokenizer.from_pretrained(os.path.join(checkpoint_dir, "Qwen2.5-VL-7B-Instruct"))
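
        # Image transformer (the 20B backbone): instantiate it from the local JSON config,
        # dropping metadata and unused keys before passing it to the constructor.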
        with open("configs/qwen_image_20B.json", 'r', encoding='utf-8') as f:
            transformer_config = json.load(f)
        transformer_config.pop("_diffusers_version")
        transformer_config.pop("_class_name")
        transformer_config.pop("pooled_projection_dim")
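
        # init_empty_weights() builds the module on the "meta" device (no memory allocated);
        # the real tensors are then streamed in from the weight file by mmgp.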
        from accelerate import init_empty_weights
        with init_empty_weights():
            transformer = QwenImageTransformer2DModel(**transformer_config)
        offload.load_model_data(transformer, transformer_filename)
        # transformer = offload.fast_load_transformers_model("transformer_quanto.safetensors", writable_tensors= True , modelClass=QwenImageTransformer2DModel, defaultConfigPath="transformer_config.json")
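
        # Text encoder: Qwen2.5-VL, loaded from text_encoder_filename with mmgp's fast loader,
        # using the config.json that ships with the checkpoint.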
        text_encoder = offload.fast_load_transformers_model(text_encoder_filename, writable_tensors=True, modelClass=Qwen2_5_VLForConditionalGeneration, defaultConfigPath=os.path.join(checkpoint_dir, "Qwen2.5-VL-7B-Instruct", "config.json"))
        # text_encoder = offload.fast_load_transformers_model(text_encoder_filename, do_quantize=True, writable_tensors= True , modelClass=Qwen2_5_VLForConditionalGeneration, defaultConfigPath="text_encoder_config.json", verboseLevel=2)
        # text_encoder.to(torch.float16)
        # offload.save_model(text_encoder, "text_encoder_quanto_fp16.safetensors", do_quantize= True)
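
        # VAE: AutoencoderKLQwenImage, with weights and config stored alongside the checkpoint.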
        vae = offload.fast_load_transformers_model(os.path.join(checkpoint_dir, "qwen_vae.safetensors"), writable_tensors=True, modelClass=AutoencoderKLQwenImage, defaultConfigPath=os.path.join(checkpoint_dir, "qwen_vae_config.json"))
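
        # Assemble the pipeline and keep direct references to the components so callers
        # can adjust tiling/offloading on them individually.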
        self.pipeline = QwenImagePipeline(vae, text_encoder, tokenizer, transformer, scheduler)
        self.vae = vae
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.transformer = transformer
        self.scheduler = scheduler
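
    # Text-to-image entry point; extra keyword arguments from the caller are
    # swallowed by **bbargs so a superset of options can be passed.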
    def generate(
        self,
        seed: int | None = None,
        input_prompt: str = "replace the logo with the text 'Black Forest Labs'",
        n_prompt = None,
        sampling_steps: int = 20,
        input_ref_images = None,
        width = 832,
        height = 480,
        embedded_guidance_scale: float = 4,
        fit_into_canvas = None,
        callback = None,
        loras_slists = None,
        batch_size = 1,
        video_prompt_type = "",
        VAE_tile_size = None,
        joint_pass = True,
        **bbargs
    ):
        # Reference resolutions for the supported aspect ratios (currently unused:
        # width/height come from the arguments, see the commented line below).
        aspect_ratios = {
            "1:1": (1328, 1328),
            "16:9": (1664, 928),
            "9:16": (928, 1664),
            "4:3": (1472, 1140),
            "3:4": (1140, 1472),
        }
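
        # Optional VAE tiling/slicing to reduce peak VRAM during decode;
        # VAE_tile_size is expected as (use_tiling flag, minimum latent tile size).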
        if VAE_tile_size is not None:
            self.vae.use_tiling = VAE_tile_size[0]
            self.vae.tile_latent_min_height = VAE_tile_size[1]
            self.vae.tile_latent_min_width = VAE_tile_size[1]
            self.vae.enable_slicing()
        # width, height = aspect_ratios["16:9"]

        if n_prompt is None or len(n_prompt) == 0:
            n_prompt = "text, watermark, copyright, blurry, low resolution"

        if seed is None:
            # manual_seed() does not accept None, so fall back to a random 32-bit seed.
            seed = torch.randint(0, 2**32, (1,)).item()
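
        # The embedded guidance value from the caller is passed as true_cfg_scale,
        # i.e. classifier-free guidance against the negative prompt.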
        image = self.pipeline(
            prompt=input_prompt,
            negative_prompt=n_prompt,
            width=width,
            height=height,
            num_inference_steps=sampling_steps,
            num_images_per_prompt=batch_size,
            true_cfg_scale=embedded_guidance_scale,
            callback=callback,
            pipeline=self,
            loras_slists=loras_slists,
            joint_pass=joint_pass,
            generator=torch.Generator(device="cuda").manual_seed(seed),
        )
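
        # The pipeline returns a tensor batch, or None (e.g. when generation is aborted);
        # transpose(0, 1) swaps the first two axes into the layout the caller expects.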
        if image is None:
            return None
        return image.transpose(0, 1)