# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Modified from diffusers==0.29.2
#
# ==============================================================================

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from diffusers.schedulers.scheduling_utils import SchedulerMixin


@dataclass
class FlowMatchDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of the previous timestep. `prev_sample` should be used as the next model
            input in the denoising loop.
    """

    prev_sample: torch.FloatTensor


class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Euler scheduler for discrete flow matching.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the
    generic methods the library implements for all schedulers, such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule.
        reverse (`bool`, defaults to `True`):
            Whether to reverse the timestep schedule.
        solver (`str`, defaults to `"euler"`):
            The ODE solver used for sampling. Only `"euler"` is currently supported.
        n_tokens (`int`, *optional*):
            Number of tokens in the input sequence. Unused by this scheduler; kept for API compatibility.
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        reverse: bool = True,
        solver: str = "euler",
        n_tokens: Optional[int] = None,
    ):
        sigmas = torch.linspace(1, 0, num_train_timesteps + 1)

        if not reverse:
            sigmas = sigmas.flip(0)

        self.sigmas = sigmas
        # the timestep values fed to the model: sigma scaled to the training range
        self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)

        self._step_index = None
        self._begin_index = None

        self.supported_solver = ["euler"]
        if solver not in self.supported_solver:
            raise ValueError(
                f"Solver {solver} not supported. Supported solvers: {self.supported_solver}"
            )

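    # Illustrative note (not in the original file): with the defaults, `sigmas`
    # is torch.linspace(1, 0, 1001) and `timesteps` runs [1000., 999., ..., 1.],
    # i.e. the training-time grid before any time shift is applied.
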
    @property
    def step_index(self):
        """
        The index counter for the current timestep. It increases by 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from the pipeline before inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def _sigma_to_t(self, sigma):
        return sigma * self.config.num_train_timesteps

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Optional[Union[str, torch.device]] = None,
        n_tokens: Optional[int] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
            n_tokens (`int`, *optional*):
                Number of tokens in the input sequence. Unused by this scheduler; kept for API compatibility.
        """
        self.num_inference_steps = num_inference_steps

        sigmas = torch.linspace(1, 0, num_inference_steps + 1)
        sigmas = self.sd3_time_shift(sigmas)

        if not self.config.reverse:
            sigmas = 1 - sigmas

        self.sigmas = sigmas
        self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(
            dtype=torch.float32, device=device
        )

        # Reset step index
        self._step_index = None

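    # A worked example (illustrative, not in the original file): with
    # num_inference_steps=4 and shift=1.0 (so the time shift is the identity),
    #   sigmas    = [1.00, 0.75, 0.50, 0.25, 0.00]
    #   timesteps = sigmas[:-1] * 1000 = [1000., 750., 500., 250.]
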
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def scale_model_input(
        self, sample: torch.Tensor, timestep: Optional[int] = None
    ) -> torch.Tensor:
        """
        No-op for this scheduler; returns `sample` unchanged. Kept for interface compatibility with other schedulers.
        """
        return sample

    def sd3_time_shift(self, t: torch.Tensor):
        return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)

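    # A worked example (illustrative, not in the original file): the shift maps
    # sigma' = shift * sigma / (1 + (shift - 1) * sigma), so with shift=3.0 and
    # sigma=0.5 we get 1.5 / 2.0 = 0.75, i.e. larger shifts keep more of the
    # schedule at high noise levels. With shift=1.0 the mapping is the identity.
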
    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`):
                Whether or not to return a [`FlowMatchDiscreteSchedulerOutput`] or a tuple.

        Returns:
            [`FlowMatchDiscreteSchedulerOutput`] or `tuple`:
                If `return_dict` is `True`, [`FlowMatchDiscreteSchedulerOutput`] is returned, otherwise a tuple is
                returned where the first element is the sample tensor.
        """

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `FlowMatchDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        # Euler update for the flow-matching ODE: x <- x + v * d(sigma), where the
        # model output is interpreted as the predicted velocity field v
        dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]

        if self.config.solver == "euler":
            prev_sample = sample + model_output.to(torch.float32) * dt
        else:
            raise ValueError(
                f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}"
            )

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        return self.config.num_train_timesteps
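

# A minimal usage sketch (illustrative, not part of the original file). The model
# call below is a hypothetical stand-in for a flow-matching denoiser that
# predicts the velocity field.
if __name__ == "__main__":
    scheduler = FlowMatchDiscreteScheduler(shift=1.0)
    scheduler.set_timesteps(num_inference_steps=4)

    sample = torch.randn(1, 4, 8, 8)  # start from pure noise (sigma = 1)
    for t in scheduler.timesteps:
        model_output = -sample  # hypothetical velocity; a real pipeline calls its model here
        sample = scheduler.step(model_output, t, sample).prev_sample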