fixed causvid scheduler

This commit is contained in:
deepbeepmeep 2025-07-16 18:09:06 +02:00
parent 49aaa12689
commit a356c6af4b
6 changed files with 10 additions and 26 deletions

View File

@ -72,10 +72,7 @@ class model_factory:
if self._interrupt:
return None
rng = torch.Generator(device="cuda")
if seed is None:
seed = rng.seed()
device="cuda"
if input_ref_images != None and len(input_ref_images) > 0:
image_ref = input_ref_images[0]
w, h = image_ref.size
@ -91,7 +88,7 @@ class model_factory:
target_height=height,
bs=batch_size,
seed=seed,
device="cuda",
device=device,
)
inp.pop("img_cond_orig")
@ -103,7 +100,7 @@ class model_factory:
if x==None: return None
# decode latents to pixel space
x = unpack_latent(x)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
with torch.autocast(device_type=device, dtype=torch.bfloat16):
x = self.vae.decode(x)
x = x.clamp(-1, 1)

View File

@ -30,8 +30,8 @@ def get_noise(
2 * math.ceil(height / 16),
2 * math.ceil(width / 16),
dtype=dtype,
generator=torch.Generator(device="cuda").manual_seed(seed),
).to(device)
generator=torch.Generator(device=device).manual_seed(seed),
)
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:

View File

@ -15,12 +15,6 @@ from torch import nn
logger = logging.getLogger("dinov2")
try:
from xformers.ops import memory_efficient_attention, unbind, fmha
XFORMERS_AVAILABLE = True
except ImportError:
logger.warning("xFormers not available")
XFORMERS_AVAILABLE = False

View File

@ -23,13 +23,6 @@ from .mlp import Mlp
logger = logging.getLogger("dinov2")
try:
from xformers.ops import fmha
from xformers.ops import scaled_index_add, index_select_cat
XFORMERS_AVAILABLE = True
except ImportError:
# logger.warning("xFormers not available")
XFORMERS_AVAILABLE = False

View File

@ -53,7 +53,7 @@ class FlowMatchScheduler():
else:
sigma_ = self.sigmas[timestep_id + 1].reshape(-1, 1, 1, 1)
prev_sample = sample + model_output * (sigma_ - sigma)
return prev_sample
return [prev_sample]
def add_noise(self, original_samples, noise, timestep):
"""

View File

@ -493,10 +493,10 @@ def extract_audio_tracks(source_video, verbose=False, query_only= False):
except ffmpeg.Error as e:
print(f"FFmpeg error during audio extraction: {e}")
return []
return 0 if query_only else []
except Exception as e:
print(f"Error during audio extraction: {e}")
return []
return 0 if query_only else []
def combine_video_with_audio_tracks(target_video, audio_tracks, output_video, verbose=False):
"""