mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-04 06:15:17 +00:00)

Commit 6b348b97e8 (parent 6581d05916): MatanyoneGP
@@ -20,7 +20,7 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
 **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep

 ## 🔥 Latest Updates
-### July 21 2025: WanGP v7.1 
+### July 21 2025: WanGP v7.12
 - Flux Family Reunion : *Flux Dev* and *Flux Schnell* have been invited aboard WanGP. To celebrate that, Loras support for the Flux *diffusers* format has also been added.

 - LTX Video upgraded to version 0.9.8: you can now generate 1800 frames (1 min of video !) in one go without a sliding window. With the distilled model it will take only 5 minutes with a RTX 4090 (you will need 22 GB of VRAM though). I have added options to select higher humber frames if you want to experiment
@@ -30,9 +30,9 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
 - LTX IC-Lora support: these are special Loras that consumes a conditional image or video
 Beside the pose, depth and canny IC-Loras transparently loaded there is the *detailer* (https://huggingface.co/Lightricks/LTX-Video-ICLoRA-detailer-13b-0.9.8) which is basically an upsampler. Add the *detailer* as a Lora and use LTX Raw Format as control net choice to use it.

-And Also:
-- easier way to select video resolution 
-- started to optimize Matanyone to reduce VRAM requirements
+- Matanyone is now also for the GPU Poor as its VRAM requirements have been divided by 2! (7.12 shadow update)
+
+- Easier way to select video resolution 


 ### July 15 2025: WanGP v7.0 is an AI Powered Photoshop
@@ -608,11 +608,11 @@ def query_model_def(model_type, model_def):
     LTXV_config = model_def.get("LTXV_config", "")
     distilled= "distilled" in LTXV_config 
     model_def_output = {
-		"lock_inference_steps": True,
 		"no_guidance": True,		
     }
     if distilled:
         model_def_output.update({
+			"lock_inference_steps": True,
         "no_negative_prompt" : True,
     })
         
@@ -28,7 +28,9 @@ arg_mask_save = False
 model_loaded = False
 model = None
 matanyone_model = None
+model_in_GPU = False
+matanyone_in_GPU = False
+bfloat16_supported = False
 # SAM generator
 class MaskGenerator():
     def __init__(self, sam_checkpoint, device):
@@ -66,7 +68,6 @@ def get_frames_from_image(image_input, image_state):
     Return 
         [[0:nearest_frame], [nearest_frame:], nearest_frame]
     """
-    load_sam()

     user_name = time.time()
     frames = [image_input] * 2  # hardcode: mimic a video with 2 frames
@@ -85,7 +86,7 @@ def get_frames_from_image(image_input, image_state):
         }
     image_info = "Image Name: N/A,\nFPS: N/A,\nTotal Frames: {},\nImage Size:{}".format(len(frames), image_size)
     set_image_encoder_patch()
-    torch.cuda.empty_cache()
+    select_SAM()
     model.samcontroler.sam_controler.reset_image() 
     model.samcontroler.sam_controler.set_image(image_state["origin_images"][0])
     torch.cuda.empty_cache()
@@ -108,7 +109,6 @@ def get_frames_from_video(video_input, video_state):
         [[0:nearest_frame], [nearest_frame:], nearest_frame]
     """

-    load_sam()

     while model == None:
         time.sleep(1)
@@ -168,7 +168,7 @@ def get_frames_from_video(video_input, video_state):
         }
     video_info = "Video Name: {},\nFPS: {},\nTotal Frames: {},\nImage Size:{}".format(video_state["video_name"], round(video_state["fps"], 0), len(frames), image_size)
     set_image_encoder_patch()
-    torch.cuda.empty_cache()    
+    select_SAM()
     model.samcontroler.sam_controler.reset_image() 
     model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
     torch.cuda.empty_cache()    
@@ -237,18 +237,37 @@ def patched_forward(self, x: torch.Tensor) -> torch.Tensor:
         attn += rel_w[:, :, :, None, :]
         return attn.view(B, q_h * q_w, k_h * k_w)

-    def pay_attention(self, x: torch.Tensor) -> torch.Tensor:
+    def pay_attention(self, x: torch.Tensor, split_heads = 1) -> torch.Tensor:
             B, H, W, _ = x.shape
             # qkv with shape (3, B, nHead, H * W, C)
             qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)

+            if not bfloat16_supported: qkv = qkv.to(torch.float16)

             # q, k, v with shape (B * nHead, H * W, C)
             q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
-            attn_mask = None
-            if self.use_rel_pos:
-                attn_mask = get_decomposed_rel_pos(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
-            x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=self.scale)
+            if split_heads == 1:
+                attn_mask = None
+                if self.use_rel_pos:
+                    attn_mask = get_decomposed_rel_pos(q, self.rel_pos_h.to(q), self.rel_pos_w.to(q), (H, W), (H, W))
+                x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=self.scale)
+            else:
+                chunk_size = self.num_heads // split_heads 
+                x = torch.empty_like(q)
+                q_chunks = torch.split(q, chunk_size)
+                k_chunks = torch.split(k, chunk_size)
+                v_chunks = torch.split(v, chunk_size)
+                x_chunks = torch.split(x, chunk_size)
+                for x_chunk, q_chunk, k_chunk, v_chunk  in zip(x_chunks, q_chunks, k_chunks, v_chunks):
+                    attn_mask = None
+                    if self.use_rel_pos:
+                        attn_mask = get_decomposed_rel_pos(q_chunk, self.rel_pos_h.to(q), self.rel_pos_w.to(q), (H, W), (H, W))
+                    x_chunk[...]  = F.scaled_dot_product_attention(q_chunk, k_chunk, v_chunk, attn_mask=attn_mask, scale=self.scale)
+                del x_chunk, q_chunk, k_chunk, v_chunk
             del q, k, v, attn_mask
             x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
+            if not bfloat16_supported: x = x.to(torch.bfloat16)

             return self.proj(x)

     shortcut = x
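Note on the split-heads path added above: instead of one scaled_dot_product_attention call over all heads, the batch-head slices are processed in groups and each partial result is written into a preallocated buffer, so only one group's attention weights are alive at a time. A minimal self-contained sketch of the same idea (illustrative function name, relative-position bias omitted):

```python
import torch
import torch.nn.functional as F

def chunked_attention(q, k, v, num_heads, split_heads=4, scale=None):
    """Sketch of the split-heads path: q, k, v have shape (B*num_heads, L, C);
    dim 0 is split into groups so only one group's attention weights are
    materialized at a time."""
    chunk_size = num_heads // split_heads
    out = torch.empty_like(q)                      # preallocated output buffer
    for o, qc, kc, vc in zip(torch.split(out, chunk_size),
                             torch.split(q, chunk_size),
                             torch.split(k, chunk_size),
                             torch.split(v, chunk_size)):
        # each o is a view into `out`, so this write fills the buffer in place
        o[...] = F.scaled_dot_product_attention(qc, kc, vc, scale=scale)
    return out
```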
@@ -257,8 +276,17 @@ def patched_forward(self, x: torch.Tensor) -> torch.Tensor:
     if self.window_size > 0:
         H, W = x.shape[1], x.shape[2]
         x, pad_hw = window_partition(x, self.window_size)
+    x_shape = x.shape

+    if x_shape[0] > 10:
+        chunk_size = int(x.shape[0]/4) + 1
+        x_chunks = torch.split(x, chunk_size)
+        for i, x_chunk  in enumerate(x_chunks):
+            x_chunk[...] = pay_attention(self.attn,x_chunk)  
+    else:
+        x = pay_attention(self.attn,x, 4)

-    x = pay_attention(self.attn,x)

     # Reverse window partition
     if self.window_size > 0:
         x = window_unpartition(x, self.window_size, pad_hw, (H, W))
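The in-place writes in both chunked loops rely on torch.split returning views of the parent tensor rather than copies, so filling each chunk fills the original buffer. A quick standalone check of that behaviour:

```python
import torch

x = torch.empty(8, 4)
for i, chunk in enumerate(torch.split(x, 3)):   # chunks are views of x, not copies
    chunk[...] = i                              # writes land directly in x's storage
assert x[0, 0] == 0 and x[3, 0] == 1 and x[7, 0] == 2
```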
@@ -270,7 +298,7 @@ def patched_forward(self, x: torch.Tensor) -> torch.Tensor:
     return x

 def set_image_encoder_patch():
-    if not hasattr(image_encoder_block, "patched"):
+    if not hasattr(image_encoder_block, "patched"):  #and False
         image_encoder_block.forward = patched_forward
         image_encoder_block.patched = True

@@ -289,11 +317,12 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr
         coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
         interactive_state["negative_click_times"] += 1

-    torch.cuda.empty_cache()    
+    select_SAM()
     # prompt for sam model
     set_image_encoder_patch()
     model.samcontroler.sam_controler.reset_image()
     model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]])
+    torch.cuda.empty_cache()
     prompt = get_prompt(click_state=click_state, click_input=coordinate)

     mask, logit, painted_image = model.first_frame_click( 
@@ -387,7 +416,7 @@ def image_matting(video_state, interactive_state, mask_dropdown, erode_kernel_si
     # operation error
     if len(np.unique(template_mask))==1:
         template_mask[0][0]=1
-    torch.cuda.empty_cache()    
+    select_matanyone()
     foreground, alpha = matanyone(matanyone_processor, following_frames, template_mask*255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size, n_warmup=refine_iter)
     torch.cuda.empty_cache()    

@@ -452,19 +481,25 @@ def video_matting(video_state, end_slider, matting_type, interactive_state, mask
     # operation error
     if len(np.unique(template_mask))==1:
         template_mask[0][0]=1
-    torch.cuda.empty_cache()    
+    select_matanyone()
     foreground, alpha = matanyone(matanyone_processor, following_frames, template_mask*255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size)
     torch.cuda.empty_cache()    
     output_frames = []
     foreground_mat = matting_type == "Foreground"
+    new_alpha = []
     if not foreground_mat:
-        new_alpha = []
         for frame_alpha in alpha:
             frame_temp = frame_alpha.copy()
             frame_alpha[frame_temp > 127] = 0
             frame_alpha[frame_temp <= 127] = 255
             new_alpha.append(frame_alpha)
-        alpha = new_alpha
+    else:
+        for frame_alpha in alpha:
+            frame_alpha[frame_alpha > 127] = 255
+            frame_alpha[frame_alpha <= 127] = 0
+            new_alpha.append(frame_alpha)
+    alpha = new_alpha

     # for frame_origin, frame_alpha in zip(following_frames, alpha):
     #     if foreground_mat:
     #         frame_alpha[frame_alpha > 127] = 255
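For clarity, the reworked alpha post-processing behaves as follows: with 'Foreground' matting the soft matte is binarized, otherwise it is inverted so the complement of the matte is kept. A tiny illustration with made-up values:

```python
import numpy as np

alpha = np.array([30, 128, 200], dtype=np.uint8)  # soft matte values

fg = alpha.copy()            # "Foreground": binarize the matte
fg[alpha > 127] = 255
fg[alpha <= 127] = 0         # -> [0, 255, 255]

bg = alpha.copy()            # background matting: keep the complement
bg[alpha > 127] = 0
bg[alpha <= 127] = 255       # -> [255, 0, 0]
```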
@@ -572,21 +607,42 @@ def restart():
         gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
         gr.update(visible=False), gr.update(visible=False, choices=[], value=[]), "", gr.update(visible=False)

-def load_sam():
-    global model_loaded
-    global model
-    global matanyone_model 
-    model.samcontroler.sam_controler.model.to(arg_device)
+# def load_sam():
+#     global model_loaded
+#     global model
+#     model.samcontroler.sam_controler.model.to(arg_device)
+
+#     global matanyone_model 
+#     matanyone_model.to(arg_device)
+
+
+def select_matanyone():
+    global matanyone_in_GPU, model_in_GPU 
+    if matanyone_in_GPU: return
+    model.samcontroler.sam_controler.model.to("cpu")
+    model_in_GPU = False
+    torch.cuda.empty_cache()
     matanyone_model.to(arg_device)
+    matanyone_in_GPU = True
+
+def select_SAM():
+    global matanyone_in_GPU, model_in_GPU 
+    if model_in_GPU: return
+    matanyone_model.to("cpu")
+    matanyone_in_GPU = False
+    torch.cuda.empty_cache()
+    model.samcontroler.sam_controler.model.to(arg_device)
+    model_in_GPU = True

 def load_unload_models(selected):
     global model_loaded
     global model
-    global matanyone_model 
+    global matanyone_model, matanyone_processor, matanyone_in_GPU , model_in_GPU, bfloat16_supported
     if selected:
         # print("Matanyone Tab Selected")
         if model_loaded:
-            load_sam()
+            pass
+            # load_sam()
         else:
             # args, defined in track_anything.py
             sam_checkpoint_url_dict = {
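select_SAM and select_matanyone introduced above carry the main VRAM saving of 7.12: only one of the two models is resident on the GPU at any time, the other is parked on the CPU and the CUDA cache is flushed before the swap. A condensed sketch of that residency pattern with generic names (not the module's actual globals):

```python
import torch

class ModelJuggler:
    """Sketch: keep at most one of two large modules resident on the GPU."""
    def __init__(self, sam, matanyone, device="cuda"):
        self.sam, self.matanyone, self.device = sam, matanyone, device
        self.on_gpu = None                      # "sam", "matanyone" or None

    def _swap(self, want, other):
        if self.on_gpu == want:
            return
        other.to("cpu")                         # evict the other model first
        torch.cuda.empty_cache()                # release its cached blocks
        getattr(self, want).to(self.device)
        self.on_gpu = want

    def select_sam(self):
        self._swap("sam", self.matanyone)

    def select_matanyone(self):
        self._swap("matanyone", self.sam)
```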
@@ -604,21 +660,33 @@ def load_unload_models(selected):
             transfer_stream = torch.cuda.Stream()
             with torch.cuda.stream(transfer_stream):
                 # initialize sams
-                model = MaskGenerator(sam_checkpoint, arg_device)
+                major, minor = torch.cuda.get_device_capability(arg_device)
+                if  major < 8:
+                    bfloat16_supported = False
+                else:
+                    bfloat16_supported = True
+
+                model = MaskGenerator(sam_checkpoint, "cpu")
+                model.samcontroler.sam_controler.model.to("cpu").to(torch.bfloat16).to(arg_device)
+                model_in_GPU = True
                 from .matanyone.model.matanyone import MatAnyone
                 matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone")
                 # pipe ={"mat" : matanyone_model, "sam" :model.samcontroler.sam_controler.model }
                 # offload.profile(pipe)
-                matanyone_model = matanyone_model.to(arg_device).eval()
+                matanyone_model = matanyone_model.to("cpu").eval()
+                matanyone_in_GPU = False
                 matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg)
             model_loaded  = True
     else:
         # print("Matanyone Tab UnSelected")
         import gc
-        model.samcontroler.sam_controler.model.to("cpu")
-        matanyone_model.to("cpu")
+        # model.samcontroler.sam_controler.model.to("cpu")
+        # matanyone_model.to("cpu")
+        model = matanyone_model = matanyone_processor = None
+        matanyone_in_GPU = model_in_GPU = False
         gc.collect()
         torch.cuda.empty_cache()
+        model_loaded = False


 def get_vmc_event_handler():
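The new initialization path also gates bfloat16 on the GPU's compute capability: cards below capability 8.0 (pre-Ampere) set bfloat16_supported to False, and the patched attention casts to float16 instead. A small helper illustrating the same check (name is illustrative):

```python
import torch

def pick_sam_dtype(device="cuda"):
    """bfloat16 needs compute capability >= 8.0 (Ampere and newer);
    older GPUs fall back to float16 for the attention path."""
    major, _minor = torch.cuda.get_device_capability(device)
    return torch.bfloat16 if major >= 8 else torch.float16
```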
@@ -663,10 +731,11 @@ def display(tabs, tab_state, model_choice, vace_video_input, vace_image_input, v

     # download assets

-    gr.Markdown("<B>Mast Edition is provided by MatAnyone</B>")
+    gr.Markdown("<B>Mast Edition is provided by MatAnyone and VRAM optimized by DeepBeepMeep</B>")
     gr.Markdown("If you have some trouble creating the perfect mask, be aware of these tips:")
     gr.Markdown("- Using the Matanyone Settings you can also define Negative Point Prompts to remove parts of the current selection.")
     gr.Markdown("- Sometime it is very hard to fit everything you want in a single mask, it may be much easier to combine multiple independent sub Masks before producing the Matting : each sub Mask is created by selecting an  area of an image and by clicking the Add Mask button. Sub masks can then be enabled / disabled in the Matanyone settings.")
+    gr.Markdown("The Mask Generation time and the VRAM consumed are proportional to the number of frames and the resolution. So if relevant, you may reduce the number of frames in the Matanyone Settings. You will need for the moment to resize yourself the video if needed.")
    
     with gr.Column( visible=True):
         with gr.Row():
@@ -34,24 +34,38 @@ def get_similarity(mk: torch.Tensor,
         uncert_mask = uncert_mask.expand(-1, 64, -1)
         qk = qk * uncert_mask
         qe = qe * uncert_mask
+    # Behold the work of DeeBeepMeep the Code Butcher !
     if qe is not None:
         # See XMem's appendix for derivation
         mk = mk.transpose(1, 2)
         a_sq = (mk.pow(2) @ qe)
-        two_ab = 2 * (mk @ (qk * qe))
+        two_ab =  mk @ (qk * qe)
+        two_ab *= 2
+        two_ab.sub_(a_sq)
+        del a_sq
         b_sq = (qe * qk.pow(2)).sum(1, keepdim=True)
-        similarity = (-a_sq + two_ab - b_sq)
+        two_ab.sub_(b_sq)
+        similarity = two_ab
+        del b_sq, two_ab 
+        # similarity = (-a_sq + two_ab - b_sq)
     else:
         # similar to STCN if we don't have the selection term
         a_sq = mk.pow(2).sum(1).unsqueeze(2)
-        two_ab = 2 * (mk.transpose(1, 2) @ qk)
-        similarity = (-a_sq + two_ab)
+        two_ab = mk.transpose(1, 2) @ qk
+        two_ab *= 2
+        two_ab.sub_(a_sq)
+        del a_sq
+        similarity = two_ab
+        del two_ab 
+        # similarity = (-a_sq + two_ab)

     if ms is not None:
-        similarity = similarity * ms / math.sqrt(CK)  # B*N*HW
+        similarity *= ms
+        similarity /=  math.sqrt(CK)
+        # similarity = similarity * ms / math.sqrt(CK)  # B*N*HW
     else:
-        similarity = similarity / math.sqrt(CK)  # B*N*HW
+        similarity /=  math.sqrt(CK)
+        # similarity = similarity / math.sqrt(CK)  # B*N*HW

     return similarity
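The rewrite of get_similarity keeps the XMem formula, similarity = -a² + 2ab - b², but accumulates it in place on the 2ab buffer instead of materializing three full B×N×HW temporaries at once. A short numerical check with arbitrary shapes showing the in-place reformulation matches the original expression:

```python
import torch

B, CK, N, HW = 1, 64, 32, 100
mk = torch.randn(B, N, CK)        # memory keys, already transposed as in the code above
qk = torch.randn(B, CK, HW)       # query keys
qe = torch.rand(B, CK, HW)        # query selection term

a_sq = mk.pow(2) @ qe
b_sq = (qe * qk.pow(2)).sum(1, keepdim=True)
ab = mk @ (qk * qe)

reference = -a_sq + 2 * ab - b_sq              # original one-shot expression
inplace = ab.clone()                           # in-place variant from the diff
inplace *= 2
inplace.sub_(a_sq)
inplace.sub_(b_sq)

assert torch.allclose(reference, inplace)      # same values, fewer live temporaries
```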
@@ -47,9 +47,13 @@ def matanyone(processor, frames_np, mask, r_erode=0, r_dilate=0, n_warmup=10):

     frames = []
     phas = []
+    i = 0
     for ti, frame_single in tqdm.tqdm(enumerate(frames_np)):
         image = to_tensor(frame_single).cuda().float()
+        if i % 10 ==0:
+            pass
+            # torch.cuda.empty_cache()
+        i += 1
         if ti == 0:
             output_prob = processor.step(image, mask, objects=objects)      # encode given mask
             output_prob = processor.step(image, first_frame_pred=True)      # clear past memory for warmup frames

wgp.py
@@ -51,7 +51,7 @@ AUTOSAVE_FILENAME = "queue.zip"
 PROMPT_VARS_MAX = 10

 target_mmgp_version = "3.5.1"
-WanGP_version = "7.11"
+WanGP_version = "7.12"
 settings_version = 2.22
 max_source_video_frames = 3000
 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -275,6 +275,10 @@ def process_prompt_and_add_tasks(state, model_choice):
     skip_steps_cache_type= inputs["skip_steps_cache_type"]
     MMAudio_setting = inputs["MMAudio_setting"]

+
+    if not model_def.get("lock_inference_steps", False) and model_type in ["ltxv_13B"] and num_inference_steps < 20:
+        gr.Info("The minimum number of steps should be 20") 
+        return
     if skip_steps_cache_type == "mag":
         if model_type in  ["sky_df_1.3B", "sky_df_14B"]:
             gr.Info("Mag Cache is not supported with Diffusion Forcing")
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user