From dd7b09c6d3d9b552f35062bb6a205c1921f2ce94 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Tue, 20 May 2025 20:15:56 +0200
Subject: [PATCH] added missing file

---
 wan/utils/basic_flowmatch.py | 83 ++++++++++++++++++++++++++++++++++++
 1 file changed, 83 insertions(+)
 create mode 100644 wan/utils/basic_flowmatch.py

diff --git a/wan/utils/basic_flowmatch.py b/wan/utils/basic_flowmatch.py
new file mode 100644
index 0000000..591510b
--- /dev/null
+++ b/wan/utils/basic_flowmatch.py
@@ -0,0 +1,83 @@
+"""
+The following code is copied from https://github.com/modelscope/DiffSynth-Studio/blob/main/diffsynth/schedulers/flow_match.py
+"""
+import torch
+
+
+class FlowMatchScheduler():
+
+    def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003 / 1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
+        self.num_train_timesteps = num_train_timesteps
+        self.shift = shift
+        self.sigma_max = sigma_max
+        self.sigma_min = sigma_min
+        self.inverse_timesteps = inverse_timesteps
+        self.extra_one_step = extra_one_step
+        self.reverse_sigmas = reverse_sigmas
+        self.set_timesteps(num_inference_steps)
+
+    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False):
+        sigma_start = self.sigma_min + \
+            (self.sigma_max - self.sigma_min) * denoising_strength
+        if self.extra_one_step:
+            self.sigmas = torch.linspace(
+                sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
+        else:
+            self.sigmas = torch.linspace(
+                sigma_start, self.sigma_min, num_inference_steps)
+        if self.inverse_timesteps:
+            self.sigmas = torch.flip(self.sigmas, dims=[0])
+        self.sigmas = self.shift * self.sigmas / \
+            (1 + (self.shift - 1) * self.sigmas)
+        if self.reverse_sigmas:
+            self.sigmas = 1 - self.sigmas
+        self.timesteps = self.sigmas * self.num_train_timesteps
+        if training:
+            x = self.timesteps
+            y = torch.exp(-2 * ((x - num_inference_steps / 2) /
+                          num_inference_steps) ** 2)
+            y_shifted = y - y.min()
+            bsmntw_weighing = y_shifted * \
+                (num_inference_steps / y_shifted.sum())
+            self.linear_timesteps_weights = bsmntw_weighing
+
+    def step(self, model_output, timestep, sample, to_final=False):
+        self.sigmas = self.sigmas.to(model_output.device)
+        self.timesteps = self.timesteps.to(model_output.device)
+        timestep_id = torch.argmin(
+            (self.timesteps - timestep).abs(), dim=0)
+        sigma = self.sigmas[timestep_id].reshape(-1, 1, 1, 1)
+        if to_final or (timestep_id + 1 >= len(self.timesteps)).any():
+            sigma_ = 1 if (
+                self.inverse_timesteps or self.reverse_sigmas) else 0
+        else:
+            sigma_ = self.sigmas[timestep_id + 1].reshape(-1, 1, 1, 1)
+        prev_sample = sample + model_output * (sigma_ - sigma)
+        return prev_sample
+
+    def add_noise(self, original_samples, noise, timestep):
+        """
+        Diffusion forward corruption process.
+        Input:
+        - clean_latent: the clean latent with shape [B, C, H, W]
+        - noise: the noise with shape [B, C, H, W]
+        - timestep: the timestep with shape [B]
+        Output: the corrupted latent with shape [B, C, H, W]
+        """
+        self.sigmas = self.sigmas.to(noise.device)
+        self.timesteps = self.timesteps.to(noise.device)
+        timestep_id = torch.argmin(
+            (self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
+        sigma = self.sigmas[timestep_id].reshape(-1, 1, 1, 1)
+        sample = (1 - sigma) * original_samples + sigma * noise
+        return sample.type_as(noise)
+
+    def training_target(self, sample, noise, timestep):
+        target = noise - sample
+        return target
+
+    def training_weight(self, timestep):
+        timestep_id = torch.argmin(
+            (self.timesteps - timestep.to(self.timesteps.device)).abs())
+        weights = self.linear_timesteps_weights[timestep_id]
+        return weights
\ No newline at end of file
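
Usage sketch (not part of the patch): a minimal inference loop and training-time corruption with the scheduler above, assuming the repository's wan package is importable; the random tensors stand in for a real velocity-predicting model and for VAE latents.

import torch
from wan.utils.basic_flowmatch import FlowMatchScheduler

# Inference: walk the shifted sigma schedule from pure noise (sigma_max = 1.0) to the final latent.
scheduler = FlowMatchScheduler(num_inference_steps=10, shift=3.0)
latents = torch.randn(1, 16, 32, 32)          # stand-in for VAE latents [B, C, H, W]
for t in scheduler.timesteps:
    model_output = torch.randn_like(latents)  # stand-in for the model's velocity prediction
    latents = scheduler.step(model_output, t, latents)

# Training: corrupt clean latents at sampled timesteps and regress the velocity target (noise - sample).
scheduler.set_timesteps(1000, training=True)
clean = torch.randn(2, 16, 32, 32)
noise = torch.randn_like(clean)
t = scheduler.timesteps[torch.randint(0, len(scheduler.timesteps), (2,))]
noisy = scheduler.add_noise(clean, noise, t)
target = scheduler.training_target(clean, noise, t)
weight = scheduler.training_weight(t[0])      # expects a single (scalar) timestep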