mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-04 06:15:17 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			429 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			429 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import math
 | 
						|
from typing import Callable
 | 
						|
 | 
						|
import numpy as np
 | 
						|
import torch
 | 
						|
from einops import rearrange, repeat
 | 
						|
from PIL import Image
 | 
						|
from torch import Tensor
 | 
						|
 | 
						|
from .model import Flux
 | 
						|
from .modules.autoencoder import AutoEncoder
 | 
						|
from .modules.conditioner import HFEmbedder
 | 
						|
from .modules.image_embedders import CannyImageEncoder, DepthImageEncoder, ReduxImageEncoder
 | 
						|
from .util import PREFERED_KONTEXT_RESOLUTIONS
 | 
						|
from einops import rearrange, repeat
 | 
						|
from typing import Literal
 | 
						|
import torchvision.transforms.functional as TVF
 | 
						|
 | 
						|
 | 
						|
def get_noise(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
):
    """Draw seeded Gaussian noise for ``num_samples`` 16-channel latents.

    The two trailing sizes are ``2 * ceil(height / 16)`` and
    ``2 * ceil(width / 16)``; doubling the rounded-up value keeps both even
    so the result can later be packed into 2x2 patches.
    """
    # A dedicated generator makes the draw reproducible for a given seed
    # without touching the global RNG state.
    rng = torch.Generator(device=device)
    rng.manual_seed(seed)

    latent_shape = (
        num_samples,
        16,
        2 * math.ceil(height / 16),
        2 * math.ceil(width / 16),
    )
    return torch.randn(latent_shape, generator=rng, dtype=dtype, device=device)
 | 
						|
 | 
						|
 | 
						|
def prepare_prompt(
    t5: HFEmbedder,
    clip: HFEmbedder,
    bs: int,
    prompt: str | list[str],
    neg: bool = False,
    device: str = "cuda",
) -> dict[str, Tensor]:
    """Encode ``prompt`` with both text encoders and bundle the results.

    Args:
        t5: Callable producing the per-token text embedding (``txt``).
        clip: Callable producing the pooled text embedding (``vec``).
        bs: Requested batch size. When it is 1 and ``prompt`` is a list, the
            list length is used instead.
        prompt: A single prompt or a list of prompts.
        neg: When true, the returned keys carry a ``neg_`` prefix.
        device: Device every returned tensor is moved to.

    Returns:
        A dict with the text embedding, an all-zero ids tensor of shape
        ``(bs, txt.shape[1], 3)`` and the pooled embedding.
    """
    prompts = [prompt] if isinstance(prompt, str) else prompt
    if bs == 1 and not isinstance(prompt, str):
        bs = len(prompts)

    def _expand(x: Tensor) -> Tensor:
        # A single encoded prompt is replicated across the whole batch.
        if x.shape[0] == 1 and bs > 1:
            return repeat(x, "1 ... -> bs ...", bs=bs)
        return x

    txt = _expand(t5(prompts))
    txt_ids = torch.zeros(bs, txt.shape[1], 3)
    vec = _expand(clip(prompts))

    if neg:
        keys = ("neg_txt", "neg_txt_ids", "neg_vec")
    else:
        keys = ("txt", "txt_ids", "vec")
    return {key: value.to(device) for key, value in zip(keys, (txt, txt_ids, vec))}
 | 
						|
 | 
						|
 | 
						|
def prepare_img(img: Tensor) -> dict[str, Tensor]:
    """Pack a latent image into a token sequence and build its position ids.

    Args:
        img: Tensor of shape ``(b, c, h, w)``; ``h`` and ``w`` must be even
            because the image is cut into 2x2 patches.

    Returns:
        ``"img"``: the packed tokens, shape ``(b, (h/2)*(w/2), c*4)``.
        ``"img_ids"``: shape ``(b, (h/2)*(w/2), 3)`` on ``img``'s device;
        channel 0 is zero, channel 1 is the patch row, channel 2 the patch
        column.
    """
    bs, _, h, w = img.shape
    rows, cols = h // 2, w // 2

    # The batch size is read from `img` itself, so the packed tensor already
    # has `bs` samples and never needs to be expanded.
    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

    img_ids = torch.zeros(rows, cols, 3)
    img_ids[..., 1] = torch.arange(rows)[:, None]
    img_ids[..., 2] = torch.arange(cols)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    return {
        "img": img,
        "img_ids": img_ids.to(img.device),
    }
 | 
						|
 | 
						|
 | 
						|
 | 
						|
 | 
						|
 | 
						|
