import os
import re
import time
import math
from dataclasses import dataclass
from glob import iglob

import torch
import torchvision.transforms.functional as TVF
from PIL import Image
from transformers import SiglipVisionModel, SiglipImageProcessor
from mmgp import offload as offload

from shared.utils.utils import calculate_new_dimensions, convert_image_to_tensor, convert_tensor_to_image
from .sampling import denoise, get_schedule, prepare_kontext, prepare_prompt, prepare_multi_ip, unpack
from .modules.layers import get_linear_split_map
from .util import (
    aspect_ratio_to_height_width,
    load_ae,
    load_clip,
    load_flow_model,
    load_t5,
    save_image,
)


def resize_and_centercrop_image(image, target_height_ref1, target_width_ref1):
    # Snap the target size down to multiples of 64, resize the tensor so it
    # covers the target, then center-crop to exactly (target_height, target_width).
    target_height_ref1 = int(target_height_ref1 // 64 * 64)
    target_width_ref1 = int(target_width_ref1 // 64 * 64)
    h, w = image.shape[-2:]
    if h < target_height_ref1 or w < target_width_ref1:
        # Compute the aspect ratio and upscale until both sides cover the target.
        aspect_ratio = w / h
        if h < target_height_ref1:
            new_h = target_height_ref1
            new_w = new_h * aspect_ratio
            if new_w < target_width_ref1:
                new_w = target_width_ref1
                new_h = new_w / aspect_ratio
        else:
            new_w = target_width_ref1
            new_h = new_w / aspect_ratio
            if new_h < target_height_ref1:
                new_h = target_height_ref1
                new_w = new_h * aspect_ratio
    else:
        # Downscale so the tighter dimension still covers the target ratio.
        aspect_ratio = w / h
        tgt_aspect_ratio = target_width_ref1 / target_height_ref1
        if aspect_ratio > tgt_aspect_ratio:
            new_h = target_height_ref1
            new_w = new_h * aspect_ratio
        else:
            new_w = target_width_ref1
            new_h = new_w / aspect_ratio
    # Resize with TVF.resize.
    image = TVF.resize(image, (math.ceil(new_h), math.ceil(new_w)))
    # Compute the center-crop offsets.
    top = (image.shape[-2] - target_height_ref1) // 2
    left = (image.shape[-1] - target_width_ref1) // 2
    # Center-crop with TVF.crop.
    image = TVF.crop(image, top, left, target_height_ref1, target_width_ref1)
    return image


def stitch_images(img1, img2):
    # Resize img2 to match img1's height, then paste the two side by side.
    width1, height1 = img1.size
    width2, height2 = img2.size
    new_width2 = int(width2 * height1 / height2)
    img2_resized = img2.resize((new_width2, height1), Image.Resampling.LANCZOS)
    stitched = Image.new('RGB', (width1 + new_width2, height1))
    stitched.paste(img1, (0, 0))
    stitched.paste(img2_resized, (width1, 0))
    return stitched
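
# Minimal sanity check for the two helpers above: an illustrative sketch only,
# kept under a __main__ guard so importing this module stays side-effect free.
# It only runs in an environment where this module's imports resolve.
if __name__ == "__main__":
    # stitch_images: the second image is scaled to the first image's height,
    # so the stitched canvas is (w1 + w2 * h1 / h2, h1).
    _a = Image.new("RGB", (640, 480), "red")
    _b = Image.new("RGB", (300, 600), "blue")
    assert stitch_images(_a, _b).size == (640 + int(300 * 480 / 600), 480)

    # resize_and_centercrop_image: targets are snapped to multiples of 64,
    # then the tensor is cover-resized and center-cropped.
    _out = resize_and_centercrop_image(torch.rand(3, 500, 700), 512, 512)
    assert tuple(_out.shape) == (3, 512, 512)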


class model_factory:
    def __init__(
        self,
        checkpoint_dir,
        model_filename = None,
        model_type = None,
        model_def = None,
        base_model_type = None,
        text_encoder_filename = None,
        quantizeTransformer = False,
        save_quantized = False,
        dtype = torch.bfloat16,
        VAE_dtype = torch.float32,
        mixed_precision_transformer = False
    ):
        self.device = torch.device("cuda")
        self.VAE_dtype = VAE_dtype
        self.dtype = dtype
        torch_device = "cpu"
        self.guidance_max_phases = model_def.get("guidance_max_phases", 0)
        # model_filename = ["c:/temp/flux1-schnell.safetensors"]
        self.t5 = load_t5(torch_device, text_encoder_filename, max_length=512)
        self.clip = load_clip(torch_device)
        self.name = model_def.get("flux-model", "flux-dev")
        # self.name = "flux-dev-kontext"
        # self.name = "flux-dev"
        # self.name = "flux-schnell"
        source = model_def.get("source", None)
        self.model = load_flow_model(self.name, model_filename[0] if source is None else source, torch_device)
        self.vae = load_ae(self.name, device=torch_device)

        siglip_processor = siglip_model = feature_embedder = None
        if self.name == 'flux-dev-uso':
            siglip_path = "ckpts/siglip-so400m-patch14-384"
            siglip_processor = SiglipImageProcessor.from_pretrained(siglip_path)
            siglip_model = SiglipVisionModel.from_pretrained(siglip_path)
            siglip_model.eval().to("cpu")
            if len(model_filename) > 1:
                from .modules.layers import SigLIPMultiFeatProjModel
                feature_embedder = SigLIPMultiFeatProjModel(
                    siglip_token_nums=729,
                    style_token_nums=64,
                    siglip_token_dims=1152,
                    hidden_size=3072,  # self.hidden_size
                    context_layer_norm=True,
                )
                offload.load_model_data(feature_embedder, model_filename[1])
        self.vision_encoder = siglip_model
        self.vision_encoder_processor = siglip_processor
        self.feature_embedder = feature_embedder

        # offload.change_dtype(self.model, dtype, True)
        # offload.save_model(self.model, "flux-dev.safetensors")
        if source is not None:
            from wgp import save_model
            save_model(self.model, model_type, dtype, None)
        if save_quantized:
            from wgp import save_quantized_model
            save_quantized_model(self.model, model_type, model_filename[0], dtype, None)

        split_linear_modules_map = get_linear_split_map()
        self.model.split_linear_modules_map = split_linear_modules_map
        offload.split_linear_modules(self.model, split_linear_modules_map)

    def generate(
        self,
        seed: int | None = None,
        input_prompt: str = "replace the logo with the text 'Black Forest Labs'",
        n_prompt: str | None = None,
        sampling_steps: int = 20,
        input_ref_images = None,
        width = 832,
        height = 480,
        embedded_guidance_scale: float = 2.5,
        guide_scale = 2.5,
        fit_into_canvas = None,
        callback = None,
        loras_slists = None,
        batch_size = 1,
        video_prompt_type = "",
        joint_pass = False,
        image_refs_relative_size = 100,
        **bbargs
    ):
        # _interrupt is expected to be set on this object by the host pipeline.
        if self._interrupt:
            return None
        if self.guidance_max_phases < 1:
            guide_scale = 1
        if n_prompt is None or len(n_prompt) == 0:
            n_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigured, blurry, smudged, restricted palette, flat colors"
        device = "cuda"
        flux_dev_uso = self.name in ['flux-dev-uso']
        image_stitching = self.name not in ['flux-dev-uso']  # and False
        # image_refs_relative_size = 100
        crop = False
        input_ref_images = [] if input_ref_images is None else input_ref_images[:]
        ref_style_imgs = []
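
        # video_prompt_type flags, as inferred from the branches below:
        #   "I": reference images are supplied and condition the generation
        #   "J": (flux-dev-uso) treat every reference image as a style reference
        #   "K": the first reference image sets the output canvas; height/width
        #        are recomputed from its aspect ratio via calculate_new_dimensions
        # Without "I" (or with an empty list), input_ref_images is reset to None.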
        if "I" in video_prompt_type and len(input_ref_images) > 0:
            if flux_dev_uso:
                if "J" in video_prompt_type:
                    ref_style_imgs = input_ref_images
                    input_ref_images = []
                elif len(input_ref_images) > 1:
                    ref_style_imgs = input_ref_images[-1:]
                    input_ref_images = input_ref_images[:-1]
            if image_stitching:
                # Image stitching method: merge all reference images into one canvas.
                stitched = input_ref_images[0]
                if "K" in video_prompt_type:
                    w, h = input_ref_images[0].size
                    # The actual rescale happens later in prepare_kontext.
                    height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
                for new_img in input_ref_images[1:]:
                    stitched = stitch_images(stitched, new_img)
                input_ref_images = [stitched]
            else:
                first_ref = 0
                if "K" in video_prompt_type:
                    # Image latents tiling method: the first reference defines the canvas.
                    w, h = input_ref_images[0].size
                    if crop:
                        img = convert_image_to_tensor(input_ref_images[0])
                        img = resize_and_centercrop_image(img, height, width)
                        input_ref_images[0] = convert_tensor_to_image(img)
                    else:
                        height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
                        input_ref_images[0] = input_ref_images[0].resize((width, height), resample=Image.Resampling.LANCZOS)
                    first_ref = 1
                for i in range(first_ref, len(input_ref_images)):
                    w, h = input_ref_images[i].size
                    if crop:
                        img = convert_image_to_tensor(input_ref_images[i])
                        img = resize_and_centercrop_image(img, int(height * image_refs_relative_size / 100), int(width * image_refs_relative_size / 100))
                        input_ref_images[i] = convert_tensor_to_image(img)
                    else:
                        image_height, image_width = calculate_new_dimensions(int(height * image_refs_relative_size / 100), int(width * image_refs_relative_size / 100), h, w, fit_into_canvas)
                        input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
        else:
            input_ref_images = None

        if flux_dev_uso:
            inp, height, width = prepare_multi_ip(
                ae=self.vae,
                img_cond_list=input_ref_images,
                target_width=width,
                target_height=height,
                bs=batch_size,
                seed=seed,
                device=device,
            )
        else:
            inp, height, width = prepare_kontext(
                ae=self.vae,
                img_cond_list=input_ref_images,
                target_width=width,
                target_height=height,
                bs=batch_size,
                seed=seed,
                device=device,
            )

        inp.update(prepare_prompt(self.t5, self.clip, batch_size, input_prompt))
        if guide_scale != 1:
            inp.update(prepare_prompt(self.t5, self.clip, batch_size, n_prompt, neg=True, device=device))
        timesteps = get_schedule(sampling_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))

        ref_style_imgs = [self.vision_encoder_processor(img, return_tensors="pt").to(self.device) for img in ref_style_imgs]
        if self.feature_embedder is not None and ref_style_imgs is not None and len(ref_style_imgs) > 0 and self.vision_encoder is not None:
            # Project the style features into the textual hidden space.
            siglip_embedding = [self.vision_encoder(**emb, output_hidden_states=True) for emb in ref_style_imgs]
            siglip_embedding = torch.cat([self.feature_embedder(emb) for emb in siglip_embedding], dim=1)
            siglip_embedding_ids = torch.zeros(
                siglip_embedding.shape[0], siglip_embedding.shape[1], 3
            ).to(device)
            inp["siglip_embedding"] = siglip_embedding
            inp["siglip_embedding_ids"] = siglip_embedding_ids

        def unpack_latent(x):
            return unpack(x.float(), height, width)

        # Denoise the initial noise.
        x = denoise(
            self.model,
            **inp,
            timesteps=timesteps,
            guidance=embedded_guidance_scale,
            real_guidance_scale=guide_scale,
            callback=callback,
            pipeline=self,
            loras_slists=loras_slists,
            unpack_latent=unpack_latent,
            joint_pass=joint_pass,
        )
        if x is None:
            return None

        # Decode the latents to pixel space.
        x = unpack_latent(x)
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            x = self.vae.decode(x)

        x = x.clamp(-1, 1)
        x = x.transpose(0, 1)
        return x
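

# Usage sketch (illustrative only, not executed). The checkpoint paths and
# model_def keys below are placeholders/assumptions; real values come from the
# host application's (wgp) model definitions.
#
#   pipe = model_factory(
#       checkpoint_dir="ckpts",                                 # hypothetical
#       model_filename=["ckpts/flux1-dev.safetensors"],         # hypothetical path
#       model_def={"flux-model": "flux-dev"},
#       text_encoder_filename="ckpts/T5_xxl.safetensors",       # hypothetical path
#   )
#   pipe._interrupt = False
#   x = pipe.generate(
#       seed=42,
#       input_prompt="a watercolor fox in a snowy forest",
#       width=832, height=480, sampling_steps=20,
#   )
#   # x is a tensor clamped to [-1, 1] with the batch and channel axes swapped
#   # (channel-first, batch second), or None if generation was interrupted.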