mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-12-16 12:13:27 +00:00)

Commit 7676ded155: "fixes"
Parent commit: e7f17868be
@@ -18,7 +18,7 @@ numpy>=1.23.5,<2
 einops
 moviepy==1.0.3
 mmgp==3.4.9
-peft==0.14.0
+peft==0.15.0
 mutagen
 pydantic==2.10.6
 decord
@@ -1070,10 +1070,8 @@ class WanModel(ModelMixin, ConfigMixin):
         # Vace embeddings
         c = [self.vace_patch_embedding(u.to(self.vace_patch_embedding.weight.dtype).unsqueeze(0)) for u in vace_context]
         c = [u.flatten(2).transpose(1, 2) for u in c]
-        # c = c[0]
-        c = [ [sub_c] for sub_c in c]
         kwargs['context_scale'] = vace_context_scale
-        hints_list = [ c ]* len(x_list)
+        hints_list = [ [ [sub_c] for sub_c in c] for _ in range(len(x_list)) ]
         del c

         should_calc = True
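Why the hints_list rewrite matters: multiplying a Python list only repeats references, so every sample would share (and mutate) the same hint wrappers, while the new comprehension builds an independent wrapper list per sample. A minimal plain-Python sketch of the difference (illustrative values, not the model's tensors):

    # Minimal sketch (plain Python, illustrative values): list multiplication vs a
    # per-sample comprehension when building one hints entry per sample.
    c = [["hint_a"], ["hint_b"]]               # stand-in for per-block VACE hints

    shared = [c] * 3                           # all three entries are the same object
    independent = [[[sub_c] for sub_c in c] for _ in range(3)]  # fresh wrappers per entry

    shared[0].append("extra")                  # mutating one entry leaks into the others
    print(shared[1][-1])                       # -> extra
    independent[0].append("extra")
    print(len(independent[1]))                 # -> 2, other entries stay untouched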
@@ -133,9 +133,10 @@ class WanT2V:
         reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
         inactive = self.vae.encode(inactive, tile_size = tile_size)

-        if overlapped_latents != None :
+        if overlapped_latents != None and False :
             # inactive[0][:, 0:1] = self.vae.encode([frames[0][:, 0:1]], tile_size = tile_size)[0] # redundant
-            inactive[0][:, 1:overlapped_latents.shape[1] + 1] = overlapped_latents
+            for t in inactive:
+                t[:, 1:overlapped_latents.shape[1] + 1] = overlapped_latents

         reactive = self.vae.encode(reactive, tile_size = tile_size)
         latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
@@ -162,7 +163,7 @@ class WanT2V:
         result_masks = []
         for mask, refs in zip(masks, ref_images):
             c, depth, height, width = mask.shape
-            new_depth = int((depth + 3) // self.vae_stride[0])
+            new_depth = int((depth + 3) // self.vae_stride[0]) # number of latent tokens (ref tokens not included)
             height = 2 * (int(height) // (self.vae_stride[1] * 2))
             width = 2 * (int(width) // (self.vae_stride[2] * 2))

@@ -189,7 +190,7 @@ class WanT2V:
     def vace_latent(self, z, m):
         return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]

-    def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, fill_max = False, outpainting_dims = None):
+    def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, fill_max = False, outpainting_dims = None, return_mask = False):
         from wan.utils.utils import save_image
         ref_width, ref_height = ref_img.size
         if (ref_height, ref_width) == image_size and outpainting_dims == None:
@@ -212,15 +213,24 @@ class WanT2V:
         ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
         ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
         if outpainting_dims != None:
-            white_canvas = torch.full((3, 1, final_height, final_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
-            white_canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img
+            canvas = torch.full((3, 1, final_height, final_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
+            canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img
         else:
-            white_canvas = torch.full((3, 1, canvas_height, canvas_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
-            white_canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
-        ref_img = white_canvas
-        return ref_img.to(device)
+            canvas = torch.full((3, 1, canvas_height, canvas_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
+            canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
+        ref_img = canvas
+        canvas = None
+        if return_mask:
+            if outpainting_dims != None:
+                canvas = torch.ones((3, 1, final_height, final_width), dtype= torch.float, device=device) # [-1, 1]
+                canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = 0
+            else:
+                canvas = torch.ones((3, 1, canvas_height, canvas_width), dtype= torch.float, device=device) # [-1, 1]
+                canvas[:, :, top:top + new_height, left:left + new_width] = 0
+            canvas = canvas.to(device)
+        return ref_img.to(device), canvas

-    def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_frames= [], start_frame = 0, fit_into_canvas = None, pre_src_video = None, inject_frames = [], outpainting_dims = None):
+    def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_frames= [], start_frame = 0, fit_into_canvas = None, pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False):
         image_sizes = []
         trim_video = len(keep_frames)
         def conv_tensor(t, device):
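For context, the return_mask extension pairs each fitted image with a canvas mask that is 1 over untouched background and 0 where the image was pasted. A self-contained sketch of that pattern with made-up names and shapes (the real method also handles resizing, outpainting margins and fill_max):

    import torch

    def paste_with_mask(img, canvas_h, canvas_w, top, left, bg_value=1.0):
        """Paste a (3, 1, h, w) image into a solid canvas and return (canvas, mask).

        The mask is 1 where the canvas keeps its background value and 0 where the
        image was pasted, mirroring the return_mask=True behaviour sketched above.
        Illustrative helper only, not the repo's API.
        """
        _, _, h, w = img.shape
        canvas = torch.full((3, 1, canvas_h, canvas_w), bg_value, dtype=torch.float)
        canvas[:, :, top:top + h, left:left + w] = img
        mask = torch.ones((3, 1, canvas_h, canvas_w), dtype=torch.float)
        mask[:, :, top:top + h, left:left + w] = 0
        return canvas, mask

    # Tiny usage example: a dummy 2x2 image pasted into a 4x6 canvas.
    dummy = torch.zeros(3, 1, 2, 2)
    canvas, mask = paste_with_mask(dummy, canvas_h=4, canvas_w=6, top=1, left=2)
    print(canvas.shape, mask.sum().item())   # torch.Size([3, 1, 4, 6]) 60.0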
@@ -270,16 +280,22 @@ class WanT2V:

         for k, frame in enumerate(inject_frames):
             if frame != None:
-                src_video[i][:, k:k+1] = self.fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims)
-                src_mask[i][:, k:k+1] = 0
+                src_video[i][:, k:k+1], src_mask[i][:, k:k+1] = self.fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims, return_mask= True)


+        self.background_mask = None
         for i, ref_images in enumerate(src_ref_images):
             if ref_images is not None:
                 image_size = image_sizes[i]
                 for j, ref_img in enumerate(ref_images):
                     if ref_img is not None and not torch.is_tensor(ref_img):
-                        src_ref_images[i][j] = self.fit_image_into_canvas(ref_img, image_size, 1, device)
+                        if j==0 and any_background_ref:
+                            if self.background_mask == None: self.background_mask = [None] * len(src_ref_images)
+                            src_ref_images[i][j], self.background_mask[i] = self.fit_image_into_canvas(ref_img, image_size, 0, device, True, outpainting_dims, return_mask= True)
+                        else:
+                            src_ref_images[i][j], _ = self.fit_image_into_canvas(ref_img, image_size, 1, device)
+        if self.background_mask != None:
+            self.background_mask = [ item if item != None else self.background_mask[0] for item in self.background_mask ] # duplicate the background mask when a second control net is used, since the first control net's image ref is modified by the ref
         return src_video, src_mask, src_ref_images

     def decode_latent(self, zs, ref_images=None, tile_size= 0 ):
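The final back-fill of self.background_mask relies on a small Python detail: inside the list comprehension the name still refers to the old list, so background_mask[0] is the first control net's mask. A tiny illustrative sketch (dummy values, not the repo's tensors):

    # Illustrative only: back-fill missing per-control-net background masks from the
    # first one, the way prepare_source duplicates it for a second control net.
    background_mask = ["mask_from_first_controlnet", None]
    background_mask = [item if item is not None else background_mask[0] for item in background_mask]
    print(background_mask)   # ['mask_from_first_controlnet', 'mask_from_first_controlnet']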
@@ -408,11 +424,20 @@ class WanT2V:
         input_frames = [u.to(self.device) for u in input_frames]
         input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
         input_masks = [u.to(self.device) for u in input_masks]
+        if self.background_mask != None: self.background_mask = [m.to(self.device) for m in self.background_mask]
         previous_latents = None
         # if overlapped_latents != None:
         # input_ref_images = [u[-1:] for u in input_ref_images]
         z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents )
         m0 = self.vace_encode_masks(input_masks, input_ref_images)
+        if self.background_mask != None:
+            zbg = self.vace_encode_frames([ref_img[0] for ref_img in input_ref_images], None, masks=self.background_mask, tile_size = VAE_tile_size )
+            mbg = self.vace_encode_masks(self.background_mask, None)
+            for zz0, mm0, zzbg, mmbg in zip(z0, m0, zbg, mbg):
+                zz0[:, 0:1] = zzbg
+                mm0[:, 0:1] = mmbg
+
+            self.background_mask = zz0 = mm0 = zzbg = mmbg = None
         z = self.vace_latent(z0, m0)

         target_shape = list(z0[0].shape)
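The background branch splices a separately encoded background latent and mask into the first temporal slot of every stream. A dummy-tensor sketch of that slice assignment (shapes are made up, not the model's real channel counts):

    import torch

    # Minimal sketch: splice a separately encoded "background" latent/mask into the
    # first temporal slot of each stream, as zz0[:, 0:1] = zzbg does above.
    z0 = [torch.zeros(32, 5, 8, 8)]    # per-stream latents: (channels, latent frames, h, w)
    m0 = [torch.zeros(64, 5, 8, 8)]    # matching vace masks
    zbg = [torch.randn(32, 1, 8, 8)]   # background ref encoded on its own
    mbg = [torch.ones(64, 1, 8, 8)]

    for zz0, mm0, zzbg, mmbg in zip(z0, m0, zbg, mbg):
        zz0[:, 0:1] = zzbg             # in-place slice assignment keeps shapes and devices
        mm0[:, 0:1] = mmbg

    print(z0[0][:, 0].abs().sum() > 0, m0[0][:, 1].abs().sum() == 0)  # tensor(True) tensor(True)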
@@ -499,25 +524,15 @@ class WanT2V:
             self.model.setup_chipmunk()

         for i, t in enumerate(tqdm(timesteps)):

             timestep = [t]
             if overlapped_latents != None :
-                # overlap_noise_factor = overlap_noise *(i/(len(timesteps)-1)) / 1000
                 overlap_noise_factor = overlap_noise / 1000
-                # overlap_noise_factor = (1000-t )/ 1000 # overlap_noise / 1000
-                # latent_noise_factor = 1 #max(min(1, (t - overlap_noise) / 1000 ),0)
                 latent_noise_factor = t / 1000
-                for zz, zz_r, ll in zip(z, z_reactive, [latents]):
-                    pass
+                for zz, zz_r, ll in zip(z, z_reactive, [latents, None]): # extra None for second control net
                     zz[0:16, ref_images_count:overlapped_latents_size + ref_images_count] = zz_r[:, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(zz_r[:, ref_images_count:] ) * overlap_noise_factor
                     if ll != None:
                         ll[:, 0:overlapped_latents_size + ref_images_count] = zz_r * (1.0 - latent_noise_factor) + torch.randn_like(zz_r ) * latent_noise_factor

-            if conditioning_latents_size > 0 and overlap_noise > 0:
-                pass
-                overlap_noise_factor = overlap_noise / 1000
-                # latents[:, conditioning_latents_size + ref_images_count:] = latents[:, conditioning_latents_size + ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(latents[:, conditioning_latents_size + ref_images_count:]) * overlap_noise_factor
-                # timestep = [torch.tensor([t.item()] * (conditioning_latents_size + ref_images_count) + [t.item() - overlap_noise]*(target_shape[1] - conditioning_latents_size - ref_images_count))]

             if target_camera != None:
                 latent_model_input = torch.cat([latents, source_latents], dim=1)
             else:
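The overlapped-latent handling re-noises previously generated latents with a linear blend, latent * (1 - f) + noise * f, where f is overlap_noise / 1000 for the overlapped region and t / 1000 when a latent should match the scheduler's current noise level. A standalone sketch of that blend with dummy tensors and illustrative names:

    import torch

    def renoise(latent: torch.Tensor, noise_factor: float) -> torch.Tensor:
        """Linear blend of a clean latent with fresh Gaussian noise.

        noise_factor = 0 keeps the latent untouched, 1 replaces it with pure noise;
        the loop above uses overlap_noise / 1000 for overlapped frames and t / 1000
        for latents that should follow the current denoising step.
        """
        return latent * (1.0 - noise_factor) + torch.randn_like(latent) * noise_factor

    # Dummy example: re-noise only the first two "overlapped" latent frames.
    latents = torch.zeros(16, 8, 4, 4)   # (channels, latent frames, h, w), made-up shape
    latents[:, 0:2] = renoise(latents[:, 0:2], 50 / 1000)
    print(latents[:, 0:2].std().item() > 0, latents[:, 2:].abs().sum().item() == 0.0)  # True True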
wgp.py
@@ -45,7 +45,7 @@ AUTOSAVE_FILENAME = "queue.zip"
 PROMPT_VARS_MAX = 10

 target_mmgp_version = "3.4.9"
-WanGP_version = "6.3"
+WanGP_version = "6.31"
 settings_version = 2
 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None

@@ -3351,6 +3351,7 @@ def generate_video(

     original_image_refs = image_refs
     frames_to_inject = []
+    any_background_ref = False
     outpainting_dims = None if video_guide_outpainting== None or len(video_guide_outpainting) == 0 or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#") else [int(v) for v in video_guide_outpainting.split(" ")]

     if image_refs != None and len(image_refs) > 0 and (hunyuan_custom or phantom or hunyuan_avatar or vace):
@@ -3368,9 +3369,9 @@ def generate_video(
             h, w = get_outpainting_full_area_dimensions(h,w, outpainting_dims)
         default_image_size = calculate_new_dimensions(height, width, h, w, fit_canvas)
         fit_canvas = None

     if len(image_refs) > nb_frames_positions:
-        if hunyuan_avatar: remove_background_images_ref = 0
+        any_background_ref = remove_background_images_ref != 1
         if remove_background_images_ref > 0:
             send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")])
             os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg")
@@ -3620,9 +3621,17 @@ def generate_video(
                     fit_into_canvas = sample_fit_canvas,
                     inject_frames= frames_to_inject_parsed,
                     outpainting_dims = outpainting_dims,
+                    any_background_ref = any_background_ref
                 )
-                if len(frames_to_inject_parsed):
-                    refresh_preview["image_refs"] = [convert_tensor_to_image(src_video[0], frame_no) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject] + image_refs[nb_frames_positions:]
+                if len(frames_to_inject_parsed) or any_background_ref:
+                    new_image_refs = [convert_tensor_to_image(src_video[0], frame_no) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject]
+                    if any_background_ref:
+                        new_image_refs += [convert_tensor_to_image(image_refs_copy[0], 0)] + image_refs[nb_frames_positions+1:]
+                    else:
+                        new_image_refs += image_refs[nb_frames_positions:]
+                    refresh_preview["image_refs"] = new_image_refs
+                    new_image_refs = None
+
                 if sample_fit_canvas != None:
                     image_size = src_video[0].shape[-2:]
                     sample_fit_canvas = None
@@ -5315,9 +5324,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                         ("Use Vace raw format", "V"),
                         ("Keep Unchanged", "UV"),
                         ("Transfer Human Motion & Depth", "PDV"),
-                        ("Transfer Human Motion & Shape", "PSV"),
+                        ("Transfer Human Motion & Shapes", "PSV"),
                         ("Transfer Human Motion & Flow", "PFV"),
-                        ("Transfer Depth & Shape", "DSV"),
+                        ("Transfer Depth & Shapes", "DSV"),
                         ("Transfer Depth & Flow", "DFV"),
                         ("Transfer Shapes & Flow", "SFV"),
                     ],
@@ -5392,7 +5401,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                 video_guide_outpainting_value = ui_defaults.get("video_guide_outpainting","#")
                 video_guide_outpainting = gr.Text(value=video_guide_outpainting_value , visible= False)
                 with gr.Group():
-                    video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video or Injected Reference Frames", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") )
+                    video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Background or Injected Reference Frames", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") )
                     with gr.Row(visible = not video_guide_outpainting_value.startswith("#")) as video_guide_outpainting_row:
                         video_guide_outpainting_value = video_guide_outpainting_value[1:] if video_guide_outpainting_value.startswith("#") else video_guide_outpainting_value
                         video_guide_outpainting_list = [0] * 4 if len(video_guide_outpainting_value) == 0 else [int(v) for v in video_guide_outpainting_value.split(" ")]
@@ -5450,7 +5459,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
             wizard_variables = "\n".join(variables)
             for _ in range( PROMPT_VARS_MAX - len(prompt_vars)):
                 prompt_vars.append(gr.Textbox(visible= False, min_width=80, show_label= False))
-            with gr.Column(not advanced_prompt) as prompt_column_wizard:
+            with gr.Column(visible=not advanced_prompt) as prompt_column_wizard:
                 wizard_prompt = gr.Textbox(visible = not advanced_prompt, label=wizard_prompt_label, value=default_wizard_prompt, lines=3)
                 wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False)
                 wizard_variables_var = gr.Text(wizard_variables, visible = False)
@@ -5759,7 +5768,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
             sliding_window_tab, misc_tab, prompt_enhancer_row, inference_steps_row, skip_layer_guidance_row,
             video_prompt_type_video_guide, video_prompt_type_video_mask, video_prompt_type_image_refs,
             video_guide_outpainting_col,video_guide_outpainting_top, video_guide_outpainting_bottom, video_guide_outpainting_left, video_guide_outpainting_right,
-            video_guide_outpainting_checkbox, video_guide_outpainting_row] # show_advanced presets_column,
+            video_guide_outpainting_checkbox, video_guide_outpainting_row, show_advanced] # presets_column,
     if update_form:
         locals_dict = locals()
         gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs