mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 14:16:57 +00:00
fixed causvid scheduler
This commit is contained in:
parent
49aaa12689
commit
a356c6af4b
@ -72,10 +72,7 @@ class model_factory:
|
|||||||
if self._interrupt:
|
if self._interrupt:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
rng = torch.Generator(device="cuda")
|
device="cuda"
|
||||||
if seed is None:
|
|
||||||
seed = rng.seed()
|
|
||||||
|
|
||||||
if input_ref_images != None and len(input_ref_images) > 0:
|
if input_ref_images != None and len(input_ref_images) > 0:
|
||||||
image_ref = input_ref_images[0]
|
image_ref = input_ref_images[0]
|
||||||
w, h = image_ref.size
|
w, h = image_ref.size
|
||||||
@ -91,7 +88,7 @@ class model_factory:
|
|||||||
target_height=height,
|
target_height=height,
|
||||||
bs=batch_size,
|
bs=batch_size,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
device="cuda",
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
inp.pop("img_cond_orig")
|
inp.pop("img_cond_orig")
|
||||||
@ -103,7 +100,7 @@ class model_factory:
|
|||||||
if x==None: return None
|
if x==None: return None
|
||||||
# decode latents to pixel space
|
# decode latents to pixel space
|
||||||
x = unpack_latent(x)
|
x = unpack_latent(x)
|
||||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
with torch.autocast(device_type=device, dtype=torch.bfloat16):
|
||||||
x = self.vae.decode(x)
|
x = self.vae.decode(x)
|
||||||
|
|
||||||
x = x.clamp(-1, 1)
|
x = x.clamp(-1, 1)
|
||||||
|
|||||||
@ -30,8 +30,8 @@ def get_noise(
|
|||||||
2 * math.ceil(height / 16),
|
2 * math.ceil(height / 16),
|
||||||
2 * math.ceil(width / 16),
|
2 * math.ceil(width / 16),
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
generator=torch.Generator(device="cuda").manual_seed(seed),
|
generator=torch.Generator(device=device).manual_seed(seed),
|
||||||
).to(device)
|
)
|
||||||
|
|
||||||
|
|
||||||
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
|
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
|
||||||
|
|||||||
@ -15,13 +15,7 @@ from torch import nn
|
|||||||
|
|
||||||
logger = logging.getLogger("dinov2")
|
logger = logging.getLogger("dinov2")
|
||||||
|
|
||||||
try:
|
XFORMERS_AVAILABLE = False
|
||||||
from xformers.ops import memory_efficient_attention, unbind, fmha
|
|
||||||
|
|
||||||
XFORMERS_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
logger.warning("xFormers not available")
|
|
||||||
XFORMERS_AVAILABLE = False
|
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
class Attention(nn.Module):
|
||||||
|
|||||||
@ -23,14 +23,7 @@ from .mlp import Mlp
|
|||||||
logger = logging.getLogger("dinov2")
|
logger = logging.getLogger("dinov2")
|
||||||
|
|
||||||
|
|
||||||
try:
|
XFORMERS_AVAILABLE = False
|
||||||
from xformers.ops import fmha
|
|
||||||
from xformers.ops import scaled_index_add, index_select_cat
|
|
||||||
|
|
||||||
XFORMERS_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
# logger.warning("xFormers not available")
|
|
||||||
XFORMERS_AVAILABLE = False
|
|
||||||
|
|
||||||
|
|
||||||
class Block(nn.Module):
|
class Block(nn.Module):
|
||||||
|
|||||||
@ -53,7 +53,7 @@ class FlowMatchScheduler():
|
|||||||
else:
|
else:
|
||||||
sigma_ = self.sigmas[timestep_id + 1].reshape(-1, 1, 1, 1)
|
sigma_ = self.sigmas[timestep_id + 1].reshape(-1, 1, 1, 1)
|
||||||
prev_sample = sample + model_output * (sigma_ - sigma)
|
prev_sample = sample + model_output * (sigma_ - sigma)
|
||||||
return prev_sample
|
return [prev_sample]
|
||||||
|
|
||||||
def add_noise(self, original_samples, noise, timestep):
|
def add_noise(self, original_samples, noise, timestep):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -493,10 +493,10 @@ def extract_audio_tracks(source_video, verbose=False, query_only= False):
|
|||||||
|
|
||||||
except ffmpeg.Error as e:
|
except ffmpeg.Error as e:
|
||||||
print(f"FFmpeg error during audio extraction: {e}")
|
print(f"FFmpeg error during audio extraction: {e}")
|
||||||
return []
|
return 0 if query_only else []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error during audio extraction: {e}")
|
print(f"Error during audio extraction: {e}")
|
||||||
return []
|
return 0 if query_only else []
|
||||||
|
|
||||||
def combine_video_with_audio_tracks(target_video, audio_tracks, output_video, verbose=False):
|
def combine_video_with_audio_tracks(target_video, audio_tracks, output_video, verbose=False):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user