Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-04 22:26:36 +00:00)

commit f9f63cbc79
parent 99fd9aea32

    intermediate commit
@@ -7,8 +7,6 @@
            "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_bf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_quanto_bf16_int8.safetensors"
        ],
-       "image_outputs": true,
-       "reference_image": true,
        "flux-model": "flux-dev-kontext"
    },
    "prompt": "add a hat",
@@ -6,8 +6,6 @@
        "modules": [ ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_projector_bf16.safetensors"]],
        "URLs": "flux",
        "loras": ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_dit_lora_bf16.safetensors"],
-       "image_outputs": true,
-       "reference_image": true,
        "flux-model": "flux-dev-uso"
    },
    "prompt": "the man is wearing a hat",
@@ -9,9 +9,7 @@
        ],
        "attention": {
            "<89": "sdpa"
-       },
-       "reference_image": true,
-       "image_outputs": true
+       }
    },
    "prompt": "add a hat",
    "resolution": "1280x720",
@@ -13,28 +13,41 @@ class family_handler():
         flux_schnell = flux_model == "flux-schnell"
         flux_chroma = flux_model == "flux-chroma"
         flux_uso = flux_model == "flux-dev-uso"
-        model_def_output = {
+        flux_kontext = flux_model == "flux-dev-kontext"
+
+        extra_model_def = {
             "image_outputs" : True,
             "no_negative_prompt" : not flux_chroma,
         }
         if flux_chroma:
-            model_def_output["guidance_max_phases"] = 1
+            extra_model_def["guidance_max_phases"] = 1
         elif not flux_schnell:
-            model_def_output["embedded_guidance"] = True
+            extra_model_def["embedded_guidance"] = True
         if flux_uso :
-            model_def_output["any_image_refs_relative_size"] = True
-            model_def_output["no_background_removal"] = True
-            model_def_output["image_ref_choices"] = {
+            extra_model_def["any_image_refs_relative_size"] = True
+            extra_model_def["no_background_removal"] = True
+            extra_model_def["image_ref_choices"] = {
                 "choices":[("No Reference Image", ""),("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "KI"),
                     ("Up to two Images are Style Images", "KIJ")],
                 "default": "KI",
                 "letters_filter": "KIJ",
                 "label": "Reference Images / Style Images"
             }
-        model_def_output["lock_image_refs_ratios"] = True
-
-        return model_def_output
+
+        if flux_kontext:
+            extra_model_def["image_ref_choices"] = {
+                "choices": [
+                    ("None", ""),
+                    ("Conditional Images is first Main Subject / Landscape and may be followed by People / Objects", "KI"),
+                    ("Conditional Images are People / Objects", "I"),
+                ],
+                "letters_filter": "KI",
+            }
+
+        extra_model_def["lock_image_refs_ratios"] = True
+
+        return extra_model_def

     @staticmethod
     def query_supported_types():
@@ -122,10 +135,12 @@ class family_handler():
     def update_default_settings(base_model_type, model_def, ui_defaults):
         flux_model = model_def.get("flux-model", "flux-dev")
         flux_uso = flux_model == "flux-dev-uso"
+        flux_kontext = flux_model == "flux-dev-kontext"
         ui_defaults.update({
             "embedded_guidance": 2.5,
         })
-        if model_def.get("reference_image", False):
+        if flux_kontext or flux_uso:
             ui_defaults.update({
                 "video_prompt_type": "KI",
             })
@@ -24,44 +24,6 @@ from .util import (

 from PIL import Image

-def resize_and_centercrop_image(image, target_height_ref1, target_width_ref1):
-    target_height_ref1 = int(target_height_ref1 // 64 * 64)
-    target_width_ref1 = int(target_width_ref1 // 64 * 64)
-    h, w = image.shape[-2:]
-    if h < target_height_ref1 or w < target_width_ref1:
-        # compute the aspect ratio
-        aspect_ratio = w / h
-        if h < target_height_ref1:
-            new_h = target_height_ref1
-            new_w = new_h * aspect_ratio
-            if new_w < target_width_ref1:
-                new_w = target_width_ref1
-                new_h = new_w / aspect_ratio
-        else:
-            new_w = target_width_ref1
-            new_h = new_w / aspect_ratio
-            if new_h < target_height_ref1:
-                new_h = target_height_ref1
-                new_w = new_h * aspect_ratio
-    else:
-        aspect_ratio = w / h
-        tgt_aspect_ratio = target_width_ref1 / target_height_ref1
-        if aspect_ratio > tgt_aspect_ratio:
-            new_h = target_height_ref1
-            new_w = new_h * aspect_ratio
-        else:
-            new_w = target_width_ref1
-            new_h = new_w / aspect_ratio
-    # rescale the image with TVF.resize
-    image = TVF.resize(image, (math.ceil(new_h), math.ceil(new_w)))
-    # compute the center-crop parameters
-    top = (image.shape[-2] - target_height_ref1) // 2
-    left = (image.shape[-1] - target_width_ref1) // 2
-    # center-crop with TVF.crop
-    image = TVF.crop(image, top, left, target_height_ref1, target_width_ref1)
-    return image
-
-
 def stitch_images(img1, img2):
     # Resize img2 to match img1's height
     width1, height1 = img1.size
@@ -171,8 +133,6 @@ class model_factory:
         device="cuda"
         flux_dev_uso = self.name in ['flux-dev-uso']
         image_stiching = not self.name in ['flux-dev-uso'] #and False
-        # image_refs_relative_size = 100
-        crop = False
         input_ref_images = [] if input_ref_images is None else input_ref_images[:]
         ref_style_imgs = []
         if "I" in video_prompt_type and len(input_ref_images) > 0:
@@ -186,36 +146,15 @@ class model_factory:
            if image_stiching:
                # image stiching method
                stiched = input_ref_images[0]
-               if "K" in video_prompt_type :
-                   w, h = input_ref_images[0].size
-                   height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
-                   # actual rescale will happen in prepare_kontext
                for new_img in input_ref_images[1:]:
                    stiched = stitch_images(stiched, new_img)
                input_ref_images = [stiched]
            else:
-               first_ref = 0
-               if "K" in video_prompt_type:
-                   # image latents tiling method
-                   w, h = input_ref_images[0].size
-                   if crop :
-                       img = convert_image_to_tensor(input_ref_images[0])
-                       img = resize_and_centercrop_image(img, height, width)
-                       input_ref_images[0] = convert_tensor_to_image(img)
-                   else:
-                       height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
-                       input_ref_images[0] = input_ref_images[0].resize((width, height), resample=Image.Resampling.LANCZOS)
-                   first_ref = 1
-
-               for i in range(first_ref,len(input_ref_images)):
+               # latents stiching with resize
+               for i in range(len(input_ref_images)):
                    w, h = input_ref_images[i].size
-                   if crop:
-                       img = convert_image_to_tensor(input_ref_images[i])
-                       img = resize_and_centercrop_image(img, int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100))
-                       input_ref_images[i] = convert_tensor_to_image(img)
-                   else:
-                       image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas)
-                       input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
+                   image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas)
+                   input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
        else:
            input_ref_images = None

@@ -861,11 +861,6 @@ class HunyuanVideoSampler(Inference):
             freqs_cos, freqs_sin = self.get_rotary_pos_embed(target_frame_num, target_height, target_width, enable_RIFLEx)
         else:
             if self.avatar:
-                w, h = input_ref_images.size
-                target_height, target_width = calculate_new_dimensions(target_height, target_width, h, w, fit_into_canvas)
-                if target_width != w or target_height != h:
-                    input_ref_images = input_ref_images.resize((target_width,target_height), resample=Image.Resampling.LANCZOS)
-
                 concat_dict = {'mode': 'timecat', 'bias': -1}
                 freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict)
             else:
@@ -51,6 +51,23 @@ class family_handler():
         extra_model_def["tea_cache"] = True
         extra_model_def["mag_cache"] = True

+        if base_model_type in ["hunyuan_custom_edit"]:
+            extra_model_def["guide_preprocessing"] = {
+                "selection": ["MV", "PMV"],
+            }
+
+            extra_model_def["mask_preprocessing"] = {
+                "selection": ["A", "NA"],
+                "default" : "NA"
+            }
+
+        if base_model_type in ["hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_custom"]:
+            extra_model_def["image_ref_choices"] = {
+                "choices": [("Reference Image", "I")],
+                "letters_filter":"I",
+                "visible": False,
+            }
+
         if base_model_type in ["hunyuan_avatar"]: extra_model_def["no_background_removal"] = True

         if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar"]:
@ -141,6 +158,10 @@ class family_handler():
|
|||||||
|
|
||||||
return hunyuan_model, pipe
|
return hunyuan_model, pipe
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
|
||||||
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_default_settings(base_model_type, model_def, ui_defaults):
|
def update_default_settings(base_model_type, model_def, ui_defaults):
|
||||||
ui_defaults["embedded_guidance_scale"]= 6.0
|
ui_defaults["embedded_guidance_scale"]= 6.0
|
||||||
|
|||||||
@@ -300,9 +300,6 @@ class LTXV:
             prefix_size, height, width = input_video.shape[-3:]
         else:
             if image_start != None:
-                frame_width, frame_height = image_start.size
-                if fit_into_canvas != None:
-                    height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas, 32)
                 conditioning_media_paths.append(image_start.unsqueeze(1))
                 conditioning_start_frames.append(0)
                 conditioning_control_frames.append(False)
@@ -26,6 +26,15 @@ class family_handler():
         extra_model_def["sliding_window"] = True
         extra_model_def["image_prompt_types_allowed"] = "TSEV"

+        extra_model_def["guide_preprocessing"] = {
+            "selection": ["", "PV", "DV", "EV", "V"],
+            "labels" : { "V": "Use LTXV raw format"}
+        }
+
+        extra_model_def["mask_preprocessing"] = {
+            "selection": ["", "A", "NA", "XA", "XNA"],
+        }
+
         return extra_model_def

     @staticmethod
@@ -28,7 +28,7 @@ from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Aut
 from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
 from diffusers import FlowMatchEulerDiscreteScheduler
 from PIL import Image
-from shared.utils.utils import calculate_new_dimensions
+from shared.utils.utils import calculate_new_dimensions, convert_image_to_tensor, convert_tensor_to_image

 XLA_AVAILABLE = False

@@ -563,6 +563,8 @@ class QwenImagePipeline(): #DiffusionPipeline
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
         image = None,
+        image_mask = None,
+        denoising_strength = 0,
         callback=None,
         pipeline=None,
         loras_slists=None,
@@ -694,14 +696,33 @@ class QwenImagePipeline(): #DiffusionPipeline
             image_width = image_width // multiple_of * multiple_of
             image_height = image_height // multiple_of * multiple_of
             ref_height, ref_width = 1568, 672
-            if height * width < ref_height * ref_width: ref_height , ref_width = height , width
-            if image_height * image_width > ref_height * ref_width:
-                image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
-
-            image = image.resize((image_width,image_height), resample=Image.Resampling.LANCZOS)
+            if image_mask is None:
+                if height * width < ref_height * ref_width: ref_height , ref_width = height , width
+                if image_height * image_width > ref_height * ref_width:
+                    image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
+                if (image_width,image_height) != image.size:
+                    image = image.resize((image_width,image_height), resample=Image.Resampling.LANCZOS)
+                image_mask_latents = None
+            else:
+                # _, image_width, image_height = min(
+                #     (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS
+                # )
+                image_height, image_width = calculate_new_dimensions(height, width, image_height, image_width, False, block_size=multiple_of)
+                # image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
+                height, width = image_height, image_width
+                image_mask_latents = convert_image_to_tensor(image_mask.resize((width // 16, height // 16), resample=Image.Resampling.LANCZOS))
+                image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1]
+                image_mask_rebuilt = image_mask_latents.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0)
+                convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png")
+                image_mask_latents = image_mask_latents.reshape(1, -1, 1).to(device)
+
             prompt_image = image
-            image = self.image_processor.preprocess(image, image_height, image_width)
-            image = image.unsqueeze(2)
+            if image.size != (image_width, image_height):
+                image = image.resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
+
+            image.save("nnn.png")
+            image = convert_image_to_tensor(image).unsqueeze(0).unsqueeze(2)

             has_neg_prompt = negative_prompt is not None or (
                 negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
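The hunk above builds `image_mask_latents` by shrinking the user-provided mask to one value per 16x16 latent patch and flattening it so it lines up with the packed latent tokens, while `image_mask_rebuilt` blows the same mask back up to pixel resolution for the final composite. A minimal stand-alone sketch of that idea, assuming a white-on-black PIL mask and a 16-pixel patch size (function names and the [0, 1] binarization threshold are mine, not the pipeline's API):

```python
import numpy as np
import torch
from PIL import Image

def mask_to_token_mask(mask_img: Image.Image, width: int, height: int, patch: int = 16) -> torch.Tensor:
    """Shrink a binary PIL mask to one value per latent patch and flatten it.

    Returns a (1, num_tokens, 1) tensor: 1.0 where a patch may be repainted,
    0.0 where the original image must be preserved.
    """
    small = mask_img.convert("L").resize((width // patch, height // patch), Image.Resampling.LANCZOS)
    m = torch.from_numpy(np.array(small)).float() / 255.0
    m = (m > 0.5).float()            # binarize again after the lossy resize
    return m.reshape(1, -1, 1)       # row-major patch order, one scalar per token

def token_mask_to_pixels(token_mask: torch.Tensor, width: int, height: int, patch: int = 16) -> torch.Tensor:
    """Upsample the token mask back to pixel resolution (used for the final paste-back)."""
    m = token_mask.reshape(1, height // patch, width // patch)
    return m.repeat_interleave(patch, dim=-1).repeat_interleave(patch, dim=-2)

if __name__ == "__main__":
    mask = Image.new("L", (1024, 768), 255)          # hypothetical all-white mask (edit everywhere)
    tm = mask_to_token_mask(mask, 1024, 768)
    print(tm.shape, token_mask_to_pixels(tm, 1024, 768).shape)
    # torch.Size([1, 3072, 1]) torch.Size([1, 768, 1024])
```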
@@ -744,6 +765,8 @@ class QwenImagePipeline(): #DiffusionPipeline
             generator,
             latents,
         )
+        original_image_latents = None if image_latents is None else image_latents.clone()
+
         if image is not None:
             img_shapes = [
                 [
@@ -788,6 +811,15 @@ class QwenImagePipeline(): #DiffusionPipeline
         negative_txt_seq_lens = (
             negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
         )
+        morph = False
+        if image_mask_latents is not None and denoising_strength <= 1.:
+            first_step = int(len(timesteps) * (1. - denoising_strength))
+            if not morph:
+                latent_noise_factor = timesteps[first_step]/1000
+                latents = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor
+                timesteps = timesteps[first_step:]
+                self.scheduler.timesteps = timesteps
+                self.scheduler.sigmas= self.scheduler.sigmas[first_step:]

         # 6. Denoising loop
         self.scheduler.set_begin_index(0)
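With `denoising_strength` below 1 the loop no longer starts from pure noise: the hunk above jumps to an intermediate timestep and re-noises the encoded input image to the matching level. A rough, self-contained illustration of that scheduling trick under the flow-matching convention sigma = t/1000 (variable names are illustrative, not the pipeline's):

```python
import torch

def partial_denoise_start(image_latents: torch.Tensor,
                          timesteps: torch.Tensor,
                          denoising_strength: float,
                          generator=None):
    """Pick the timestep to resume from and build the matching noisy latents.

    denoising_strength = 1.0 starts from the first (noisiest) timestep;
    smaller values skip early steps so more of the input image survives.
    """
    first_step = int(len(timesteps) * (1.0 - denoising_strength))
    sigma = timesteps[first_step].item() / 1000.0   # assumed flow-matching noise level in [0, 1]
    noise = torch.randn(image_latents.shape, generator=generator,
                        device=image_latents.device, dtype=image_latents.dtype)
    # Linear interpolation between the clean latents and pure noise.
    latents = image_latents * (1.0 - sigma) + noise * sigma
    return latents, timesteps[first_step:]

if __name__ == "__main__":
    lat = torch.zeros(1, 16, 64)                    # toy "encoded image" latents
    ts = torch.linspace(999, 1, 20)                 # toy descending timestep table
    start, remaining = partial_denoise_start(lat, ts, denoising_strength=0.6)
    print(len(remaining), start.abs().mean().item() > 0)   # 12 steps left, latents partly noised
```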
@@ -797,10 +829,15 @@ class QwenImagePipeline(): #DiffusionPipeline
             update_loras_slists(self.transformer, loras_slists, updated_num_steps)
             callback(-1, None, True, override_num_inference_steps = updated_num_steps)

         for i, t in enumerate(timesteps):
             if self.interrupt:
                 continue

+            if image_mask_latents is not None and denoising_strength <1. and i == first_step and morph:
+                latent_noise_factor = t/1000
+                latents = original_image_latents * (1.0 - latent_noise_factor) + latents * latent_noise_factor
+
             self._current_timestep = t
             # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
             timestep = t.expand(latents.shape[0]).to(latents.dtype)
@@ -865,6 +902,12 @@ class QwenImagePipeline(): #DiffusionPipeline
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+                if image_mask_latents is not None:
+                    next_t = timesteps[i+1] if i<len(timesteps)-1 else 0
+                    latent_noise_factor = next_t / 1000
+                    noisy_image = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor
+                    latents = noisy_image * (1-image_mask_latents) + image_mask_latents * latents
+                    noisy_image = None

                 if latents.dtype != latents_dtype:
                     if torch.backends.mps.is_available():
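The per-step blend above is the heart of the inpainting loop: after every scheduler step, everything outside the mask is overwritten with a copy of the original latents re-noised to the *next* timestep, so only the masked tokens ever drift away from the source image. A sketch of one such update, assuming the same sigma = t/1000 convention (not the project's exact code):

```python
import torch

def masked_latent_update(latents: torch.Tensor,
                         original_latents: torch.Tensor,
                         token_mask: torch.Tensor,
                         next_t: float) -> torch.Tensor:
    """Pin the un-masked tokens to the source image at the noise level of the next step.

    token_mask is 1.0 where repainting is allowed and 0.0 where the source must be kept,
    broadcastable against the packed latents (e.g. shape (1, num_tokens, 1)).
    """
    sigma = next_t / 1000.0
    noisy_source = original_latents * (1.0 - sigma) + torch.randn_like(original_latents) * sigma
    return noisy_source * (1.0 - token_mask) + latents * token_mask

if __name__ == "__main__":
    lat = torch.randn(1, 8, 4)
    src = torch.zeros(1, 8, 4)
    mask = torch.tensor([1., 1., 0., 0., 0., 0., 0., 0.]).reshape(1, 8, 1)
    out = masked_latent_update(lat, src, mask, next_t=0.0)   # final step: sigma = 0
    print(torch.allclose(out[:, 2:], src[:, 2:]))            # True: unmasked tokens equal the source
```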
@@ -878,7 +921,7 @@ class QwenImagePipeline(): #DiffusionPipeline

         self._current_timestep = None
         if output_type == "latent":
-            image = latents
+            output_image = latents
         else:
             latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
             latents = latents.to(self.vae.dtype)
@@ -891,7 +934,9 @@ class QwenImagePipeline(): #DiffusionPipeline
                 latents.device, latents.dtype
             )
             latents = latents / latents_std + latents_mean
-            image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
+            output_image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
+            if image_mask is not None:
+                output_image = image.squeeze(2) * (1 - image_mask_rebuilt) + output_image.to(image) * image_mask_rebuilt


-        return image
+        return output_image
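After decoding, the hunk above pastes the decoded result back into the original picture at pixel level, so even slight VAE round-trip drift outside the mask is discarded. A compact version of that composite, assuming a (B, C, H, W) layout in [-1, 1] (a sketch, not the pipeline's exact tensors):

```python
import torch

def composite_inpaint(source: torch.Tensor, decoded: torch.Tensor, pixel_mask: torch.Tensor) -> torch.Tensor:
    """Keep the untouched source pixels outside the mask, the freshly decoded pixels inside it."""
    # pixel_mask: (1, 1, H, W) or broadcastable, with 1.0 inside the editable region.
    return source * (1.0 - pixel_mask) + decoded.to(source) * pixel_mask

if __name__ == "__main__":
    src = torch.zeros(1, 3, 64, 64)
    dec = torch.ones(1, 3, 64, 64)
    mask = torch.zeros(1, 1, 64, 64); mask[..., :, 32:] = 1.0      # right half editable
    out = composite_inpaint(src, dec, mask)
    print(out[..., 0].mean().item(), out[..., 63].mean().item())   # 0.0 (kept) vs 1.0 (repainted)
```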
@@ -9,7 +9,7 @@ def get_qwen_text_encoder_filename(text_encoder_quantization):
 class family_handler():
     @staticmethod
     def query_model_def(base_model_type, model_def):
-        model_def_output = {
+        extra_model_def = {
             "image_outputs" : True,
             "sample_solvers":[
                 ("Default", "default"),
@@ -18,8 +18,18 @@ class family_handler():
             "lock_image_refs_ratios": True,
         }

-        return model_def_output
+        if base_model_type in ["qwen_image_edit_20B"]:
+            extra_model_def["inpaint_support"] = True
+            extra_model_def["image_ref_choices"] = {
+                "choices": [
+                    ("None", ""),
+                    ("Conditional Images is first Main Subject / Landscape and may be followed by People / Objects", "KI"),
+                    ("Conditional Images are People / Objects", "I"),
+                ],
+                "letters_filter": "KI",
+            }
+
+        return extra_model_def

     @staticmethod
     def query_supported_types():
@@ -75,14 +85,18 @@ class family_handler():
         if ui_defaults.get("sample_solver", "") == "":
             ui_defaults["sample_solver"] = "default"

+        if settings_version < 2.32:
+            ui_defaults["denoising_strength"] = 1.
+
     @staticmethod
     def update_default_settings(base_model_type, model_def, ui_defaults):
         ui_defaults.update({
             "guidance_scale": 4,
             "sample_solver": "default",
         })
-        if model_def.get("reference_image", False):
+        if base_model_type in ["qwen_image_edit_20B"]:
             ui_defaults.update({
                 "video_prompt_type": "KI",
+                "denoising_strength" : 1.,
             })

@@ -103,6 +103,8 @@ class model_factory():
         n_prompt = None,
         sampling_steps: int = 20,
         input_ref_images = None,
+        image_guide= None,
+        image_mask= None,
         width= 832,
         height=480,
         guide_scale: float = 4,
@@ -114,6 +116,7 @@ class model_factory():
         VAE_tile_size = None,
         joint_pass = True,
         sample_solver='default',
+        denoising_strength = 1.,
         **bbargs
     ):
         # Generate with different aspect ratios
@@ -174,8 +177,9 @@ class model_factory():

         if n_prompt is None or len(n_prompt) == 0:
             n_prompt= "text, watermark, copyright, blurry, low resolution"
-        if input_ref_images is not None:
+        if image_guide is not None:
+            input_ref_images = [image_guide]
+        elif input_ref_images is not None:
             # image stiching method
             stiched = input_ref_images[0]
             if "K" in video_prompt_type :
@@ -190,6 +194,7 @@ class model_factory():
             prompt=input_prompt,
             negative_prompt=n_prompt,
             image = input_ref_images,
+            image_mask = image_mask,
             width=width,
             height=height,
             num_inference_steps=sampling_steps,
@@ -199,6 +204,7 @@ class model_factory():
             pipeline=self,
             loras_slists=loras_slists,
             joint_pass = joint_pass,
+            denoising_strength=denoising_strength,
             generator=torch.Generator(device="cuda").manual_seed(seed)
         )
         if image is None: return None
@@ -261,7 +261,7 @@ class WanAny2V:
     def vace_latent(self, z, m):
         return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]

-    def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, fill_max = False, outpainting_dims = None, return_mask = False):
+    def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, full_frame = False, outpainting_dims = None, return_mask = False):
         from shared.utils.utils import save_image
         ref_width, ref_height = ref_img.size
         if (ref_height, ref_width) == image_size and outpainting_dims == None:
@@ -270,18 +270,23 @@ class WanAny2V:
         else:
             if outpainting_dims != None:
                 final_height, final_width = image_size
-                canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 8)
+                canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1)
             else:
                 canvas_height, canvas_width = image_size
-            scale = min(canvas_height / ref_height, canvas_width / ref_width)
-            new_height = int(ref_height * scale)
-            new_width = int(ref_width * scale)
-            if fill_max and (canvas_height - new_height) < 16:
-                new_height = canvas_height
-            if fill_max and (canvas_width - new_width) < 16:
-                new_width = canvas_width
-            top = (canvas_height - new_height) // 2
-            left = (canvas_width - new_width) // 2
+            if full_frame:
+                new_height = canvas_height
+                new_width = canvas_width
+                top = left = 0
+            else:
+                # if fill_max and (canvas_height - new_height) < 16:
+                #     new_height = canvas_height
+                # if fill_max and (canvas_width - new_width) < 16:
+                #     new_width = canvas_width
+                scale = min(canvas_height / ref_height, canvas_width / ref_width)
+                new_height = int(ref_height * scale)
+                new_width = int(ref_width * scale)
+                top = (canvas_height - new_height) // 2
+                left = (canvas_width - new_width) // 2
             ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
             ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
             if outpainting_dims != None:
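The new `full_frame` switch above either stretches the reference across the whole canvas or keeps the previous behaviour of scaling it down to fit and centring it with letterbox margins. A small stand-alone version of that placement decision (PIL only; the 8-vs-1 block-size change to `get_outpainting_frame_location` is not modelled, and the function name is mine):

```python
from PIL import Image

def place_on_canvas(ref_img: Image.Image, canvas_height: int, canvas_width: int, full_frame: bool = False):
    """Return the resized reference and the top-left offset used to paste it onto the canvas."""
    ref_width, ref_height = ref_img.size
    if full_frame:
        # Use the whole canvas, ignoring the source aspect ratio.
        new_height, new_width, top, left = canvas_height, canvas_width, 0, 0
    else:
        # Aspect-preserving fit, centred on the canvas (the pre-existing behaviour).
        scale = min(canvas_height / ref_height, canvas_width / ref_width)
        new_height, new_width = int(ref_height * scale), int(ref_width * scale)
        top = (canvas_height - new_height) // 2
        left = (canvas_width - new_width) // 2
    resized = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
    return resized, top, left

if __name__ == "__main__":
    img = Image.new("RGB", (600, 360))
    _, top, left = place_on_canvas(img, 720, 1280, full_frame=False)
    print(top, left)   # 0 40: the 1200x720 fit is centred with 40-pixel side margins
```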
@@ -302,7 +307,7 @@ class WanAny2V:
         canvas = canvas.to(device)
         return ref_img.to(device), canvas

-    def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_video_guide_frames= [], start_frame = 0, fit_into_canvas = None, pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False):
+    def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_video_guide_frames= [], start_frame = 0, pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False):
         image_sizes = []
         trim_video_guide = len(keep_video_guide_frames)
         def conv_tensor(t, device):
@@ -533,22 +538,16 @@ class WanAny2V:
         any_end_frame = False
         if image_start is None:
             if infinitetalk:
+                new_shot = "Q" in video_prompt_type
                 if input_frames is not None:
                     image_ref = input_frames[:, 0]
-                    if input_video is None: input_video = input_frames[:, 0:1]
-                    new_shot = "Q" in video_prompt_type
                 else:
-                    if pre_video_frame is None:
-                        new_shot = True
-                    else:
-                        if input_ref_images is None:
-                            input_ref_images, new_shot = [pre_video_frame], False
-                        else:
-                            input_ref_images, new_shot = [img.resize(pre_video_frame.size, resample=Image.Resampling.LANCZOS) for img in input_ref_images], "Q" in video_prompt_type
-                    if input_ref_images is None: raise Exception("Missing Reference Image")
+                    if input_ref_images is None:
+                        if pre_video_frame is None: raise Exception("Missing Reference Image")
+                        input_ref_images = [pre_video_frame]
                     new_shot = new_shot and window_no <= len(input_ref_images)
                     image_ref = convert_image_to_tensor(input_ref_images[ min(window_no, len(input_ref_images))-1 ])
-                if new_shot:
+                if new_shot or input_video is None:
                     input_video = image_ref.unsqueeze(1)
                 else:
                     color_correction_strength = 0 #disable color correction as transition frames between shots may have a complete different color level than the colors of the new shot
@@ -35,7 +35,7 @@ class family_handler():
             "label" : "Generation Type"
         }

-        extra_model_def["image_prompt_types_allowed"] = "TSEV"
+        extra_model_def["image_prompt_types_allowed"] = "TSV"


         return extra_model_def
@@ -110,19 +110,79 @@ class family_handler():
             "tea_cache" : not (base_model_type in ["i2v_2_2", "ti2v_2_2" ] or multiple_submodels),
             "mag_cache" : True,
             "keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"],
+            "convert_image_guide_to_video" : True,
             "sample_solvers":[
                 ("unipc", "unipc"),
                 ("euler", "euler"),
                 ("dpm++", "dpm++"),
                 ("flowmatch causvid", "causvid"), ]
         })

+        if base_model_type in ["t2v"]:
+            extra_model_def["guide_custom_choices"] = {
+                "choices":[("Use Text Prompt Only", ""),("Video to Video guided by Text Prompt", "GUV")],
+                "default": "",
+                "letters_filter": "GUV",
+                "label": "Video to Video"
+            }
+
         if base_model_type in ["infinitetalk"]:
             extra_model_def["no_background_removal"] = True
-            # extra_model_def["at_least_one_image_ref_needed"] = True
+            extra_model_def["all_image_refs_are_background_ref"] = True
+            extra_model_def["guide_custom_choices"] = {
+                "choices":[
+                    ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Sharp Transitions", "QKI"),
+                    ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Smooth Transitions", "KI"),
+                    ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Sharp Transitions", "QRUV"),
+                    ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Smooth Transitions", "RUV"),
+                    ("Video to Video, amount of motion transferred depends on Denoising Strength - Sharp Transitions", "GQUV"),
+                    ("Video to Video, amount of motion transferred depends on Denoising Strength - Smooth Transitions", "GUV"),
+                ],
+                "default": "KI",
+                "letters_filter": "RGUVQKI",
+                "label": "Video to Video",
+                "show_label" : False,
+            }
+
+            # extra_model_def["at_least_one_image_ref_needed"] = True
+        if vace_class:
+            extra_model_def["guide_preprocessing"] = {
+                "selection": ["", "UV", "PV", "DV", "SV", "LV", "CV", "MV", "V", "PDV", "PSV", "PLV" , "DSV", "DLV", "SLV"],
+                "labels" : { "V": "Use Vace raw format"}
+            }
+            extra_model_def["mask_preprocessing"] = {
+                "selection": ["", "A", "NA", "XA", "XNA", "YA", "YNA", "WA", "WNA", "ZA", "ZNA"],
+            }
+
+            extra_model_def["image_ref_choices"] = {
+                "choices": [("None", ""),
+                    ("Inject only People / Objects", "I"),
+                    ("Inject Landscape and then People / Objects", "KI"),
+                    ("Inject Frames and then People / Objects", "FI"),
+                ],
+                "letters_filter": "KFI",
+            }
+
-        if base_model_type in ["standin"] or vace_class:
             extra_model_def["lock_image_refs_ratios"] = True

+        if base_model_type in ["standin"]:
+            extra_model_def["lock_image_refs_ratios"] = True
+            extra_model_def["image_ref_choices"] = {
+                "choices": [
+                    ("No Reference Image", ""),
+                    ("Reference Image is a Person Face", "I"),
+                ],
+                "letters_filter":"I",
+            }
+
+        if base_model_type in ["phantom_1.3B", "phantom_14B"]:
+            extra_model_def["image_ref_choices"] = {
+                "choices": [("Reference Image", "I")],
+                "letters_filter":"I",
+                "visible": False,
+            }
+
         if base_model_type in ["recam_1.3B"]:
             extra_model_def["keep_frames_video_guide_not_supported"] = True
             extra_model_def["model_modes"] = {
@@ -141,10 +201,18 @@ class family_handler():
                 "default": 1,
                 "label" : "Camera Movement Type"
             }
+            extra_model_def["guide_preprocessing"] = {
+                "selection": ["UV"],
+                "labels" : { "UV": "Control Video"},
+                "visible" : False,
+            }
+
         if vace_class or base_model_type in ["infinitetalk"]:
             image_prompt_types_allowed = "TVL"
         elif base_model_type in ["ti2v_2_2"]:
-            image_prompt_types_allowed = "TSEVL"
+            image_prompt_types_allowed = "TSVL"
+        elif test_multitalk(base_model_type) or base_model_type in ["fantasy"]:
+            image_prompt_types_allowed = "SVL"
         elif i2v:
             image_prompt_types_allowed = "SEVL"
         else:
@@ -7,7 +7,6 @@ import psutil
 # import ffmpeg
 import imageio
 from PIL import Image
-
 import cv2
 import torch
 import torch.nn.functional as F
@@ -33,6 +32,8 @@ model_in_GPU = False
 matanyone_in_GPU = False
 bfloat16_supported = False
 # SAM generator
+import copy
+
 class MaskGenerator():
     def __init__(self, sam_checkpoint, device):
         global args_device
@@ -89,6 +90,7 @@ def get_frames_from_image(image_input, image_state):
         "last_frame_numer": 0,
         "fps": None
         }
+
     image_info = "Image Name: N/A,\nFPS: N/A,\nTotal Frames: {},\nImage Size:{}".format(len(frames), image_size)
     set_image_encoder_patch()
     select_SAM()
@@ -717,27 +719,33 @@ def load_unload_models(selected):
 def get_vmc_event_handler():
     return load_unload_models

-def export_to_vace_video_input(foreground_video_output):
-    gr.Info("Masked Video Input transferred to Vace For Inpainting")
-    return "V#" + str(time.time()), foreground_video_output
-
-def export_image(image_refs, image_output):
-    gr.Info("Masked Image transferred to Current Video")
+def export_image(state, image_output):
+    ui_settings = get_current_model_settings(state)
+    image_refs = ui_settings["image_refs"]
     if image_refs == None:
         image_refs =[]
     image_refs.append( image_output)
-    return image_refs
+    ui_settings["image_refs"] = image_refs
+    gr.Info("Masked Image transferred to Current Image Generator")
+    return time.time()

-def export_image_mask(image_input, image_mask):
-    gr.Info("Input Image & Mask transferred to Current Video")
-    return Image.fromarray(image_input), image_mask
+def export_image_mask(state, image_input, image_mask):
+    ui_settings = get_current_model_settings(state)
+    ui_settings["image_guide"] = Image.fromarray(image_input)
+    ui_settings["image_mask"] = image_mask
+
+    gr.Info("Input Image & Mask transferred to Current Image Generator")
+    return time.time()


-def export_to_current_video_engine( foreground_video_output, alpha_video_output):
+def export_to_current_video_engine(state, foreground_video_output, alpha_video_output):
+    ui_settings = get_current_model_settings(state)
+    ui_settings["video_guide"] = foreground_video_output
+    ui_settings["video_mask"] = alpha_video_output
+
     gr.Info("Original Video and Full Mask have been transferred")
-    # return "MV#" + str(time.time()), foreground_video_output, alpha_video_output
-    return foreground_video_output, alpha_video_output
+    return time.time()


 def teleport_to_video_tab(tab_state):
@@ -746,9 +754,10 @@ def teleport_to_video_tab(tab_state):
     return gr.Tabs(selected="video_gen")


-def display(tabs, tab_state, server_config, vace_video_input, vace_image_input, vace_video_mask, vace_image_mask, vace_image_refs):
+def display(tabs, tab_state, state, refresh_form_trigger, server_config, get_current_model_settings_fn): #, vace_video_input, vace_image_input, vace_video_mask, vace_image_mask, vace_image_refs):
     # my_tab.select(fn=load_unload_models, inputs=[], outputs=[])
-    global image_output_codec, video_output_codec
+    global image_output_codec, video_output_codec, get_current_model_settings
+    get_current_model_settings = get_current_model_settings_fn

     image_output_codec = server_config.get("image_output_codec", None)
     video_output_codec = server_config.get("video_output_codec", None)
@@ -871,7 +880,7 @@ def display(tabs, tab_state, server_config, vace_video_input, vace_image_input,
                 template_frame = gr.Image(label="Start Frame", type="pil",interactive=True, elem_id="template_frame", visible=False, elem_classes="image")
                 with gr.Row():
                     clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False, min_width=100)
-                    add_mask_button = gr.Button(value="Set Mask", interactive=True, visible=False, min_width=100)
+                    add_mask_button = gr.Button(value="Add Mask", interactive=True, visible=False, min_width=100)
                     remove_mask_button = gr.Button(value="Remove Mask", interactive=True, visible=False, min_width=100) # no use
                     matting_button = gr.Button(value="Generate Video Matting", interactive=True, visible=False, min_width=100)
                 with gr.Row():
@@ -892,7 +901,7 @@ def display(tabs, tab_state, server_config, vace_video_input, vace_image_input,
             with gr.Row(visible= True):
                 export_to_current_video_engine_btn = gr.Button("Export to Control Video Input and Video Mask Input", visible= False)

-            export_to_current_video_engine_btn.click( fn=export_to_current_video_engine, inputs= [foreground_video_output, alpha_video_output], outputs= [vace_video_input, vace_video_mask]).then( #video_prompt_video_guide_trigger,
+            export_to_current_video_engine_btn.click( fn=export_to_current_video_engine, inputs= [state, foreground_video_output, alpha_video_output], outputs= [refresh_form_trigger]).then( #video_prompt_video_guide_trigger,
                 fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs])

@@ -1089,9 +1098,9 @@ def display(tabs, tab_state, server_config, vace_video_input, vace_image_input,
             # with gr.Column(scale=2, visible= True):
             export_image_mask_btn = gr.Button(value="Set to Control Image & Mask", visible=False, elem_classes="new_button")

-            export_image_btn.click( fn=export_image, inputs= [vace_image_refs, foreground_image_output], outputs= [vace_image_refs]).then( #video_prompt_video_guide_trigger,
+            export_image_btn.click( fn=export_image, inputs= [state, foreground_image_output], outputs= [refresh_form_trigger]).then( #video_prompt_video_guide_trigger,
                 fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs])
-            export_image_mask_btn.click( fn=export_image_mask, inputs= [image_input, alpha_image_output], outputs= [vace_image_input, vace_image_mask]).then( #video_prompt_video_guide_trigger,
+            export_image_mask_btn.click( fn=export_image_mask, inputs= [state, image_input, alpha_image_output], outputs= [refresh_form_trigger]).then( #video_prompt_video_guide_trigger,
                 fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs])

             # first step: get the image information
@@ -1148,5 +1157,21 @@ def display(tabs, tab_state, server_config, vace_video_input, vace_image_input,
             outputs=[foreground_image_output, alpha_image_output,foreground_image_output, alpha_image_output,bbox_info, export_image_btn, export_image_mask_btn]
         )

+        nada = gr.State({})
+        # clear input
+        gr.on(
+            triggers=[image_input.clear], #image_input.change,
+            fn=restart,
+            inputs=[],
+            outputs=[
+                image_state,
+                interactive_state,
+                click_state,
+                foreground_image_output, alpha_image_output,
+                template_frame,
+                image_selection_slider, image_selection_slider, track_pause_number_slider,point_prompt, export_image_btn, export_image_mask_btn, bbox_info, clear_button_click,
+                add_mask_button, matting_button, template_frame, foreground_image_output, alpha_image_output, remove_mask_button, export_image_btn, export_image_mask_btn, mask_dropdown, nada, step2_title
+            ],
+            queue=False,
+            show_progress=False)
+
@@ -23,7 +23,7 @@ librosa==0.11.0
 speechbrain==1.0.3

 # UI & interaction
-gradio==5.23.0
+gradio==5.29.0
 dashscope
 loguru

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, Literal

 import gradio as gr
 import PIL
+import time
 from PIL import Image as PILImage

 FilePath = str
@@ -20,6 +21,9 @@ def get_list( objs):
         return []
     return [ obj[0] if isinstance(obj, tuple) else obj for obj in objs]

+def record_last_action(st, last_action):
+    st["last_action"] = last_action
+    st["last_time"] = time.time()
 class AdvancedMediaGallery:
     def __init__(
         self,
@@ -60,9 +64,10 @@ class AdvancedMediaGallery:
         self.state: Optional[gr.State] = None
         self._initial_state: Dict[str, Any] = {
             "items": items,
-            "selected": (len(items) - 1) if items else None,
+            "selected": (len(items) - 1) if items else 0, # None,
             "single": bool(single_image_mode),
             "mode": self.media_mode,
+            "last_action": "",
         }

     # ---------------- helpers ----------------
@@ -210,6 +215,13 @@ class AdvancedMediaGallery:

     def _on_select(self, state: Dict[str, Any], gallery, evt: gr.SelectData) :
         # Mirror the selected index into state and the gallery (server-side selected_index)
+
+        st = get_state(state)
+        last_time = st.get("last_time", None)
+        if last_time is not None and abs(time.time()- last_time)< 0.5: # crappy trick to detect if onselect is unwanted (buggy gallery)
+            # print(f"ignored:{time.time()}, real {st['selected']}")
+            return gr.update(selected_index=st["selected"]), st
+
         idx = None
         if evt is not None and hasattr(evt, "index"):
             ix = evt.index
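The last_time check added above is a time-based debounce: any select event that arrives within half a second of a programmatic gallery update is treated as spurious and the stored selection is re-asserted. The same guard in isolation (the 0.5 s window and dict-based state mirror the hunk; the surrounding Gradio wiring is omitted and the function names are illustrative):

```python
import time

DEBOUNCE_SECONDS = 0.5

def record_last_action(st: dict, last_action: str) -> None:
    # Called right before the code updates the gallery itself, so the next
    # select event echoed back by the component can be recognised and dropped.
    st["last_action"] = last_action
    st["last_time"] = time.time()

def should_ignore_select(st: dict) -> bool:
    """Return True when a select event fires too soon after our own gallery update."""
    last_time = st.get("last_time")
    return last_time is not None and abs(time.time() - last_time) < DEBOUNCE_SECONDS

if __name__ == "__main__":
    st = {"selected": 2}
    record_last_action(st, "add")
    print(should_ignore_select(st))   # True: the event right after our update is ignored
    time.sleep(0.6)
    print(should_ignore_select(st))   # False: a later click is a genuine user selection
```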
@@ -220,17 +232,28 @@ class AdvancedMediaGallery:
                     idx = ix[0] * max(1, int(self.columns)) + ix[1]
                 else:
                     idx = ix[0]
-        st = get_state(state)
         n = len(get_list(gallery))
         sel = idx if (idx is not None and 0 <= idx < n) else None
+        # print(f"image selected evt index:{sel}/{evt.selected}")
         st["selected"] = sel
-        # return gr.update(selected_index=sel), st
-        # return gr.update(), st
-        return st
+        return gr.update(), st

+    def _on_upload(self, value: List[Any], state: Dict[str, Any]) :
+        # Fires when users upload via the Gallery itself.
+        # items_filtered = self._filter_items_by_mode(list(value or []))
+        items_filtered = list(value or [])
+        st = get_state(state)
+        new_items = self._paths_from_payload(items_filtered)
+        st["items"] = new_items
+        new_sel = len(new_items) - 1
+        st["selected"] = new_sel
+        record_last_action(st,"add")
+        return gr.update(selected_index=new_sel), st

     def _on_gallery_change(self, value: List[Any], state: Dict[str, Any]) :
         # Fires when users add/drag/drop/delete via the Gallery itself.
-        items_filtered = self._filter_items_by_mode(list(value or []))
+        # items_filtered = self._filter_items_by_mode(list(value or []))
+        items_filtered = list(value or [])
         st = get_state(state)
         st["items"] = items_filtered
         # Keep selection if still valid, else default to last
@@ -240,10 +263,9 @@ class AdvancedMediaGallery:
         else:
             new_sel = old_sel
         st["selected"] = new_sel
-        # return gr.update(value=items_filtered, selected_index=new_sel), st
-        # return gr.update(value=items_filtered), st
-        return gr.update(), st
+        st["last_action"] ="gallery_change"
+        # print(f"gallery change: set sel {new_sel}")
+        return gr.update(selected_index=new_sel), st

     def _on_add(self, files_payload: Any, state: Dict[str, Any], gallery):
         """
@@ -252,7 +274,8 @@ class AdvancedMediaGallery:
         and re-selects the last inserted item.
         """
         # New items (respect image/video mode)
-        new_items = self._filter_items_by_mode(self._paths_from_payload(files_payload))
+        # new_items = self._filter_items_by_mode(self._paths_from_payload(files_payload))
+        new_items = self._paths_from_payload(files_payload)

         st = get_state(state)
         cur: List[Any] = get_list(gallery)
@@ -298,30 +321,6 @@ class AdvancedMediaGallery:
             if k is not None:
                 seen_new.add(k)

-        # Remove any existing occurrences of the incoming items from current list,
-        # BUT keep the currently selected item even if it's also in incoming.
-        cur_clean: List[Any] = []
-        # sel_item = cur[sel] if (sel is not None and 0 <= sel < len(cur)) else None
-        # for idx, it in enumerate(cur):
-        #     k = key_of(it)
-        #     if it is sel_item:
-        #         cur_clean.append(it)
-        #         continue
-        #     if k is not None and k in seen_new:
-        #         continue  # drop duplicate; we'll reinsert at the target spot
-        #     cur_clean.append(it)
-
-        # # Compute insertion position: right AFTER the (possibly shifted) selected item
-        # if sel_item is not None:
-        #     # find sel_item's new index in cur_clean
-        #     try:
-        #         pos_sel = cur_clean.index(sel_item)
-        #     except ValueError:
-        #         # Shouldn't happen, but fall back to end
-        #         pos_sel = len(cur_clean) - 1
-        #     insert_pos = pos_sel + 1
-        # else:
-        #     insert_pos = len(cur_clean)  # no selection -> append at end
         insert_pos = min(sel, len(cur) -1)
         cur_clean = cur
         # Build final list and selection
@ -330,6 +329,8 @@ class AdvancedMediaGallery:
|
|||||||
|
|
||||||
st["items"] = merged
|
st["items"] = merged
|
||||||
st["selected"] = new_sel
|
st["selected"] = new_sel
|
||||||
|
record_last_action(st,"add")
|
||||||
|
# print(f"gallery add: set sel {new_sel}")
|
||||||
return gr.update(value=merged, selected_index=new_sel), st
|
return gr.update(value=merged, selected_index=new_sel), st
|
||||||
|
|
||||||
    def _on_remove(self, state: Dict[str, Any], gallery) :
@ -342,8 +343,9 @@ class AdvancedMediaGallery:
            return gr.update(value=[], selected_index=None), st
        new_sel = min(sel, len(items) - 1)
        st["items"] = items; st["selected"] = new_sel
        # return gr.update(value=items, selected_index=new_sel), st
        record_last_action(st,"remove")
        return gr.update(value=items), st
        # print(f"gallery del: new sel {new_sel}")
        return gr.update(value=items, selected_index=new_sel), st

    def _on_move(self, delta: int, state: Dict[str, Any], gallery) :
        st = get_state(state); items: List[Any] = get_list(gallery); sel = st.get("selected", None)
@ -354,11 +356,15 @@ class AdvancedMediaGallery:
            return gr.update(value=items, selected_index=sel), st
        items[sel], items[j] = items[j], items[sel]
        st["items"] = items; st["selected"] = j
        record_last_action(st,"move")
        # print(f"gallery move: set sel {j}")
        return gr.update(value=items, selected_index=j), st

    def _on_clear(self, state: Dict[str, Any]) :
        st = {"items": [], "selected": None, "single": get_state(state).get("single", False), "mode": self.media_mode}
        return gr.update(value=[], selected_index=0), st
        record_last_action(st,"clear")
        # print(f"Clear all")
        return gr.update(value=[], selected_index=None), st

    def _on_toggle_single(self, to_single: bool, state: Dict[str, Any]) :
        st = get_state(state); st["single"] = bool(to_single)
@ -382,30 +388,38 @@ class AdvancedMediaGallery:
    def mount(self, parent: Optional[gr.Blocks | gr.Group | gr.Row | gr.Column] = None, update_form = False):
        if parent is not None:
            with parent:
                col = self._build_ui()
                col = self._build_ui(update_form)
        else:
            col = self._build_ui()
            col = self._build_ui(update_form)
        if not update_form:
            self._wire_events()
        return col

    def _build_ui(self) -> gr.Column:
    def _build_ui(self, update = False) -> gr.Column:
        with gr.Column(elem_id=self.elem_id, elem_classes=self.elem_classes) as col:
            self.container = col

            self.state = gr.State(dict(self._initial_state))

            self.gallery = gr.Gallery(
            if update:
                label=self.label,
                self.gallery = gr.update(
                    value=self._initial_state["items"],
                height=self.height,
                    selected_index=self._initial_state["selected"],  # server-side selection
                columns=self.columns,
                    label=self.label,
                    show_label=self.show_label,
                preview= True,
                )
                # type="pil",
            else:
                file_types= list(IMAGE_EXTS) if self.media_mode == "image" else list(VIDEO_EXTS),
                self.gallery = gr.Gallery(
                selected_index=self._initial_state["selected"],  # server-side selection
                    value=self._initial_state["items"],
            )
                    label=self.label,
                    height=self.height,
                    columns=self.columns,
                    show_label=self.show_label,
                    preview= True,
                    # type="pil",  # very slow
                    file_types= list(IMAGE_EXTS) if self.media_mode == "image" else list(VIDEO_EXTS),
                    selected_index=self._initial_state["selected"],  # server-side selection
                )

            # One-line controls
            exts = sorted(IMAGE_EXTS if self.media_mode == "image" else VIDEO_EXTS) if self.accept_filter else None
@ -418,10 +432,10 @@ class AdvancedMediaGallery:
                size="sm",
                min_width=1,
            )
            self.btn_remove = gr.Button("Remove", size="sm", min_width=1)
            self.btn_remove = gr.Button(" Remove ", size="sm", min_width=1)
            self.btn_left = gr.Button("◀ Left", size="sm", visible=not self._initial_state["single"], min_width=1)
            self.btn_right = gr.Button("Right ▶", size="sm", visible=not self._initial_state["single"], min_width=1)
            self.btn_clear = gr.Button("Clear", variant="secondary", size="sm", visible=not self._initial_state["single"], min_width=1)
            self.btn_clear = gr.Button(" Clear ", variant="secondary", size="sm", visible=not self._initial_state["single"], min_width=1)

        return col

@ -430,14 +444,24 @@ class AdvancedMediaGallery:
        self.gallery.select(
            self._on_select,
            inputs=[self.state, self.gallery],
            outputs=[self.state],
            outputs=[self.gallery, self.state],
            trigger_mode="always_last",
        )

        # Gallery value changed by user actions (click-to-add, drag-drop, internal remove, etc.)
        self.gallery.change(
        self.gallery.upload(
            self._on_upload,
            inputs=[self.gallery, self.state],
            outputs=[self.gallery, self.state],
            trigger_mode="always_last",
        )

        # Gallery value changed by user actions (click-to-add, drag-drop, internal remove, etc.)
        self.gallery.upload(
            self._on_gallery_change,
            inputs=[self.gallery, self.state],
            outputs=[self.gallery, self.state],
            trigger_mode="always_last",
        )
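Aside, not part of this diff: every listener in this class now passes trigger_mode="always_last". In Gradio 4.x this is a standard event-listener argument; with "always_last", rapid repeated triggers are coalesced so that one final run happens after the in-flight handler finishes, which keeps the gallery value, its selected_index, and the State object from racing each other. A minimal sketch of the wiring pattern with a placeholder handler (names here are illustrative):

import gradio as gr

def on_upload(value, st):
    # placeholder: echo the gallery back unchanged and keep the state as-is
    return gr.update(), st

with gr.Blocks() as demo:
    state = gr.State({})
    gallery = gr.Gallery()
    gallery.upload(on_upload, inputs=[gallery, state], outputs=[gallery, state],
                   trigger_mode="always_last")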

        # Add via UploadButton
@ -445,6 +469,7 @@ class AdvancedMediaGallery:
            self._on_add,
            inputs=[self.upload_btn, self.state, self.gallery],
            outputs=[self.gallery, self.state],
            trigger_mode="always_last",
        )

        # Remove selected
@ -452,6 +477,7 @@ class AdvancedMediaGallery:
            self._on_remove,
            inputs=[self.state, self.gallery],
            outputs=[self.gallery, self.state],
            trigger_mode="always_last",
        )

        # Reorder using selected index, keep same item selected
@ -459,11 +485,13 @@ class AdvancedMediaGallery:
            lambda st, gallery: self._on_move(-1, st, gallery),
            inputs=[self.state, self.gallery],
            outputs=[self.gallery, self.state],
            trigger_mode="always_last",
        )
        self.btn_right.click(
            lambda st, gallery: self._on_move(+1, st, gallery),
            inputs=[self.state, self.gallery],
            outputs=[self.gallery, self.state],
            trigger_mode="always_last",
        )

        # Clear all
@ -471,6 +499,7 @@ class AdvancedMediaGallery:
            self._on_clear,
            inputs=[self.state],
            outputs=[self.gallery, self.state],
            trigger_mode="always_last",
        )

    # ---------------- public API ----------------
@ -19,6 +19,7 @@ import tempfile
import subprocess
import json
from functools import lru_cache
os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg")


from PIL import Image
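Aside, not part of this diff: rembg resolves its model cache through the U2NET_HOME environment variable (falling back to ~/.u2net), so setting it before the first new_session() call should pin the u2net weights under the local ckpts/rembg folder. Illustrative sketch under that assumption:

import os

# Set before rembg is imported/first used; otherwise the default ~/.u2net cache is picked up.
os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg")

from rembg import new_session  # noqa: E402  (import deliberately after the env var)
session = new_session()        # first use downloads the weights into ckpts/rembg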
@ -207,30 +208,62 @@ def get_outpainting_frame_location(final_height, final_width, outpainting_dims
    if (margin_left + width) > final_width or outpainting_right == 0: margin_left = final_width - width
    return height, width, margin_top, margin_left

def calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = 16):
def rescale_and_crop(img, w, h):
    if fit_into_canvas == None:
    ow, oh = img.size
    target_ratio = w / h
    orig_ratio = ow / oh

    if orig_ratio > target_ratio:
        # Crop width first
        nw = int(oh * target_ratio)
        img = img.crop(((ow - nw) // 2, 0, (ow + nw) // 2, oh))
    else:
        # Crop height first
        nh = int(ow / target_ratio)
        img = img.crop((0, (oh - nh) // 2, ow, (oh + nh) // 2))

    return img.resize((w, h), Image.LANCZOS)

def calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = 16):
    if fit_into_canvas == None or fit_into_canvas == 2:
        # return image_height, image_width
        return canvas_height, canvas_width
    if fit_into_canvas:
    if fit_into_canvas == 1:
        scale1 = min(canvas_height / image_height, canvas_width / image_width)
        scale2 = min(canvas_width / image_height, canvas_height / image_width)
        scale = max(scale1, scale2)
    else:
    else: #0 or #2 (crop)
        scale = (canvas_height * canvas_width / (image_height * image_width))**(1/2)

    new_height = round( image_height * scale / block_size) * block_size
    new_width = round( image_width * scale / block_size) * block_size
    return new_height, new_width
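Aside, not part of this diff: a quick worked example of the crop-then-resize path in rescale_and_crop above. For a 1920x1080 source and an 832x480 target, orig_ratio ≈ 1.778 exceeds target_ratio ≈ 1.733, so the width is center-cropped to int(1080 * 832 / 480) = 1872 pixels before the LANCZOS resize; the opposite branch crops the height instead. Assuming the helper from this diff is in scope:

from PIL import Image

src = Image.new("RGB", (1920, 1080))   # stand-in for a real frame
out = rescale_and_crop(src, 832, 480)  # center-crop to the target aspect ratio, then resize
assert out.size == (832, 480)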

def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, ignore_first, fit_into_canvas = False ):
def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fit_into_canvas, fit_crop, block_size = 16):
    if fit_crop:
        image = rescale_and_crop(image, canvas_width, canvas_height)
        new_width, new_height = image.size
    else:
        image_width, image_height = image.size
        new_height, new_width = calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = block_size )
        image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
    return image, new_height, new_width

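Aside, not part of this diff: how the two branches of the new calculate_dimensions_and_resize_image behave, assuming the helpers above are in scope. With fit_crop the output is exactly canvas-sized; without it the output is a block_size-aligned resolution chosen by calculate_new_dimensions, so it may differ slightly from the canvas (e.g. 832x464 below):

from PIL import Image

img = Image.new("RGB", (1920, 1080))

# Hard crop to the canvas: always returns an 832x480 image.
cropped, h, w = calculate_dimensions_and_resize_image(img, 480, 832, fit_into_canvas=0, fit_crop=True)

# Fit/rescale: each side is rounded to a multiple of block_size (16 here).
fitted, h2, w2 = calculate_dimensions_and_resize_image(img, 480, 832, fit_into_canvas=1, fit_crop=False, block_size=16)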
def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None ):
    if rm_background:
        session = new_session()

    output_list =[]
    for i, img in enumerate(img_list):
        width, height = img.size
        if fit_into_canvas == None or any_background_ref == 1 and i==0 or any_background_ref == 2:
        if fit_into_canvas:
            if outpainting_dims is not None:
                resized_image =img
            elif img.size != (budget_width, budget_height):
                resized_image= img.resize((budget_width, budget_height), resample=Image.Resampling.LANCZOS)
            else:
                resized_image =img
        elif fit_into_canvas == 1:
            white_canvas = np.ones((budget_height, budget_width, 3), dtype=np.uint8) * 255
            scale = min(budget_height / height, budget_width / width)
            new_height = int(height * scale)
@ -242,10 +275,10 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg
            resized_image = Image.fromarray(white_canvas)
        else:
            scale = (budget_height * budget_width / (height * width))**(1/2)
            new_height = int( round(height * scale / 16) * 16)
            new_height = int( round(height * scale / block_size) * block_size)
            new_width = int( round(width * scale / 16) * 16)
            new_width = int( round(width * scale / block_size) * block_size)
            resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
        if rm_background and not (ignore_first and i == 0) :
        if rm_background and not (any_background_ref and i==0 or any_background_ref == 2) :
            # resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
            resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
        output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200,
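Aside, not part of this diff: fit_into_canvas has moved from a boolean to a small integer code in these helpers. An informal mapping, inferred from the branches above and therefore an assumption rather than documented behaviour:

# Informal reading of the fit_into_canvas codes used above (inferred, not authoritative):
FIT_INTO_CANVAS_MODES = {
    None: "use the canvas/budget size as-is (also applied to background reference images)",
    0:    "rescale to keep the pixel budget, rounding each side to a block_size multiple",
    1:    "fit inside the canvas; reference images are letterboxed onto a white canvas",
    2:    "crop/force to the canvas resolution",
}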