Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-03 22:04:21 +00:00)

intermediate commit

Commit: f9f63cbc79
Parent: 99fd9aea32
@@ -7,8 +7,6 @@
            "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_bf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_quanto_bf16_int8.safetensors"
        ],
        "image_outputs": true,
        "reference_image": true,
        "flux-model": "flux-dev-kontext"
    },
    "prompt": "add a hat",

@@ -6,8 +6,6 @@
        "modules": [ ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_projector_bf16.safetensors"]],
        "URLs": "flux",
        "loras": ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_dit_lora_bf16.safetensors"],
        "image_outputs": true,
        "reference_image": true,
        "flux-model": "flux-dev-uso"
    },
    "prompt": "the man is wearing a hat",

@@ -9,9 +9,7 @@
        ],
        "attention": {
            "<89": "sdpa"
        },
        "reference_image": true,
        "image_outputs": true
        }
    },
    "prompt": "add a hat",
    "resolution": "1280x720",
@@ -13,28 +13,41 @@ class family_handler():
        flux_schnell = flux_model == "flux-schnell"
        flux_chroma = flux_model == "flux-chroma"
        flux_uso = flux_model == "flux-dev-uso"
        model_def_output = {
        flux_kontext = flux_model == "flux-dev-kontext"

        extra_model_def = {
            "image_outputs" : True,
            "no_negative_prompt" : not flux_chroma,
        }
        if flux_chroma:
            model_def_output["guidance_max_phases"] = 1
            extra_model_def["guidance_max_phases"] = 1
        elif not flux_schnell:
            model_def_output["embedded_guidance"] = True
            extra_model_def["embedded_guidance"] = True
        if flux_uso :
            model_def_output["any_image_refs_relative_size"] = True
            model_def_output["no_background_removal"] = True

            model_def_output["image_ref_choices"] = {
            extra_model_def["any_image_refs_relative_size"] = True
            extra_model_def["no_background_removal"] = True
            extra_model_def["image_ref_choices"] = {
                "choices":[("No Reference Image", ""),("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "KI"),
                            ("Up to two Images are Style Images", "KIJ")],
                "default": "KI",
                "letters_filter": "KIJ",
                "label": "Reference Images / Style Images"
            }
        model_def_output["lock_image_refs_ratios"] = True

        if flux_kontext:
            extra_model_def["image_ref_choices"] = {
                "choices": [
                    ("None", ""),
                    ("Conditional Images is first Main Subject / Landscape and may be followed by People / Objects", "KI"),
                    ("Conditional Images are People / Objects", "I"),
                    ],
                "letters_filter": "KI",
            }

        return model_def_output

        extra_model_def["lock_image_refs_ratios"] = True

        return extra_model_def

    @staticmethod
    def query_supported_types():

@@ -122,10 +135,12 @@ class family_handler():
    def update_default_settings(base_model_type, model_def, ui_defaults):
        flux_model = model_def.get("flux-model", "flux-dev")
        flux_uso = flux_model == "flux-dev-uso"
        flux_kontext = flux_model == "flux-dev-kontext"
        ui_defaults.update({
            "embedded_guidance":  2.5,
        })
        if model_def.get("reference_image", False):
        })

        if flux_kontext or flux_uso:
            ui_defaults.update({
                "video_prompt_type": "KI",
            })
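Because the rename from `model_def_output` to `extra_model_def` interleaves old and new lines above, here is a hedged reconstruction of how the method reads after the change, built only from the added lines in this hunk; the `flux_model` lookup at the top is an assumption inferred from the matching code in `update_default_settings`:

```python
# Sketch only, not the literal file: the net effect of the hunk above.
def query_model_def(base_model_type, model_def):
    flux_model = model_def.get("flux-model", "flux-dev")   # assumed lookup
    flux_schnell = flux_model == "flux-schnell"
    flux_chroma = flux_model == "flux-chroma"
    flux_uso = flux_model == "flux-dev-uso"
    flux_kontext = flux_model == "flux-dev-kontext"

    extra_model_def = {
        "image_outputs": True,
        "no_negative_prompt": not flux_chroma,
    }
    if flux_chroma:
        extra_model_def["guidance_max_phases"] = 1
    elif not flux_schnell:
        extra_model_def["embedded_guidance"] = True

    if flux_uso:
        extra_model_def["any_image_refs_relative_size"] = True
        extra_model_def["no_background_removal"] = True
        extra_model_def["image_ref_choices"] = {
            "choices": [("No Reference Image", ""),
                        ("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "KI"),
                        ("Up to two Images are Style Images", "KIJ")],
            "default": "KI",
            "letters_filter": "KIJ",
            "label": "Reference Images / Style Images",
        }

    if flux_kontext:
        extra_model_def["image_ref_choices"] = {
            "choices": [("None", ""),
                        ("Conditional Images is first Main Subject / Landscape and may be followed by People / Objects", "KI"),
                        ("Conditional Images are People / Objects", "I")],
            "letters_filter": "KI",
        }

    extra_model_def["lock_image_refs_ratios"] = True
    return extra_model_def
```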
@@ -24,44 +24,6 @@ from .util import (

from PIL import Image

def resize_and_centercrop_image(image, target_height_ref1, target_width_ref1):
    target_height_ref1 = int(target_height_ref1 // 64 * 64)
    target_width_ref1 = int(target_width_ref1 // 64 * 64)
    h, w = image.shape[-2:]
    if h < target_height_ref1 or w < target_width_ref1:
        # compute the aspect ratio
        aspect_ratio = w / h
        if h < target_height_ref1:
            new_h = target_height_ref1
            new_w = new_h * aspect_ratio
            if new_w < target_width_ref1:
                new_w = target_width_ref1
                new_h = new_w / aspect_ratio
        else:
            new_w = target_width_ref1
            new_h = new_w / aspect_ratio
            if new_h < target_height_ref1:
                new_h = target_height_ref1
                new_w = new_h * aspect_ratio
    else:
        aspect_ratio = w / h
        tgt_aspect_ratio = target_width_ref1 / target_height_ref1
        if aspect_ratio > tgt_aspect_ratio:
            new_h = target_height_ref1
            new_w = new_h * aspect_ratio
        else:
            new_w = target_width_ref1
            new_h = new_w / aspect_ratio
    # resize the image with TVF.resize
    image = TVF.resize(image, (math.ceil(new_h), math.ceil(new_w)))
    # compute the center-crop parameters
    top = (image.shape[-2] - target_height_ref1) // 2
    left = (image.shape[-1] - target_width_ref1) // 2
    # center-crop with TVF.crop
    image = TVF.crop(image, top, left, target_height_ref1, target_width_ref1)
    return image


def stitch_images(img1, img2):
    # Resize img2 to match img1's height
    width1, height1 = img1.size
@@ -171,8 +133,6 @@ class model_factory:
            device="cuda"
            flux_dev_uso = self.name in ['flux-dev-uso']
            image_stiching =  not self.name in ['flux-dev-uso'] #and False
            # image_refs_relative_size = 100
            crop = False
            input_ref_images = [] if input_ref_images is None else input_ref_images[:]
            ref_style_imgs = []
            if "I" in video_prompt_type and len(input_ref_images) > 0:

@@ -186,36 +146,15 @@ class model_factory:
                if image_stiching:
                    # image stiching method
                    stiched = input_ref_images[0]
                    if "K" in video_prompt_type :
                        w, h = input_ref_images[0].size
                        height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
                        # actual rescale will happen in prepare_kontext
                    for new_img in input_ref_images[1:]:
                        stiched = stitch_images(stiched, new_img)
                    input_ref_images  = [stiched]
                else:
                    first_ref = 0
                    if "K" in video_prompt_type:
                        # image latents tiling method
                        w, h = input_ref_images[0].size
                        if crop :
                            img = convert_image_to_tensor(input_ref_images[0])
                            img = resize_and_centercrop_image(img, height, width)
                            input_ref_images[0] = convert_tensor_to_image(img)
                        else:
                            height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
                            input_ref_images[0] = input_ref_images[0].resize((width, height), resample=Image.Resampling.LANCZOS)
                        first_ref = 1

                    for i in range(first_ref,len(input_ref_images)):
                    # latents stiching with resize
                    for i in range(len(input_ref_images)):
                        w, h = input_ref_images[i].size
                        if crop:
                            img = convert_image_to_tensor(input_ref_images[i])
                            img = resize_and_centercrop_image(img, int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100))
                            input_ref_images[i] = convert_tensor_to_image(img)
                        else:
                            image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas)
                            input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
                        image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas)
                        input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
            else:
                input_ref_images = None
@@ -861,11 +861,6 @@ class HunyuanVideoSampler(Inference):
            freqs_cos, freqs_sin = self.get_rotary_pos_embed(target_frame_num, target_height, target_width, enable_RIFLEx)
        else:
            if self.avatar:
                w, h = input_ref_images.size
                target_height, target_width = calculate_new_dimensions(target_height, target_width, h, w, fit_into_canvas)
                if target_width != w or target_height != h:
                    input_ref_images = input_ref_images.resize((target_width,target_height), resample=Image.Resampling.LANCZOS)

                concat_dict = {'mode': 'timecat', 'bias': -1}
                freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict)
            else:
@@ -51,6 +51,23 @@ class family_handler():
        extra_model_def["tea_cache"] = True
        extra_model_def["mag_cache"] = True

        if base_model_type in ["hunyuan_custom_edit"]:
            extra_model_def["guide_preprocessing"] = {
                "selection": ["MV", "PMV"],
            }

            extra_model_def["mask_preprocessing"] = {
                "selection": ["A", "NA"],
                "default" : "NA"
            }

        if base_model_type in ["hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_custom"]:
            extra_model_def["image_ref_choices"] = {
                "choices": [("Reference Image", "I")],
                "letters_filter":"I",
                "visible": False,
            }

        if base_model_type in ["hunyuan_avatar"]: extra_model_def["no_background_removal"] = True

        if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar"]:

@@ -141,6 +158,10 @@ class family_handler():

        return hunyuan_model, pipe

    @staticmethod
    def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
        pass

    @staticmethod
    def update_default_settings(base_model_type, model_def, ui_defaults):
        ui_defaults["embedded_guidance_scale"]= 6.0
@@ -300,9 +300,6 @@ class LTXV:
            prefix_size, height, width = input_video.shape[-3:]
        else:
            if image_start != None:
                frame_width, frame_height  = image_start.size
                if fit_into_canvas != None:
                    height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas, 32)
                conditioning_media_paths.append(image_start.unsqueeze(1))
                conditioning_start_frames.append(0)
                conditioning_control_frames.append(False)

@@ -26,6 +26,15 @@ class family_handler():
        extra_model_def["sliding_window"] = True
        extra_model_def["image_prompt_types_allowed"] = "TSEV"

        extra_model_def["guide_preprocessing"] = {
            "selection": ["", "PV", "DV", "EV", "V"],
            "labels" : { "V": "Use LTXV raw format"}
        }

        extra_model_def["mask_preprocessing"] = {
            "selection": ["", "A", "NA", "XA", "XNA"],
        }

        return extra_model_def

    @staticmethod
@@ -28,7 +28,7 @@ from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Aut
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
from diffusers import FlowMatchEulerDiscreteScheduler
from PIL import Image
from shared.utils.utils import calculate_new_dimensions
from shared.utils.utils import calculate_new_dimensions, convert_image_to_tensor, convert_tensor_to_image

XLA_AVAILABLE = False

@@ -563,6 +563,8 @@ class QwenImagePipeline(): #DiffusionPipeline
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
        image = None,
        image_mask = None,
        denoising_strength = 0,
        callback=None,
        pipeline=None,
        loras_slists=None,
@@ -694,14 +696,33 @@ class QwenImagePipeline(): #DiffusionPipeline
            image_width = image_width // multiple_of * multiple_of
            image_height = image_height // multiple_of * multiple_of
            ref_height, ref_width = 1568, 672
            if height * width < ref_height * ref_width: ref_height , ref_width = height , width
            if image_height * image_width > ref_height * ref_width:
                image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)

            image = image.resize((image_width,image_height), resample=Image.Resampling.LANCZOS)
            if image_mask is None:
                if height * width < ref_height * ref_width: ref_height , ref_width = height , width
                if image_height * image_width > ref_height * ref_width:
                    image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
                if (image_width,image_height) != image.size:
                    image = image.resize((image_width,image_height), resample=Image.Resampling.LANCZOS)
                image_mask_latents = None
            else:
                # _, image_width, image_height = min(
                #     (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS
                # )
                image_height, image_width = calculate_new_dimensions(height, width, image_height, image_width, False, block_size=multiple_of)
                # image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
                height, width = image_height, image_width
                image_mask_latents = convert_image_to_tensor(image_mask.resize((width // 16, height // 16), resample=Image.Resampling.LANCZOS))
                image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1]
                image_mask_rebuilt = image_mask_latents.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0)
                convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png")
                image_mask_latents = image_mask_latents.reshape(1, -1, 1).to(device)

            prompt_image = image
            image = self.image_processor.preprocess(image, image_height, image_width)
            image = image.unsqueeze(2)
            if image.size != (image_width, image_height):
                image = image.resize((image_width, image_height), resample=Image.Resampling.LANCZOS)

            image.save("nnn.png")
            image = convert_image_to_tensor(image).unsqueeze(0).unsqueeze(2)

        has_neg_prompt = negative_prompt is not None or (
            negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
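The masked-editing branch added in this hunk resizes the user mask down to one value per 16x16 latent patch, binarizes it, and keeps both a flat per-token copy (for blending latents during denoising) and a rebuilt full-resolution copy (for compositing the decoded image at the end). A minimal, self-contained sketch of that preparation, using assumed helpers in place of the project's own `convert_image_to_tensor`:

```python
# Sketch of the mask preparation above; helper names and the [-1, 1] normalization are assumptions.
import torch
from PIL import Image
import torchvision.transforms.functional as TF

def build_mask_latents(image_mask: Image.Image, height: int, width: int):
    # 1. Downsample the mask to latent-patch resolution (one value per 16x16 pixel block).
    small = image_mask.resize((width // 16, height // 16), resample=Image.Resampling.LANCZOS)
    mask = TF.to_tensor(small).sub_(0.5).div_(0.5)        # scale to [-1, 1], as the diff assumes
    mask = torch.where(mask > -0.5, 1.0, 0.0)[0:1]        # binarize, keep a single channel
    # 2. Rebuild a full-resolution mask for compositing the decoded image afterwards.
    mask_rebuilt = mask.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0)
    # 3. Flatten to one value per latent token so it can gate the packed latents.
    mask_latents = mask.reshape(1, -1, 1)
    return mask_latents, mask_rebuilt
```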
@@ -744,6 +765,8 @@ class QwenImagePipeline(): #DiffusionPipeline
            generator,
            latents,
        )
        original_image_latents = None if image_latents is None else image_latents.clone()

        if image is not None:
            img_shapes = [
                [
@@ -788,6 +811,15 @@ class QwenImagePipeline(): #DiffusionPipeline
        negative_txt_seq_lens = (
            negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
        )
        morph = False
        if image_mask_latents is not None and denoising_strength <= 1.:
            first_step = int(len(timesteps) * (1. - denoising_strength))
            if not morph:
                latent_noise_factor = timesteps[first_step]/1000
                latents  = original_image_latents  * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor
                timesteps = timesteps[first_step:]
                self.scheduler.timesteps = timesteps
                self.scheduler.sigmas= self.scheduler.sigmas[first_step:]

        # 6. Denoising loop
        self.scheduler.set_begin_index(0)
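The `denoising_strength` handling added here skips the earliest timesteps and starts the loop from a partially noised copy of the input latents: with strength 1.0 the full schedule runs, and smaller values keep the edit closer to the original image. A rough standalone illustration of the same arithmetic, assuming timesteps on a 0 to 1000 scale as in the hunk above:

```python
import torch

def start_from_image(original_latents: torch.Tensor, timesteps: torch.Tensor, denoising_strength: float):
    # Skip the first (1 - strength) fraction of the schedule.
    first_step = int(len(timesteps) * (1.0 - denoising_strength))
    noise_factor = timesteps[first_step] / 1000            # assumed [0, 1000) timestep scale
    noise = torch.randn_like(original_latents)
    # Linear interpolation between the clean latents and pure noise, mirroring the hunk above.
    latents = original_latents * (1.0 - noise_factor) + noise * noise_factor
    return latents, timesteps[first_step:]

# e.g. denoising_strength=0.6 keeps the last 60% of the steps and, on a roughly linear
# schedule, starts from latents that are about 60% noise.
```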
@@ -797,10 +829,15 @@ class QwenImagePipeline(): #DiffusionPipeline
            update_loras_slists(self.transformer, loras_slists, updated_num_steps)
            callback(-1, None, True, override_num_inference_steps = updated_num_steps)


        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            if image_mask_latents is not None and denoising_strength <1. and i == first_step and morph:
                latent_noise_factor = t/1000
                latents  = original_image_latents  * (1.0 - latent_noise_factor) + latents * latent_noise_factor

            self._current_timestep = t
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latents.shape[0]).to(latents.dtype)
@@ -865,6 +902,12 @@ class QwenImagePipeline(): #DiffusionPipeline
            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
            if image_mask_latents is not None:
                next_t = timesteps[i+1] if i<len(timesteps)-1 else 0
                latent_noise_factor = next_t / 1000
                noisy_image  = original_image_latents  * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor
                latents  =  noisy_image * (1-image_mask_latents)  + image_mask_latents * latents
                noisy_image = None

            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
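After each scheduler step, the hunk above re-noises the original latents to the noise level expected at the next timestep and pastes them back outside the mask, so only the masked region is actually re-generated. A compact sketch of that blend with the same variable roles, wrapped in a hypothetical helper:

```python
import torch

def blend_inpaint_step(latents, original_latents, mask_latents, next_t: float):
    # Re-noise the untouched image to the level the scheduler expects at the next step.
    noise_factor = next_t / 1000
    noisy_image = original_latents * (1.0 - noise_factor) + torch.randn_like(original_latents) * noise_factor
    # Keep the denoised prediction only where the mask is 1; restore the input elsewhere.
    return noisy_image * (1 - mask_latents) + mask_latents * latents
```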
@@ -878,7 +921,7 @@ class QwenImagePipeline(): #DiffusionPipeline

        self._current_timestep = None
        if output_type == "latent":
            image = latents
            output_image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = latents.to(self.vae.dtype)

@@ -891,7 +934,9 @@ class QwenImagePipeline(): #DiffusionPipeline
                latents.device, latents.dtype
            )
            latents = latents / latents_std + latents_mean
            image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
            output_image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
            if image_mask is not None:
                output_image = image.squeeze(2) * (1 - image_mask_rebuilt) + output_image.to(image) * image_mask_rebuilt

        return image
        return output_image
@@ -9,7 +9,7 @@ def get_qwen_text_encoder_filename(text_encoder_quantization):
class family_handler():
    @staticmethod
    def query_model_def(base_model_type, model_def):
        model_def_output = {
        extra_model_def = {
            "image_outputs" : True,
            "sample_solvers":[
                            ("Default", "default"),

@@ -18,8 +18,18 @@ class family_handler():
            "lock_image_refs_ratios": True,
        }

        if base_model_type in ["qwen_image_edit_20B"]:
            extra_model_def["inpaint_support"] = True
            extra_model_def["image_ref_choices"] = {
            "choices": [
                ("None", ""),
                ("Conditional Images is first Main Subject / Landscape and may be followed by People / Objects", "KI"),
                ("Conditional Images are People / Objects", "I"),
                ],
            "letters_filter": "KI",
            }

        return model_def_output
        return extra_model_def

    @staticmethod
    def query_supported_types():

@@ -75,14 +85,18 @@ class family_handler():
        if ui_defaults.get("sample_solver", "") == "":
            ui_defaults["sample_solver"] = "default"

        if settings_version < 2.32:
            ui_defaults["denoising_strength"] = 1.

    @staticmethod
    def update_default_settings(base_model_type, model_def, ui_defaults):
        ui_defaults.update({
            "guidance_scale":  4,
            "sample_solver": "default",
        })
        if model_def.get("reference_image", False):
        if base_model_type in ["qwen_image_edit_20B"]:
            ui_defaults.update({
                "video_prompt_type": "KI",
                "denoising_strength" : 1.,
            })
@@ -103,6 +103,8 @@ class model_factory():
        n_prompt = None,
        sampling_steps: int = 20,
        input_ref_images = None,
        image_guide= None,
        image_mask= None,
        width= 832,
        height=480,
        guide_scale: float = 4,

@@ -114,6 +116,7 @@ class model_factory():
        VAE_tile_size = None,
        joint_pass = True,
        sample_solver='default',
        denoising_strength = 1.,
        **bbargs
    ):
        # Generate with different aspect ratios

@@ -174,8 +177,9 @@ class model_factory():

        if n_prompt is None or len(n_prompt) == 0:
            n_prompt=  "text, watermark, copyright, blurry, low resolution"

        if input_ref_images is not None:
        if image_guide is not None:
            input_ref_images = [image_guide]
        elif input_ref_images is not None:
            # image stiching method
            stiched = input_ref_images[0]
            if "K" in video_prompt_type :

@@ -190,6 +194,7 @@ class model_factory():
            prompt=input_prompt,
            negative_prompt=n_prompt,
            image = input_ref_images,
            image_mask = image_mask,
            width=width,
            height=height,
            num_inference_steps=sampling_steps,

@@ -199,6 +204,7 @@ class model_factory():
            pipeline=self,
            loras_slists=loras_slists,
            joint_pass = joint_pass,
            denoising_strength=denoising_strength,
            generator=torch.Generator(device="cuda").manual_seed(seed)
        )
        if image is None: return None
@@ -261,7 +261,7 @@ class WanAny2V:
    def vace_latent(self, z, m):
        return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]

    def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, fill_max = False, outpainting_dims = None, return_mask = False):
    def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, full_frame = False, outpainting_dims = None, return_mask = False):
        from shared.utils.utils import save_image
        ref_width, ref_height = ref_img.size
        if (ref_height, ref_width) == image_size and outpainting_dims  == None:

@@ -270,18 +270,23 @@ class WanAny2V:
        else:
            if outpainting_dims != None:
                final_height, final_width = image_size
                canvas_height, canvas_width, margin_top, margin_left =   get_outpainting_frame_location(final_height, final_width,  outpainting_dims, 8)
                canvas_height, canvas_width, margin_top, margin_left =   get_outpainting_frame_location(final_height, final_width,  outpainting_dims, 1)
            else:
                canvas_height, canvas_width = image_size
            scale = min(canvas_height / ref_height, canvas_width / ref_width)
            new_height = int(ref_height * scale)
            new_width = int(ref_width * scale)
            if fill_max  and (canvas_height - new_height) < 16:
            if full_frame:
                new_height = canvas_height
            if fill_max  and (canvas_width - new_width) < 16:
                new_width = canvas_width
            top = (canvas_height - new_height) // 2
            left = (canvas_width - new_width) // 2
                top = left = 0
            else:
                # if fill_max  and (canvas_height - new_height) < 16:
                #     new_height = canvas_height
                # if fill_max  and (canvas_width - new_width) < 16:
                #     new_width = canvas_width
                scale = min(canvas_height / ref_height, canvas_width / ref_width)
                new_height = int(ref_height * scale)
                new_width = int(ref_width * scale)
                top = (canvas_height - new_height) // 2
                left = (canvas_width - new_width) // 2
            ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
            ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
            if outpainting_dims != None:
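The reworked branch above either stretches the reference to fill the whole canvas (`full_frame`) or scales it to fit while preserving aspect ratio and centers it. A small standalone sketch of the fit-and-center arithmetic, PIL only, with a hypothetical function name:

```python
from PIL import Image

def fit_into_canvas(ref_img: Image.Image, canvas_height: int, canvas_width: int, full_frame: bool = False):
    ref_width, ref_height = ref_img.size
    if full_frame:
        # Stretch to cover the canvas exactly; no letterboxing offsets needed.
        new_width, new_height, top, left = canvas_width, canvas_height, 0, 0
    else:
        # Largest size that still fits inside the canvas at the original aspect ratio.
        scale = min(canvas_height / ref_height, canvas_width / ref_width)
        new_height, new_width = int(ref_height * scale), int(ref_width * scale)
        top, left = (canvas_height - new_height) // 2, (canvas_width - new_width) // 2
    resized = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
    return resized, top, left   # caller pastes the resized image at (top, left) on the canvas
```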
@@ -302,7 +307,7 @@ class WanAny2V:
                canvas = canvas.to(device)
        return ref_img.to(device), canvas

    def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size,  device, keep_video_guide_frames= [], start_frame = 0,  fit_into_canvas = None, pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False):
    def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size,  device, keep_video_guide_frames= [], start_frame = 0, pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False):
        image_sizes = []
        trim_video_guide = len(keep_video_guide_frames)
        def conv_tensor(t, device):

@@ -533,22 +538,16 @@ class WanAny2V:
            any_end_frame = False
            if image_start is None:
                if infinitetalk:
                    new_shot = "Q" in video_prompt_type
                    if input_frames is not None:
                        image_ref = input_frames[:, 0]
                        if input_video is None: input_video = input_frames[:, 0:1]
                        new_shot = "Q" in video_prompt_type
                    else:
                        if pre_video_frame is None:
                            new_shot = True
                        else:
                            if input_ref_images is None:
                                input_ref_images, new_shot = [pre_video_frame], False
                            else:
                                input_ref_images, new_shot = [img.resize(pre_video_frame.size, resample=Image.Resampling.LANCZOS) for img in input_ref_images], "Q" in video_prompt_type
                        if input_ref_images is None: raise Exception("Missing Reference Image")
                        if input_ref_images is None:
                            if pre_video_frame is None: raise Exception("Missing Reference Image")
                            input_ref_images = [pre_video_frame]
                        new_shot = new_shot and window_no <= len(input_ref_images)
                        image_ref = convert_image_to_tensor(input_ref_images[ min(window_no, len(input_ref_images))-1 ])
                    if new_shot:
                    if new_shot or input_video is None:
                        input_video = image_ref.unsqueeze(1)
                    else:
                        color_correction_strength = 0 #disable color correction as transition frames between shots may have a complete different color level than the colors of the new shot
@@ -35,7 +35,7 @@ class family_handler():
                    "label" : "Generation Type"
        }

        extra_model_def["image_prompt_types_allowed"] = "TSEV"
        extra_model_def["image_prompt_types_allowed"] = "TSV"


        return extra_model_def
 | 
			
		||||
        "tea_cache" : not (base_model_type in ["i2v_2_2", "ti2v_2_2" ] or multiple_submodels),
 | 
			
		||||
        "mag_cache" : True,
 | 
			
		||||
        "keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"],
 | 
			
		||||
        "convert_image_guide_to_video" : True,
 | 
			
		||||
        "sample_solvers":[
 | 
			
		||||
                            ("unipc", "unipc"),
 | 
			
		||||
                            ("euler", "euler"),
 | 
			
		||||
                            ("dpm++", "dpm++"),
 | 
			
		||||
                            ("flowmatch causvid", "causvid"), ]
 | 
			
		||||
        })
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        if base_model_type in ["t2v"]: 
 | 
			
		||||
            extra_model_def["guide_custom_choices"] = {
 | 
			
		||||
                "choices":[("Use Text Prompt Only", ""),("Video to Video guided by Text Prompt", "GUV")],
 | 
			
		||||
                "default": "",
 | 
			
		||||
                "letters_filter": "GUV",
 | 
			
		||||
                "label": "Video to Video"
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        if base_model_type in ["infinitetalk"]: 
 | 
			
		||||
            extra_model_def["no_background_removal"] = True
 | 
			
		||||
            # extra_model_def["at_least_one_image_ref_needed"] = True
 | 
			
		||||
            extra_model_def["all_image_refs_are_background_ref"] = True
 | 
			
		||||
            extra_model_def["guide_custom_choices"] = {
 | 
			
		||||
            "choices":[
 | 
			
		||||
                ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Sharp Transitions", "QKI"),
 | 
			
		||||
                ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Smooth Transitions", "KI"),
 | 
			
		||||
                ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Sharp Transitions", "QRUV"),
 | 
			
		||||
                ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Smooth Transitions", "RUV"),
 | 
			
		||||
                ("Video to Video, amount of motion transferred depends on Denoising Strength - Sharp Transitions", "GQUV"),
 | 
			
		||||
                ("Video to Video, amount of motion transferred depends on Denoising Strength - Smooth Transitions", "GUV"),
 | 
			
		||||
            ],
 | 
			
		||||
            "default": "KI",
 | 
			
		||||
            "letters_filter": "RGUVQKI",
 | 
			
		||||
            "label": "Video to Video",
 | 
			
		||||
            "show_label" : False,
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            # extra_model_def["at_least_one_image_ref_needed"] = True
 | 
			
		||||
        if vace_class:
 | 
			
		||||
            extra_model_def["guide_preprocessing"] = {
 | 
			
		||||
                    "selection": ["", "UV", "PV", "DV", "SV", "LV", "CV", "MV", "V", "PDV", "PSV", "PLV" , "DSV", "DLV", "SLV"],
 | 
			
		||||
                    "labels" : { "V": "Use Vace raw format"}
 | 
			
		||||
                }
 | 
			
		||||
            extra_model_def["mask_preprocessing"] = {
 | 
			
		||||
                    "selection": ["", "A", "NA", "XA", "XNA", "YA", "YNA", "WA", "WNA", "ZA", "ZNA"],
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
            extra_model_def["image_ref_choices"] = {
 | 
			
		||||
                    "choices": [("None", ""),
 | 
			
		||||
                    ("Inject only People / Objects", "I"),
 | 
			
		||||
                    ("Inject Landscape and then People / Objects", "KI"),
 | 
			
		||||
                    ("Inject Frames and then People / Objects", "FI"),
 | 
			
		||||
                    ],
 | 
			
		||||
                    "letters_filter":  "KFI",
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        if base_model_type in ["standin"] or vace_class: 
 | 
			
		||||
            extra_model_def["lock_image_refs_ratios"] = True
 | 
			
		||||
 | 
			
		||||
        if base_model_type in ["standin"]: 
 | 
			
		||||
            extra_model_def["lock_image_refs_ratios"] = True
 | 
			
		||||
            extra_model_def["image_ref_choices"] = {
 | 
			
		||||
                "choices": [
 | 
			
		||||
                    ("No Reference Image", ""),
 | 
			
		||||
                    ("Reference Image is a Person Face", "I"),
 | 
			
		||||
                    ],
 | 
			
		||||
                "letters_filter":"I",
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        if base_model_type in ["phantom_1.3B", "phantom_14B"]: 
 | 
			
		||||
            extra_model_def["image_ref_choices"] = {
 | 
			
		||||
                "choices": [("Reference Image", "I")],
 | 
			
		||||
                "letters_filter":"I",
 | 
			
		||||
                "visible": False,
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        if base_model_type in ["recam_1.3B"]: 
 | 
			
		||||
            extra_model_def["keep_frames_video_guide_not_supported"] = True
 | 
			
		||||
            extra_model_def["model_modes"] = {
 | 
			
		||||
@ -141,10 +201,18 @@ class family_handler():
 | 
			
		||||
                        "default": 1,
 | 
			
		||||
                        "label" : "Camera Movement Type"
 | 
			
		||||
            }
 | 
			
		||||
            extra_model_def["guide_preprocessing"] = {
 | 
			
		||||
                    "selection": ["UV"],
 | 
			
		||||
                    "labels" : { "UV": "Control Video"},
 | 
			
		||||
                    "visible" : False,
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
        if vace_class or base_model_type in ["infinitetalk"]:
 | 
			
		||||
            image_prompt_types_allowed = "TVL"
 | 
			
		||||
        elif base_model_type in ["ti2v_2_2"]:
 | 
			
		||||
            image_prompt_types_allowed = "TSEVL"
 | 
			
		||||
            image_prompt_types_allowed = "TSVL"
 | 
			
		||||
        elif test_multitalk(base_model_type) or base_model_type in ["fantasy"]:
 | 
			
		||||
            image_prompt_types_allowed = "SVL"
 | 
			
		||||
        elif i2v:
 | 
			
		||||
            image_prompt_types_allowed = "SEVL"
 | 
			
		||||
        else:
 | 
			
		||||
 | 
			
		||||
@@ -7,7 +7,6 @@ import psutil
# import ffmpeg
import imageio
from PIL import Image

import cv2
import torch
import torch.nn.functional as F

@@ -33,6 +32,8 @@ model_in_GPU = False
matanyone_in_GPU = False
bfloat16_supported = False
# SAM generator
import copy

class MaskGenerator():
    def __init__(self, sam_checkpoint, device):
        global args_device

@@ -89,6 +90,7 @@ def get_frames_from_image(image_input, image_state):
        "last_frame_numer": 0,
        "fps": None
        }

    image_info = "Image Name: N/A,\nFPS: N/A,\nTotal Frames: {},\nImage Size:{}".format(len(frames), image_size)
    set_image_encoder_patch()
    select_SAM()
@@ -717,27 +719,33 @@ def load_unload_models(selected):
def get_vmc_event_handler():
    return load_unload_models

def export_to_vace_video_input(foreground_video_output):
    gr.Info("Masked Video Input transferred to Vace For Inpainting")
    return "V#" + str(time.time()), foreground_video_output


def export_image(image_refs, image_output):
    gr.Info("Masked Image transferred to Current Video")
def export_image(state, image_output):
    ui_settings = get_current_model_settings(state)
    image_refs = ui_settings["image_refs"]
    if image_refs == None:
        image_refs =[]
    image_refs.append( image_output)
    return image_refs
    ui_settings["image_refs"] = image_refs
    gr.Info("Masked Image transferred to Current Image Generator")
    return time.time()

def export_image_mask(image_input, image_mask):
    gr.Info("Input Image & Mask transferred to Current Video")
    return Image.fromarray(image_input), image_mask
def export_image_mask(state, image_input, image_mask):
    ui_settings = get_current_model_settings(state)
    ui_settings["image_guide"] = Image.fromarray(image_input)
    ui_settings["image_mask"] = image_mask

    gr.Info("Input Image & Mask transferred to Current Image Generator")
    return time.time()


def export_to_current_video_engine( foreground_video_output, alpha_video_output):
def export_to_current_video_engine(state, foreground_video_output, alpha_video_output):
    ui_settings = get_current_model_settings(state)
    ui_settings["video_guide"] = foreground_video_output
    ui_settings["video_mask"] = alpha_video_output

    gr.Info("Original Video and Full Mask have been transferred")
    # return "MV#" + str(time.time()), foreground_video_output, alpha_video_output
    return foreground_video_output, alpha_video_output
    return time.time()


def teleport_to_video_tab(tab_state):

@@ -746,9 +754,10 @@ def teleport_to_video_tab(tab_state):
    return gr.Tabs(selected="video_gen")


def display(tabs, tab_state, server_config,  vace_video_input, vace_image_input, vace_video_mask, vace_image_mask, vace_image_refs):
def display(tabs, tab_state, state, refresh_form_trigger, server_config, get_current_model_settings_fn): #,  vace_video_input, vace_image_input, vace_video_mask, vace_image_mask, vace_image_refs):
    # my_tab.select(fn=load_unload_models, inputs=[], outputs=[])
    global image_output_codec, video_output_codec
    global image_output_codec, video_output_codec, get_current_model_settings
    get_current_model_settings = get_current_model_settings_fn

    image_output_codec = server_config.get("image_output_codec", None)
    video_output_codec = server_config.get("video_output_codec", None)
@@ -871,7 +880,7 @@ def display(tabs, tab_state, server_config,  vace_video_input, vace_image_input,
                            template_frame = gr.Image(label="Start Frame", type="pil",interactive=True, elem_id="template_frame", visible=False, elem_classes="image")
                            with gr.Row():
                                clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False,  min_width=100)
                                add_mask_button = gr.Button(value="Set Mask", interactive=True, visible=False, min_width=100)
                                add_mask_button = gr.Button(value="Add Mask", interactive=True, visible=False, min_width=100)
                                remove_mask_button = gr.Button(value="Remove Mask", interactive=True, visible=False,  min_width=100) # no use
                                matting_button = gr.Button(value="Generate Video Matting", interactive=True, visible=False,  min_width=100)
                            with gr.Row():

@@ -892,7 +901,7 @@ def display(tabs, tab_state, server_config,  vace_video_input, vace_image_input,
                            with gr.Row(visible= True):
                                export_to_current_video_engine_btn = gr.Button("Export to Control Video Input and Video Mask Input", visible= False)

                export_to_current_video_engine_btn.click(  fn=export_to_current_video_engine, inputs= [foreground_video_output, alpha_video_output], outputs= [vace_video_input, vace_video_mask]).then( #video_prompt_video_guide_trigger,
                export_to_current_video_engine_btn.click(  fn=export_to_current_video_engine, inputs= [state, foreground_video_output, alpha_video_output], outputs= [refresh_form_trigger]).then( #video_prompt_video_guide_trigger,
                    fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs])


@@ -1089,9 +1098,9 @@ def display(tabs, tab_state, server_config,  vace_video_input, vace_image_input,
                    # with gr.Column(scale=2, visible= True):
                        export_image_mask_btn = gr.Button(value="Set to Control Image & Mask", visible=False, elem_classes="new_button")

                export_image_btn.click(  fn=export_image, inputs= [vace_image_refs, foreground_image_output], outputs= [vace_image_refs]).then( #video_prompt_video_guide_trigger,
                export_image_btn.click(  fn=export_image, inputs= [state, foreground_image_output], outputs= [refresh_form_trigger]).then( #video_prompt_video_guide_trigger,
                    fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs])
                export_image_mask_btn.click(  fn=export_image_mask, inputs= [image_input, alpha_image_output], outputs= [vace_image_input, vace_image_mask]).then( #video_prompt_video_guide_trigger,
                export_image_mask_btn.click(  fn=export_image_mask, inputs= [state, image_input, alpha_image_output], outputs= [refresh_form_trigger]).then( #video_prompt_video_guide_trigger,
                    fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs])

                # first step: get the image information

@@ -1148,5 +1157,21 @@ def display(tabs, tab_state, server_config,  vace_video_input, vace_image_input,
                    outputs=[foreground_image_output, alpha_image_output,foreground_image_output, alpha_image_output,bbox_info, export_image_btn, export_image_mask_btn]
                )


                nada = gr.State({})
                # clear input
                gr.on(
                    triggers=[image_input.clear], #image_input.change,
                    fn=restart,
                    inputs=[],
                    outputs=[
                        image_state,
                        interactive_state,
                        click_state,
                        foreground_image_output, alpha_image_output,
                        template_frame,
                        image_selection_slider, image_selection_slider, track_pause_number_slider,point_prompt, export_image_btn, export_image_mask_btn, bbox_info, clear_button_click,
                        add_mask_button, matting_button, template_frame, foreground_image_output, alpha_image_output, remove_mask_button, export_image_btn, export_image_mask_btn, mask_dropdown, nada, step2_title
                    ],
                    queue=False,
                    show_progress=False)
@@ -23,7 +23,7 @@ librosa==0.11.0
speechbrain==1.0.3

# UI & interaction
gradio==5.23.0
gradio==5.29.0
dashscope
loguru
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, Literal

import gradio as gr
import PIL
import time
from PIL import Image as PILImage

FilePath = str

@@ -20,6 +21,9 @@ def get_list( objs):
        return []
    return [ obj[0] if isinstance(obj, tuple) else obj for obj in objs]

def record_last_action(st, last_action):
    st["last_action"] = last_action
    st["last_time"] = time.time()
class AdvancedMediaGallery:
    def __init__(
        self,

@@ -60,9 +64,10 @@ class AdvancedMediaGallery:
        self.state: Optional[gr.State] = None
        self._initial_state: Dict[str, Any] = {
            "items": items,
            "selected": (len(items) - 1) if items else None,
            "selected": (len(items) - 1) if items else 0, # None,
            "single": bool(single_image_mode),
            "mode": self.media_mode,
            "last_action": "",
        }

    # ---------------- helpers ----------------
@ -210,6 +215,13 @@ class AdvancedMediaGallery:
 | 
			
		||||
 | 
			
		||||
    def _on_select(self, state: Dict[str, Any], gallery, evt: gr.SelectData) :
 | 
			
		||||
        # Mirror the selected index into state and the gallery (server-side selected_index)
 | 
			
		||||
 | 
			
		||||
        st = get_state(state)
 | 
			
		||||
        last_time = st.get("last_time", None)
 | 
			
		||||
        if last_time is not None and abs(time.time()- last_time)< 0.5: # crappy trick to detect if onselect is unwanted (buggy gallery)
 | 
			
		||||
            # print(f"ignored:{time.time()}, real {st['selected']}")
 | 
			
		||||
            return gr.update(selected_index=st["selected"]), st
 | 
			
		||||
 | 
			
		||||
        idx = None
 | 
			
		||||
        if evt is not None and hasattr(evt, "index"):
 | 
			
		||||
            ix = evt.index
 | 
			
		||||
@ -220,17 +232,28 @@ class AdvancedMediaGallery:
                    idx = ix[0] * max(1, int(self.columns)) + ix[1]
                else:
                    idx = ix[0]
        st = get_state(state)
        n = len(get_list(gallery))
        sel = idx if (idx is not None and 0 <= idx < n) else None
        # print(f"image selected evt index:{sel}/{evt.selected}")
        st["selected"] = sel
        # return gr.update(selected_index=sel), st
        # return gr.update(), st
        return st
        return gr.update(), st

    def _on_upload(self, value: List[Any], state: Dict[str, Any]) :
        # Fires when users upload via the Gallery itself.
        # items_filtered = self._filter_items_by_mode(list(value or []))
        items_filtered = list(value or [])
        st = get_state(state)
        new_items = self._paths_from_payload(items_filtered)
        st["items"] = new_items
        new_sel = len(new_items) - 1
        st["selected"] = new_sel
        record_last_action(st,"add")
        return gr.update(selected_index=new_sel), st

    def _on_gallery_change(self, value: List[Any], state: Dict[str, Any]) :
        # Fires when users add/drag/drop/delete via the Gallery itself.
        items_filtered = self._filter_items_by_mode(list(value or []))
        # items_filtered = self._filter_items_by_mode(list(value or []))
        items_filtered = list(value or [])
        st = get_state(state)
        st["items"] = items_filtered
        # Keep selection if still valid, else default to last
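
The top of this hunk normalises gr.SelectData.index, which the handler accommodates both as a flat int and as a (row, col) pair; a standalone sketch of the same arithmetic (flatten_select_index is a hypothetical helper name):

def flatten_select_index(ix, columns):
    if isinstance(ix, (list, tuple)):
        if len(ix) >= 2:
            return ix[0] * max(1, int(columns)) + ix[1]   # row * columns + col
        return ix[0]
    return int(ix)

print(flatten_select_index((1, 2), columns=4))   # -> 6
print(flatten_select_index(5, columns=4))        # -> 5
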
@ -240,10 +263,9 @@ class AdvancedMediaGallery:
        else:
            new_sel = old_sel
        st["selected"] = new_sel
        # return gr.update(value=items_filtered, selected_index=new_sel), st
        # return gr.update(value=items_filtered), st

        return gr.update(), st
        st["last_action"] ="gallery_change"
        # print(f"gallery change: set sel {new_sel}")
        return gr.update(selected_index=new_sel), st

    def _on_add(self, files_payload: Any, state: Dict[str, Any], gallery):
        """
@ -252,7 +274,8 @@ class AdvancedMediaGallery:
        and re-selects the last inserted item.
        """
        # New items (respect image/video mode)
        new_items = self._filter_items_by_mode(self._paths_from_payload(files_payload))
        # new_items = self._filter_items_by_mode(self._paths_from_payload(files_payload))
        new_items = self._paths_from_payload(files_payload)

        st = get_state(state)
        cur: List[Any] = get_list(gallery)
@ -298,30 +321,6 @@ class AdvancedMediaGallery:
                if k is not None:
                    seen_new.add(k)

        # Remove any existing occurrences of the incoming items from current list,
        # BUT keep the currently selected item even if it's also in incoming.
        cur_clean: List[Any] = []
        # sel_item = cur[sel] if (sel is not None and 0 <= sel < len(cur)) else None
        # for idx, it in enumerate(cur):
        #     k = key_of(it)
        #     if it is sel_item:
        #         cur_clean.append(it)
        #         continue
        #     if k is not None and k in seen_new:
        #         continue  # drop duplicate; we'll reinsert at the target spot
        #     cur_clean.append(it)

        # # Compute insertion position: right AFTER the (possibly shifted) selected item
        # if sel_item is not None:
        #     # find sel_item's new index in cur_clean
        #     try:
        #         pos_sel = cur_clean.index(sel_item)
        #     except ValueError:
        #         # Shouldn't happen, but fall back to end
        #         pos_sel = len(cur_clean) - 1
        #     insert_pos = pos_sel + 1
        # else:
        #     insert_pos = len(cur_clean)  # no selection -> append at end
        insert_pos = min(sel, len(cur) -1)
        cur_clean = cur
        # Build final list and selection
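
The replacement logic above clamps insert_pos to the current selection and keeps cur as-is; the final merge happens in unchanged context lines, so the following is only an assumed illustration of the intended "insert after the selected item, then select the last inserted one" behaviour:

def insert_after_selected(cur, new_items, sel):
    # Assumed reconstruction for illustration; not the literal code of _on_add.
    insert_pos = min(sel, len(cur) - 1)
    merged = cur[:insert_pos + 1] + new_items + cur[insert_pos + 1:]
    new_sel = insert_pos + len(new_items)
    return merged, new_sel

print(insert_after_selected(["a", "b", "c"], ["x", "y"], sel=1))
# (['a', 'b', 'x', 'y', 'c'], 3)   -> 'y', the last inserted item, ends up selected
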
@ -330,6 +329,8 @@ class AdvancedMediaGallery:

        st["items"] = merged
        st["selected"] = new_sel
        record_last_action(st,"add")
        # print(f"gallery add: set sel {new_sel}")
        return gr.update(value=merged, selected_index=new_sel), st

    def _on_remove(self, state: Dict[str, Any], gallery) :
@ -342,8 +343,9 @@ class AdvancedMediaGallery:
            return gr.update(value=[], selected_index=None), st
        new_sel = min(sel, len(items) - 1)
        st["items"] = items; st["selected"] = new_sel
        # return gr.update(value=items, selected_index=new_sel), st
        return gr.update(value=items), st
        record_last_action(st,"remove")
        # print(f"gallery del: new sel {new_sel}")
        return gr.update(value=items, selected_index=new_sel), st

    def _on_move(self, delta: int, state: Dict[str, Any], gallery) :
        st = get_state(state); items: List[Any] = get_list(gallery); sel = st.get("selected", None)
@ -354,11 +356,15 @@ class AdvancedMediaGallery:
            return gr.update(value=items, selected_index=sel), st
        items[sel], items[j] = items[j], items[sel]
        st["items"] = items; st["selected"] = j
        record_last_action(st,"move")
        # print(f"gallery move: set sel {j}")
        return gr.update(value=items, selected_index=j), st

    def _on_clear(self, state: Dict[str, Any]) :
        st = {"items": [], "selected": None, "single": get_state(state).get("single", False), "mode": self.media_mode}
        return gr.update(value=[], selected_index=0), st
        record_last_action(st,"clear")
        # print(f"Clear all")
        return gr.update(value=[], selected_index=None), st

    def _on_toggle_single(self, to_single: bool, state: Dict[str, Any]) :
        st = get_state(state); st["single"] = bool(to_single)
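
The Left/Right handlers reduce to an adjacent swap that keeps the selection on the moved item, so repeated clicks keep walking it through the list; a compact sketch:

def move(items, sel, delta):
    j = sel + delta
    if sel is None or not (0 <= sel < len(items)) or not (0 <= j < len(items)):
        return items, sel                      # out of range: leave list and selection alone
    items[sel], items[j] = items[j], items[sel]
    return items, j                            # selection follows the moved item

print(move(["a", "b", "c"], sel=2, delta=-1))  # (['a', 'c', 'b'], 1)
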
@ -382,30 +388,38 @@ class AdvancedMediaGallery:
    def mount(self, parent: Optional[gr.Blocks | gr.Group | gr.Row | gr.Column] = None, update_form = False):
        if parent is not None:
            with parent:
                col = self._build_ui()
                col = self._build_ui(update_form)
        else:
            col = self._build_ui()
            col = self._build_ui(update_form)
        if not update_form:
            self._wire_events()
        return col

    def _build_ui(self) -> gr.Column:
    def _build_ui(self, update = False) -> gr.Column:
        with gr.Column(elem_id=self.elem_id, elem_classes=self.elem_classes) as col:
            self.container = col

            self.state = gr.State(dict(self._initial_state))

            self.gallery = gr.Gallery(
                label=self.label,
                value=self._initial_state["items"],
                height=self.height,
                columns=self.columns,
                show_label=self.show_label,
                preview= True,
                # type="pil",
                file_types= list(IMAGE_EXTS) if self.media_mode == "image" else list(VIDEO_EXTS), 
                selected_index=self._initial_state["selected"],  # server-side selection
            )
            if update:
                self.gallery = gr.update(
                    value=self._initial_state["items"],
                    selected_index=self._initial_state["selected"],  # server-side selection
                    label=self.label,
                    show_label=self.show_label,
                )
            else:
                self.gallery = gr.Gallery(
                    value=self._initial_state["items"],
                    label=self.label,
                    height=self.height,
                    columns=self.columns,
                    show_label=self.show_label,
                    preview= True,
                    # type="pil", # very slow
                    file_types= list(IMAGE_EXTS) if self.media_mode == "image" else list(VIDEO_EXTS), 
                    selected_index=self._initial_state["selected"],  # server-side selection
                )

            # One-line controls
            exts = sorted(IMAGE_EXTS if self.media_mode == "image" else VIDEO_EXTS) if self.accept_filter else None
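
The new update flag lets the same builder serve two cases: on first render it instantiates the gr.Gallery, while on a form refresh it only emits a gr.update(...) payload for the component that already exists. A reduced sketch of that pattern (build_gallery and the keyword choices are illustrative, mirroring the constructor arguments used above):

import gradio as gr

def build_gallery(initial_state, label="Media", update=False):
    if update:
        # Re-render path: return an update payload for the existing component.
        return gr.update(
            value=initial_state["items"],
            selected_index=initial_state["selected"],
            label=label,
        )
    # First-render path: create the component itself.
    return gr.Gallery(
        value=initial_state["items"],
        selected_index=initial_state["selected"],
        label=label,
        preview=True,
    )
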
@ -418,10 +432,10 @@ class AdvancedMediaGallery:
                    size="sm",
                    min_width=1,
                )
                self.btn_remove = gr.Button("Remove", size="sm", min_width=1)
                self.btn_remove = gr.Button(" Remove ", size="sm", min_width=1)
                self.btn_left   = gr.Button("◀ Left",  size="sm", visible=not self._initial_state["single"], min_width=1)
                self.btn_right  = gr.Button("Right ▶", size="sm", visible=not self._initial_state["single"], min_width=1)
                self.btn_clear  = gr.Button("Clear",   variant="secondary", size="sm", visible=not self._initial_state["single"], min_width=1)
                self.btn_clear  = gr.Button(" Clear ",   variant="secondary", size="sm", visible=not self._initial_state["single"], min_width=1)

        return col

@ -430,14 +444,24 @@ class AdvancedMediaGallery:
        self.gallery.select(
            self._on_select,
            inputs=[self.state, self.gallery],
            outputs=[self.state],
            outputs=[self.gallery, self.state],
            trigger_mode="always_last",
        )

        # Gallery value changed by user actions (click-to-add, drag-drop, internal remove, etc.)
        self.gallery.change(
        self.gallery.upload(
            self._on_upload,
            inputs=[self.gallery, self.state],
            outputs=[self.gallery, self.state],
            trigger_mode="always_last",
        )

        # Gallery value changed by user actions (click-to-add, drag-drop, internal remove, etc.)
        self.gallery.upload(
            self._on_gallery_change,
            inputs=[self.gallery, self.state],
            outputs=[self.gallery, self.state],
            trigger_mode="always_last",
        )

        # Add via UploadButton
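
Note that the select binding now writes back to the gallery as well as the state: when _on_select judges an event spurious it answers with gr.update(selected_index=...), which snaps the client back to the previously selected item. The shape of each binding is the same (shown out of class context, names as above):

self.gallery.select(
    self._on_select,
    inputs=[self.state, self.gallery],
    outputs=[self.gallery, self.state],   # gallery added so the handler can restore selected_index
    trigger_mode="always_last",           # coalesce bursts of events: only the last one runs
)
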
@ -445,6 +469,7 @@ class AdvancedMediaGallery:
            self._on_add,
            inputs=[self.upload_btn, self.state, self.gallery],
            outputs=[self.gallery, self.state],
            trigger_mode="always_last",
        )

        # Remove selected
@ -452,6 +477,7 @@ class AdvancedMediaGallery:
            self._on_remove,
            inputs=[self.state, self.gallery],
            outputs=[self.gallery, self.state],
            trigger_mode="always_last",
        )

        # Reorder using selected index, keep same item selected
@ -459,11 +485,13 @@ class AdvancedMediaGallery:
            lambda st, gallery: self._on_move(-1, st, gallery),
            inputs=[self.state, self.gallery],
            outputs=[self.gallery, self.state],
            trigger_mode="always_last",
        )
        self.btn_right.click(
            lambda st, gallery: self._on_move(+1, st, gallery),
            inputs=[self.state, self.gallery],
            outputs=[self.gallery, self.state],
            trigger_mode="always_last",
        )

        # Clear all
@ -471,6 +499,7 @@ class AdvancedMediaGallery:
            self._on_clear,
            inputs=[self.state],
            outputs=[self.gallery, self.state],
            trigger_mode="always_last",
        )

    # ---------------- public API ----------------

@ -19,6 +19,7 @@ import tempfile
import subprocess
import json
from functools import lru_cache
os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg")


from PIL import Image
@ -207,30 +208,62 @@ def  get_outpainting_frame_location(final_height, final_width,  outpainting_dims
    if (margin_left + width) > final_width or outpainting_right == 0: margin_left = final_width - width
    return height, width, margin_top, margin_left

def calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = 16):
    if fit_into_canvas == None:
def rescale_and_crop(img, w, h):
    ow, oh = img.size
    target_ratio = w / h
    orig_ratio = ow / oh

    if orig_ratio > target_ratio:
        # Crop width first
        nw = int(oh * target_ratio)
        img = img.crop(((ow - nw) // 2, 0, (ow + nw) // 2, oh))
    else:
        # Crop height first
        nh = int(ow / target_ratio)
        img = img.crop((0, (oh - nh) // 2, ow, (oh + nh) // 2))

    return img.resize((w, h), Image.LANCZOS)

def calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas,  block_size = 16):
    if fit_into_canvas == None or fit_into_canvas == 2:
        # return image_height, image_width
        return canvas_height, canvas_width
    if fit_into_canvas:
    if fit_into_canvas == 1:
        scale1  = min(canvas_height / image_height, canvas_width / image_width)
        scale2  = min(canvas_width / image_height, canvas_height / image_width)
        scale = max(scale1, scale2) 
    else:
    else: #0 or #2 (crop)
        scale = (canvas_height * canvas_width / (image_height * image_width))**(1/2)

    new_height = round( image_height * scale / block_size) * block_size
    new_width = round( image_width * scale / block_size) * block_size
    return new_height, new_width

def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, ignore_first, fit_into_canvas = False ):
def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fit_into_canvas, fit_crop, block_size = 16):
    if fit_crop:
        image = rescale_and_crop(image, canvas_width, canvas_height)
        new_width, new_height = image.size  
    else:
        image_width, image_height = image.size
        new_height, new_width = calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = block_size )
        image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) 
    return image, new_height, new_width

def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None ):
    if rm_background:
        session = new_session() 

    output_list =[]
    for i, img in enumerate(img_list):
        width, height =  img.size 

        if fit_into_canvas:
        if fit_into_canvas == None or any_background_ref == 1 and i==0 or any_background_ref == 2:
            if outpainting_dims is not None:
                resized_image =img 
            elif img.size != (budget_width, budget_height):
                resized_image= img.resize((budget_width, budget_height), resample=Image.Resampling.LANCZOS) 
            else:
                resized_image =img
        elif fit_into_canvas == 1:
            white_canvas = np.ones((budget_height, budget_width, 3), dtype=np.uint8) * 255 
            scale = min(budget_height / height, budget_width / width)
            new_height = int(height * scale)
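
A worked example of the centre-crop path in rescale_and_crop above: a 1000x600 source brought to 512x512 is cropped to a 600x600 square around its centre before the final resize (values chosen only for illustration):

from PIL import Image

img = Image.new("RGB", (1000, 600))
ow, oh = img.size                       # 1000, 600
target_ratio = 512 / 512                # 1.0; source ratio 1000/600 ~ 1.67 is wider
nw = int(oh * target_ratio)             # 600: crop the width, keep the full height
cropped = img.crop(((ow - nw) // 2, 0, (ow + nw) // 2, oh))   # box = (200, 0, 800, 600)
print(cropped.size)                     # (600, 600)
print(cropped.resize((512, 512), Image.LANCZOS).size)         # (512, 512)
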
@ -242,10 +275,10 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg
            resized_image = Image.fromarray(white_canvas)  
        else:
            scale = (budget_height * budget_width / (height * width))**(1/2)
            new_height = int( round(height * scale / 16) * 16)
            new_width = int( round(width * scale / 16) * 16)
            new_height = int( round(height * scale / block_size) * block_size)
            new_width = int( round(width * scale / block_size) * block_size)
            resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS) 
        if rm_background  and not (ignore_first and i == 0) :
        if rm_background  and not (any_background_ref and i==0 or any_background_ref == 2) :
            # resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
            resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
        output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200,

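
For the pixel-budget branch just above, the arithmetic with block_size = 16 works out like this for a 1280x720 reference resized toward an 832x480 budget (numbers purely illustrative):

height, width = 720, 1280
budget_height, budget_width = 480, 832
scale = (budget_height * budget_width / (height * width)) ** 0.5   # ~0.658
new_height = int(round(height * scale / 16) * 16)                  # 480
new_width  = int(round(width  * scale / 16) * 16)                  # 848
# Aspect ratio is preserved and both sides stay multiples of the block size,
# so the result can land slightly off the exact budget (848 wide vs the 832 target).
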