# Wan2.1/models/flux/flux_main.py
import os
import re
import time
from dataclasses import dataclass
from glob import iglob
from mmgp import offload
import torch
from shared.utils.utils import calculate_new_dimensions
from .sampling import denoise, get_schedule, prepare_kontext, prepare_prompt, prepare_multi_ip, unpack
from .modules.layers import get_linear_split_map
from transformers import SiglipVisionModel, SiglipImageProcessor
import torchvision.transforms.functional as TVF
import math
from shared.utils.utils import convert_image_to_tensor, convert_tensor_to_image
from .util import (
aspect_ratio_to_height_width,
load_ae,
load_clip,
load_flow_model,
load_t5,
save_image,
)
from PIL import Image
def resize_and_centercrop_image(image, target_height_ref1, target_width_ref1):
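"""Resize a (C, H, W) image tensor so it covers the target size, then center-crop.
The target height and width are first snapped down to multiples of 64; the image
is scaled (keeping its aspect ratio) until both dimensions reach at least the
target, and the result is cropped around the center to exactly that size.
"""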
target_height_ref1 = int(target_height_ref1 // 64 * 64)
target_width_ref1 = int(target_width_ref1 // 64 * 64)
h, w = image.shape[-2:]
if h < target_height_ref1 or w < target_width_ref1:
# compute the aspect ratio
aspect_ratio = w / h
if h < target_height_ref1:
new_h = target_height_ref1
new_w = new_h * aspect_ratio
if new_w < target_width_ref1:
new_w = target_width_ref1
new_h = new_w / aspect_ratio
else:
new_w = target_width_ref1
new_h = new_w / aspect_ratio
if new_h < target_height_ref1:
new_h = target_height_ref1
new_w = new_h * aspect_ratio
else:
aspect_ratio = w / h
tgt_aspect_ratio = target_width_ref1 / target_height_ref1
if aspect_ratio > tgt_aspect_ratio:
new_h = target_height_ref1
new_w = new_h * aspect_ratio
else:
new_w = target_width_ref1
new_h = new_w / aspect_ratio
# rescale the image with TVF.resize
image = TVF.resize(image, (math.ceil(new_h), math.ceil(new_w)))
# compute the top-left offsets for the center crop
top = (image.shape[-2] - target_height_ref1) // 2
left = (image.shape[-1] - target_width_ref1) // 2
# center-crop with TVF.crop
image = TVF.crop(image, top, left, target_height_ref1, target_width_ref1)
return image
def stitch_images(img1, img2):
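"""Paste img2 to the right of img1, rescaling img2 to match img1's height."""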
# Resize img2 to match img1's height
width1, height1 = img1.size
width2, height2 = img2.size
new_width2 = int(width2 * height1 / height2)
img2_resized = img2.resize((new_width2, height1), Image.Resampling.LANCZOS)
stitched = Image.new('RGB', (width1 + new_width2, height1))
stitched.paste(img1, (0, 0))
stitched.paste(img2_resized, (width1, 0))
return stitched
class model_factory:
def __init__(
self,
checkpoint_dir,
model_filename = None,
model_type = None,
model_def = None,
base_model_type = None,
text_encoder_filename = None,
quantizeTransformer = False,
save_quantized = False,
dtype = torch.bfloat16,
VAE_dtype = torch.float32,
mixed_precision_transformer = False
):
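"""Load the Flux pipeline: the T5 and CLIP text encoders, the flow transformer
and the autoencoder (all on CPU), plus, for the 'flux-dev-uso' variant, a SigLIP
vision encoder and an optional feature projector for style reference images."""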
self.device = torch.device("cuda")
self.VAE_dtype = VAE_dtype
self.dtype = dtype
torch_device = "cpu"
self.guidance_max_phases = model_def.get("guidance_max_phases", 0)
# model_filename = ["c:/temp/flux1-schnell.safetensors"]
self.t5 = load_t5(torch_device, text_encoder_filename, max_length=512)
self.clip = load_clip(torch_device)
self.name = model_def.get("flux-model", "flux-dev")
# self.name= "flux-dev-kontext"
# self.name= "flux-dev"
# self.name= "flux-schnell"
source = model_def.get("source", None)
self.model = load_flow_model(self.name, model_filename[0] if source is None else source, torch_device)
self.vae = load_ae(self.name, device=torch_device)
siglip_processor = siglip_model = feature_embedder = None
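# The USO variant additionally loads a SigLIP vision tower and, when a second
# checkpoint is given, a multi-feature projector to encode style reference images.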
if self.name == 'flux-dev-uso':
siglip_path = "ckpts/siglip-so400m-patch14-384"
siglip_processor = SiglipImageProcessor.from_pretrained(siglip_path)
siglip_model = SiglipVisionModel.from_pretrained(siglip_path)
siglip_model.eval().to("cpu")
if len(model_filename) > 1:
from .modules.layers import SigLIPMultiFeatProjModel
feature_embedder = SigLIPMultiFeatProjModel(
siglip_token_nums=729,
style_token_nums=64,
siglip_token_dims=1152,
hidden_size=3072, #self.hidden_size,
context_layer_norm=True,
)
offload.load_model_data(feature_embedder, model_filename[1])
self.vision_encoder = siglip_model
self.vision_encoder_processor = siglip_processor
self.feature_embedder = feature_embedder
# offload.change_dtype(self.model, dtype, True)
# offload.save_model(self.model, "flux-dev.safetensors")
if source is not None:
from wgp import save_model
save_model(self.model, model_type, dtype, None)
if save_quantized:
from wgp import save_quantized_model
save_quantized_model(self.model, model_type, model_filename[0], dtype, None)
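# Split selected linear modules of the transformer according to the model's
# split map (consumed by the mmgp offload machinery).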
split_linear_modules_map = get_linear_split_map()
self.model.split_linear_modules_map = split_linear_modules_map
offload.split_linear_modules(self.model, split_linear_modules_map )
def generate(
self,
seed: int | None = None,
input_prompt: str = "replace the logo with the text 'Black Forest Labs'",
n_prompt: str | None = None,
sampling_steps: int = 20,
input_ref_images = None,
width= 832,
height=480,
embedded_guidance_scale: float = 2.5,
guide_scale = 2.5,
fit_into_canvas = None,
callback = None,
loras_slists = None,
batch_size = 1,
video_prompt_type = "",
joint_pass = False,
image_refs_relative_size = 100,
**bbargs
):
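"""Sample images for `input_prompt` with the Flux flow-matching denoiser and
decode them with the VAE.
Reference images are either stitched into a single conditioning image or
encoded as separate latents depending on the model variant, and flux-dev-uso
can additionally receive SigLIP style embeddings. Returns the decoded images
as a tensor clamped to [-1, 1], or None if generation was interrupted.
"""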
if self._interrupt:
return None
if self.guidance_max_phases < 1: guide_scale = 1
if n_prompt is None or len(n_prompt) == 0: n_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
device="cuda"
flux_dev_uso = self.name in ['flux-dev-uso']
image_stiching = self.name not in ['flux-dev-uso']
# image_refs_relative_size = 100
crop = False
input_ref_images = [] if input_ref_images is None else input_ref_images[:]
ref_style_imgs = []
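# video_prompt_type flags used below: "I" = reference images are provided,
# "K" = the first reference also sets the output resolution,
# "J" = treat the references as style images (USO variant only).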
if "I" in video_prompt_type and len(input_ref_images) > 0:
if flux_dev_uso :
if "J" in video_prompt_type:
ref_style_imgs = input_ref_images
input_ref_images = []
elif len(input_ref_images) > 1 :
ref_style_imgs = input_ref_images[-1:]
input_ref_images = input_ref_images[:-1]
if image_stiching:
# image stitching method
stiched = input_ref_images[0]
if "K" in video_prompt_type :
w, h = input_ref_images[0].size
height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
# actual rescale will happen in prepare_kontext
for new_img in input_ref_images[1:]:
stiched = stitch_images(stiched, new_img)
input_ref_images = [stiched]
else:
first_ref = 0
if "K" in video_prompt_type:
# image latents tiling method
w, h = input_ref_images[0].size
if crop :
img = convert_image_to_tensor(input_ref_images[0])
img = resize_and_centercrop_image(img, height, width)
input_ref_images[0] = convert_tensor_to_image(img)
else:
height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
input_ref_images[0] = input_ref_images[0].resize((width, height), resample=Image.Resampling.LANCZOS)
first_ref = 1
for i in range(first_ref,len(input_ref_images)):
w, h = input_ref_images[i].size
if crop:
img = convert_image_to_tensor(input_ref_images[i])
img = resize_and_centercrop_image(img, int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100))
input_ref_images[i] = convert_tensor_to_image(img)
else:
image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas)
input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
else:
input_ref_images = None
if flux_dev_uso :
inp, height, width = prepare_multi_ip(
ae=self.vae,
img_cond_list=input_ref_images,
target_width=width,
target_height=height,
bs=batch_size,
seed=seed,
device=device,
)
else:
inp, height, width = prepare_kontext(
ae=self.vae,
img_cond_list=input_ref_images,
target_width=width,
target_height=height,
bs=batch_size,
seed=seed,
device=device,
)
inp.update(prepare_prompt(self.t5, self.clip, batch_size, input_prompt))
if guide_scale != 1:
inp.update(prepare_prompt(self.t5, self.clip, batch_size, n_prompt, neg = True, device=device))
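# Build the flow-matching timestep schedule; all variants except flux-schnell
# use the resolution-dependent shift.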
timesteps = get_schedule(sampling_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))
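# Encode style references with SigLIP and project them into the transformer's
# hidden space so they can be fed to the denoiser as extra conditioning.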
ref_style_imgs = [self.vision_encoder_processor(img, return_tensors="pt").to(self.device) for img in ref_style_imgs]
if self.feature_embedder is not None and ref_style_imgs is not None and len(ref_style_imgs) > 0 and self.vision_encoder is not None:
# project the style features into the textual hidden space
siglip_embedding = [self.vision_encoder(**emb, output_hidden_states=True) for emb in ref_style_imgs]
siglip_embedding = torch.cat([self.feature_embedder(emb) for emb in siglip_embedding], dim=1)
siglip_embedding_ids = torch.zeros( siglip_embedding.shape[0], siglip_embedding.shape[1], 3 ).to(device)
inp["siglip_embedding"] = siglip_embedding
inp["siglip_embedding_ids"] = siglip_embedding_ids
def unpack_latent(x):
return unpack(x.float(), height, width)
# denoise initial noise
x = denoise(self.model, **inp, timesteps=timesteps, guidance=embedded_guidance_scale, real_guidance_scale =guide_scale, callback=callback, pipeline=self, loras_slists= loras_slists, unpack_latent = unpack_latent, joint_pass = joint_pass)
if x is None: return None
# decode latents to pixel space
x = unpack_latent(x)
with torch.autocast(device_type=device, dtype=torch.bfloat16):
x = self.vae.decode(x)
x = x.clamp(-1, 1)
x = x.transpose(0, 1)
return x
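# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): it only exercises the
# constructor and generate() as defined above. The checkpoint paths, the
# reference image path and the model_def contents are placeholders/assumptions,
# and `_interrupt` is normally managed by the host application.
if __name__ == "__main__":
    factory = model_factory(
        checkpoint_dir="ckpts",  # hypothetical location
        model_filename=["ckpts/flux1-dev-kontext.safetensors"],  # hypothetical checkpoint
        model_def={"flux-model": "flux-dev-kontext", "guidance_max_phases": 1},
        text_encoder_filename="ckpts/T5_xxl_enc_bf16.safetensors",  # hypothetical checkpoint
    )
    factory._interrupt = False  # normally set by the host app before generate()

    ref = Image.open("reference.png").convert("RGB")  # hypothetical reference image
    out = factory.generate(
        seed=42,
        input_prompt="replace the logo with the text 'Black Forest Labs'",
        sampling_steps=20,
        input_ref_images=[ref],
        video_prompt_type="KI",  # "I": use reference images, "K": reference sets the resolution
        width=1024,
        height=1024,
    )
    if out is not None:
        # out is (C, B, H, W) in [-1, 1]; save the first sample
        # (assumes convert_tensor_to_image accepts a single (C, H, W) tensor in [-1, 1]).
        convert_tensor_to_image(out[:, 0]).save("flux_sample.png")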