from mmgp import offload
import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch, json, os
import math

from diffusers.image_processor import VaeImageProcessor
from .transformer_qwenimage import QwenImageTransformer2DModel

from diffusers.utils import logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, AutoTokenizer, Qwen2VLProcessor
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
from diffusers import FlowMatchEulerDiscreteScheduler
from .pipeline_qwenimage import QwenImagePipeline
from PIL import Image
from shared.utils.utils import calculate_new_dimensions

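# Stitch two PIL images side by side: img2 is rescaled to img1's height and pasted
# to the right of img1, so several reference images can be fed as a single canvas.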
def stitch_images(img1, img2):
    # Resize img2 to match img1's height
    width1, height1 = img1.size
    width2, height2 = img2.size
    new_width2 = int(width2 * height1 / height2)
    img2_resized = img2.resize((new_width2, height1), Image.Resampling.LANCZOS)

    stitched = Image.new('RGB', (width1 + new_width2, height1))
    stitched.paste(img1, (0, 0))
    stitched.paste(img2_resized, (width1, 0))
    return stitched


class model_factory():
    def __init__(
        self,
        checkpoint_dir,
        model_filename = None,
        model_type = None,
        model_def = None,
        base_model_type = None,
        text_encoder_filename = None,
        quantizeTransformer = False,
        save_quantized = False,
        dtype = torch.bfloat16,
        VAE_dtype = torch.float32,
        mixed_precision_transformer = False,
    ):

        transformer_filename = model_filename[0]
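        # The tokenizer is always loaded from the bundled Qwen2.5-VL-7B-Instruct folder;
        # the edit variant additionally needs the Qwen2VL processor so the conditioning
        # image can be encoded together with the prompt.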
        processor = None
        tokenizer = None
        if base_model_type == "qwen_image_edit_20B":
            processor = Qwen2VLProcessor.from_pretrained(os.path.join(checkpoint_dir, "Qwen2.5-VL-7B-Instruct"))
        tokenizer = AutoTokenizer.from_pretrained(os.path.join(checkpoint_dir, "Qwen2.5-VL-7B-Instruct"))

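        # Instantiate the transformer on the meta device (no memory allocated) from the
        # base 20B config, then stream the real weights in below, either from an explicit
        # "source" checkpoint in the model definition or from the selected transformer file.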
        base_config_file = "configs/qwen_image_20B.json"
        with open(base_config_file, 'r', encoding='utf-8') as f:
            transformer_config = json.load(f)
        # Drop config entries that are not constructor arguments of the transformer
        transformer_config.pop("_diffusers_version", None)
        transformer_config.pop("_class_name", None)
        transformer_config.pop("pooled_projection_dim", None)

        from accelerate import init_empty_weights
        with init_empty_weights():
            transformer = QwenImageTransformer2DModel(**transformer_config)
        source = model_def.get("source", None)

        if source is not None:
            offload.load_model_data(transformer, source)
        else:
            offload.load_model_data(transformer, transformer_filename)
        # transformer = offload.fast_load_transformers_model("transformer_quanto.safetensors", writable_tensors= True , modelClass=QwenImageTransformer2DModel, defaultConfigPath="transformer_config.json")

        if source is not None:
            from wgp import save_model
            save_model(transformer, model_type, dtype, None)

        if save_quantized:
            from wgp import save_quantized_model
            save_quantized_model(transformer, model_type, model_filename[0], dtype, base_config_file)

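        # Load the Qwen2.5-VL text encoder and the image VAE from local safetensors
        # files through mmgp's fast loader.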
        text_encoder = offload.fast_load_transformers_model(text_encoder_filename, writable_tensors=True, modelClass=Qwen2_5_VLForConditionalGeneration, defaultConfigPath=os.path.join(checkpoint_dir, "Qwen2.5-VL-7B-Instruct", "config.json"))
        # text_encoder = offload.fast_load_transformers_model(text_encoder_filename, do_quantize=True,  writable_tensors= True , modelClass=Qwen2_5_VLForConditionalGeneration, defaultConfigPath="text_encoder_config.json", verboseLevel=2)
        # text_encoder.to(torch.float16)
        # offload.save_model(text_encoder, "text_encoder_quanto_fp16.safetensors", do_quantize= True)

        vae = offload.fast_load_transformers_model(os.path.join(checkpoint_dir, "qwen_vae.safetensors"), writable_tensors=True, modelClass=AutoencoderKLQwenImage, defaultConfigPath=os.path.join(checkpoint_dir, "qwen_vae_config.json"))

        self.pipeline = QwenImagePipeline(vae, text_encoder, tokenizer, transformer, processor)
        self.vae = vae
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.transformer = transformer
        self.processor = processor

    def generate(
        self,
        seed: int | None = None,
        input_prompt: str = "replace the logo with the text 'Black Forest Labs'",
        n_prompt = None,
        sampling_steps: int = 20,
        input_ref_images = None,
        image_guide = None,
        image_mask = None,
        width = 832,
        height = 480,
        guide_scale: float = 4,
        fit_into_canvas = None,
        callback = None,
        loras_slists = None,
        batch_size = 1,
        video_prompt_type = "",
        VAE_tile_size = None,
        joint_pass = True,
        sample_solver = 'default',
        denoising_strength = 1.,
        model_mode = 0,
        outpainting_dims = None,
        **bbargs
    ):
        # Native Qwen-Image resolution presets per aspect ratio (kept for reference;
        # see the commented-out override below)
        aspect_ratios = {
            "1:1": (1328, 1328),
            "16:9": (1664, 928),
            "9:16": (928, 1664),
            "4:3": (1472, 1140),
            "3:4": (1140, 1472)
        }

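        # Pick the scheduler preset: "lightning" targets the few-step distilled checkpoints
        # (base_shift = max_shift = log(3), i.e. an effective shift of 3), while the default
        # preset uses dynamic shifting for the standard model.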
        if sample_solver == 'lightning':
            scheduler_config = {
                "base_image_seq_len": 256,
                "base_shift": math.log(3),  # We use shift=3 in distillation
                "invert_sigmas": False,
                "max_image_seq_len": 8192,
                "max_shift": math.log(3),  # We use shift=3 in distillation
                "num_train_timesteps": 1000,
                "shift": 1.0,
                "shift_terminal": None,  # set shift_terminal to None
                "stochastic_sampling": False,
                "time_shift_type": "exponential",
                "use_beta_sigmas": False,
                "use_dynamic_shifting": True,
                "use_exponential_sigmas": False,
                "use_karras_sigmas": False,
            }
        else:
            scheduler_config = {
                "base_image_seq_len": 256,
                "base_shift": 0.5,
                "invert_sigmas": False,
                "max_image_seq_len": 8192,
                "max_shift": 0.9,
                "num_train_timesteps": 1000,
                "shift": 1.0,
                "shift_terminal": 0.02,
                "stochastic_sampling": False,
                "time_shift_type": "exponential",
                "use_beta_sigmas": False,
                "use_dynamic_shifting": True,
                "use_exponential_sigmas": False,
                "use_karras_sigmas": False
            }

        self.scheduler = FlowMatchEulerDiscreteScheduler(**scheduler_config)
        self.pipeline.scheduler = self.scheduler
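        # Optional VAE tiling plus slicing keeps decode memory bounded for large images.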
        if VAE_tile_size is not None:
            self.vae.use_tiling = VAE_tile_size[0]
            self.vae.tile_latent_min_height = VAE_tile_size[1]
            self.vae.tile_latent_min_width = VAE_tile_size[1]

        self.vae.enable_slicing()
        # width, height = aspect_ratios["16:9"]

        if n_prompt is None or len(n_prompt) == 0:
            n_prompt = "text, watermark, copyright, blurry, low resolution"
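        # Reference image handling: an explicit image_guide is used as-is; otherwise any
        # reference images are stitched side by side into one canvas, and when the prompt
        # type requests it ("K") the output size is derived from the first reference image.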
        if image_guide is not None:
            input_ref_images = [image_guide]
        elif input_ref_images is not None:
            # image stitching method
            stitched = input_ref_images[0]
            if "K" in video_prompt_type:
                w, h = input_ref_images[0].size
                height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)

            for new_img in input_ref_images[1:]:
                stitched = stitch_images(stitched, new_img)
            input_ref_images = [stitched]

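        # Run the diffusion pipeline; a None result (aborted run) is propagated as-is,
        # otherwise the decoded image batch is returned with its first two axes swapped.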
        if seed is None:
            # No seed supplied: draw one so the CUDA generator below can be seeded
            # (assumption: callers normally pass an explicit seed).
            seed = torch.seed()
        image = self.pipeline(
            prompt=input_prompt,
            negative_prompt=n_prompt,
            image=input_ref_images,
            image_mask=image_mask,
            width=width,
            height=height,
            num_inference_steps=sampling_steps,
            num_images_per_prompt=batch_size,
            true_cfg_scale=guide_scale,
            callback=callback,
            pipeline=self,
            loras_slists=loras_slists,
            joint_pass=joint_pass,
            denoising_strength=denoising_strength,
            generator=torch.Generator(device="cuda").manual_seed(seed),
            lora_inpaint=image_mask is not None and model_mode == 1,
            outpainting_dims=outpainting_dims,
        )
        if image is None: return None
        return image.transpose(0, 1)

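    # Extra LoRAs to attach to the transformer: nothing in the default mode, otherwise the
    # first "preload_URLs" entry of the model definition (resolved to its local file under
    # "ckpts") with a multiplier of 1, which this file uses for the inpainting mode.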
    def get_loras_transformer(self, get_model_recursive_prop, model_type, model_mode, **kwargs):
        if model_mode == 0: return [], []
        preloadURLs = get_model_recursive_prop(model_type, "preload_URLs")
        if len(preloadURLs) == 0: return [], []
        return [os.path.join("ckpts", os.path.basename(preloadURLs[0]))], [1]
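

# Minimal usage sketch (illustrative only): the import path, checkpoint file names and
# model definition below are hypothetical placeholders, not files shipped with this repo.
#
#   from models.qwen.qwen_main import model_factory   # hypothetical import path
#
#   factory = model_factory(
#       checkpoint_dir="ckpts",
#       model_filename=["ckpts/qwen_image_20B_bf16.safetensors"],           # hypothetical file
#       model_def={},
#       base_model_type="qwen_image_20B",
#       text_encoder_filename="ckpts/qwen_text_encoder_bf16.safetensors",   # hypothetical file
#   )
#   images = factory.generate(
#       seed=42,
#       input_prompt="a lighthouse at dusk, oil painting",
#       width=1328, height=1328,
#       sampling_steps=20,
#   )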