diff --git a/requirements.txt b/requirements.txt
index 1d82bc9..a948cf7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -18,7 +18,7 @@ numpy>=1.23.5,<2
 einops
 moviepy==1.0.3
 mmgp==3.4.9
-peft==0.14.0
+peft==0.15.0
 mutagen
 pydantic==2.10.6
 decord
diff --git a/wan/modules/model.py b/wan/modules/model.py
index fba8a67..edf04e2 100644
--- a/wan/modules/model.py
+++ b/wan/modules/model.py
@@ -1070,10 +1070,8 @@ class WanModel(ModelMixin, ConfigMixin):
         # Vace embeddings
         c = [self.vace_patch_embedding(u.to(self.vace_patch_embedding.weight.dtype).unsqueeze(0)) for u in vace_context]
         c = [u.flatten(2).transpose(1, 2) for u in c]
-        # c = c[0]
-        c = [ [sub_c] for sub_c in c]
         kwargs['context_scale'] = vace_context_scale
-        hints_list = [ c ]* len(x_list)
+        hints_list = [ [ [sub_c] for sub_c in c] for _ in range(len(x_list)) ]
         del c
         should_calc = True
diff --git a/wan/text2video.py b/wan/text2video.py
index 4020393..03ceef4 100644
--- a/wan/text2video.py
+++ b/wan/text2video.py
@@ -133,9 +133,10 @@ class WanT2V:
            reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
            inactive = self.vae.encode(inactive, tile_size = tile_size)
-           if overlapped_latents != None :
+           if overlapped_latents != None and False :
                # inactive[0][:, 0:1] = self.vae.encode([frames[0][:, 0:1]], tile_size = tile_size)[0] # redundant
-               inactive[0][:, 1:overlapped_latents.shape[1] + 1] = overlapped_latents
+               for t in inactive:
+                   t[:, 1:overlapped_latents.shape[1] + 1] = overlapped_latents
            reactive = self.vae.encode(reactive, tile_size = tile_size)
            latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
@@ -162,7 +163,7 @@ class WanT2V:
        result_masks = []
        for mask, refs in zip(masks, ref_images):
            c, depth, height, width = mask.shape
-           new_depth = int((depth + 3) // self.vae_stride[0])
+           new_depth = int((depth + 3) // self.vae_stride[0]) # number of latent tokens (ref tokens not included)
            height = 2 * (int(height) // (self.vae_stride[1] * 2))
            width = 2 * (int(width) // (self.vae_stride[2] * 2))
@@ -189,7 +190,7 @@ class WanT2V:
    def vace_latent(self, z, m):
        return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]

-   def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, fill_max = False, outpainting_dims = None):
+   def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, fill_max = False, outpainting_dims = None, return_mask = False):
        from wan.utils.utils import save_image
        ref_width, ref_height = ref_img.size
        if (ref_height, ref_width) == image_size and outpainting_dims == None:
@@ -212,15 +213,24 @@ class WanT2V:
            ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
            ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
            if outpainting_dims != None:
-               white_canvas = torch.full((3, 1, final_height, final_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
-               white_canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img
+               canvas = torch.full((3, 1, final_height, final_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
+               canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img
            else:
-               white_canvas = torch.full((3, 1, canvas_height, canvas_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
-               white_canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
-           ref_img = white_canvas
-           return ref_img.to(device)
+               canvas = torch.full((3, 1, canvas_height, canvas_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1]
+               canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
+           ref_img = canvas
+           canvas = None
+           if return_mask:
+               if outpainting_dims != None:
+                   canvas = torch.ones((3, 1, final_height, final_width), dtype= torch.float, device=device) # [-1, 1]
+                   canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = 0
+               else:
+                   canvas = torch.ones((3, 1, canvas_height, canvas_width), dtype= torch.float, device=device) # [-1, 1]
+                   canvas[:, :, top:top + new_height, left:left + new_width] = 0
+               canvas = canvas.to(device)
+           return ref_img.to(device), canvas

-   def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_frames= [], start_frame = 0, fit_into_canvas = None, pre_src_video = None, inject_frames = [], outpainting_dims = None):
+   def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_frames= [], start_frame = 0, fit_into_canvas = None, pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False):
        image_sizes = []
        trim_video = len(keep_frames)
        def conv_tensor(t, device):
@@ -270,16 +280,22 @@ class WanT2V:
        for k, frame in enumerate(inject_frames):
            if frame != None:
-               src_video[i][:, k:k+1] = self.fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims)
-               src_mask[i][:, k:k+1] = 0
+               src_video[i][:, k:k+1], src_mask[i][:, k:k+1] = self.fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims, return_mask= True)
+       self.background_mask = None
        for i, ref_images in enumerate(src_ref_images):
            if ref_images is not None:
                image_size = image_sizes[i]
                for j, ref_img in enumerate(ref_images):
                    if ref_img is not None and not torch.is_tensor(ref_img):
-                       src_ref_images[i][j] = self.fit_image_into_canvas(ref_img, image_size, 1, device)
+                       if j==0 and any_background_ref:
+                           if self.background_mask == None: self.background_mask = [None] * len(src_ref_images)
+                           src_ref_images[i][j], self.background_mask[i] = self.fit_image_into_canvas(ref_img, image_size, 0, device, True, outpainting_dims, return_mask= True)
+                       else:
+                           src_ref_images[i][j], _ = self.fit_image_into_canvas(ref_img, image_size, 1, device)
+       if self.background_mask != None:
+           self.background_mask = [ item if item != None else self.background_mask[0] for item in self.background_mask ] # duplicate the background mask when a second control net is used, since the first controlnet's image ref was modified
        return src_video, src_mask, src_ref_images

    def decode_latent(self, zs, ref_images=None, tile_size= 0 ):
@@ -408,11 +424,20 @@ class WanT2V:
            input_frames = [u.to(self.device) for u in input_frames]
            input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
            input_masks = [u.to(self.device) for u in input_masks]
+           if self.background_mask != None: self.background_mask = [m.to(self.device) for m in self.background_mask]
            previous_latents = None
            # if overlapped_latents != None:
            #     input_ref_images = [u[-1:] for u in input_ref_images]
            z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents )
            m0 = self.vace_encode_masks(input_masks, input_ref_images)
+           if self.background_mask != None:
+               zbg = self.vace_encode_frames([ref_img[0] for ref_img in input_ref_images], None, masks=self.background_mask, tile_size = VAE_tile_size )
+               mbg = self.vace_encode_masks(self.background_mask, None)
+               for zz0, mm0, zzbg, mmbg in zip(z0, m0, zbg, mbg):
+                   zz0[:, 0:1] = zzbg
+                   mm0[:, 0:1] = mmbg
+
+               self.background_mask = zz0 = mm0 = zzbg = mmbg = None
            z = self.vace_latent(z0, m0)

            target_shape = list(z0[0].shape)
@@ -499,24 +524,14 @@
                self.model.setup_chipmunk()

        for i, t in enumerate(tqdm(timesteps)):
-           timestep = [t]
            if overlapped_latents != None :
-               # overlap_noise_factor = overlap_noise *(i/(len(timesteps)-1)) / 1000
                overlap_noise_factor = overlap_noise / 1000
-               # overlap_noise_factor = (1000-t )/ 1000 # overlap_noise / 1000
-               # latent_noise_factor = 1 #max(min(1, (t - overlap_noise) / 1000 ),0)
                latent_noise_factor = t / 1000
-               for zz, zz_r, ll in zip(z, z_reactive, [latents]):
-                   pass
+               for zz, zz_r, ll in zip(z, z_reactive, [latents, None]): # extra None for second control net
                    zz[0:16, ref_images_count:overlapped_latents_size + ref_images_count] = zz_r[:, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(zz_r[:, ref_images_count:] ) * overlap_noise_factor
-                   ll[:, 0:overlapped_latents_size + ref_images_count] = zz_r * (1.0 - latent_noise_factor) + torch.randn_like(zz_r ) * latent_noise_factor
-
-           if conditioning_latents_size > 0 and overlap_noise > 0:
-               pass
-               overlap_noise_factor = overlap_noise / 1000
-               # latents[:, conditioning_latents_size + ref_images_count:] = latents[:, conditioning_latents_size + ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(latents[:, conditioning_latents_size + ref_images_count:]) * overlap_noise_factor
-               # timestep = [torch.tensor([t.item()] * (conditioning_latents_size + ref_images_count) + [t.item() - overlap_noise]*(target_shape[1] - conditioning_latents_size - ref_images_count))]
+                   if ll != None:
+                       ll[:, 0:overlapped_latents_size + ref_images_count] = zz_r * (1.0 - latent_noise_factor) + torch.randn_like(zz_r ) * latent_noise_factor

            if target_camera != None:
                latent_model_input = torch.cat([latents, source_latents], dim=1)
diff --git a/wgp.py b/wgp.py
index eb4d92b..e9e4f31 100644
--- a/wgp.py
+++ b/wgp.py
@@ -45,7 +45,7 @@ AUTOSAVE_FILENAME = "queue.zip"
 PROMPT_VARS_MAX = 10

 target_mmgp_version = "3.4.9"
-WanGP_version = "6.3"
+WanGP_version = "6.31"
 settings_version = 2
 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -3351,6 +3351,7 @@
    original_image_refs = image_refs
    frames_to_inject = []
+   any_background_ref = False
    outpainting_dims = None if video_guide_outpainting== None or len(video_guide_outpainting) == 0 or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#") else [int(v) for v in video_guide_outpainting.split(" ")]
    if image_refs != None and len(image_refs) > 0 and (hunyuan_custom or phantom or hunyuan_avatar or vace):
@@ -3368,9 +3369,9 @@
                h, w = get_outpainting_full_area_dimensions(h,w, outpainting_dims)
            default_image_size = calculate_new_dimensions(height, width, h, w, fit_canvas)
            fit_canvas = None
-       if len(image_refs) > nb_frames_positions:
            if hunyuan_avatar: remove_background_images_ref = 0
+           any_background_ref = remove_background_images_ref != 1
            if remove_background_images_ref > 0:
                send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")])
                os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg")
@@ -3620,9 +3621,17 @@
                        fit_into_canvas = sample_fit_canvas,
                        inject_frames= frames_to_inject_parsed,
                        outpainting_dims = outpainting_dims,
+                       any_background_ref = any_background_ref
                    )
-               if len(frames_to_inject_parsed):
-                   refresh_preview["image_refs"] = [convert_tensor_to_image(src_video[0], frame_no) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject] + image_refs[nb_frames_positions:]
+               if len(frames_to_inject_parsed) or any_background_ref:
+                   new_image_refs = [convert_tensor_to_image(src_video[0], frame_no) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject]
+                   if any_background_ref:
+                       new_image_refs += [convert_tensor_to_image(image_refs_copy[0], 0)] + image_refs[nb_frames_positions+1:]
+                   else:
+                       new_image_refs += image_refs[nb_frames_positions:]
+                   refresh_preview["image_refs"] = new_image_refs
+                   new_image_refs = None
+
                if sample_fit_canvas != None:
                    image_size = src_video[0].shape[-2:]
                    sample_fit_canvas = None
@@ -5315,9 +5324,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                        ("Use Vace raw format", "V"),
                        ("Keep Unchanged", "UV"),
                        ("Transfer Human Motion & Depth", "PDV"),
-                       ("Transfer Human Motion & Shape", "PSV"),
+                       ("Transfer Human Motion & Shapes", "PSV"),
                        ("Transfer Human Motion & Flow", "PFV"),
-                       ("Transfer Depth & Shape", "DSV"),
+                       ("Transfer Depth & Shapes", "DSV"),
                        ("Transfer Depth & Flow", "DFV"),
                        ("Transfer Shapes & Flow", "SFV"),
                    ],
@@ -5392,7 +5401,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                    video_guide_outpainting_value = ui_defaults.get("video_guide_outpainting","#")
                    video_guide_outpainting = gr.Text(value=video_guide_outpainting_value , visible= False)
                    with gr.Group():
-                       video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video or Injected Reference Frames", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") )
+                       video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Background or Injected Reference Frames", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") )
                        with gr.Row(visible = not video_guide_outpainting_value.startswith("#")) as video_guide_outpainting_row:
                            video_guide_outpainting_value = video_guide_outpainting_value[1:] if video_guide_outpainting_value.startswith("#") else video_guide_outpainting_value
                            video_guide_outpainting_list = [0] * 4 if len(video_guide_outpainting_value) == 0 else [int(v) for v in video_guide_outpainting_value.split(" ")]
@@ -5450,7 +5459,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                wizard_variables = "\n".join(variables)
                for _ in range( PROMPT_VARS_MAX - len(prompt_vars)):
                    prompt_vars.append(gr.Textbox(visible= False, min_width=80, show_label= False))
-               with gr.Column(not advanced_prompt) as prompt_column_wizard:
+               with gr.Column(visible=not advanced_prompt) as prompt_column_wizard:
                    wizard_prompt = gr.Textbox(visible = not advanced_prompt, label=wizard_prompt_label, value=default_wizard_prompt, lines=3)
                    wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False)
                    wizard_variables_var = gr.Text(wizard_variables, visible = False)
@@ -5759,7 +5768,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                    sliding_window_tab, misc_tab, prompt_enhancer_row, inference_steps_row, skip_layer_guidance_row,
                    video_prompt_type_video_guide, video_prompt_type_video_mask, video_prompt_type_image_refs,
                    video_guide_outpainting_col,video_guide_outpainting_top, video_guide_outpainting_bottom, video_guide_outpainting_left, video_guide_outpainting_right,
-                   video_guide_outpainting_checkbox, video_guide_outpainting_row] # show_advanced presets_column,
+                   video_guide_outpainting_checkbox, video_guide_outpainting_row, show_advanced] # presets_column,
    if update_form:
        locals_dict = locals()
        gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs
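
Reviewer note: a minimal standalone sketch of the mask convention introduced by the new return_mask path in fit_image_into_canvas. This is illustrative only and not code from the patch; the helper name make_canvas_and_mask and the toy sizes are invented for the example. The assumption mirrored from fit_image_into_canvas(return_mask=True) is that the returned mask is 1.0 over padded/outpainted canvas area and 0.0 where the reference image was pasted, so prepare_source can pass it along as a background ("inactive") region mask.

import torch

def make_canvas_and_mask(ref_img: torch.Tensor, canvas_height: int, canvas_width: int,
                         top: int, left: int, bg_value: float = 1.0):
    # ref_img: (3, 1, h, w) tensor in [-1, 1], already resized to fit inside the canvas.
    _, _, h, w = ref_img.shape
    canvas = torch.full((3, 1, canvas_height, canvas_width), bg_value, dtype=torch.float)
    canvas[:, :, top:top + h, left:left + w] = ref_img
    # Mask convention assumed from the patch: 1.0 = padded background, 0.0 = reference pixels.
    mask = torch.ones((3, 1, canvas_height, canvas_width), dtype=torch.float)
    mask[:, :, top:top + h, left:left + w] = 0
    return canvas, mask

if __name__ == "__main__":
    ref = torch.zeros(3, 1, 4, 6)                # dummy reference image
    img, msk = make_canvas_and_mask(ref, 8, 8, top=2, left=1)
    print(img.shape, msk.shape, int(msk.sum()))  # mask keeps 1s only outside the 4x6 pasted region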