Mirror of https://github.com/Wan-Video/Wan2.1.git, synced 2025-11-04 22:26:36 +00:00
fixed causvid scheduler
This commit is contained in:
parent 49aaa12689, commit a356c6af4b
@@ -72,10 +72,7 @@ class model_factory:
         if self._interrupt:
             return None
 
-        rng = torch.Generator(device="cuda")
-        if seed is None:
-            seed = rng.seed()
-
+        device="cuda"
         if input_ref_images != None and len(input_ref_images) > 0:
             image_ref = input_ref_images[0]
             w, h = image_ref.size
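The removed seed handling mirrors PyTorch's Generator API: Generator.seed() draws and applies a fresh random seed (and returns it), while manual_seed() fixes one for reproducibility. A minimal standalone sketch of that pattern, with a made-up helper name (resolve_seed) and a CPU default so it runs anywhere; the repository itself uses "cuda":

    import torch

    def resolve_seed(seed=None, device="cpu"):
        # Hypothetical helper, not from the repo: pick a random seed when the
        # caller passed None, otherwise seed the generator deterministically.
        rng = torch.Generator(device=device)
        if seed is None:
            seed = rng.seed()        # draws a fresh random seed and returns it
        else:
            rng.manual_seed(seed)    # reproducible run with the caller's seed
        return seed, rng

With this block gone from the method, a None seed presumably has to be resolved by the caller before it reaches the seed=seed kwarg in the next hunk.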
@@ -91,7 +88,7 @@ class model_factory:
             target_height=height,
             bs=batch_size,
             seed=seed,
-            device="cuda",
+            device=device,
         )
 
         inp.pop("img_cond_orig")
@@ -103,7 +100,7 @@ class model_factory:
         if x==None: return None
         # decode latents to pixel space
         x = unpack_latent(x)
-        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+        with torch.autocast(device_type=device, dtype=torch.bfloat16):
             x = self.vae.decode(x)
 
         x = x.clamp(-1, 1)
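Both model_factory hunks follow the same pattern: hard-coded "cuda" strings are replaced by the single device variable introduced above, so the autocast region and the tensors it wraps agree on one device. A rough standalone sketch of the idea; decode_latents and the lambda decoder are stand-ins, not the repository's API:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    def decode_latents(latents, decoder):
        # Everything follows the one `device` variable instead of literal "cuda":
        # the input tensor, the autocast region, and the decode that runs inside it.
        latents = latents.to(device)
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            x = decoder(latents)
        return x.clamp(-1, 1)

    pixels = decode_latents(torch.randn(1, 4, 8, 8), decoder=lambda z: z.float())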
@@ -30,8 +30,8 @@ def get_noise(
         2 * math.ceil(height / 16),
         2 * math.ceil(width / 16),
         dtype=dtype,
-        generator=torch.Generator(device="cuda").manual_seed(seed),
-    ).to(device)
+        generator=torch.Generator(device=device).manual_seed(seed),
+    )
 
 
 def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
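The get_noise change is subtle: the old code built the generator on "cuda" and then moved the result with .to(device); the new code seeds the generator on the target device directly, making the trailing .to(device) unnecessary. A hedged sketch of what such a helper looks like; the channel count and the randn arguments outside the lines shown above are assumptions, not the file's actual body:

    import math
    import torch

    def get_noise_sketch(num_samples, height, width, device="cpu",
                         dtype=torch.float32, seed=0):
        # Noise for a 16x-downsampled, 2x2-packed latent grid, sampled directly
        # on `device` with a generator that lives on the same device.
        return torch.randn(
            num_samples,
            16,                               # assumed latent channel count
            2 * math.ceil(height / 16),
            2 * math.ceil(width / 16),
            dtype=dtype,
            device=device,
            generator=torch.Generator(device=device).manual_seed(seed),
        )

    noise = get_noise_sketch(1, 512, 512)     # shape (1, 16, 64, 64) on CPU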
@@ -15,13 +15,7 @@ from torch import nn
 
 logger = logging.getLogger("dinov2")
 
-try:
-    from xformers.ops import memory_efficient_attention, unbind, fmha
-
-    XFORMERS_AVAILABLE = True
-except ImportError:
-    logger.warning("xFormers not available")
-    XFORMERS_AVAILABLE = False
+XFORMERS_AVAILABLE = False
 
 
 class Attention(nn.Module):
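This hunk and the next one (both in DINOv2 modules, judging by the dinov2 logger) delete the optional xFormers import and pin XFORMERS_AVAILABLE = False, so the attention and block code always take the plain PyTorch branch. The guarded-import idiom being removed is the standard one; a generic standalone sketch of it follows, where the dispatch function is illustrative rather than the module's actual forward:

    import logging
    import torch

    logger = logging.getLogger("dinov2")

    try:
        from xformers.ops import memory_efficient_attention
        XFORMERS_AVAILABLE = True
    except ImportError:
        logger.warning("xFormers not available")
        XFORMERS_AVAILABLE = False

    def attention(q, k, v):
        # Flag-gated dispatch: xFormers kernel if importable, otherwise PyTorch's
        # built-in scaled_dot_product_attention with a layout transpose.
        if XFORMERS_AVAILABLE:
            return memory_efficient_attention(q, k, v)    # expects (B, N, heads, dim)
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # SDPA wants (B, heads, N, dim)
        return torch.nn.functional.scaled_dot_product_attention(q, k, v).transpose(1, 2)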
@@ -23,14 +23,7 @@ from .mlp import Mlp
 logger = logging.getLogger("dinov2")
 
 
-try:
-    from xformers.ops import fmha
-    from xformers.ops import scaled_index_add, index_select_cat
-
-    XFORMERS_AVAILABLE = True
-except ImportError:
-    # logger.warning("xFormers not available")
-    XFORMERS_AVAILABLE = False
+XFORMERS_AVAILABLE = False
 
 
 class Block(nn.Module):
@@ -53,7 +53,7 @@ class FlowMatchScheduler():
         else:
             sigma_ = self.sigmas[timestep_id + 1].reshape(-1, 1, 1, 1)
         prev_sample = sample + model_output * (sigma_ - sigma)
-        return prev_sample
+        return [prev_sample]
 
     def add_noise(self, original_samples, noise, timestep):
         """
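This is the hunk the commit title points at: FlowMatchScheduler.step() now returns [prev_sample], a one-element list, instead of the bare tensor, which suits callers that index the step result (result[0], in the diffusers-style convention). A toy flow-matching Euler step written the same way, purely illustrative; the sigma schedule and the call signature below are made up:

    import torch

    class ToyFlowMatchScheduler:
        # Illustrative only: a linear sigma schedule and an Euler update that
        # returns a one-element list, matching the patched return type.
        def __init__(self, num_steps=4):
            self.sigmas = torch.linspace(1.0, 0.0, num_steps + 1)

        def step(self, model_output, step_index, sample):
            sigma = self.sigmas[step_index]
            sigma_next = self.sigmas[step_index + 1]
            prev_sample = sample + model_output * (sigma_next - sigma)
            return [prev_sample]

    scheduler = ToyFlowMatchScheduler()
    x = torch.randn(1, 4, 8, 8)
    velocity = torch.zeros_like(x)            # stand-in for the model prediction
    x = scheduler.step(velocity, 0, x)[0]     # callers unwrap the list with [0]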
@@ -493,10 +493,10 @@ def extract_audio_tracks(source_video, verbose=False, query_only= False):
 
     except ffmpeg.Error as e:
         print(f"FFmpeg error during audio extraction: {e}")
-        return []
+        return 0 if query_only else []
     except Exception as e:
         print(f"Error during audio extraction: {e}")
-        return []
+        return 0 if query_only else []
 
 def combine_video_with_audio_tracks(target_video, audio_tracks, output_video, verbose=False):
     """
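The audio hunk makes the error paths respect the function's query_only flag: when the caller only asked how many audio tracks exist, an ffmpeg failure now yields 0 instead of an empty list, so the return type is consistent in both modes. A tiny self-contained illustration of that dual-mode return convention (not the repository's function):

    def count_or_list(items, query_only=False):
        # Same convention as the patched returns: counting mode yields an int,
        # listing mode yields a list, and the error values line up (0 vs []).
        try:
            return len(items) if query_only else list(items)
        except TypeError:
            return 0 if query_only else []

    print(count_or_list(["a.aac", "b.aac"], query_only=True))   # 2
    print(count_or_list(None))                                   # [] (error path)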