def prepare_redux(
    t5: HFEmbedder,
    clip: HFEmbedder,
    img: Tensor,
    prompt: str | list[str],
    encoder: ReduxImageEncoder,
    img_cond_path: str,
) -> dict[str, Tensor]:
    """Build model inputs for Redux (image-prompted) sampling.

    The conditioning image at ``img_cond_path`` is encoded with ``encoder``
    and its tokens are appended to the T5 text tokens (concatenation along
    ``dim=-2``).

    Args:
        t5: Callable producing the per-token text embedding.
        clip: Callable producing the pooled text embedding.
        img: Latent tensor of shape ``(b, c, h, w)`` with even ``h``/``w``.
        prompt: A single prompt or a list of prompts. When ``img`` has batch
            size 1 and a list is given, the list length becomes the batch
            size.
        encoder: Image encoder applied to the RGB conditioning image.
        img_cond_path: Path of the conditioning image on disk.

    Returns:
        Dict with ``img``, ``img_ids``, ``txt``, ``txt_ids`` and ``vec``, all
        on ``img``'s device.
    """
    bs, _, h, w = img.shape
    if isinstance(prompt, str):
        prompt = [prompt]
    elif bs == 1:
        bs = len(prompt)

    def _expand(x: Tensor) -> Tensor:
        # A single-sample tensor is replicated across the whole batch.
        if x.shape[0] == 1 and bs > 1:
            return repeat(x, "1 ... -> bs ...", bs=bs)
        return x

    # Close the file handle as soon as the pixels have been converted.
    with Image.open(img_cond_path) as src:
        img_cond = src.convert("RGB")
    with torch.no_grad():
        img_cond = encoder(img_cond)
    img_cond = _expand(img_cond.to(torch.bfloat16))

    img = _expand(rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2))

    rows, cols = h // 2, w // 2
    img_ids = torch.zeros(rows, cols, 3)
    img_ids[..., 1] = torch.arange(rows)[:, None]
    img_ids[..., 2] = torch.arange(cols)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    # Expand the text embedding BEFORE concatenating: `img_cond` already has
    # `bs` samples, so a batch-1 `txt` (single prompt, bs > 1) would otherwise
    # make torch.cat fail on mismatched batch sizes.
    txt = _expand(t5(prompt))
    txt = torch.cat((txt, img_cond.to(txt)), dim=-2)
    txt_ids = torch.zeros(bs, txt.shape[1], 3)

    vec = _expand(clip(prompt))

    return {
        "img": img,
        "img_ids": img_ids.to(img.device),
        "txt": txt.to(img.device),
        "txt_ids": txt_ids.to(img.device),
        "vec": vec.to(img.device),
    }
 | 
						|
 | 
						|
 | 
						|
def prepare_kontext(
    ae: AutoEncoder,
    img_cond_list: list,
    seed: int,
    device: torch.device,
    target_width: int | None = None,
    target_height: int | None = None,
    bs: int = 1,
) -> tuple[dict[str, Tensor], int, int]:
    """Encode Kontext conditioning images and create the initial noise.

    Each conditioning image is resized to the entry of
    ``PREFERED_KONTEXT_RESOLUTIONS`` whose aspect ratio is closest to its own,
    encoded with ``ae``, packed into 2x2-patch tokens, and appended to one
    conditioning sequence along ``dim=1``.

    Args:
        ae: Autoencoder; only ``ae.encode`` is used.
        img_cond_list: Images exposing ``.size`` and ``.resize`` (PIL-style).
            ``None`` is treated as an empty list.
        seed: Seed for the initial noise.
        device: Device for encoding, ids and noise.
        target_width: Output width in pixels; defaults to the resized width
            of the first conditioning image.
        target_height: Output height in pixels; defaults likewise.
        bs: Batch size.

    Returns:
        ``(inputs, target_height, target_width)`` where ``inputs`` holds
        ``img_cond_seq`` / ``img_cond_seq_ids`` (``None`` when there are no
        conditioning images) plus the ``img`` / ``img_ids`` of the noise.

    Raises:
        ValueError: If a target size is missing and there is no conditioning
            image to derive it from.
    """
    if img_cond_list is None:
        img_cond_list = []

    cond_tokens: list[Tensor] = []
    cond_ids: list[Tensor] = []
    height_offset = 0
    width_offset = 0
    for img_cond in img_cond_list:
        width, height = img_cond.size
        aspect_ratio = width / height

        # Kontext is trained on specific resolutions, using one of them is recommended
        _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
        # From here on width/height are 1/8 of the pixel size, rounded down to
        # an even number so the latents split cleanly into 2x2 patches
        # (presumably `ae.encode` downsamples by 8 -- confirm against the AE).
        width = 2 * int(width / 16)
        height = 2 * int(height / 16)

        pixels = img_cond.resize((8 * width, 8 * height), Image.Resampling.LANCZOS)
        # Map the array values from [0, 255] to [-1, 1].
        pixels = torch.from_numpy(np.array(pixels)).float() / 127.5 - 1.0
        pixels = rearrange(pixels, "h w c -> 1 c h w")
        with torch.no_grad():
            latents = ae.encode(pixels.to(device))
        del pixels  # drop the full-resolution tensor as early as possible

        latents = latents.to(torch.bfloat16)
        latents = rearrange(latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        # Test the tensor that is actually being expanded.
        if latents.shape[0] == 1 and bs > 1:
            latents = repeat(latents, "1 ... -> bs ...", bs=bs)

        # image ids are the same as base image with the first dimension set to 1
        # instead of 0; row/column indices are shifted by the accumulated
        # offsets so successive conditioning images do not share positions.
        rows, cols = height // 2, width // 2
        ids = torch.zeros(rows, cols, 3)
        ids[..., 0] = 1
        ids[..., 1] = torch.arange(rows)[:, None] + height_offset
        ids[..., 2] = torch.arange(cols)[None, :] + width_offset
        ids = repeat(ids, "h w c -> b (h w) c", b=bs)
        height_offset += rows
        width_offset += cols

        if target_width is None:
            target_width = 8 * width
        if target_height is None:
            target_height = 8 * height

        cond_tokens.append(latents)
        cond_ids.append(ids.to(device))

    if target_width is None or target_height is None:
        raise ValueError(
            "target_width and target_height are required when img_cond_list is empty"
        )

    return_dict = {
        "img_cond_seq": torch.cat(cond_tokens, dim=1) if cond_tokens else None,
        "img_cond_seq_ids": torch.cat(cond_ids, dim=1) if cond_ids else None,
    }
    img = get_noise(
        bs,
        target_height,
        target_width,
        device=device,
        dtype=torch.bfloat16,
        seed=seed,
    )
    return_dict.update(prepare_img(img))

    return return_dict, target_height, target_width
 | 
						|
 | 
						|
 | 
						|
def time_shift(mu: float, sigma: float, t: Tensor):
    """Warp timesteps ``t`` as ``e^mu / (e^mu + (1 / t - 1) ** sigma)``."""
    scale = math.exp(mu)
    inverse_odds = (1 / t - 1) ** sigma
    return scale / (scale + inverse_odds)
 | 
						|
 | 
						|
 | 
						|
def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    """Return the straight line through ``(x1, y1)`` and ``(x2, y2)`` as a callable."""
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def line(x: float) -> float:
        return slope * x + intercept

    return line
 | 
						|
 | 
						|
 | 
						|
def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    """Return ``num_steps + 1`` timesteps running from 1 down to 0.

    With ``shift`` enabled the evenly spaced values are warped by
    ``time_shift`` (sigma fixed at 1.0), using a ``mu`` interpolated
    linearly from ``image_seq_len`` between ``base_shift`` and ``max_shift``.
    """
    # num_steps intervals need num_steps + 1 boundaries; the extra one is 0.
    timesteps = torch.linspace(1, 0, num_steps + 1)
    if not shift:
        return timesteps.tolist()

    # shifting the schedule to favor high timesteps for higher signal images;
    # mu is estimated by linear interpolation between two points
    mu_for_seq_len = get_lin_function(y1=base_shift, y2=max_shift)
    shifted = time_shift(mu_for_seq_len(image_seq_len), 1.0, timesteps)
    return shifted.tolist()
 | 
						|
 | 
						|
 | 
						|
def denoise(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    guidance: float = 4.0,
    real_guidance_scale = None,
    # negative conditioning, only used when real_guidance_scale > 1
    neg_txt: Tensor = None,
    neg_txt_ids: Tensor= None,
    neg_vec: Tensor = None,
    # extra img tokens (channel-wise)
    img_cond: Tensor | None = None,
    # extra img tokens (sequence-wise)
    img_cond_seq: Tensor | None = None,
    img_cond_seq_ids: Tensor | None = None,
    siglip_embedding = None,
    siglip_embedding_ids = None,
    callback=None,
    pipeline=None,
    loras_slists=None,
    unpack_latent = None,
    joint_pass= False,
):
    """Run the sampling loop over ``timesteps`` and return the final latents.

    At every step the model prediction is scaled by ``t_prev - t_curr`` and
    added to ``img`` **in place**.

    Args:
        model: Called with keyword arguments ``img``, ``img_ids``,
            ``txt_list``, ``txt_ids_list``, ``y_list``, ``timesteps``,
            ``guidance`` plus the extras collected in ``kwargs`` below; it
            returns one prediction per entry of ``txt_list``.
        img, img_ids: Packed latent tokens and their position ids.
        txt, txt_ids, vec: Positive text conditioning.
        timesteps: Schedule boundaries; ``len(timesteps) - 1`` steps are run.
        guidance: Value broadcast to one entry per sample and passed to the
            model as ``guidance``.
        real_guidance_scale: Classifier-free guidance scale. ``None`` or any
            value ``<= 1`` disables it, and the negative inputs are not used.
        neg_txt, neg_txt_ids, neg_vec: Negative conditioning for
            classifier-free guidance.
        img_cond: Concatenated to ``img`` on the last dimension each step.
        img_cond_seq, img_cond_seq_ids: Concatenated to the tokens / ids on
            dimension 1.
        siglip_embedding, siglip_embedding_ids: Forwarded to the model as-is.
        callback: Optional progress callback. It is called as
            ``callback(-1, None, True[, override_num_inference_steps=...])``
            before the loop and ``callback(i, preview, False)`` after each
            step.
        pipeline: Forwarded to the model; when given, a truthy
            ``pipeline._interrupt`` aborts sampling.
        loras_slists: Passed to ``update_loras_slists`` (only when a callback
            is set, as in the original code).
        unpack_latent: Turns ``img`` into the preview handed to ``callback``;
            required whenever ``callback`` is set.
        joint_pass: Evaluate the positive and negative branch in a single
            model call when classifier-free guidance is active.

    Returns:
        The denoised ``img``, or ``None`` if sampling was interrupted or the
        model returned ``None``.
    """
    kwargs = {'pipeline': pipeline, 'callback': callback, "img_len" : img.shape[1], "siglip_embedding": siglip_embedding, "siglip_embedding_ids": siglip_embedding_ids}

    updated_num_steps = len(timesteps) - 1
    if callback is not None:
        callback(-1, None, True)
        from shared.utils.loras_mutipliers import update_loras_slists
        update_loras_slists(model, loras_slists, updated_num_steps)
        callback(-1, None, True, override_num_inference_steps = updated_num_steps)
    from mmgp import offload

    # `None > 1` raises TypeError, so the default must be handled explicitly.
    use_cfg = real_guidance_scale is not None and real_guidance_scale > 1

    # this is ignored for schnell
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

    # The position ids do not change between steps; build them once.
    img_input_ids = img_ids
    if img_cond_seq is not None:
        img_input_ids = torch.cat((img_ids, img_cond_seq_ids), dim=1)

    def run_model(img_input, t_vec, txt_list, txt_ids_list, y_list):
        # One forward pass; returns one prediction per conditioning entry.
        return model(
            img=img_input,
            img_ids=img_input_ids,
            txt_list=txt_list,
            txt_ids_list=txt_ids_list,
            y_list=y_list,
            timesteps=t_vec,
            guidance=guidance_vec,
            **kwargs
        )

    for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
        offload.set_step_no_for_lora(model, i)
        if pipeline is not None and pipeline._interrupt:
            return None

        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        # `img` changes every step, so the token input is rebuilt each time.
        img_input = img
        if img_cond is not None:
            img_input = torch.cat((img, img_cond), dim=-1)
        if img_cond_seq is not None:
            img_input = torch.cat((img_input, img_cond_seq), dim=1)

        neg_pred = None
        if not use_cfg:
            # Without guidance the negative branch is never needed.
            pred = run_model(img_input, t_vec, [txt], [txt_ids], [vec])[0]
            if pred is None:
                return None
        elif joint_pass:
            pred, neg_pred = run_model(
                img_input, t_vec, [txt, neg_txt], [txt_ids, neg_txt_ids], [vec, neg_vec]
            )
            if pred is None:
                return None
        else:
            pred = run_model(img_input, t_vec, [txt], [txt_ids], [vec])[0]
            if pred is None:
                return None
            neg_pred = run_model(img_input, t_vec, [neg_txt], [neg_txt_ids], [neg_vec])[0]
            if neg_pred is None:
                return None

        if use_cfg:
            pred = neg_pred + real_guidance_scale * (pred - neg_pred)

        img += (t_prev - t_curr) * pred
        if callback is not None:
            preview = unpack_latent(img).transpose(0,1)
            callback(i, preview, False)

    return img
 | 
						|
 | 
						|
def prepare_multi_ip(
    ae: AutoEncoder,
    img_cond_list: list,
    seed: int,
    device: torch.device,
    target_width: int | None = None,
    target_height: int | None = None,
    bs: int = 1,
    pe: Literal["d", "h", "w", "o"] = "d",
) -> tuple[dict[str, Tensor], int, int]:
    """Encode reference images and create the initial noise for multi-image prompting.

    Args:
        ae: Autoencoder; only ``ae.encode`` is used.
        img_cond_list: Reference images accepted by torchvision's
            ``to_tensor``. ``None`` is treated as an empty list.
        seed: Seed for the initial noise.
        device: Device for encoding, ids and noise.
        target_width: Output width in pixels; defaults to 8x the latent width
            of the first reference image.
        target_height: Output height in pixels; defaults likewise.
        bs: Batch size.
        pe: Which position-id axes of the reference tokens are shifted past
            the positions already in use: ``"d"`` both, ``"h"`` rows only,
            ``"w"`` columns only, ``"o"`` none.

    Returns:
        ``(inputs, target_height, target_width)`` where ``inputs`` holds
        ``img`` / ``img_ids`` for the noise and ``img_cond_seq`` /
        ``img_cond_seq_ids`` for the references (``None`` when there are
        none).

    Raises:
        ValueError: If a target size is missing and there is no reference
            image to derive it from.
    """
    assert pe in ["d", "h", "w", "o"]
    if img_cond_list is None:
        img_cond_list = []

    # Encoding is inference-only here, as in prepare_kontext.
    with torch.no_grad():
        ref_imgs = [
            ae.encode(
                (TVF.to_tensor(ref_img) * 2.0 - 1.0)
                .unsqueeze(0)
                .to(device, torch.float32)
            ).to(torch.bfloat16)
            for ref_img in img_cond_list
        ]

    # The target size must be known BEFORE the noise is drawn; fall back to
    # the first reference image when the caller did not provide one.
    if target_width is None or target_height is None:
        if not ref_imgs:
            raise ValueError(
                "target_width and target_height are required when img_cond_list is empty"
            )
        _, _, first_h, first_w = ref_imgs[0].shape
        if target_width is None:
            target_width = 8 * first_w
        if target_height is None:
            target_height = 8 * first_h

    img = get_noise(bs, target_height, target_width, device=device, dtype=torch.bfloat16, seed=seed)
    bs, _, h, w = img.shape
    rows, cols = h // 2, w // 2

    # target image tokens and their position ids
    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    img_ids = torch.zeros(rows, cols, 3)
    img_ids[..., 1] = torch.arange(rows)[:, None]
    img_ids[..., 2] = torch.arange(cols)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    ref_tokens: list[Tensor] = []
    ref_ids: list[Tensor] = []
    pe_shift_h, pe_shift_w = rows, cols
    for ref_img in ref_imgs:
        _, _, ref_h, ref_w = ref_img.shape
        ref_rows, ref_cols = ref_h // 2, ref_w // 2

        ref_img = rearrange(ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        if ref_img.shape[0] == 1 and bs > 1:
            ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)

        # Shift the reference ids in height and/or width past the largest
        # index used so far on that axis.
        h_offset = pe_shift_h if pe in {"d", "h"} else 0
        w_offset = pe_shift_w if pe in {"d", "w"} else 0
        ids = torch.zeros(ref_rows, ref_cols, 3)
        ids[..., 1] = torch.arange(ref_rows)[:, None] + h_offset
        ids[..., 2] = torch.arange(ref_cols)[None, :] + w_offset
        ids = repeat(ids, "h w c -> b (h w) c", b=bs)

        ref_tokens.append(ref_img)
        ref_ids.append(ids.to(device))

        # update the pe shift for the next reference image
        pe_shift_h += ref_rows
        pe_shift_w += ref_cols

    return {
        "img": img,
        "img_ids": img_ids.to(img.device),
        "img_cond_seq": torch.cat(ref_tokens, dim=1) if ref_tokens else None,
        "img_cond_seq_ids": torch.cat(ref_ids, dim=1) if ref_ids else None,
    }, target_height, target_width
 | 
						|
 | 
						|
 | 
						|
def unpack(x: Tensor, height: int, width: int) -> Tensor:
    """Undo the 2x2 patch packing: ``(b, h*w, c*4)`` tokens -> ``(b, c, 2h, 2w)``.

    The token grid is ``ceil(height / 16)`` by ``ceil(width / 16)``.
    """
    grid_h = math.ceil(height / 16)
    grid_w = math.ceil(width / 16)
    pattern = "b (h w) (c ph pw) -> b c (h ph) (w pw)"
    return rearrange(x, pattern, h=grid_h, w=grid_w, ph=2, pw=2)
 |