import torch
from .utils import *
from functools import partial

# Many thanks to the LanPaint team for this implementation (https://github.com/scraed/LanPaint/)


def _pack_latents(latents):
    # Fold each 2x2 spatial patch into the channel dimension, turning the
    # (B, C, T, H, W) latent grid into a (B, H/2 * W/2, 4*C) token sequence.
    batch_size, num_channels_latents, _, height, width = latents.shape

    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

    return latents


def _unpack_latents(latents, height, width, vae_scale_factor=8):
    # Inverse of _pack_latents: recover the (B, C, 1, H, W) latent grid from
    # the token sequence, given the pixel-space height and width.
    batch_size, num_patches, channels = latents.shape

    height = 2 * (int(height) // (vae_scale_factor * 2))
    width = 2 * (int(width) // (vae_scale_factor * 2))

    latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)

    latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)

    return latents


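# A minimal round-trip sanity check for the two helpers above (illustrative,
# not part of the original file); shapes assume vae_scale_factor=8:
#
#     lat = torch.randn(1, 16, 1, 90, 160)           # 720x1280 pixels -> 90x160 latents
#     packed = _pack_latents(lat)                    # -> (1, 45 * 80, 64)
#     restored = _unpack_latents(packed, 720, 1280)  # -> (1, 16, 1, 90, 160)
#     assert torch.equal(lat, restored)

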
class LanPaint():
    def __init__(self, NSteps=5, Friction=15, Lambda=8, Beta=1, StepSize=0.15, IS_FLUX=True, IS_FLOW=False):
        self.n_steps = NSteps      # Langevin iterations per denoising step
        self.chara_lamb = Lambda   # guidance strength toward the known region
        self.IS_FLUX = IS_FLUX
        self.IS_FLOW = IS_FLOW
        self.step_size = StepSize  # base Langevin step size
        self.friction = Friction   # friction of the underdamped Langevin dynamics
        self.chara_beta = Beta     # time-scale factor for the known-region branch
        self.img_dim_size = None   # rank of the unpacked latent, set in __call__

    def add_none_dims(self, array):
        # Index with ':' for the batch dimension and None for every remaining
        # image dimension, so per-batch scalars broadcast over the latent.
        index = (slice(None),) + (None,) * (self.img_dim_size - 1)
        return array[index]

    def remove_none_dims(self, array):
        # Inverse of add_none_dims: keep the batch dimension and take index 0
        # along each of the broadcast dimensions.
        index = (slice(None),) + (0,) * (self.img_dim_size - 1)
        return array[index]

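    # For example (illustrative): with img_dim_size = 5, add_none_dims maps a
    # (B,) tensor to shape (B, 1, 1, 1, 1), which broadcasts against a
    # (B, C, T, H, W) latent; remove_none_dims maps it back to (B,).
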
    def __call__(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, latent_image, noise, sigma, latent_mask, n_steps=None, height=720, width=1280, vae_scale_factor=8):
        # Work in the unpacked (B, C, T, H, W) layout; repack before returning.
        latent_image = _unpack_latents(latent_image, height=height, width=width, vae_scale_factor=vae_scale_factor)
        noise = _unpack_latents(noise, height=height, width=width, vae_scale_factor=vae_scale_factor)
        x = _unpack_latents(x, height=height, width=width, vae_scale_factor=vae_scale_factor)
        latent_mask = _unpack_latents(latent_mask, height=height, width=width, vae_scale_factor=vae_scale_factor)
        self.height = height
        self.width = width
        self.vae_scale_factor = vae_scale_factor
        self.img_dim_size = len(x.shape)
        self.latent_image = latent_image
        self.noise = noise
        if n_steps is None:
            n_steps = self.n_steps
        out = self.LanPaint(denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, self.IS_FLUX, self.IS_FLOW)
        out = _pack_latents(out)
        return out

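    # Assumed usage sketch (caller names are illustrative, not from this file):
    #
    #     lanpaint = LanPaint(NSteps=5, IS_FLUX=True)
    #     # denoise(latents, cfg_scale) -> (noise_pred, neg_noise_pred) in packed layout
    #     # cfg_predictions(pos, neg, scale, t) -> classifier-free-guided prediction
    #     x = lanpaint(denoise, cfg_predictions, true_cfg_scale, cfg_BIG,
    #                  x, latent_image, noise, sigma, latent_mask,
    #                  height=720, width=1280)
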
    def LanPaint(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, IS_FLUX, IS_FLOW):
        if IS_FLUX:
            cfg_BIG = 1.0

        def double_denoise(latents, t):
            # Run the model once, then build both the standard-CFG and the
            # "BIG"-CFG predictions from the same noise estimates.
            latents = _pack_latents(latents)
            noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale)
            if noise_pred is None:
                return None, None
            predict_std = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t)
            predict_std = _unpack_latents(predict_std, self.height, self.width, self.vae_scale_factor)
            if true_cfg_scale == cfg_BIG:
                predict_big = predict_std
            else:
                predict_big = cfg_predictions(noise_pred, neg_noise_pred, cfg_BIG, t)
                predict_big = _unpack_latents(predict_big, self.height, self.width, self.vae_scale_factor)
            return predict_std, predict_big

        if len(sigma.shape) == 0:
            sigma = sigma.reshape(1)  # promote 0-dim sigma to 1-dim, keeping device/dtype
        latent_mask = 1 - latent_mask
        if IS_FLUX or IS_FLOW:
            Flow_t = sigma
            abt = (1 - Flow_t)**2 / ((1 - Flow_t)**2 + Flow_t**2)
            VE_Sigma = Flow_t / (1 - Flow_t)
        else:
            VE_Sigma = sigma
            abt = 1 / (1 + VE_Sigma**2)
            Flow_t = (1 - abt)**0.5 / ((1 - abt)**0.5 + abt**0.5)
        current_times = (VE_Sigma, abt, Flow_t)

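        # Worked example of the reparameterization (illustrative): at flow time
        # t = 0.5, abt = 0.25 / (0.25 + 0.25) = 0.5 and VE_Sigma = 0.5 / 0.5 = 1,
        # consistent with abt = 1 / (1 + sigma^2) at sigma = 1.
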
        step_size = self.step_size * (1 - abt)
        step_size = self.add_none_dims(step_size)
        # self.inner_model.inner_model.scale_latent_inpaint returns variance exploding x_t values
        # This is the replace step
        # x = x * (1 - latent_mask) + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image) * latent_mask

        noisy_image = self.latent_image * (1.0 - sigma) + self.noise * sigma
        x = x * (1 - latent_mask) + noisy_image * latent_mask

        if IS_FLUX or IS_FLOW:
            x_t = x * (self.add_none_dims(abt)**0.5 + (1 - self.add_none_dims(abt))**0.5)
        else:
            x_t = x / (1 + self.add_none_dims(VE_Sigma)**2)**0.5  # switch to variance preserving x_t values

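        # Why the flow branch scaling works (a short derivation, not in the
        # original file): with x = (1 - t) * x0 + t * noise and abt as defined
        # above, (1 - t) = abt**0.5 / (abt**0.5 + (1 - abt)**0.5), so multiplying
        # by (abt**0.5 + (1 - abt)**0.5) gives the variance-preserving form
        # x_t = abt**0.5 * x0 + (1 - abt)**0.5 * noise.
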
        ############ LanPaint Iterations Start ###############
        # after noise_scaling, noise = latent_image + noise * sigma, which is x_t in the variance exploding diffusion model notation for the known region.
        args = None
        for i in range(n_steps):
            score_func = partial(self.score_model, y=self.latent_image, mask=latent_mask, abt=self.add_none_dims(abt), sigma=self.add_none_dims(VE_Sigma), tflow=self.add_none_dims(Flow_t), denoise_func=double_denoise)
            x_t, args = self.langevin_dynamics(x_t, score_func, latent_mask, step_size, current_times, sigma_x=self.add_none_dims(self.sigma_x(abt)), sigma_y=self.add_none_dims(self.sigma_y(abt)), args=args)
        if IS_FLUX or IS_FLOW:
            x = x_t / (self.add_none_dims(abt)**0.5 + (1 - self.add_none_dims(abt))**0.5)
        else:
            x = x_t * (1 + self.add_none_dims(VE_Sigma)**2)**0.5  # switch back to variance exploding x values
        ############ LanPaint Iterations End ###############
        # out is x_0
        # out, _ = self.inner_model(x, sigma, model_options=model_options, seed=seed)
        # out = out * (1-latent_mask) + self.latent_image * latent_mask
        # return out
        return x

    def score_model(self, x_t, y, mask, abt, sigma, tflow, denoise_func):
        lamb = self.chara_lamb
        if self.IS_FLUX or self.IS_FLOW:
            x = x_t / (abt**0.5 + (1 - abt)**0.5)  # switch to the Gaussian flow matching parameterization
            x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(tflow))
            if x_0 is None:
                return None
        else:
            x = x_t * (1 + sigma**2)**0.5  # switch to variance exploding
            x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(sigma))
            if x_0 is None:
                return None

        score_x = -(x_t - x_0)
        score_y = -(1 + lamb) * (x_t - y) + lamb * (x_t - x_0_BIG)
        return score_x * (1 - mask) + score_y * mask

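    # In equation form (restating the code above): on the unknown region the
    # score is -(x_t - x_0); on the known region it is
    #     -(1 + lambda) * (x_t - y) + lambda * (x_t - x_0_BIG),
    # i.e. x_t is pulled toward the known latents y, with a lambda-weighted
    # push away from the low-CFG prediction x_0_BIG.
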
    def sigma_x(self, abt):
        # the time scale for the x_t update; abt**0 yields ones with abt's
        # shape, device and dtype
        return abt**0

    def sigma_y(self, abt):
        # the time scale for the known-region update, scaled by chara_beta
        beta = self.chara_beta * abt**0
        return beta

    def langevin_dynamics(self, x_t, score, mask, step_size, current_times, sigma_x=1, sigma_y=0, args=None):
        # prepare the step size and time parameters
        with torch.autocast(device_type=x_t.device.type, dtype=torch.float32):
            step_sizes = self.prepare_step_size(current_times, step_size, sigma_x, sigma_y)
            sigma, abt, dtx, dty, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y = step_sizes
        if torch.mean(dtx) <= 0.:
            return x_t, args
        # -------------------------------------------------------------------------
        # Compute the Langevin dynamics update in variance preserving notation
        # -------------------------------------------------------------------------
        # Blend the per-branch coefficients: the x branch outside the mask, the
        # y (known-region) branch inside it.
        A = A_x * (1 - mask) + A_y * mask
        D = D_x * (1 - mask) + D_y * mask
        dt = dtx * (1 - mask) + dty * mask
        Gamma = Gamma_x * (1 - mask) + Gamma_y * mask

        def Coef_C(x_t):
            x0 = self.x0_evaluation(x_t, score, sigma, args)
            C = (abt**0.5 * x0 - x_t) / (1 - abt) + A * x_t
            return C

        def advance_time(x_t, v, dt, Gamma, A, C, D):
            # StochasticHarmonicOscillator is expected to come from .utils
            # (imported with * above).
            dtype = x_t.dtype
            with torch.autocast(device_type=x_t.device.type, dtype=torch.float32):
                osc = StochasticHarmonicOscillator(Gamma, A, C, D)
                x_t, v = osc.dynamics(x_t, v, dt)
            x_t = x_t.to(dtype)
            v = v.to(dtype)
            return x_t, v

        if args is None:
            # first iteration: full step, then cache (v, C) for the next call
            v = None
            C = Coef_C(x_t)
            x_t, v = advance_time(x_t, v, dt, Gamma, A, C, D)
        else:
            # later iterations: half step, velocity kick from the change in C,
            # then another half step (a midpoint-style update)
            v, C = args

            x_t, v = advance_time(x_t, v, dt / 2, Gamma, A, C, D)

            C_new = Coef_C(x_t)
            v = v + Gamma**0.5 * (C_new - C) * dt

            x_t, v = advance_time(x_t, v, dt / 2, Gamma, A, C, D)

            C = C_new

        return x_t, (v, C)

    def prepare_step_size(self, current_times, step_size, sigma_x, sigma_y):
        # -------------------------------------------------------------------------
        # Unpack current times parameters (sigma and abt)
        sigma, abt, flow_t = current_times
        sigma = self.add_none_dims(sigma)
        abt = self.add_none_dims(abt)
        # Compute time step (dtx, dty) for x and y branches.
        dtx = 2 * step_size * sigma_x
        dty = 2 * step_size * sigma_y

        # -------------------------------------------------------------------------
        # Define friction parameter Gamma_hat for each branch.
        # Multiplying by sigma**0 yields a tensor of the proper shape/device/dtype.
        Gamma_hat_x = self.friction**2 * self.step_size * sigma_x / 0.1 * sigma**0
        Gamma_hat_y = self.friction**2 * self.step_size * sigma_y / 0.1 * sigma**0
        # adjust dt to match denoise-addnoise step sizes
        Gamma_hat_x /= 2.
        Gamma_hat_y /= 2.
        A_t_x = 1 / (1 - abt) * dtx / 2
        A_t_y = (1 + self.chara_lamb) / (1 - abt) * dty / 2

        A_x = A_t_x / (dtx / 2)
        A_y = A_t_y / (dty / 2)
        Gamma_x = Gamma_hat_x / (dtx / 2)
        Gamma_y = Gamma_hat_y / (dty / 2)

        #D_x = (2 * (1 + sigma**2) )**0.5
        #D_y = (2 * (1 + sigma**2) )**0.5
        D_x = (2 * abt**0)**0.5
        D_y = (2 * abt**0)**0.5
        return sigma, abt, dtx / 2, dty / 2, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y

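    # With the defaults (Friction=15, StepSize=0.15, sigma_x = sigma_y = 1), an
    # illustrative evaluation: Gamma_hat = 15**2 * 0.15 / 0.1 = 337.5, halved
    # to 168.75 before being divided by the per-branch half step dt/2.
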
    def x0_evaluation(self, x_t, score, sigma, args):
        # On the unmasked branch score(x_t) = -(x_t - x_0), so adding it to x_t
        # recovers the denoised estimate x_0; on the masked branch it yields
        # the guided target instead.
        x0 = x_t + score(x_t)
        return x0