# Wan2.1/models/flux/flux_main.py
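"""Flux image-generation wrapper.

Builds the Flux flow-matching model together with its T5 and CLIP text
encoders and the autoencoder, optionally a SigLIP vision encoder for style
references (flux-dev-uso), and exposes a generate() entry point that prepares
the conditioning, denoises the latents and decodes them back to pixel space.
"""
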
import os
import re
import time
from dataclasses import dataclass
from glob import iglob
from mmgp import offload
import torch
from shared.utils.utils import calculate_new_dimensions, convert_image_to_tensor, convert_tensor_to_image
from .sampling import denoise, get_schedule, prepare_kontext, prepare_prompt, prepare_multi_ip, unpack
from .modules.layers import get_linear_split_map
from transformers import SiglipVisionModel, SiglipImageProcessor
import torchvision.transforms.functional as TVF
import math
from .util import (
    aspect_ratio_to_height_width,
    load_ae,
    load_clip,
    load_flow_model,
    load_t5,
    save_image,
)
from PIL import Image


def preprocess_ref(raw_image: Image.Image, long_size: int = 512):
    # Get the original image dimensions
    image_w, image_h = raw_image.size
    # Scale so that the long side equals long_size, preserving the aspect ratio
    if image_w >= image_h:
        new_w = long_size
        new_h = int((long_size / image_w) * image_h)
    else:
        new_h = long_size
        new_w = int((long_size / image_h) * image_w)
    raw_image = raw_image.resize((new_w, new_h), resample=Image.Resampling.LANCZOS)
    # Round both sides down to multiples of 16, then center-crop to that size
    target_w = new_w // 16 * 16
    target_h = new_h // 16 * 16
    left = (new_w - target_w) // 2
    top = (new_h - target_h) // 2
    right = left + target_w
    bottom = top + target_h
    raw_image = raw_image.crop((left, top, right, bottom))
    # Ensure RGB mode
    raw_image = raw_image.convert("RGB")
    return raw_image
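
# Example: preprocess_ref(Image.open("ref.png"), 320) returns an RGB image whose
# long side is 320 px and whose width and height are both multiples of 16.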


def stitch_images(img1, img2):
    # Resize img2 to match img1's height, preserving its aspect ratio
    width1, height1 = img1.size
    width2, height2 = img2.size
    new_width2 = int(width2 * height1 / height2)
    img2_resized = img2.resize((new_width2, height1), Image.Resampling.LANCZOS)
    # Paste the two images side by side on a single canvas
    stitched = Image.new('RGB', (width1 + new_width2, height1))
    stitched.paste(img1, (0, 0))
    stitched.paste(img2_resized, (width1, 0))
    return stitched
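
# Example: stitching a 512x512 image with a 256x512 image yields a 768x512 canvas.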


class model_factory:
    def __init__(
        self,
        checkpoint_dir,
        model_filename=None,
        model_type=None,
        model_def=None,
        base_model_type=None,
        text_encoder_filename=None,
        quantizeTransformer=False,
        save_quantized=False,
        dtype=torch.bfloat16,
        VAE_dtype=torch.float32,
        mixed_precision_transformer=False,
    ):
        self.device = torch.device("cuda")
        self.VAE_dtype = VAE_dtype
        self.dtype = dtype
        torch_device = "cpu"  # load all submodels on CPU first
        self.guidance_max_phases = model_def.get("guidance_max_phases", 0)
        # model_filename = ["c:/temp/flux1-schnell.safetensors"]
        self.t5 = load_t5(torch_device, text_encoder_filename, max_length=512)
        self.clip = load_clip(torch_device)
        self.name = model_def.get("flux-model", "flux-dev")
        # self.name = "flux-dev-kontext"
        # self.name = "flux-dev"
        # self.name = "flux-schnell"
        source = model_def.get("source", None)
        self.model = load_flow_model(self.name, model_filename[0] if source is None else source, torch_device)
        self.model_def = model_def
        self.vae = load_ae(self.name, device=torch_device)
        siglip_processor = siglip_model = feature_embedder = None
        if self.name == 'flux-dev-uso':
            # flux-dev-uso embeds style reference images with a SigLIP vision encoder
            siglip_path = "ckpts/siglip-so400m-patch14-384"
            siglip_processor = SiglipImageProcessor.from_pretrained(siglip_path)
            siglip_model = SiglipVisionModel.from_pretrained(siglip_path)
            siglip_model.eval().to("cpu")
            if len(model_filename) > 1:
                from .modules.layers import SigLIPMultiFeatProjModel
                # Projects the 729 SigLIP patch tokens (dim 1152) to 64 style
                # tokens in the transformer's 3072-dim hidden space
                feature_embedder = SigLIPMultiFeatProjModel(
                    siglip_token_nums=729,
                    style_token_nums=64,
                    siglip_token_dims=1152,
                    hidden_size=3072,  # self.hidden_size
                    context_layer_norm=True,
                )
                offload.load_model_data(feature_embedder, model_filename[1])
        self.vision_encoder = siglip_model
        self.vision_encoder_processor = siglip_processor
        self.feature_embedder = feature_embedder
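        # Post-load housekeeping: optionally re-save or quantize the checkpoint,
        # then register the linear-split map used by the offload module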
        # offload.change_dtype(self.model, dtype, True)
        # offload.save_model(self.model, "flux-dev.safetensors")
        if source is not None:
            from wgp import save_model
            save_model(self.model, model_type, dtype, None)
        if save_quantized:
            from wgp import save_quantized_model
            save_quantized_model(self.model, model_type, model_filename[0], dtype, None)
        split_linear_modules_map = get_linear_split_map()
        self.model.split_linear_modules_map = split_linear_modules_map
        offload.split_linear_modules(self.model, split_linear_modules_map)
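
    # generate() runs the full pipeline: normalize the reference images, encode
    # them (latent stitching, image stitching, or kontext conditioning), prepare
    # the text embeddings, denoise, decode with the VAE, and optionally composite
    # the result with the guide image through the inpainting mask.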
    def generate(
        self,
        seed: int | None = None,
        input_prompt: str = "replace the logo with the text 'Black Forest Labs'",
        n_prompt: str | None = None,
        sampling_steps: int = 20,
        input_ref_images=None,
        image_guide=None,
        image_mask=None,
        width=832,
        height=480,
        embedded_guidance_scale: float = 2.5,
        guide_scale=2.5,
        fit_into_canvas=None,
        callback=None,
        loras_slists=None,
        batch_size=1,
        video_prompt_type="",
        joint_pass=False,
        image_refs_relative_size=100,
        denoising_strength=1.0,
        **bbargs
    ):
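        """Generate a batch of images and return them as a tensor, or None if interrupted.

        `video_prompt_type` is a flag string: "I" enables reference images and,
        for flux-dev-uso, "J" marks all references as style images.
        """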
        if self._interrupt:
            return None
        # Disable classifier-free guidance when the model has no guidance phases
        if self.guidance_max_phases < 1:
            guide_scale = 1
        if n_prompt is None or len(n_prompt) == 0:
            n_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
        device = "cuda"
        flux_dev_uso = self.name in ['flux-dev-uso']
        flux_dev_umo = self.name in ['flux-dev-umo']
        latent_stitching = self.name in ['flux-dev-uso', 'flux-dev-umo']
        lock_dimensions = False
        input_ref_images = [] if input_ref_images is None else input_ref_images[:]
        if flux_dev_umo:
            # flux-dev-umo expects fixed-size references: 512 px long side for a
            # single reference, 320 px when there are several
            ref_long_side = 512 if len(input_ref_images) <= 1 else 320
            input_ref_images = [preprocess_ref(img, ref_long_side) for img in input_ref_images]
            lock_dimensions = True
        ref_style_imgs = []
        if "I" in video_prompt_type and len(input_ref_images) > 0:
            if flux_dev_uso:
                # "J" treats every reference as a style image; otherwise the last
                # reference is the style image and the rest are content references
                if "J" in video_prompt_type:
                    ref_style_imgs = input_ref_images
                    input_ref_images = []
                elif len(input_ref_images) > 1:
                    ref_style_imgs = input_ref_images[-1:]
                    input_ref_images = input_ref_images[:-1]
            if latent_stitching:
                # latent stitching: references are encoded separately; resize them
                # relative to the output resolution unless dimensions are locked
                if not lock_dimensions:
                    for i in range(len(input_ref_images)):
                        w, h = input_ref_images[i].size
                        image_height, image_width = calculate_new_dimensions(int(height * image_refs_relative_size / 100), int(width * image_refs_relative_size / 100), h, w, 0)
                        input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
            else:
                # image stitching: paste all references side by side into one image
                stitched = input_ref_images[0]
                for new_img in input_ref_images[1:]:
                    stitched = stitch_images(stitched, new_img)
                input_ref_images = [stitched]
        elif image_guide is not None:
            # no reference images requested: fall back to the edit guide image
            input_ref_images = [image_guide]
        else:
            input_ref_images = None
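        # input_ref_images is now either a list of conditioning images,
        # [image_guide] for edits, or None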
        if latent_stitching:
            inp, height, width = prepare_multi_ip(
                ae=self.vae,
                img_cond_list=input_ref_images,
                target_width=width,
                target_height=height,
                bs=batch_size,
                seed=seed,
                device=device,
            )
        else:
            inp, height, width = prepare_kontext(
                ae=self.vae,
                img_cond_list=input_ref_images,
                target_width=width,
                target_height=height,
                bs=batch_size,
                seed=seed,
                device=device,
                img_mask=image_mask,
            )
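        # inp now holds the image-side conditioning built by the prepare_* helper;
        # the prepare_prompt calls below add the T5/CLIP text embeddings (and the
        # negative prompt when real guidance is enabled)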
        inp.update(prepare_prompt(self.t5, self.clip, batch_size, input_prompt))
        if guide_scale != 1:
            inp.update(prepare_prompt(self.t5, self.clip, batch_size, n_prompt, neg=True, device=device))
        timesteps = get_schedule(sampling_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))
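        # flux-schnell is a few-step distilled model, so its schedule is left unshifted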
        if self.feature_embedder is not None and self.vision_encoder is not None and len(ref_style_imgs) > 0:
            # project the SigLIP style features into the text token hidden space
            ref_style_imgs = [self.vision_encoder_processor(img, return_tensors="pt").to(self.device) for img in ref_style_imgs]
            siglip_embedding = [self.vision_encoder(**emb, output_hidden_states=True) for emb in ref_style_imgs]
            siglip_embedding = torch.cat([self.feature_embedder(emb) for emb in siglip_embedding], dim=1)
            siglip_embedding_ids = torch.zeros(siglip_embedding.shape[0], siglip_embedding.shape[1], 3).to(device)
            inp["siglip_embedding"] = siglip_embedding
            inp["siglip_embedding_ids"] = siglip_embedding_ids

        def unpack_latent(x):
            return unpack(x.float(), height, width)

        # denoise the initial noise
        x = denoise(self.model, **inp, timesteps=timesteps, guidance=embedded_guidance_scale, real_guidance_scale=guide_scale, callback=callback, pipeline=self, loras_slists=loras_slists, unpack_latent=unpack_latent, joint_pass=joint_pass, denoising_strength=denoising_strength)
        if x is None:
            return None

        # decode latents to pixel space
        x = unpack_latent(x)
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            x = self.vae.decode(x)

        if image_mask is not None:
            # composite: keep guide-image pixels where the mask is 0 and
            # generated pixels where the mask is 1
            img_msk_rebuilt = inp["img_msk_rebuilt"]
            img = convert_image_to_tensor(image_guide)
            x = img.squeeze(2) * (1 - img_msk_rebuilt) + x.to(img) * img_msk_rebuilt

        x = x.clamp(-1, 1)
        x = x.transpose(0, 1)
        return x