import math from typing import Callable import numpy as np import torch from einops import rearrange, repeat from PIL import Image from torch import Tensor from .model import Flux from .modules.autoencoder import AutoEncoder from .modules.conditioner import HFEmbedder from .modules.image_embedders import CannyImageEncoder, DepthImageEncoder, ReduxImageEncoder from .util import PREFERED_KONTEXT_RESOLUTIONS from einops import rearrange, repeat from typing import Literal import torchvision.transforms.functional as TVF def get_noise( num_samples: int, height: int, width: int, device: torch.device, dtype: torch.dtype, seed: int, ): return torch.randn( num_samples, 16, # allow for packing 2 * math.ceil(height / 16), 2 * math.ceil(width / 16), dtype=dtype, device=device, generator=torch.Generator(device=device).manual_seed(seed), ) def prepare_prompt(t5: HFEmbedder, clip: HFEmbedder, bs: int, prompt: str | list[str], neg: bool = False, device: str = "cuda") -> dict[str, Tensor]: if bs == 1 and not isinstance(prompt, str): bs = len(prompt) if isinstance(prompt, str): prompt = [prompt] txt = t5(prompt) if txt.shape[0] == 1 and bs > 1: txt = repeat(txt, "1 ... -> bs ...", bs=bs) txt_ids = torch.zeros(bs, txt.shape[1], 3) vec = clip(prompt) if vec.shape[0] == 1 and bs > 1: vec = repeat(vec, "1 ... -> bs ...", bs=bs) return { "neg_txt" if neg else "txt": txt.to(device), "neg_txt_ids" if neg else "txt_ids": txt_ids.to(device), "neg_vec" if neg else "vec": vec.to(device), } def prepare_img( img: Tensor) -> dict[str, Tensor]: bs, c, h, w = img.shape img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) if img.shape[0] == 1 and bs > 1: img = repeat(img, "1 ... -> bs ...", bs=bs) img_ids = torch.zeros(h // 2, w // 2, 3) img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) return { "img": img, "img_ids": img_ids.to(img.device), } def prepare_redux( t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str], encoder: ReduxImageEncoder, img_cond_path: str, ) -> dict[str, Tensor]: bs, _, h, w = img.shape if bs == 1 and not isinstance(prompt, str): bs = len(prompt) img_cond = Image.open(img_cond_path).convert("RGB") with torch.no_grad(): img_cond = encoder(img_cond) img_cond = img_cond.to(torch.bfloat16) if img_cond.shape[0] == 1 and bs > 1: img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs) img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) if img.shape[0] == 1 and bs > 1: img = repeat(img, "1 ... -> bs ...", bs=bs) img_ids = torch.zeros(h // 2, w // 2, 3) img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) if isinstance(prompt, str): prompt = [prompt] txt = t5(prompt) txt = torch.cat((txt, img_cond.to(txt)), dim=-2) if txt.shape[0] == 1 and bs > 1: txt = repeat(txt, "1 ... -> bs ...", bs=bs) txt_ids = torch.zeros(bs, txt.shape[1], 3) vec = clip(prompt) if vec.shape[0] == 1 and bs > 1: vec = repeat(vec, "1 ... -> bs ...", bs=bs) return { "img": img, "img_ids": img_ids.to(img.device), "txt": txt.to(img.device), "txt_ids": txt_ids.to(img.device), "vec": vec.to(img.device), } def prepare_kontext( ae: AutoEncoder, img_cond_list: list, seed: int, device: torch.device, target_width: int | None = None, target_height: int | None = None, bs: int = 1, img_mask = None, ) -> tuple[dict[str, Tensor], int, int]: # load and encode the conditioning image res_match_output = img_mask is not None img_cond_seq = None img_cond_seq_ids = None if img_cond_list == None: img_cond_list = [] height_offset = 0 width_offset = 0 for cond_no, img_cond in enumerate(img_cond_list): width, height = img_cond.size aspect_ratio = width / height if res_match_output: width, height = target_width, target_height else: # Kontext is trained on specific resolutions, using one of them is recommended _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS) width = 2 * int(width / 16) height = 2 * int(height / 16) img_cond = img_cond.resize((8 * width, 8 * height), Image.Resampling.LANCZOS) img_cond = np.array(img_cond) img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0 img_cond = rearrange(img_cond, "h w c -> 1 c h w") with torch.no_grad(): img_cond_latents = ae.encode(img_cond.to(device)) img_cond_latents = img_cond_latents.to(torch.bfloat16) img_cond_latents = rearrange(img_cond_latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) if img_cond.shape[0] == 1 and bs > 1: img_cond_latents = repeat(img_cond_latents, "1 ... -> bs ...", bs=bs) img_cond = None # image ids are the same as base image with the first dimension set to 1 # instead of 0 img_cond_ids = torch.zeros(height // 2, width // 2, 3) img_cond_ids[..., 0] = 1 img_cond_ids[..., 1] = img_cond_ids[..., 1] + torch.arange(height // 2)[:, None] + height_offset img_cond_ids[..., 2] = img_cond_ids[..., 2] + torch.arange(width // 2)[None, :] + width_offset img_cond_ids = repeat(img_cond_ids, "h w c -> b (h w) c", b=bs) height_offset += height // 2 width_offset += width // 2 if target_width is None: target_width = 8 * width if target_height is None: target_height = 8 * height img_cond_ids = img_cond_ids.to(device) if cond_no == 0: img_cond_seq, img_cond_seq_ids = img_cond_latents, img_cond_ids else: img_cond_seq, img_cond_seq_ids = torch.cat([img_cond_seq, img_cond_latents], dim=1), torch.cat([img_cond_seq_ids, img_cond_ids], dim=1) return_dict = { "img_cond_seq": img_cond_seq, "img_cond_seq_ids": img_cond_seq_ids, } if img_mask is not None: from shared.utils.utils import convert_image_to_tensor, convert_tensor_to_image # image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of) image_mask_latents = convert_image_to_tensor(img_mask.resize((target_width // 16, target_height // 16), resample=Image.Resampling.LANCZOS)) image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1] image_mask_rebuilt = image_mask_latents.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0) convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png") image_mask_latents = image_mask_latents.reshape(1, -1, 1).to(device) return_dict.update({ "img_msk_latents": image_mask_latents, "img_msk_rebuilt": image_mask_rebuilt, }) img = get_noise( bs, target_height, target_width, device=device, dtype=torch.bfloat16, seed=seed, ) return_dict.update(prepare_img(img)) return return_dict, target_height, target_width def time_shift(mu: float, sigma: float, t: Tensor): return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) def get_lin_function( x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 ) -> Callable[[float], float]: m = (y2 - y1) / (x2 - x1) b = y1 - m * x1 return lambda x: m * x + b def get_schedule( num_steps: int, image_seq_len: int, base_shift: float = 0.5, max_shift: float = 1.15, shift: bool = True, ) -> list[float]: # extra step for zero timesteps = torch.linspace(1, 0, num_steps + 1) # shifting the schedule to favor high timesteps for higher signal images if shift: # estimate mu based on linear estimation between two points mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) timesteps = time_shift(mu, 1.0, timesteps) return timesteps.tolist() def denoise( model: Flux, # model input img: Tensor, img_ids: Tensor, txt: Tensor, txt_ids: Tensor, vec: Tensor, # sampling parameters timesteps: list[float], guidance: float = 4.0, real_guidance_scale = None, # extra img tokens (channel-wise) neg_txt: Tensor = None, neg_txt_ids: Tensor= None, neg_vec: Tensor = None, img_cond: Tensor | None = None, # extra img tokens (sequence-wise) img_cond_seq: Tensor | None = None, img_cond_seq_ids: Tensor | None = None, siglip_embedding = None, siglip_embedding_ids = None, callback=None, pipeline=None, loras_slists=None, unpack_latent = None, joint_pass= False, img_msk_latents = None, img_msk_rebuilt = None, denoising_strength = 1, ): kwargs = {'pipeline': pipeline, 'callback': callback, "img_len" : img.shape[1], "siglip_embedding": siglip_embedding, "siglip_embedding_ids": siglip_embedding_ids} if callback != None: callback(-1, None, True) original_image_latents = None if img_cond_seq is None else img_cond_seq.clone() morph, first_step = False, 0 if img_msk_latents is not None: randn = torch.randn_like(original_image_latents) if denoising_strength < 1.: first_step = int(len(timesteps) * (1. - denoising_strength)) if not morph: latent_noise_factor = timesteps[first_step] latents = original_image_latents * (1.0 - latent_noise_factor) + randn * latent_noise_factor img = latents.to(img) latents = None timesteps = timesteps[first_step:] updated_num_steps= len(timesteps) -1 if callback != None: from shared.utils.loras_mutipliers import update_loras_slists update_loras_slists(model, loras_slists, updated_num_steps) callback(-1, None, True, override_num_inference_steps = updated_num_steps) from mmgp import offload # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])): offload.set_step_no_for_lora(model, first_step + i) if pipeline._interrupt: return None if img_msk_latents is not None and denoising_strength <1. and i == first_step and morph: latent_noise_factor = t_curr/1000 img = original_image_latents * (1.0 - latent_noise_factor) + img * latent_noise_factor t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) img_input = img img_input_ids = img_ids if img_cond is not None: img_input = torch.cat((img, img_cond), dim=-1) if img_cond_seq is not None: img_input = torch.cat((img_input, img_cond_seq), dim=1) img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1) if not joint_pass or real_guidance_scale == 1: pred = model( img=img_input, img_ids=img_input_ids, txt_list=[txt], txt_ids_list=[txt_ids], y_list=[vec], timesteps=t_vec, guidance=guidance_vec, **kwargs )[0] if pred == None: return None if real_guidance_scale> 1: neg_pred = model( img=img_input, img_ids=img_input_ids, txt_list=[neg_txt], txt_ids_list=[neg_txt_ids], y_list=[neg_vec], timesteps=t_vec, guidance=guidance_vec, **kwargs )[0] if neg_pred == None: return None else: pred, neg_pred = model( img=img_input, img_ids=img_input_ids, txt_list=[txt, neg_txt], txt_ids_list=[txt_ids, neg_txt_ids], y_list=[vec, neg_vec], timesteps=t_vec, guidance=guidance_vec, **kwargs ) if pred == None: return None if real_guidance_scale > 1: pred = neg_pred + real_guidance_scale * (pred - neg_pred) img += (t_prev - t_curr) * pred if img_msk_latents is not None: latent_noise_factor = t_prev # noisy_image = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor noisy_image = original_image_latents * (1.0 - latent_noise_factor) + randn * latent_noise_factor img = noisy_image * (1-img_msk_latents) + img_msk_latents * img noisy_image = None if callback is not None: preview = unpack_latent(img).transpose(0,1) callback(i, preview, False) return img def prepare_multi_ip( ae: AutoEncoder, img_cond_list: list, seed: int, device: torch.device, target_width: int | None = None, target_height: int | None = None, bs: int = 1, pe: Literal["d", "h", "w", "o"] = "d", ) -> dict[str, Tensor]: ref_imgs = img_cond_list assert pe in ["d", "h", "w", "o"] ref_imgs = [ ae.encode( (TVF.to_tensor(ref_img) * 2.0 - 1.0) .unsqueeze(0) .to(device, torch.float32) ).to(torch.bfloat16) for ref_img in img_cond_list ] img = get_noise( bs, target_height, target_width, device=device, dtype=torch.bfloat16, seed=seed) bs, c, h, w = img.shape # tgt img img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) if img.shape[0] == 1 and bs > 1: img = repeat(img, "1 ... -> bs ...", bs=bs) img_ids = torch.zeros(h // 2, w // 2, 3) img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) img_cond_seq = img_cond_seq_ids = None pe_shift_w, pe_shift_h = w // 2, h // 2 for cond_no, ref_img in enumerate(ref_imgs): _, _, ref_h1, ref_w1 = ref_img.shape ref_img = rearrange( ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2 ) if ref_img.shape[0] == 1 and bs > 1: ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs) ref_img_ids1 = torch.zeros(ref_h1 // 2, ref_w1 // 2, 3) # img id分别在宽高偏移各自最大值 h_offset = pe_shift_h if pe in {"d", "h"} else 0 w_offset = pe_shift_w if pe in {"d", "w"} else 0 ref_img_ids1[..., 1] = ( ref_img_ids1[..., 1] + torch.arange(ref_h1 // 2)[:, None] + h_offset ) ref_img_ids1[..., 2] = ( ref_img_ids1[..., 2] + torch.arange(ref_w1 // 2)[None, :] + w_offset ) ref_img_ids1 = repeat(ref_img_ids1, "h w c -> b (h w) c", b=bs) if target_width is None: target_width = 8 * ref_w1 if target_height is None: target_height = 8 * ref_h1 ref_img_ids1 = ref_img_ids1.to(device) if cond_no == 0: img_cond_seq, img_cond_seq_ids = ref_img, ref_img_ids1 else: img_cond_seq, img_cond_seq_ids = torch.cat([img_cond_seq, ref_img], dim=1), torch.cat([img_cond_seq_ids, ref_img_ids1], dim=1) # 更新pe shift pe_shift_h += ref_h1 // 2 pe_shift_w += ref_w1 // 2 return { "img": img, "img_ids": img_ids.to(img.device), "img_cond_seq": img_cond_seq, "img_cond_seq_ids": img_cond_seq_ids, }, target_height, target_width def unpack(x: Tensor, height: int, width: int) -> Tensor: return rearrange( x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=math.ceil(height / 16), w=math.ceil(width / 16), ph=2, pw=2, )