From a356c6af4b72c5ed12b5ff7be2b27d97d059037d Mon Sep 17 00:00:00 2001 From: deepbeepmeep Date: Wed, 16 Jul 2025 18:09:06 +0200 Subject: [PATCH] fixed causvid scheduler --- flux/flux_main.py | 9 +++------ flux/sampling.py | 4 ++-- preprocessing/depth_anything_v2/layers/attention.py | 8 +------- preprocessing/depth_anything_v2/layers/block.py | 9 +-------- wan/utils/basic_flowmatch.py | 2 +- wan/utils/utils.py | 4 ++-- 6 files changed, 10 insertions(+), 26 deletions(-) diff --git a/flux/flux_main.py b/flux/flux_main.py index b782cc9..202eb44 100644 --- a/flux/flux_main.py +++ b/flux/flux_main.py @@ -72,10 +72,7 @@ class model_factory: if self._interrupt: return None - rng = torch.Generator(device="cuda") - if seed is None: - seed = rng.seed() - + device="cuda" if input_ref_images != None and len(input_ref_images) > 0: image_ref = input_ref_images[0] w, h = image_ref.size @@ -91,7 +88,7 @@ class model_factory: target_height=height, bs=batch_size, seed=seed, - device="cuda", + device=device, ) inp.pop("img_cond_orig") @@ -103,7 +100,7 @@ class model_factory: if x==None: return None # decode latents to pixel space x = unpack_latent(x) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + with torch.autocast(device_type=device, dtype=torch.bfloat16): x = self.vae.decode(x) x = x.clamp(-1, 1) diff --git a/flux/sampling.py b/flux/sampling.py index 7581dea..5c137f1 100644 --- a/flux/sampling.py +++ b/flux/sampling.py @@ -30,8 +30,8 @@ def get_noise( 2 * math.ceil(height / 16), 2 * math.ceil(width / 16), dtype=dtype, - generator=torch.Generator(device="cuda").manual_seed(seed), - ).to(device) + generator=torch.Generator(device=device).manual_seed(seed), + ) def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]: diff --git a/preprocessing/depth_anything_v2/layers/attention.py b/preprocessing/depth_anything_v2/layers/attention.py index f1cacb1..5a35c06 100644 --- a/preprocessing/depth_anything_v2/layers/attention.py +++ b/preprocessing/depth_anything_v2/layers/attention.py @@ -15,13 +15,7 @@ from torch import nn logger = logging.getLogger("dinov2") -try: - from xformers.ops import memory_efficient_attention, unbind, fmha - - XFORMERS_AVAILABLE = True -except ImportError: - logger.warning("xFormers not available") - XFORMERS_AVAILABLE = False +XFORMERS_AVAILABLE = False class Attention(nn.Module): diff --git a/preprocessing/depth_anything_v2/layers/block.py b/preprocessing/depth_anything_v2/layers/block.py index a711a1f..8de1d57 100644 --- a/preprocessing/depth_anything_v2/layers/block.py +++ b/preprocessing/depth_anything_v2/layers/block.py @@ -23,14 +23,7 @@ from .mlp import Mlp logger = logging.getLogger("dinov2") -try: - from xformers.ops import fmha - from xformers.ops import scaled_index_add, index_select_cat - - XFORMERS_AVAILABLE = True -except ImportError: - # logger.warning("xFormers not available") - XFORMERS_AVAILABLE = False +XFORMERS_AVAILABLE = False class Block(nn.Module): diff --git a/wan/utils/basic_flowmatch.py b/wan/utils/basic_flowmatch.py index 591510b..ceb4657 100644 --- a/wan/utils/basic_flowmatch.py +++ b/wan/utils/basic_flowmatch.py @@ -53,7 +53,7 @@ class FlowMatchScheduler(): else: sigma_ = self.sigmas[timestep_id + 1].reshape(-1, 1, 1, 1) prev_sample = sample + model_output * (sigma_ - sigma) - return prev_sample + return [prev_sample] def add_noise(self, original_samples, noise, timestep): """ diff --git a/wan/utils/utils.py b/wan/utils/utils.py index cbd34e9..53f3b73 100644 --- a/wan/utils/utils.py +++ b/wan/utils/utils.py @@ -493,10 +493,10 @@ def extract_audio_tracks(source_video, verbose=False, query_only= False): except ffmpeg.Error as e: print(f"FFmpeg error during audio extraction: {e}") - return [] + return 0 if query_only else [] except Exception as e: print(f"Error during audio extraction: {e}") - return [] + return 0 if query_only else [] def combine_video_with_audio_tracks(target_video, audio_tracks, output_video, verbose=False): """