DeepBeepMeep 2025-06-25 01:06:59 +02:00
parent e7f17868be
commit 7676ded155
4 changed files with 62 additions and 40 deletions


@@ -18,7 +18,7 @@ numpy>=1.23.5,<2
einops
moviepy==1.0.3
mmgp==3.4.9
peft==0.14.0
peft==0.15.0
mutagen
pydantic==2.10.6
decord
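
The only change in this file is the peft pin moving from 0.14.0 to 0.15.0. As a hedged aside (not part of the commit), a quick runtime guard against environment drift could look like this, using only the standard library:

from importlib.metadata import PackageNotFoundError, version

def check_peft_pin(expected: str = "0.15.0") -> None:
    # Fail fast if the installed peft does not match the requirements pin.
    try:
        installed = version("peft")
    except PackageNotFoundError:
        raise RuntimeError("peft is not installed; run: pip install -r requirements.txt")
    if installed != expected:
        raise RuntimeError(f"found peft=={installed}, but peft=={expected} is pinned")

check_peft_pin()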


@@ -1070,10 +1070,8 @@ class WanModel(ModelMixin, ConfigMixin):
# Vace embeddings
c = [self.vace_patch_embedding(u.to(self.vace_patch_embedding.weight.dtype).unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
# c = c[0]
c = [ [sub_c] for sub_c in c]
kwargs['context_scale'] = vace_context_scale
hints_list = [ c ]* len(x_list)
hints_list = [ [ [sub_c] for sub_c in c] for _ in range(len(x_list)) ]
del c
should_calc = True
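
The hints_list change replaces list replication with a per-entry comprehension. With [ c ] * len(x_list) every entry referenced the same inner lists, so an in-place edit made for one pass would show up in all of them; the comprehension builds independent wrappers. A minimal sketch with plain strings standing in for the hint tensors (purely illustrative, not the real objects):

c = ["hint_a", "hint_b"]                          # stand-ins for the flattened VACE hints

shared = [[[s] for s in c]] * 3                   # old pattern: three references to one list
fresh = [[[s] for s in c] for _ in range(3)]      # new pattern: an independent copy per entry

shared[0][0].append("edited")
print(shared[1][0])    # ['hint_a', 'edited']  -> the edit leaks into every entry
fresh[0][0].append("edited")
print(fresh[1][0])     # ['hint_a']            -> the other entries are untouched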


@@ -133,9 +133,10 @@ class WanT2V:
reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
inactive = self.vae.encode(inactive, tile_size = tile_size)
if overlapped_latents != None :
if overlapped_latents != None and False :
# inactive[0][:, 0:1] = self.vae.encode([frames[0][:, 0:1]], tile_size = tile_size)[0] # redundant
inactive[0][:, 1:overlapped_latents.shape[1] + 1] = overlapped_latents
for t in inactive:
t[:, 1:overlapped_latents.shape[1] + 1] = overlapped_latents
reactive = self.vae.encode(reactive, tile_size = tile_size)
latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
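
Two things change in this hunk: the overlap branch is temporarily disabled by the added "and False", and when it runs it now writes the overlapped latents into every encoded inactive stream instead of only the first. A rough sketch with dummy tensors (shapes are placeholders, not the real VAE output):

import torch

inactive = [torch.zeros(16, 8, 32, 32) for _ in range(2)]   # (C, T, H, W) latents per stream
overlapped_latents = torch.randn(16, 3, 32, 32)             # latents carried over from the previous window

for t in inactive:
    # Slot 0 is left alone (reference frame); slots 1..T' receive the overlap.
    t[:, 1:overlapped_latents.shape[1] + 1] = overlapped_latents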
@@ -162,7 +163,7 @@ class WanT2V:
result_masks = []
for mask, refs in zip(masks, ref_images):
c, depth, height, width = mask.shape
new_depth = int((depth + 3) // self.vae_stride[0])
new_depth = int((depth + 3) // self.vae_stride[0]) # nb latent tokens (ref tokens not included)
height = 2 * (int(height) // (self.vae_stride[1] * 2))
width = 2 * (int(width) // (self.vae_stride[2] * 2))
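
The clarified comment on new_depth is easier to check with numbers. Assuming the usual Wan VAE strides of (4, 8, 8) (an assumption, not stated in this hunk), the mapping from mask frames to latent tokens, excluding reference tokens, works out as:

vae_stride = (4, 8, 8)                     # assumed (t, h, w) strides

for depth in (17, 49, 81):                 # pixel frames in the control mask
    new_depth = int((depth + 3) // vae_stride[0])
    print(depth, "->", new_depth)          # 17 -> 5, 49 -> 13, 81 -> 21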
@@ -189,7 +190,7 @@ class WanT2V:
def vace_latent(self, z, m):
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, fill_max = False, outpainting_dims = None):
def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, fill_max = False, outpainting_dims = None, return_mask = False):
from wan.utils.utils import save_image
ref_width, ref_height = ref_img.size
if (ref_height, ref_width) == image_size and outpainting_dims == None:
@@ -212,15 +213,24 @@ class WanT2V:
ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
if outpainting_dims != None:
white_canvas = torch.full((3, 1, final_height, final_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
white_canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img
canvas = torch.full((3, 1, final_height, final_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img
else:
white_canvas = torch.full((3, 1, canvas_height, canvas_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
white_canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
ref_img = white_canvas
return ref_img.to(device)
canvas = torch.full((3, 1, canvas_height, canvas_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
ref_img = canvas
canvas = None
if return_mask:
if outpainting_dims != None:
canvas = torch.ones((3, 1, final_height, final_width), dtype= torch.float, device=device) # [-1, 1]
canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = 0
else:
canvas = torch.ones((3, 1, canvas_height, canvas_width), dtype= torch.float, device=device) # [-1, 1]
canvas[:, :, top:top + new_height, left:left + new_width] = 0
canvas = canvas.to(device)
return ref_img.to(device), canvas
def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_frames= [], start_frame = 0, fit_into_canvas = None, pre_src_video = None, inject_frames = [], outpainting_dims = None):
def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_frames= [], start_frame = 0, fit_into_canvas = None, pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False):
image_sizes = []
trim_video = len(keep_frames)
def conv_tensor(t, device):
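
To make the new return_mask contract above concrete: the second return value is a canvas-shaped tensor that is 1 over padding and 0 wherever the reference image was pasted. A simplified, self-contained restatement of that logic (the function name and sizes are ours, not the repo's):

import torch

def paste_with_mask(ref_img, canvas_h, canvas_w, top, left, fill_value=1.0, device="cpu"):
    # ref_img: (3, 1, H, W) in [-1, 1], already resized to fit the canvas.
    _, _, h, w = ref_img.shape
    canvas = torch.full((3, 1, canvas_h, canvas_w), fill_value, dtype=torch.float, device=device)
    canvas[:, :, top:top + h, left:left + w] = ref_img
    mask = torch.ones((3, 1, canvas_h, canvas_w), dtype=torch.float, device=device)
    mask[:, :, top:top + h, left:left + w] = 0      # 0 = real content, 1 = padding
    return canvas, mask

frame, mask = paste_with_mask(torch.zeros(3, 1, 64, 48), canvas_h=128, canvas_w=128, top=32, left=40)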
@@ -270,16 +280,22 @@ class WanT2V:
for k, frame in enumerate(inject_frames):
if frame != None:
src_video[i][:, k:k+1] = self.fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims)
src_mask[i][:, k:k+1] = 0
src_video[i][:, k:k+1], src_mask[i][:, k:k+1] = self.fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims, return_mask= True)
self.background_mask = None
for i, ref_images in enumerate(src_ref_images):
if ref_images is not None:
image_size = image_sizes[i]
for j, ref_img in enumerate(ref_images):
if ref_img is not None and not torch.is_tensor(ref_img):
src_ref_images[i][j] = self.fit_image_into_canvas(ref_img, image_size, 1, device)
if j==0 and any_background_ref:
if self.background_mask == None: self.background_mask = [None] * len(src_ref_images)
src_ref_images[i][j], self.background_mask[i] = self.fit_image_into_canvas(ref_img, image_size, 0, device, True, outpainting_dims, return_mask= True)
else:
src_ref_images[i][j], _ = self.fit_image_into_canvas(ref_img, image_size, 1, device)
if self.background_mask != None:
self.background_mask = [ item if item != None else self.background_mask[0] for item in self.background_mask ] # duplicate the background mask for the second control net, since only the first controlnet image ref was modified
return src_video, src_mask, src_ref_images
def decode_latent(self, zs, ref_images=None, tile_size= 0 ):
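
The background-mask duplication added in this hunk is easiest to read with toy values: with a double control net, the second entry of self.background_mask is still None, so it inherits the first control net's mask. A minimal sketch with strings instead of tensors:

background_mask = ["mask_for_controlnet_0", None]    # second control net never got its own mask
background_mask = [item if item is not None else background_mask[0] for item in background_mask]
print(background_mask)    # ['mask_for_controlnet_0', 'mask_for_controlnet_0']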
@@ -408,11 +424,20 @@ class WanT2V:
input_frames = [u.to(self.device) for u in input_frames]
input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
input_masks = [u.to(self.device) for u in input_masks]
if self.background_mask != None: self.background_mask = [m.to(self.device) for m in self.background_mask]
previous_latents = None
# if overlapped_latents != None:
# input_ref_images = [u[-1:] for u in input_ref_images]
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents )
m0 = self.vace_encode_masks(input_masks, input_ref_images)
if self.background_mask != None:
zbg = self.vace_encode_frames([ref_img[0] for ref_img in input_ref_images], None, masks=self.background_mask, tile_size = VAE_tile_size )
mbg = self.vace_encode_masks(self.background_mask, None)
for zz0, mm0, zzbg, mmbg in zip(z0, m0, zbg, mbg):
zz0[:, 0:1] = zzbg
mm0[:, 0:1] = mmbg
self.background_mask = zz0 = mm0 = zzbg = mmbg = None
z = self.vace_latent(z0, m0)
target_shape = list(z0[0].shape)
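
When a background reference exists, the new block encodes it through the same VACE path and overwrites latent slot 0, which otherwise holds the plain reference frame. A dummy-tensor sketch of just the overwrite (channel counts are placeholders, not the real VACE dimensions):

import torch

z0, m0 = [torch.zeros(32, 21, 32, 32)], [torch.zeros(64, 21, 32, 32)]    # frame + mask latents
zbg, mbg = [torch.randn(32, 1, 32, 32)], [torch.randn(64, 1, 32, 32)]    # background-encoded latents

for zz0, mm0, zzbg, mmbg in zip(z0, m0, zbg, mbg):
    zz0[:, 0:1] = zzbg    # slot 0 now carries the background reference latents
    mm0[:, 0:1] = mmbg    # and the matching background mask latents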
@@ -499,24 +524,14 @@ class WanT2V:
self.model.setup_chipmunk()
for i, t in enumerate(tqdm(timesteps)):
timestep = [t]
if overlapped_latents != None :
# overlap_noise_factor = overlap_noise *(i/(len(timesteps)-1)) / 1000
overlap_noise_factor = overlap_noise / 1000
# overlap_noise_factor = (1000-t )/ 1000 # overlap_noise / 1000
# latent_noise_factor = 1 #max(min(1, (t - overlap_noise) / 1000 ),0)
latent_noise_factor = t / 1000
for zz, zz_r, ll in zip(z, z_reactive, [latents]):
pass
for zz, zz_r, ll in zip(z, z_reactive, [latents, None]): # extra None for second control net
zz[0:16, ref_images_count:overlapped_latents_size + ref_images_count] = zz_r[:, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(zz_r[:, ref_images_count:] ) * overlap_noise_factor
ll[:, 0:overlapped_latents_size + ref_images_count] = zz_r * (1.0 - latent_noise_factor) + torch.randn_like(zz_r ) * latent_noise_factor
if conditioning_latents_size > 0 and overlap_noise > 0:
pass
overlap_noise_factor = overlap_noise / 1000
# latents[:, conditioning_latents_size + ref_images_count:] = latents[:, conditioning_latents_size + ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(latents[:, conditioning_latents_size + ref_images_count:]) * overlap_noise_factor
# timestep = [torch.tensor([t.item()] * (conditioning_latents_size + ref_images_count) + [t.item() - overlap_noise]*(target_shape[1] - conditioning_latents_size - ref_images_count))]
if ll != None:
ll[:, 0:overlapped_latents_size + ref_images_count] = zz_r * (1.0 - latent_noise_factor) + torch.randn_like(zz_r ) * latent_noise_factor
if target_camera != None:
latent_model_input = torch.cat([latents, source_latents], dim=1)
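
The re-noising in this loop is a single linear blend: keep (1 - a) of the clean latents and inject a of fresh Gaussian noise, with a = noise_level / 1000 (overlap_noise for the overlapped frames, the current timestep t for the full re-noise). A hedged, stand-alone sketch of that formula:

import torch

def renoise(clean: torch.Tensor, noise_level: float) -> torch.Tensor:
    # a = 0 keeps the latents untouched, a = 1 replaces them with pure noise.
    a = noise_level / 1000
    return clean * (1.0 - a) + torch.randn_like(clean) * a

z_r = torch.randn(16, 4, 32, 32)         # hypothetical reactive latents from the previous window
soft = renoise(z_r, noise_level=50)      # mild perturbation of the overlapped frames
hard = renoise(z_r, noise_level=800)     # heavy re-noise early in sampling, when t is near 1000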

wgp.py

@@ -45,7 +45,7 @@ AUTOSAVE_FILENAME = "queue.zip"
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.4.9"
WanGP_version = "6.3"
WanGP_version = "6.31"
settings_version = 2
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -3351,6 +3351,7 @@ def generate_video(
original_image_refs = image_refs
frames_to_inject = []
any_background_ref = False
outpainting_dims = None if video_guide_outpainting== None or len(video_guide_outpainting) == 0 or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#") else [int(v) for v in video_guide_outpainting.split(" ")]
if image_refs != None and len(image_refs) > 0 and (hunyuan_custom or phantom or hunyuan_avatar or vace):
@@ -3368,9 +3369,9 @@ def generate_video(
h, w = get_outpainting_full_area_dimensions(h,w, outpainting_dims)
default_image_size = calculate_new_dimensions(height, width, h, w, fit_canvas)
fit_canvas = None
if len(image_refs) > nb_frames_positions:
if hunyuan_avatar: remove_background_images_ref = 0
any_background_ref = remove_background_images_ref != 1
if remove_background_images_ref > 0:
send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")])
os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg")
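
The outpainting_dims one-liner folds several "disabled" cases into one expression; unrolled into a small helper (the function name is ours, not the repo's) it reads:

def parse_outpainting(value):
    # '#'-prefixed, empty, None or all-zero strings all mean "no outpainting".
    if value is None or len(value) == 0 or value == "0 0 0 0" or value.startswith("#"):
        return None
    return [int(v) for v in value.split(" ")]

print(parse_outpainting("#10 0 0 0"))    # None (disabled)
print(parse_outpainting("10 20 0 0"))    # [10, 20, 0, 0]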
@@ -3620,9 +3621,17 @@ def generate_video(
fit_into_canvas = sample_fit_canvas,
inject_frames= frames_to_inject_parsed,
outpainting_dims = outpainting_dims,
any_background_ref = any_background_ref
)
if len(frames_to_inject_parsed):
refresh_preview["image_refs"] = [convert_tensor_to_image(src_video[0], frame_no) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject] + image_refs[nb_frames_positions:]
if len(frames_to_inject_parsed) or any_background_ref:
new_image_refs = [convert_tensor_to_image(src_video[0], frame_no) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject]
if any_background_ref:
new_image_refs += [convert_tensor_to_image(image_refs_copy[0], 0)] + image_refs[nb_frames_positions+1:]
else:
new_image_refs += image_refs[nb_frames_positions:]
refresh_preview["image_refs"] = new_image_refs
new_image_refs = None
if sample_fit_canvas != None:
image_size = src_video[0].shape[-2:]
sample_fit_canvas = None
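
The preview refresh now treats the background reference separately from the remaining refs. With toy placeholders (all values illustrative only), the list assembly behaves as follows:

frames_to_inject_parsed = [True, False, True]
image_refs = ["pos_0", "pos_1", "pos_2", "background_ref", "object_ref"]
nb_frames_positions, any_background_ref = 3, True

new_image_refs = [f"frame_{i}" for i, inject in enumerate(frames_to_inject_parsed) if inject]
if any_background_ref:
    # The background preview comes from the pre-processing copy; later refs are kept as-is.
    new_image_refs += ["background_ref_from_copy"] + image_refs[nb_frames_positions + 1:]
else:
    new_image_refs += image_refs[nb_frames_positions:]
print(new_image_refs)    # ['frame_0', 'frame_2', 'background_ref_from_copy', 'object_ref']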
@@ -5315,9 +5324,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
("Use Vace raw format", "V"),
("Keep Unchanged", "UV"),
("Transfer Human Motion & Depth", "PDV"),
("Transfer Human Motion & Shape", "PSV"),
("Transfer Human Motion & Shapes", "PSV"),
("Transfer Human Motion & Flow", "PFV"),
("Transfer Depth & Shape", "DSV"),
("Transfer Depth & Shapes", "DSV"),
("Transfer Depth & Flow", "DFV"),
("Transfer Shapes & Flow", "SFV"),
],
@@ -5392,7 +5401,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
video_guide_outpainting_value = ui_defaults.get("video_guide_outpainting","#")
video_guide_outpainting = gr.Text(value=video_guide_outpainting_value , visible= False)
with gr.Group():
video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video or Injected Reference Frames", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") )
video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Background or Injected Reference Frames", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") )
with gr.Row(visible = not video_guide_outpainting_value.startswith("#")) as video_guide_outpainting_row:
video_guide_outpainting_value = video_guide_outpainting_value[1:] if video_guide_outpainting_value.startswith("#") else video_guide_outpainting_value
video_guide_outpainting_list = [0] * 4 if len(video_guide_outpainting_value) == 0 else [int(v) for v in video_guide_outpainting_value.split(" ")]
@@ -5450,7 +5459,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
wizard_variables = "\n".join(variables)
for _ in range( PROMPT_VARS_MAX - len(prompt_vars)):
prompt_vars.append(gr.Textbox(visible= False, min_width=80, show_label= False))
with gr.Column(not advanced_prompt) as prompt_column_wizard:
with gr.Column(visible=not advanced_prompt) as prompt_column_wizard:
wizard_prompt = gr.Textbox(visible = not advanced_prompt, label=wizard_prompt_label, value=default_wizard_prompt, lines=3)
wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False)
wizard_variables_var = gr.Text(wizard_variables, visible = False)
@@ -5759,7 +5768,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
sliding_window_tab, misc_tab, prompt_enhancer_row, inference_steps_row, skip_layer_guidance_row,
video_prompt_type_video_guide, video_prompt_type_video_mask, video_prompt_type_image_refs,
video_guide_outpainting_col,video_guide_outpainting_top, video_guide_outpainting_bottom, video_guide_outpainting_left, video_guide_outpainting_right,
video_guide_outpainting_checkbox, video_guide_outpainting_row] # show_advanced presets_column,
video_guide_outpainting_checkbox, video_guide_outpainting_row, show_advanced] # presets_column,
if update_form:
locals_dict = locals()
gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs