# Mirror of https://github.com/Wan-Video/Wan2.1.git
# (synced 2025-11-04 14:16:57 +00:00; original listing: 240 lines, 11 KiB, Python)
import torch
|
|
from .utils import *
|
|
from functools import partial
|
|
|
|
# Many thanks to the LanPaint team for this implementation (https://github.com/scraed/LanPaint/)
|
|
|
|
def _pack_latents(latents):
|
|
batch_size, num_channels_latents, _, height, width = latents.shape
|
|
|
|
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
|
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
|
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
|
|
|
return latents
|
|
|
|
def _unpack_latents(latents, height, width, vae_scale_factor=8):
|
|
batch_size, num_patches, channels = latents.shape
|
|
|
|
height = 2 * (int(height) // (vae_scale_factor * 2))
|
|
width = 2 * (int(width) // (vae_scale_factor * 2))
|
|
|
|
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
|
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
|
|
|
latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
|
|
|
|
return latents
|
|
|
|
class LanPaint:
    """LanPaint inpainting refinement for packed (Flux/Qwen-style) latents.

    Each call runs a few underdamped Langevin-dynamics iterations that steer
    the latents using the known region (adapted from
    https://github.com/scraed/LanPaint/).

    Constructor arguments keep their original capitalised names for
    backward compatibility:
        NSteps:   default number of Langevin iterations per call.
        Friction: friction coefficient used in ``prepare_step_size``.
        Lambda:   weight of the known-region guidance term in ``score_model``.
        Beta:     time-scale factor of the known-region branch (``sigma_y``).
        StepSize: base step size; scaled by ``(1 - abt)`` on every call.
        IS_FLUX:  flow-time parametrisation, and forces ``cfg_BIG = 1.0``.
        IS_FLOW:  flow-time parametrisation.
    """

    def __init__(self, NSteps = 5, Friction = 15, Lambda = 8, Beta = 1, StepSize = 0.15, IS_FLUX = True, IS_FLOW = False):
        self.n_steps = NSteps
        self.chara_lamb = Lambda
        self.IS_FLUX = IS_FLUX
        self.IS_FLOW = IS_FLOW
        self.step_size = StepSize
        self.friction = Friction
        self.chara_beta = Beta
        # Rank of the unpacked latents; set by __call__ and required by add/remove_none_dims.
        self.img_dim_size = None
        # Per-call state, filled in by __call__.
        self.height = None
        self.width = None
        self.vae_scale_factor = None
        self.latent_image = None
        self.noise = None

    def add_none_dims(self, array):
        """Append ``img_dim_size - 1`` singleton axes after axis 0.

        This lets a 1-D per-batch tensor broadcast against the latents.
        """
        index = (slice(None),) + (None,) * (self.img_dim_size - 1)
        return array[index]

    def remove_none_dims(self, array):
        """Inverse of ``add_none_dims``: keep axis 0 and take index 0 on every other axis."""
        index = (slice(None),) + (0,) * (self.img_dim_size - 1)
        return array[index]

    def __call__(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, latent_image, noise, sigma, latent_mask, n_steps=None, height =720, width = 1280, vae_scale_factor = 8):
        """Run LanPaint on packed latents and return the updated packed latents.

        Args:
            denoise: ``denoise(packed_latents, true_cfg_scale)`` returns
                ``(noise_pred, neg_noise_pred)``; returning ``(None, ...)``
                aborts the run.
            cfg_predictions: ``cfg_predictions(noise_pred, neg_noise_pred, scale, t)``
                returns a packed prediction for guidance ``scale`` at time ``t``.
            true_cfg_scale: guidance scale of the standard prediction.
            cfg_BIG: guidance scale of the second prediction (forced to 1.0 when IS_FLUX).
            x, latent_image, noise, latent_mask: packed (B, N, 4 * C) tensors.
                They are unpacked to (B, C, 1, h, w) for the dynamics.
            sigma: current noise level; a 0-d or 1-D tensor.
            n_steps: Langevin iterations; defaults to ``NSteps``.
            height, width, vae_scale_factor: pixel geometry used to unpack.

        Returns:
            The packed latents, or ``None`` if ``denoise`` aborted.
        """
        unpack = partial(_unpack_latents, height=height, width=width, vae_scale_factor=vae_scale_factor)
        latent_image = unpack(latent_image)
        noise = unpack(noise)
        x = unpack(x)
        latent_mask = unpack(latent_mask)

        self.height = height
        self.width = width
        self.vae_scale_factor = vae_scale_factor
        self.img_dim_size = len(x.shape)
        self.latent_image = latent_image
        self.noise = noise

        if n_steps is None:
            n_steps = self.n_steps
        out = self.LanPaint(denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, self.IS_FLUX, self.IS_FLOW)
        if out is None:
            # denoise() requested an abort; propagate it instead of packing None.
            return None
        return _pack_latents(out)

    def LanPaint(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, IS_FLUX, IS_FLOW):
        """Run ``n_steps`` Langevin iterations on unpacked latents ``x``.

        The incoming ``latent_mask`` is inverted below. Afterwards, 1 marks the
        region taken from ``self.latent_image`` (assumes a 0/1 mask).

        Returns:
            The updated unpacked latents, or ``None`` if ``denoise`` aborted.
        """
        if IS_FLUX:
            cfg_BIG = 1.0

        def double_denoise(latents, t):
            """Return the unpacked predictions at ``true_cfg_scale`` and ``cfg_BIG``.

            Returns ``(None, None)`` when ``denoise`` aborts.
            """
            noise_pred, neg_noise_pred = denoise(_pack_latents(latents), true_cfg_scale)
            if noise_pred is None:
                return None, None
            predict_std = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t)
            predict_std = _unpack_latents(predict_std, self.height, self.width, self.vae_scale_factor)
            if true_cfg_scale == cfg_BIG:
                return predict_std, predict_std
            predict_big = cfg_predictions(noise_pred, neg_noise_pred, cfg_BIG, t)
            predict_big = _unpack_latents(predict_big, self.height, self.width, self.vae_scale_factor)
            return predict_std, predict_big

        if len(sigma.shape) == 0:
            # Promote a scalar sigma to shape (1,) so add_none_dims can broadcast it.
            # Keep sigma's device: a 1-element CPU tensor cannot be mixed with CUDA latents.
            sigma = torch.tensor([sigma.item()], device=sigma.device)
        latent_mask = 1 - latent_mask

        if IS_FLUX or IS_FLOW:
            Flow_t = sigma
            abt = (1 - Flow_t)**2 / ((1 - Flow_t)**2 + Flow_t**2)
            VE_Sigma = Flow_t / (1 - Flow_t)
        else:
            VE_Sigma = sigma
            abt = 1 / (1 + VE_Sigma**2)
            Flow_t = (1 - abt)**0.5 / ((1 - abt)**0.5 + abt**0.5)
        current_times = (VE_Sigma, abt, Flow_t)

        # Broadcastable (B, 1, ..., 1) views of the per-batch time quantities.
        abt_b = self.add_none_dims(abt)
        ve_sigma_b = self.add_none_dims(VE_Sigma)
        flow_t_b = self.add_none_dims(Flow_t)

        step_size = self.add_none_dims(self.step_size * (1 - abt))

        # Replace step: re-noise the known region (latent_mask == 1) to the current level.
        if IS_FLUX or IS_FLOW:
            noisy_image = self.latent_image * (1.0 - flow_t_b) + self.noise * flow_t_b
        else:
            # Variance-exploding x (matches the x_t <-> x conversion below): x = x0 + sigma * noise.
            noisy_image = self.latent_image + self.noise * ve_sigma_b
        x = x * (1 - latent_mask) + noisy_image * latent_mask

        # Switch to the variance-preserving x_t the dynamics operate on.
        if IS_FLUX or IS_FLOW:
            x_t = x * (abt_b**0.5 + (1 - abt_b)**0.5)
        else:
            x_t = x / (1 + ve_sigma_b**2)**0.5

        ############ LanPaint Iterations Start ###############
        # Everything below except x_t and args is loop-invariant.
        score_func = partial(self.score_model, y=self.latent_image, mask=latent_mask, abt=abt_b, sigma=ve_sigma_b, tflow=flow_t_b, denoise_func=double_denoise)
        sx = self.add_none_dims(self.sigma_x(abt))
        sy = self.add_none_dims(self.sigma_y(abt))
        args = None
        for _ in range(n_steps):
            x_t, args = self.langevin_dynamics(x_t, score_func, latent_mask, step_size, current_times, sigma_x=sx, sigma_y=sy, args=args)
            if x_t is None:
                # denoise() requested an abort.
                return None
        ############ LanPaint Iterations End ###############

        # Convert back from variance-preserving notation.
        if IS_FLUX or IS_FLOW:
            return x_t / (abt_b**0.5 + (1 - abt_b)**0.5)
        return x_t * (1 + ve_sigma_b**2)**0.5

    def score_model(self, x_t, y, mask, abt, sigma, tflow, denoise_func):
        """Return the mask-blended update direction used by ``x0_evalutation``.

        Outside the mask the result is ``x_0 - x_t``. Inside the mask it is
        ``-(1 + lambda) * (x_t - y) + lambda * (x_t - x_0_BIG)``.

        Returns ``None`` if ``denoise_func`` aborted.
        """
        lamb = self.chara_lamb
        if self.IS_FLUX or self.IS_FLOW:
            # Switch from variance-preserving x_t to flow-matching x.
            x = x_t / (abt**0.5 + (1 - abt)**0.5)
            x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(tflow))
        else:
            # Switch from variance-preserving x_t to variance-exploding x.
            x = x_t * (1 + sigma**2)**0.5
            x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(sigma))
        if x_0 is None:
            return None

        score_x = -(x_t - x_0)
        score_y = -(1 + lamb) * (x_t - y) + lamb * (x_t - x_0_BIG)
        return score_x * (1 - mask) + score_y * mask

    def sigma_x(self, abt):
        """Time-scale factor of the unmasked (x) branch: ones shaped like ``abt``."""
        return abt**0

    def sigma_y(self, abt):
        """Time-scale factor of the masked (y) branch: ``Beta`` broadcast to ``abt``'s shape."""
        return self.chara_beta * abt**0

    def langevin_dynamics(self, x_t, score, mask, step_size, current_times, sigma_x=1, sigma_y=0, args=None):
        """Advance ``x_t`` by one step of the underdamped Langevin dynamics.

        ``args`` is the ``(v, C)`` state returned by the previous call, or
        ``None`` on the first call.

        Returns:
            ``(x_t, (v, C))`` after a step.
            ``(x_t, args)`` unchanged when the mean step size is non-positive.
            ``(None, None)`` if ``score`` returned ``None`` (abort).
        """
        with torch.autocast(device_type=x_t.device.type, dtype=torch.float32):
            step_sizes = self.prepare_step_size(current_times, step_size, sigma_x, sigma_y)
            sigma, abt, dtx, dty, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y = step_sizes
        if torch.mean(dtx) <= 0.:
            return x_t, args

        # Blend per-branch coefficients with the mask (1 = known region).
        A = A_x * (1 - mask) + A_y * mask
        D = D_x * (1 - mask) + D_y * mask
        dt = dtx * (1 - mask) + dty * mask
        Gamma = Gamma_x * (1 - mask) + Gamma_y * mask

        def Coef_C(x_t):
            """Drift constant C for the oscillator; ``None`` if the score aborted."""
            x0 = self.x0_evalutation(x_t, score, sigma, args)
            if x0 is None:
                return None
            return (abt**0.5 * x0 - x_t) / (1 - abt) + A * x_t

        def advance_time(x_t, v, dt, Gamma, A, C, D):
            """Integrate the oscillator over ``dt`` in float32, then restore x_t's dtype."""
            dtype = x_t.dtype
            with torch.autocast(device_type=x_t.device.type, dtype=torch.float32):
                osc = StochasticHarmonicOscillator(Gamma, A, C, D)
                x_t, v = osc.dynamics(x_t, v, dt)
            return x_t.to(dtype), v.to(dtype)

        if args is None:
            # First step: no velocity yet; integrate the full dt.
            v = None
            C = Coef_C(x_t)
            if C is None:
                return None, None
            x_t, v = advance_time(x_t, v, dt, Gamma, A, C, D)
        else:
            # Split step: half step, velocity correction with refreshed C, half step.
            v, C = args
            x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D)
            C_new = Coef_C(x_t)
            if C_new is None:
                return None, None
            v = v + Gamma**0.5 * (C_new - C) * dt
            x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D)
            C = C_new

        return x_t, (v, C)

    def prepare_step_size(self, current_times, step_size, sigma_x, sigma_y):
        """Compute the per-branch step sizes and oscillator coefficients.

        Returns ``(sigma, abt, dtx/2, dty/2, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y)``.
        ``sigma`` and ``abt`` are given broadcastable singleton axes.
        """
        sigma, abt, _flow_t = current_times
        sigma = self.add_none_dims(sigma)
        abt = self.add_none_dims(abt)
        # Time steps for the x (unknown) and y (known) branches.
        dtx = 2 * step_size * sigma_x
        dty = 2 * step_size * sigma_y

        # Friction per branch; "* sigma**0" makes the result a tensor shaped like sigma.
        Gamma_hat_x = self.friction **2 * self.step_size * sigma_x / 0.1 * sigma**0
        Gamma_hat_y = self.friction **2 * self.step_size * sigma_y / 0.1 * sigma**0
        # Adjust dt to match denoise-addnoise step sizes.
        Gamma_hat_x /= 2.
        Gamma_hat_y /= 2.
        A_t_x = (1) / ( 1 - abt ) * dtx / 2
        A_t_y = (1+self.chara_lamb) / ( 1 - abt ) * dty / 2

        A_x = A_t_x / (dtx/2)
        A_y = A_t_y / (dty/2)
        Gamma_x = Gamma_hat_x / (dtx/2)
        Gamma_y = Gamma_hat_y / (dty/2)

        # Constant diffusion coefficient sqrt(2), shaped like abt.
        D_x = (2 * abt**0 )**0.5
        D_y = (2 * abt**0 )**0.5
        return sigma, abt, dtx/2, dty/2, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y

    def x0_evalutation(self, x_t, score, sigma, args):
        """Estimate x_0 as ``x_t + score(x_t)``; return ``None`` if ``score`` aborted.

        ``sigma`` and ``args`` are unused and kept for interface compatibility.
        """
        direction = score(x_t)
        if direction is None:
            return None
        return x_t + direction