import os
import re
import time
from dataclasses import dataclass
from glob import iglob
import math

from mmgp import offload as offload
import torch
import torchvision.transforms.functional as TVF
from transformers import SiglipVisionModel, SiglipImageProcessor
from PIL import Image

from shared.utils.utils import calculate_new_dimensions
from shared.utils.utils import convert_image_to_tensor, convert_tensor_to_image
from .sampling import denoise, get_schedule, prepare_kontext, prepare_prompt, prepare_multi_ip, unpack
from .modules.layers import get_linear_split_map
from .util import (
    aspect_ratio_to_height_width,
    load_ae,
    load_clip,
    load_flow_model,
    load_t5,
    save_image,
)


def preprocess_ref(raw_image: Image.Image, long_size: int = 512):
    # Get the original image dimensions
    image_w, image_h = raw_image.size

    # Scale so that the long side equals long_size, keeping the aspect ratio
    if image_w >= image_h:
        new_w = long_size
        new_h = int((long_size / image_w) * image_h)
    else:
        new_h = long_size
        new_w = int((long_size / image_h) * image_w)

    # Resize proportionally to the new dimensions
    raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS)
    target_w = new_w // 16 * 16
    target_h = new_h // 16 * 16

    # Compute the starting coordinates for a center crop
    left = (new_w - target_w) // 2
    top = (new_h - target_h) // 2
    right = left + target_w
    bottom = top + target_h

    # Center-crop to dimensions that are multiples of 16
    raw_image = raw_image.crop((left, top, right, bottom))

    # Convert to RGB mode
    raw_image = raw_image.convert("RGB")

    return raw_image


def stitch_images(img1, img2):
    # Resize img2 to match img1's height
    width1, height1 = img1.size
    width2, height2 = img2.size
    new_width2 = int(width2 * height1 / height2)
    img2_resized = img2.resize((new_width2, height1), Image.Resampling.LANCZOS)

    # Paste both images side by side on a new canvas
    stitched = Image.new('RGB', (width1 + new_width2, height1))
    stitched.paste(img1, (0, 0))
    stitched.paste(img2_resized, (width1, 0))
    return stitched


class model_factory:
    def __init__(
        self,
        checkpoint_dir,
        model_filename=None,
        model_type=None,
        model_def=None,
        base_model_type=None,
        text_encoder_filename=None,
        quantizeTransformer=False,
        save_quantized=False,
        dtype=torch.bfloat16,
        VAE_dtype=torch.float32,
        mixed_precision_transformer=False,
    ):
        self.device = torch.device("cuda")
        self.VAE_dtype = VAE_dtype
        self.dtype = dtype
        torch_device = "cpu"
        self.guidance_max_phases = model_def.get("guidance_max_phases", 0)
        # model_filename = ["c:/temp/flux1-schnell.safetensors"]

        self.t5 = load_t5(torch_device, text_encoder_filename, max_length=512)
        self.clip = load_clip(torch_device)
        self.name = model_def.get("flux-model", "flux-dev")
        # self.name = "flux-dev-kontext"
        # self.name = "flux-dev"
        # self.name = "flux-schnell"
        source = model_def.get("source", None)
        self.model = load_flow_model(self.name, model_filename[0] if source is None else source, torch_device)
        self.model_def = model_def
        self.vae = load_ae(self.name, device=torch_device)

        # Optional SigLIP vision encoder and feature projector (used by flux-dev-uso for style references)
        siglip_processor = siglip_model = feature_embedder = None
        if self.name == 'flux-dev-uso':
            siglip_path = "ckpts/siglip-so400m-patch14-384"
            siglip_processor = SiglipImageProcessor.from_pretrained(siglip_path)
            siglip_model = SiglipVisionModel.from_pretrained(siglip_path)
            siglip_model.eval().to("cpu")
            if len(model_filename) > 1:
                from .modules.layers import SigLIPMultiFeatProjModel
                feature_embedder = SigLIPMultiFeatProjModel(
                    siglip_token_nums=729,
                    style_token_nums=64,
                    siglip_token_dims=1152,
                    hidden_size=3072,  # self.hidden_size
                    context_layer_norm=True,
                )
                offload.load_model_data(feature_embedder, model_filename[1])
        self.vision_encoder = siglip_model
        self.vision_encoder_processor = siglip_processor
        self.feature_embedder = feature_embedder

        # offload.change_dtype(self.model, dtype, True)
        # offload.save_model(self.model, "flux-dev.safetensors")

        if source is not None:
            from wgp import save_model
            save_model(self.model, model_type, dtype, None)

        if save_quantized:
            from wgp import save_quantized_model
            save_quantized_model(self.model, model_type, model_filename[0], dtype, None)

        # Split large linear layers so they can be offloaded in smaller chunks
        split_linear_modules_map = get_linear_split_map()
        self.model.split_linear_modules_map = split_linear_modules_map
        offload.split_linear_modules(self.model, split_linear_modules_map)

    def generate(
        self,
        seed: int | None = None,
        input_prompt: str = "replace the logo with the text 'Black Forest Labs'",
        n_prompt: str | None = None,
        sampling_steps: int = 20,
        input_ref_images=None,
        input_frames=None,
        input_masks=None,
        width=832,
        height=480,
        embedded_guidance_scale: float = 2.5,
        guide_scale=2.5,
        fit_into_canvas=None,
        callback=None,
        loras_slists=None,
        batch_size=1,
        video_prompt_type="",
        joint_pass=False,
        image_refs_relative_size=100,
        denoising_strength=1.,
        **bbargs
    ):
        if self._interrupt:
            return None

        # Disable classifier-free guidance when the model definition does not support it
        if self.guidance_max_phases < 1:
            guide_scale = 1

        if n_prompt is None or len(n_prompt) == 0:
            n_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"

        device = "cuda"
        flux_dev_uso = self.name in ['flux-dev-uso']
        flux_dev_umo = self.name in ['flux-dev-umo']
        latent_stitching = self.name in ['flux-dev-uso', 'flux-dev-umo']

        lock_dimensions = False
        input_ref_images = [] if input_ref_images is None else input_ref_images[:]
        if flux_dev_umo:
            # UMO: resize reference images to a fixed long side and center-crop to multiples of 16
            ref_long_side = 512 if len(input_ref_images) <= 1 else 320
            input_ref_images = [preprocess_ref(img, ref_long_side) for img in input_ref_images]
            lock_dimensions = True

        ref_style_imgs = []
        if "I" in video_prompt_type and len(input_ref_images) > 0:
            if flux_dev_uso:
                # USO: with "J" all reference images are style references, otherwise only the last one
                if "J" in video_prompt_type:
                    ref_style_imgs = input_ref_images
                    input_ref_images = []
                elif len(input_ref_images) > 1:
                    ref_style_imgs = input_ref_images[-1:]
                    input_ref_images = input_ref_images[:-1]
            if latent_stitching:
                # latent stitching with resize
                if not lock_dimensions:
                    for i in range(len(input_ref_images)):
                        w, h = input_ref_images[i].size
                        image_height, image_width = calculate_new_dimensions(int(height * image_refs_relative_size / 100), int(width * image_refs_relative_size / 100), h, w, 0)
                        input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
            else:
                # image stitching method
                stitched = input_ref_images[0]
                for new_img in input_ref_images[1:]:
                    stitched = stitch_images(stitched, new_img)
                input_ref_images = [stitched]
        elif input_frames is not None:
            input_ref_images = [convert_tensor_to_image(input_frames)]
        else:
            input_ref_images = None

        image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels=True)

        # Encode the reference images and build the conditioning dict
        if self.name in ['flux-dev-uso', 'flux-dev-umo']:
            inp, height, width = prepare_multi_ip(
                ae=self.vae,
                img_cond_list=input_ref_images,
                target_width=width,
                target_height=height,
                bs=batch_size,
                seed=seed,
                device=device,
            )
        else:
            inp, height, width = prepare_kontext(
                ae=self.vae,
                img_cond_list=input_ref_images,
                target_width=width,
                target_height=height,
                bs=batch_size,
                seed=seed,
                device=device,
                img_mask=image_mask,
            )

        # Text conditioning (negative prompt only when classifier-free guidance is active)
        inp.update(prepare_prompt(self.t5, self.clip, batch_size, input_prompt))
        if guide_scale != 1:
            inp.update(prepare_prompt(self.t5, self.clip, batch_size, n_prompt, neg=True, device=device))
        timesteps = get_schedule(sampling_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))

        ref_style_imgs = [self.vision_encoder_processor(img, return_tensors="pt").to(self.device) for img in ref_style_imgs]
return_tensors="pt").to(self.device) for img in ref_style_imgs] if self.feature_embedder is not None and ref_style_imgs is not None and len(ref_style_imgs) > 0 and self.vision_encoder is not None: # processing style feat into textural hidden space siglip_embedding = [self.vision_encoder(**emb, output_hidden_states=True) for emb in ref_style_imgs] siglip_embedding = torch.cat([self.feature_embedder(emb) for emb in siglip_embedding], dim=1) siglip_embedding_ids = torch.zeros( siglip_embedding.shape[0], siglip_embedding.shape[1], 3 ).to(device) inp["siglip_embedding"] = siglip_embedding inp["siglip_embedding_ids"] = siglip_embedding_ids def unpack_latent(x): return unpack(x.float(), height, width) # denoise initial noise x = denoise(self.model, **inp, timesteps=timesteps, guidance=embedded_guidance_scale, real_guidance_scale =guide_scale, callback=callback, pipeline=self, loras_slists= loras_slists, unpack_latent = unpack_latent, joint_pass = joint_pass, denoising_strength = denoising_strength) if x==None: return None # decode latents to pixel space x = unpack_latent(x) with torch.autocast(device_type=device, dtype=torch.bfloat16): x = self.vae.decode(x) if image_mask is not None: from shared.utils.utils import convert_image_to_tensor img_msk_rebuilt = inp["img_msk_rebuilt"] img= input_frames.squeeze(1).unsqueeze(0) # convert_image_to_tensor(image_guide) x = img * (1 - img_msk_rebuilt) + x.to(img) * img_msk_rebuilt x = x.clamp(-1, 1) x = x.transpose(0, 1) return x