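# Flux image-generation pipeline wrapper (driven by the surrounding wgp application):
# loads the text encoders (T5, CLIP), the flow-matching transformer and the VAE for the
# flux-dev / flux-schnell / flux-dev-kontext / flux-dev-uso variants, and exposes a
# generate() method that handles reference images and optional SigLIP style conditioning.
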
import os
import re
import time
from dataclasses import dataclass
from glob import iglob
from mmgp import offload as offload
import torch
from shared.utils.utils import calculate_new_dimensions
from .sampling import denoise, get_schedule, prepare_kontext, prepare_prompt, prepare_multi_ip, unpack
from .modules.layers import get_linear_split_map
from transformers import SiglipVisionModel, SiglipImageProcessor
import torchvision.transforms.functional as TVF
import math
from shared.utils.utils import convert_image_to_tensor, convert_tensor_to_image

from .util import (
    aspect_ratio_to_height_width,
    load_ae,
    load_clip,
    load_flow_model,
    load_t5,
    save_image,
)

from PIL import Image


def stitch_images(img1, img2):
    """Stitch two PIL images side by side, resizing img2 to match img1's height."""
    width1, height1 = img1.size
    width2, height2 = img2.size
    new_width2 = int(width2 * height1 / height2)
    img2_resized = img2.resize((new_width2, height1), Image.Resampling.LANCZOS)

    stitched = Image.new('RGB', (width1 + new_width2, height1))
    stitched.paste(img1, (0, 0))
    stitched.paste(img2_resized, (width1, 0))
    return stitched
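
# Usage sketch (illustrative, with hypothetical PIL images ref_image_a / ref_image_b):
#   canvas = stitch_images(ref_image_a, ref_image_b)
# yields a single RGB canvas of ref_image_a's height, with ref_image_b rescaled to match
# before being pasted on the right. generate() uses this path for non-USO models when
# several reference images are supplied.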


class model_factory:
    def __init__(
        self,
        checkpoint_dir,
        model_filename = None,
        model_type = None,
        model_def = None,
        base_model_type = None,
        text_encoder_filename = None,
        quantizeTransformer = False,
        save_quantized = False,
        dtype = torch.bfloat16,
        VAE_dtype = torch.float32,
        mixed_precision_transformer = False
    ):
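        # Note: every heavy component is created on the CPU here (torch_device = "cpu");
        # the assumption is that mmgp's offload machinery, driven by the caller, moves
        # weights to the GPU on demand.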
        self.device = torch.device("cuda")
        self.VAE_dtype = VAE_dtype
        self.dtype = dtype
        self._interrupt = False  # interrupt flag checked in generate(); normally toggled by the calling application
        torch_device = "cpu"
        self.guidance_max_phases = model_def.get("guidance_max_phases", 0)

        # model_filename = ["c:/temp/flux1-schnell.safetensors"]

        self.t5 = load_t5(torch_device, text_encoder_filename, max_length=512)
        self.clip = load_clip(torch_device)
        self.name = model_def.get("flux-model", "flux-dev")
        # self.name= "flux-dev-kontext"
        # self.name= "flux-dev"
        # self.name= "flux-schnell"
        source = model_def.get("source", None)
        self.model = load_flow_model(self.name, model_filename[0] if source is None else source, torch_device)

        self.vae = load_ae(self.name, device=torch_device)

        # flux-dev-uso additionally uses a SigLIP vision encoder (plus an optional
        # projection head) to condition on style reference images.
        siglip_processor = siglip_model = feature_embedder = None
        if self.name == 'flux-dev-uso':
            siglip_path = "ckpts/siglip-so400m-patch14-384"
            siglip_processor = SiglipImageProcessor.from_pretrained(siglip_path)
            siglip_model = SiglipVisionModel.from_pretrained(siglip_path)
            siglip_model.eval().to("cpu")
            if len(model_filename) > 1:
                from .modules.layers import SigLIPMultiFeatProjModel
                feature_embedder = SigLIPMultiFeatProjModel(
                    siglip_token_nums=729,
                    style_token_nums=64,
                    siglip_token_dims=1152,
                    hidden_size=3072, #self.hidden_size,
                    context_layer_norm=True,
                )
                offload.load_model_data(feature_embedder, model_filename[1])
        self.vision_encoder = siglip_model
        self.vision_encoder_processor = siglip_processor
        self.feature_embedder = feature_embedder

        # offload.change_dtype(self.model, dtype, True)
        # offload.save_model(self.model, "flux-dev.safetensors")

        if source is not None:
            from wgp import save_model
            save_model(self.model, model_type, dtype, None)

        if save_quantized:
            from wgp import save_quantized_model
            save_quantized_model(self.model, model_type, model_filename[0], dtype, None)

        # Split large linear layers so the offload engine can handle them in smaller chunks.
        split_linear_modules_map = get_linear_split_map()
        self.model.split_linear_modules_map = split_linear_modules_map
        offload.split_linear_modules(self.model, split_linear_modules_map)

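    # Conventions for generate() (inferred from the code below):
    #   - video_prompt_type containing "I" signals that reference images are supplied;
    #     with flux-dev-uso, "J" marks them as style references.
    #   - image_refs_relative_size rescales reference images as a percentage of the target size.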
    def generate(
            self,
            seed: int | None = None,
            input_prompt: str = "replace the logo with the text 'Black Forest Labs'",
            n_prompt: str | None = None,
            sampling_steps: int = 20,
            input_ref_images = None,
            width = 832,
            height = 480,
            embedded_guidance_scale: float = 2.5,
            guide_scale = 2.5,
            fit_into_canvas = None,
            callback = None,
            loras_slists = None,
            batch_size = 1,
            video_prompt_type = "",
            joint_pass = False,
            image_refs_relative_size = 100,
            **bbargs
    ):
            if self._interrupt:
                return None
            if self.guidance_max_phases < 1: guide_scale = 1
            if n_prompt is None or len(n_prompt) == 0: n_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
            device = "cuda"
            flux_dev_uso = self.name in ['flux-dev-uso']
            image_stitching = self.name not in ['flux-dev-uso'] #and False
            input_ref_images = [] if input_ref_images is None else input_ref_images[:]
            ref_style_imgs = []
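            # Reference-image routing: for flux-dev-uso, "J" in video_prompt_type (or, failing
            # that, the last of several images) selects style references consumed by the SigLIP
            # branch below; the remaining images stay as content references.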
            if "I" in video_prompt_type and len(input_ref_images) > 0: 
 | 
						|
                if flux_dev_uso :
 | 
						|
                    if "J" in video_prompt_type:
 | 
						|
                        ref_style_imgs = input_ref_images
 | 
						|
                        input_ref_images = []
 | 
						|
                    elif len(input_ref_images) > 1 :
 | 
						|
                        ref_style_imgs = input_ref_images[-1:]
 | 
						|
                        input_ref_images = input_ref_images[:-1]
 | 
						|
                if image_stiching:
 | 
						|
                    # image stiching method
 | 
						|
                    stiched = input_ref_images[0]
 | 
						|
                    for new_img in input_ref_images[1:]:
 | 
						|
                        stiched = stitch_images(stiched, new_img)
 | 
						|
                    input_ref_images  = [stiched]
 | 
						|
                else:
 | 
						|
                    # latents stiching with resize 
 | 
						|
                    for i in range(len(input_ref_images)):
 | 
						|
                        w, h = input_ref_images[i].size
 | 
						|
                        image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas)
 | 
						|
                        input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS) 
 | 
						|
            else:
 | 
						|
                input_ref_images = None
 | 
						|
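            # Pack image latents and positional ids: flux-dev-uso conditions on a list of
            # reference latents (prepare_multi_ip), while the other variants use the
            # Kontext-style packing (prepare_kontext).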
            if flux_dev_uso:
                inp, height, width = prepare_multi_ip(
                    ae=self.vae,
                    img_cond_list=input_ref_images,
                    target_width=width,
                    target_height=height,
                    bs=batch_size,
                    seed=seed,
                    device=device,
                )
            else:
                inp, height, width = prepare_kontext(
                    ae=self.vae,
                    img_cond_list=input_ref_images,
                    target_width=width,
                    target_height=height,
                    bs=batch_size,
                    seed=seed,
                    device=device,
                )

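            # Encode the positive prompt with T5 + CLIP; the negative prompt is only encoded
            # when real classifier-free guidance is requested (guide_scale != 1).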
            inp.update(prepare_prompt(self.t5, self.clip, batch_size, input_prompt))
            if guide_scale != 1:
                inp.update(prepare_prompt(self.t5, self.clip, batch_size, n_prompt, neg=True, device=device))

            timesteps = get_schedule(sampling_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))

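            # flux-schnell uses an unshifted schedule; the dev/kontext/uso variants shift the
            # timesteps based on the packed image sequence length.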
            ref_style_imgs = [self.vision_encoder_processor(img, return_tensors="pt").to(self.device) for img in ref_style_imgs]
            if self.feature_embedder is not None and len(ref_style_imgs) > 0 and self.vision_encoder is not None:
                # project the SigLIP style features into the text hidden space
                siglip_embedding = [self.vision_encoder(**emb, output_hidden_states=True) for emb in ref_style_imgs]
                siglip_embedding = torch.cat([self.feature_embedder(emb) for emb in siglip_embedding], dim=1)
                siglip_embedding_ids = torch.zeros(siglip_embedding.shape[0], siglip_embedding.shape[1], 3).to(device)
                inp["siglip_embedding"] = siglip_embedding
                inp["siglip_embedding_ids"] = siglip_embedding_ids

            def unpack_latent(x):
                return unpack(x.float(), height, width)

            # denoise the initial noise
            x = denoise(self.model, **inp, timesteps=timesteps, guidance=embedded_guidance_scale, real_guidance_scale=guide_scale, callback=callback, pipeline=self, loras_slists=loras_slists, unpack_latent=unpack_latent, joint_pass=joint_pass)
            if x is None: return None
            # decode latents to pixel space
            x = unpack_latent(x)
            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                x = self.vae.decode(x)

            x = x.clamp(-1, 1)
            x = x.transpose(0, 1)  # swap batch and channel dims: (B, C, H, W) -> (C, B, H, W)
            return x
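
# Usage sketch (hypothetical paths and values; in practice the surrounding wgp application
# constructs this class and sets up mmgp offloading before calling generate()):
#   pipe = model_factory(
#       checkpoint_dir="ckpts",
#       model_filename=["ckpts/flux1-dev.safetensors"],
#       model_def={"flux-model": "flux-dev"},
#       text_encoder_filename="ckpts/t5_xxl_enc.safetensors",
#   )
#   images = pipe.generate(seed=42, input_prompt="a watercolor fox", width=1024, height=1024)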