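"""Qwen-Image pipeline factory used by the surrounding ``wgp`` application.

Builds the transformer, Qwen2.5-VL text encoder, VAE and tokenizer/processor
behind ``QwenImagePipeline`` and exposes a ``generate`` entry point. Weight
loading is handled through ``mmgp.offload``.

Illustrative usage (paths and names below are placeholders, not shipped
defaults)::

    factory = model_factory(
        checkpoint_dir="ckpts",
        model_filename=["ckpts/qwen_image_20B_bf16.safetensors"],
        model_def={},
        text_encoder_filename="ckpts/qwen_text_encoder_bf16.safetensors",
    )
    images = factory.generate(seed=42, input_prompt="a watercolor fox")
"""
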
from mmgp import offload
import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch, json, os
import math

from diffusers.image_processor import VaeImageProcessor
from .transformer_qwenimage import QwenImageTransformer2DModel

from diffusers.utils import logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, AutoTokenizer, Qwen2VLProcessor
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
from diffusers import FlowMatchEulerDiscreteScheduler
from .pipeline_qwenimage import QwenImagePipeline
from PIL import Image
from shared.utils.utils import calculate_new_dimensions


def stitch_images(img1, img2):
    """Stitch two PIL images side by side, resizing img2 to img1's height."""
    width1, height1 = img1.size
    width2, height2 = img2.size
    new_width2 = int(width2 * height1 / height2)
    img2_resized = img2.resize((new_width2, height1), Image.Resampling.LANCZOS)

    stitched = Image.new('RGB', (width1 + new_width2, height1))
    stitched.paste(img1, (0, 0))
    stitched.paste(img2_resized, (width1, 0))
    return stitched
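
# Example (illustrative only, the file names are hypothetical):
#   left, right = Image.open("ref_a.png"), Image.open("ref_b.png")
#   combined = stitch_images(left, right)  # width1 + rescaled width2, height of `left`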


class model_factory:
    def __init__(
        self,
        checkpoint_dir,
        model_filename=None,
        model_type=None,
        model_def=None,
        base_model_type=None,
        text_encoder_filename=None,
        quantizeTransformer=False,
        save_quantized=False,
        dtype=torch.bfloat16,
        VAE_dtype=torch.float32,
        mixed_precision_transformer=False,
    ):
        transformer_filename = model_filename[0]
        processor = None
        tokenizer = None
        if base_model_type == "qwen_image_edit_20B":
            # The edit model is conditioned on input images, so it needs the
            # full Qwen2.5-VL processor on top of the tokenizer.
            processor = Qwen2VLProcessor.from_pretrained(os.path.join(checkpoint_dir, "Qwen2.5-VL-7B-Instruct"))
            tokenizer = AutoTokenizer.from_pretrained(os.path.join(checkpoint_dir, "Qwen2.5-VL-7B-Instruct"))

        # Load the transformer config shipped with the repo and drop the keys
        # that are not constructor arguments of QwenImageTransformer2DModel.
        base_config_file = "configs/qwen_image_20B.json"
        with open(base_config_file, 'r', encoding='utf-8') as f:
            transformer_config = json.load(f)
        transformer_config.pop("_diffusers_version")
        transformer_config.pop("_class_name")
        transformer_config.pop("pooled_projection_dim")
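
        # Instantiate the transformer on the meta device (no memory allocated yet),
        # then let mmgp's offload helpers fill in the real weights.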
        from accelerate import init_empty_weights
        with init_empty_weights():
            transformer = QwenImageTransformer2DModel(**transformer_config)
        source = model_def.get("source", None)

        if source is not None:
            offload.load_model_data(transformer, source)
        else:
            offload.load_model_data(transformer, transformer_filename)
        # transformer = offload.fast_load_transformers_model("transformer_quanto.safetensors", writable_tensors= True , modelClass=QwenImageTransformer2DModel, defaultConfigPath="transformer_config.json")

        if source is not None:
            # When the weights came from a custom "source", re-save them via wgp.
            from wgp import save_model
            save_model(transformer, model_type, dtype, None)

        if save_quantized:
            from wgp import save_quantized_model
            save_quantized_model(transformer, model_type, model_filename[0], dtype, base_config_file)
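
        # The Qwen2.5-VL text encoder and the VAE are loaded straight from
        # safetensors via mmgp's fast loader, using the configs in checkpoint_dir.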
        text_encoder = offload.fast_load_transformers_model(
            text_encoder_filename,
            writable_tensors=True,
            modelClass=Qwen2_5_VLForConditionalGeneration,
            defaultConfigPath=os.path.join(checkpoint_dir, "Qwen2.5-VL-7B-Instruct", "config.json"),
        )
        # text_encoder = offload.fast_load_transformers_model(text_encoder_filename, do_quantize=True, writable_tensors= True , modelClass=Qwen2_5_VLForConditionalGeneration, defaultConfigPath="text_encoder_config.json", verboseLevel=2)
        # text_encoder.to(torch.float16)
        # offload.save_model(text_encoder, "text_encoder_quanto_fp16.safetensors", do_quantize= True)

        vae = offload.fast_load_transformers_model(
            os.path.join(checkpoint_dir, "qwen_vae.safetensors"),
            writable_tensors=True,
            modelClass=AutoencoderKLQwenImage,
            defaultConfigPath=os.path.join(checkpoint_dir, "qwen_vae_config.json"),
        )

        self.pipeline = QwenImagePipeline(vae, text_encoder, tokenizer, transformer, processor)
        self.vae = vae
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.transformer = transformer
        self.processor = processor

    def generate(
        self,
        seed: int | None = None,
        input_prompt: str = "replace the logo with the text 'Black Forest Labs'",
        n_prompt=None,
        sampling_steps: int = 20,
        input_ref_images=None,
        image_guide=None,
        image_mask=None,
        width=832,
        height=480,
        guide_scale: float = 4,
        fit_into_canvas=None,
        callback=None,
        loras_slists=None,
        batch_size=1,
        video_prompt_type="",
        VAE_tile_size=None,
        joint_pass=True,
        sample_solver='default',
        denoising_strength=1.0,
        model_mode=0,
        outpainting_dims=None,
        **bbargs
    ):
        # Preset resolutions for common aspect ratios (only used if the
        # aspect-ratio override further down is re-enabled).
        aspect_ratios = {
            "1:1": (1328, 1328),
            "16:9": (1664, 928),
            "9:16": (928, 1664),
            "4:3": (1472, 1140),
            "3:4": (1140, 1472)
        }
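
        # 'lightning' targets the distilled Lightning checkpoints, which are
        # trained with a fixed shift of 3; the default branch keeps the regular
        # dynamically shifted flow-matching schedule.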
        if sample_solver == 'lightning':
            scheduler_config = {
                "base_image_seq_len": 256,
                "base_shift": math.log(3),  # We use shift=3 in distillation
                "invert_sigmas": False,
                "max_image_seq_len": 8192,
                "max_shift": math.log(3),  # We use shift=3 in distillation
                "num_train_timesteps": 1000,
                "shift": 1.0,
                "shift_terminal": None,  # set shift_terminal to None
                "stochastic_sampling": False,
                "time_shift_type": "exponential",
                "use_beta_sigmas": False,
                "use_dynamic_shifting": True,
                "use_exponential_sigmas": False,
                "use_karras_sigmas": False,
            }
        else:
            scheduler_config = {
                "base_image_seq_len": 256,
                "base_shift": 0.5,
                "invert_sigmas": False,
                "max_image_seq_len": 8192,
                "max_shift": 0.9,
                "num_train_timesteps": 1000,
                "shift": 1.0,
                "shift_terminal": 0.02,
                "stochastic_sampling": False,
                "time_shift_type": "exponential",
                "use_beta_sigmas": False,
                "use_dynamic_shifting": True,
                "use_exponential_sigmas": False,
                "use_karras_sigmas": False
            }

        self.scheduler = FlowMatchEulerDiscreteScheduler(**scheduler_config)
        self.pipeline.scheduler = self.scheduler

        if VAE_tile_size is not None:
            # VAE_tile_size: (use_tiling flag, minimum latent tile size)
            self.vae.use_tiling = VAE_tile_size[0]
            self.vae.tile_latent_min_height = VAE_tile_size[1]
            self.vae.tile_latent_min_width = VAE_tile_size[1]

        self.vae.enable_slicing()
        # width, height = aspect_ratios["16:9"]

        if n_prompt is None or len(n_prompt) == 0:
            n_prompt = "text, watermark, copyright, blurry, low resolution"
        if image_guide is not None:
            input_ref_images = [image_guide]
        elif input_ref_images is not None:
            # Image stitching: combine all reference images into a single
            # side-by-side canvas before passing them to the pipeline.
            stitched = input_ref_images[0]
            if "K" in video_prompt_type:
                w, h = input_ref_images[0].size
                height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)

            for new_img in input_ref_images[1:]:
                stitched = stitch_images(stitched, new_img)
            input_ref_images = [stitched]

        if seed is None:
            # torch.Generator.manual_seed() needs an int, so fall back to a random seed.
            seed = torch.randint(0, 2**32 - 1, (1,)).item()
        image = self.pipeline(
            prompt=input_prompt,
            negative_prompt=n_prompt,
            image=input_ref_images,
            image_mask=image_mask,
            width=width,
            height=height,
            num_inference_steps=sampling_steps,
            num_images_per_prompt=batch_size,
            true_cfg_scale=guide_scale,
            callback=callback,
            pipeline=self,
            loras_slists=loras_slists,
            joint_pass=joint_pass,
            denoising_strength=denoising_strength,
            generator=torch.Generator(device="cuda").manual_seed(seed),
            lora_inpaint=image_mask is not None and model_mode == 1,
            outpainting_dims=outpainting_dims,
        )
        if image is None:
            return None
        return image.transpose(0, 1)
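
    # Returns (lora_paths, lora_multipliers) for the transformer: nothing in the
    # default mode, otherwise the first entry of the model's "preload_URLs"
    # (resolved under "ckpts/") with a multiplier of 1.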
    def get_loras_transformer(self, get_model_recursive_prop, model_type, model_mode, **kwargs):
        if model_mode == 0:
            return [], []
        preloadURLs = get_model_recursive_prop(model_type, "preload_URLs")
        return [os.path.join("ckpts", os.path.basename(preloadURLs[0]))], [1]