import torch
from .utils import *
from functools import partial

# Many thanks to the LanPaint team for this implementation (https://github.com/scraed/LanPaint/)


def _pack_latents(latents):
    """Fold 2x2 spatial patches of a 5-D latent into a packed token sequence.

    Args:
        latents: tensor of shape ``(batch, channels, 1, height, width)``. The
            third axis must have size 1 and ``height``/``width`` must be even,
            otherwise the ``view`` below fails.

    Returns:
        Tensor of shape ``(batch, (height // 2) * (width // 2), channels * 4)``.
    """
    batch_size, num_channels_latents, _, height, width = latents.shape
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
    return latents


def _unpack_latents(latents, height, width, vae_scale_factor=8):
    """Inverse of ``_pack_latents``.

    Args:
        latents: packed tensor of shape ``(batch, num_patches, channels)``.
        height, width: sizes that are divided by ``vae_scale_factor`` and rounded
            down to an even number to obtain the latent grid size.
        vae_scale_factor: divisor between ``height``/``width`` and the latent grid.

    Returns:
        Tensor of shape ``(batch, channels // 4, 1, latent_height, latent_width)``.
    """
    batch_size, _, channels = latents.shape
    height = 2 * (int(height) // (vae_scale_factor * 2))
    width = 2 * (int(width) // (vae_scale_factor * 2))
    latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
    return latents


class LanPaint:
    """Langevin-dynamics inpainting step (LanPaint) operating on packed latents.

    Each call runs ``n_steps`` Langevin iterations. The region where the
    caller's ``latent_mask`` is 0 is tied to ``latent_image`` (it is overwritten
    with a noised copy of it and pulled toward it by the score); the region
    where the mask is 1 is driven by the model's prediction.
    """

    def __init__(self, NSteps=5, Friction=15, Lambda=8, Beta=1, StepSize=0.15, IS_FLUX=True, IS_FLOW=False):
        """
        Args:
            NSteps: default number of Langevin iterations per call.
            Friction: friction coefficient used for the damping terms in
                ``prepare_step_size``.
            Lambda: weight coupling the known-region score to the ``cfg_BIG``
                prediction (see ``score_model``).
            Beta: scale of the known-region time step (returned by ``sigma_y``).
            StepSize: base step size; scaled by ``1 - abt`` on every call.
            IS_FLUX: interpret ``sigma`` as flow-matching time and force
                ``cfg_BIG`` to 1.0.
            IS_FLOW: interpret ``sigma`` as flow-matching time.
        """
        self.n_steps = NSteps
        self.chara_lamb = Lambda
        self.IS_FLUX = IS_FLUX
        self.IS_FLOW = IS_FLOW
        self.step_size = StepSize
        self.friction = Friction
        self.chara_beta = Beta
        # Rank of the unpacked latent; set on each __call__ and used for broadcasting.
        self.img_dim_size = None

    def add_none_dims(self, array):
        """Append ``img_dim_size - 1`` singleton axes so a per-sample tensor broadcasts over latents."""
        index = (slice(None),) + (None,) * (self.img_dim_size - 1)
        return array[index]

    def remove_none_dims(self, array):
        """Inverse of ``add_none_dims``: index 0 on every axis after the first."""
        index = (slice(None),) + (0,) * (self.img_dim_size - 1)
        return array[index]

    def __call__(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, latent_image, noise, sigma,
                 latent_mask, n_steps=None, height=720, width=1280, vae_scale_factor=8):
        """Run LanPaint on packed latents and return the packed result.

        Args:
            denoise: callable ``denoise(packed_latents, true_cfg_scale)`` returning
                ``(noise_pred, neg_noise_pred)``; a ``None`` ``noise_pred`` aborts.
            cfg_predictions: callable ``(noise_pred, neg_noise_pred, scale, t)``
                returning a packed prediction, used as the x_0 estimate.
            true_cfg_scale: guidance scale for the generated region.
            cfg_BIG: guidance scale for the known-region prediction (forced to
                1.0 when ``IS_FLUX``).
            x, latent_image, noise, latent_mask: packed
                ``(batch, num_patches, channels)`` tensors.
            sigma: 0-dim or per-sample tensor with the current noise level / flow time.
            n_steps: Langevin iterations; defaults to ``self.n_steps``.
            height, width, vae_scale_factor: forwarded to ``_unpack_latents``.

        Returns:
            Packed tensor in the same layout as ``x``, or ``None`` if ``denoise``
            signalled an abort.
        """
        unpack = partial(_unpack_latents, height=height, width=width, vae_scale_factor=vae_scale_factor)
        latent_image = unpack(latent_image)
        noise = unpack(noise)
        x = unpack(x)
        latent_mask = unpack(latent_mask)

        self.height = height
        self.width = width
        self.vae_scale_factor = vae_scale_factor
        self.img_dim_size = len(x.shape)
        self.latent_image = latent_image
        self.noise = noise
        if n_steps is None:
            n_steps = self.n_steps

        out = self.LanPaint(denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask,
                            n_steps, self.IS_FLUX, self.IS_FLOW)
        if out is None:
            # Propagate the abort instead of crashing inside _pack_latents.
            return None
        return _pack_latents(out)

    def LanPaint(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps,
                 IS_FLUX, IS_FLOW):
        """Core sampler on unpacked 5-D latents; returns the updated latent or ``None`` on abort."""
        if IS_FLUX:
            cfg_BIG = 1.0

        def double_denoise(latents, t):
            """Return (prediction at true_cfg_scale, prediction at cfg_BIG), both unpacked."""
            latents = _pack_latents(latents)
            noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale)
            if noise_pred is None:
                return None, None
            predict_std = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t)
            predict_std = _unpack_latents(predict_std, self.height, self.width, self.vae_scale_factor)
            if true_cfg_scale == cfg_BIG:
                predict_big = predict_std
            else:
                predict_big = cfg_predictions(noise_pred, neg_noise_pred, cfg_BIG, t)
                predict_big = _unpack_latents(predict_big, self.height, self.width, self.vae_scale_factor)
            return predict_std, predict_big

        if sigma.dim() == 0:
            # Give sigma a batch axis without a host round-trip (the old
            # torch.tensor([sigma.item()]) always produced a CPU tensor).
            sigma = sigma.reshape(1).to(torch.float32)
        # A 1-D CPU tensor cannot be combined with CUDA latents, so align devices.
        sigma = sigma.to(device=x.device)
        latent_mask = 1 - latent_mask  # after this, 1 marks the known region

        if IS_FLUX or IS_FLOW:
            Flow_t = sigma
            abt = (1 - Flow_t) ** 2 / ((1 - Flow_t) ** 2 + Flow_t ** 2)
            VE_Sigma = Flow_t / (1 - Flow_t)
        else:
            VE_Sigma = sigma
            abt = 1 / (1 + VE_Sigma ** 2)
            Flow_t = (1 - abt) ** 0.5 / ((1 - abt) ** 0.5 + abt ** 0.5)
        current_times = (VE_Sigma, abt, Flow_t)

        # Per-sample quantities broadcast to the latent's rank (loop-invariant).
        abt_b = self.add_none_dims(abt)
        ve_sigma_b = self.add_none_dims(VE_Sigma)
        flow_t_b = self.add_none_dims(Flow_t)
        sigma_b = self.add_none_dims(sigma)
        step_size = self.add_none_dims(self.step_size * (1 - abt))

        # Replace step: overwrite the known region with the reference latent noised to level sigma.
        # NOTE(review): this is the flow-matching interpolation; for a VE model (IS_FLUX and
        # IS_FLOW both False) latent + sigma * noise may be intended — confirm before using that path.
        noisy_image = self.latent_image * (1.0 - sigma_b) + self.noise * sigma_b
        x = x * (1 - latent_mask) + noisy_image * latent_mask

        # Convert to variance-preserving x_t for the Langevin iterations.
        if IS_FLUX or IS_FLOW:
            x_t = x * (abt_b ** 0.5 + (1 - abt_b) ** 0.5)
        else:
            x_t = x / (1 + ve_sigma_b ** 2) ** 0.5

        score_func = partial(
            self.score_model,
            y=self.latent_image,
            mask=latent_mask,
            abt=abt_b,
            sigma=ve_sigma_b,
            tflow=flow_t_b,
            denoise_func=double_denoise,
        )
        sigma_x = self.add_none_dims(self.sigma_x(abt))
        sigma_y = self.add_none_dims(self.sigma_y(abt))

        args = None
        for _ in range(n_steps):
            x_t, args = self.langevin_dynamics(x_t, score_func, latent_mask, step_size, current_times,
                                               sigma_x=sigma_x, sigma_y=sigma_y, args=args)
            if x_t is None:
                # The denoiser aborted during this iteration.
                return None

        # Convert back from variance-preserving x_t to the model's parametrization.
        if IS_FLUX or IS_FLOW:
            x = x_t / (abt_b ** 0.5 + (1 - abt_b) ** 0.5)
        else:
            x = x_t * (1 + ve_sigma_b ** 2) ** 0.5
        return x

    def score_model(self, x_t, y, mask, abt, sigma, tflow, denoise_func):
        """Score of the inpainting target at variance-preserving ``x_t``.

        Outside the mask the score is ``x_0 - x_t``; inside it combines a pull
        toward ``y`` (weight ``1 + lambda``) with the ``cfg_BIG`` prediction.
        Returns ``None`` if ``denoise_func`` signalled an abort.
        """
        lamb = self.chara_lamb
        if self.IS_FLUX or self.IS_FLOW:
            # Back to the flow-matching parametrization expected by the model.
            x = x_t / (abt ** 0.5 + (1 - abt) ** 0.5)
            x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(tflow))
        else:
            # Back to the variance-exploding parametrization.
            x = x_t * (1 + sigma ** 2) ** 0.5
            x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(sigma))
        if x_0 is None:
            return None

        score_x = -(x_t - x_0)
        score_y = -(1 + lamb) * (x_t - y) + lamb * (x_t - x_0_BIG)
        return score_x * (1 - mask) + score_y * mask

    def sigma_x(self, abt):
        """Time scale for the generated-region update (constant 1, shaped like ``abt``)."""
        return abt ** 0

    def sigma_y(self, abt):
        """Time scale for the known-region update (``chara_beta``, shaped like ``abt``)."""
        beta = self.chara_beta * abt ** 0
        return beta

    def langevin_dynamics(self, x_t, score, mask, step_size, current_times, sigma_x=1, sigma_y=0, args=None):
        """Advance ``x_t`` by one Langevin step.

        The integration itself is delegated to ``StochasticHarmonicOscillator``
        from ``.utils``. ``args`` carries ``(v, C)`` between calls.

        Returns:
            ``(x_t, (v, C))``. Returns ``(x_t, args)`` unchanged when the step
            size is non-positive, and ``(None, None)`` if the score signalled an
            abort.
        """
        with torch.autocast(device_type=x_t.device.type, dtype=torch.float32):
            step_sizes = self.prepare_step_size(current_times, step_size, sigma_x, sigma_y)
            sigma, abt, dtx, dty, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y = step_sizes

        if torch.mean(dtx) <= 0.:
            return x_t, args

        # Blend the per-branch coefficients: generated region (mask 0) vs known region (mask 1).
        A = A_x * (1 - mask) + A_y * mask
        D = D_x * (1 - mask) + D_y * mask
        dt = dtx * (1 - mask) + dty * mask
        Gamma = Gamma_x * (1 - mask) + Gamma_y * mask

        def Coef_C(x_t):
            """Forcing term C for the oscillator; ``None`` if the score aborted."""
            x0 = self.x0_evalutation(x_t, score, sigma, args)
            if x0 is None:
                return None
            C = (abt ** 0.5 * x0 - x_t) / (1 - abt) + A * x_t
            return C

        def advance_time(x_t, v, dt, Gamma, A, C, D):
            """Integrate the oscillator over ``dt`` and restore the input dtype."""
            dtype = x_t.dtype
            with torch.autocast(device_type=x_t.device.type, dtype=torch.float32):
                osc = StochasticHarmonicOscillator(Gamma, A, C, D)
                x_t, v = osc.dynamics(x_t, v, dt)
            x_t = x_t.to(dtype)
            v = v.to(dtype)
            return x_t, v

        if args is None:
            # First step: no velocity yet.
            v = None
            C = Coef_C(x_t)
            if C is None:
                return None, None
            x_t, v = advance_time(x_t, v, dt, Gamma, A, C, D)
        else:
            # Split step: half step, refresh C with a velocity correction, half step.
            v, C = args
            x_t, v = advance_time(x_t, v, dt / 2, Gamma, A, C, D)
            C_new = Coef_C(x_t)
            if C_new is None:
                return None, None
            v = v + Gamma ** 0.5 * (C_new - C) * dt
            x_t, v = advance_time(x_t, v, dt / 2, Gamma, A, C, D)
            C = C_new

        return x_t, (v, C)

    def prepare_step_size(self, current_times, step_size, sigma_x, sigma_y):
        """Compute per-branch step sizes and oscillator coefficients.

        Returns:
            ``(sigma, abt, dtx/2, dty/2, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y)``.
            ``sigma`` and ``abt`` are broadcast to the latent rank.
        """
        sigma, abt, flow_t = current_times
        sigma = self.add_none_dims(sigma)
        abt = self.add_none_dims(abt)

        # Time steps for the generated (x) and known (y) branches.
        dtx = 2 * step_size * sigma_x
        dty = 2 * step_size * sigma_y

        # Friction parameter Gamma_hat per branch; `sigma**0` gives a tensor of sigma's device/dtype.
        Gamma_hat_x = self.friction ** 2 * self.step_size * sigma_x / 0.1 * sigma ** 0
        Gamma_hat_y = self.friction ** 2 * self.step_size * sigma_y / 0.1 * sigma ** 0
        # Halve to match the denoise/add-noise step sizes.
        Gamma_hat_x /= 2.
        Gamma_hat_y /= 2.

        A_t_x = (1) / (1 - abt) * dtx / 2
        A_t_y = (1 + self.chara_lamb) / (1 - abt) * dty / 2
        A_x = A_t_x / (dtx / 2)
        A_y = A_t_y / (dty / 2)
        Gamma_x = Gamma_hat_x / (dtx / 2)
        Gamma_y = Gamma_hat_y / (dty / 2)

        D_x = (2 * abt ** 0) ** 0.5
        D_y = (2 * abt ** 0) ** 0.5
        return sigma, abt, dtx / 2, dty / 2, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y

    def x0_evalutation(self, x_t, score, sigma, args):
        """Estimate x0 as ``x_t + score(x_t)``; ``None`` if the score aborted.

        ``sigma`` and ``args`` are accepted for interface compatibility and unused.
        """
        s = score(x_t)
        if s is None:
            return None
        return x_t + s