intermediate commit

Author: DeepBeepMeep
Date: 2025-09-09 21:41:35 +02:00
commit f9f63cbc79 (parent 99fd9aea32)
20 changed files with 897 additions and 571 deletions

View File

@@ -7,8 +7,6 @@
            "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_bf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_quanto_bf16_int8.safetensors"
        ],
-       "image_outputs": true,
-       "reference_image": true,
        "flux-model": "flux-dev-kontext"
    },
    "prompt": "add a hat",

View File

@@ -6,8 +6,6 @@
        "modules": [ ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_projector_bf16.safetensors"]],
        "URLs": "flux",
        "loras": ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_dit_lora_bf16.safetensors"],
-       "image_outputs": true,
-       "reference_image": true,
        "flux-model": "flux-dev-uso"
    },
    "prompt": "the man is wearing a hat",

View File

@@ -9,9 +9,7 @@
        ],
        "attention": {
            "<89": "sdpa"
-       },
-       "reference_image": true,
-       "image_outputs": true
+       }
    },
    "prompt": "add a hat",
    "resolution": "1280x720",

View File

@@ -13,28 +13,41 @@ class family_handler():
        flux_schnell = flux_model == "flux-schnell"
        flux_chroma = flux_model == "flux-chroma"
        flux_uso = flux_model == "flux-dev-uso"
-       model_def_output = {
+       flux_kontext = flux_model == "flux-dev-kontext"
+       extra_model_def = {
            "image_outputs" : True,
            "no_negative_prompt" : not flux_chroma,
        }
        if flux_chroma:
-           model_def_output["guidance_max_phases"] = 1
+           extra_model_def["guidance_max_phases"] = 1
        elif not flux_schnell:
-           model_def_output["embedded_guidance"] = True
+           extra_model_def["embedded_guidance"] = True
        if flux_uso :
-           model_def_output["any_image_refs_relative_size"] = True
-           model_def_output["no_background_removal"] = True
-           model_def_output["image_ref_choices"] = {
+           extra_model_def["any_image_refs_relative_size"] = True
+           extra_model_def["no_background_removal"] = True
+           extra_model_def["image_ref_choices"] = {
                "choices":[("No Reference Image", ""),("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "KI"),
                    ("Up to two Images are Style Images", "KIJ")],
                "default": "KI",
                "letters_filter": "KIJ",
                "label": "Reference Images / Style Images"
            }
-       model_def_output["lock_image_refs_ratios"] = True
-       return model_def_output
+       if flux_kontext:
+           extra_model_def["image_ref_choices"] = {
+               "choices": [
+                   ("None", ""),
+                   ("Conditional Images is first Main Subject / Landscape and may be followed by People / Objects", "KI"),
+                   ("Conditional Images are People / Objects", "I"),
+               ],
+               "letters_filter": "KI",
+           }
+       extra_model_def["lock_image_refs_ratios"] = True
+       return extra_model_def

    @staticmethod
    def query_supported_types():

@@ -122,10 +135,12 @@ class family_handler():
    def update_default_settings(base_model_type, model_def, ui_defaults):
        flux_model = model_def.get("flux-model", "flux-dev")
        flux_uso = flux_model == "flux-dev-uso"
+       flux_kontext = flux_model == "flux-dev-kontext"
        ui_defaults.update({
            "embedded_guidance": 2.5,
        })
-       if model_def.get("reference_image", False):
+       if flux_kontext or flux_uso:
            ui_defaults.update({
                "video_prompt_type": "KI",
            })

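Note on the dictionaries above: "image_ref_choices" is a declarative description of the reference-image modes a model supports (a list of (label, letter-code) pairs, a default, and a "letters_filter" naming which letters of the video_prompt_type string the control owns). A minimal sketch of how such a dictionary could drive a UI selector and rewrite the prompt-type letters; the helper names and the gr.Radio wiring are assumptions for illustration, not part of this commit:

# Sketch (assumed helpers): turn an "image_ref_choices" dict into a Gradio radio
# and use its letters_filter to rewrite the letters of a video_prompt_type string.
import gradio as gr

def build_image_ref_radio(image_ref_choices: dict) -> gr.Radio:
    # Each choice is (human-readable label, letter code); the letter code is
    # what ends up inside video_prompt_type.
    return gr.Radio(
        choices=image_ref_choices["choices"],
        value=image_ref_choices.get("default", ""),
        label=image_ref_choices.get("label", "Reference Images"),
        visible=image_ref_choices.get("visible", True),
    )

def apply_letters_filter(video_prompt_type: str, letters_filter: str, selected: str) -> str:
    # Drop every letter governed by this control, then append the selected codes.
    kept = [c for c in video_prompt_type if c not in letters_filter]
    return "".join(kept) + selected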
View File

@@ -24,44 +24,6 @@ from .util import (
from PIL import Image

-def resize_and_centercrop_image(image, target_height_ref1, target_width_ref1):
-    target_height_ref1 = int(target_height_ref1 // 64 * 64)
-    target_width_ref1 = int(target_width_ref1 // 64 * 64)
-    h, w = image.shape[-2:]
-    if h < target_height_ref1 or w < target_width_ref1:
-        # compute the aspect ratio
-        aspect_ratio = w / h
-        if h < target_height_ref1:
-            new_h = target_height_ref1
-            new_w = new_h * aspect_ratio
-            if new_w < target_width_ref1:
-                new_w = target_width_ref1
-                new_h = new_w / aspect_ratio
-        else:
-            new_w = target_width_ref1
-            new_h = new_w / aspect_ratio
-            if new_h < target_height_ref1:
-                new_h = target_height_ref1
-                new_w = new_h * aspect_ratio
-    else:
-        aspect_ratio = w / h
-        tgt_aspect_ratio = target_width_ref1 / target_height_ref1
-        if aspect_ratio > tgt_aspect_ratio:
-            new_h = target_height_ref1
-            new_w = new_h * aspect_ratio
-        else:
-            new_w = target_width_ref1
-            new_h = new_w / aspect_ratio
-    # resize the image with TVF.resize
-    image = TVF.resize(image, (math.ceil(new_h), math.ceil(new_w)))
-    # compute the center-crop parameters
-    top = (image.shape[-2] - target_height_ref1) // 2
-    left = (image.shape[-1] - target_width_ref1) // 2
-    # center-crop with TVF.crop
-    image = TVF.crop(image, top, left, target_height_ref1, target_width_ref1)
-    return image
-
def stitch_images(img1, img2):
    # Resize img2 to match img1's height
    width1, height1 = img1.size

@@ -171,8 +133,6 @@ class model_factory:
            device="cuda"
        flux_dev_uso = self.name in ['flux-dev-uso']
        image_stiching = not self.name in ['flux-dev-uso'] #and False
-       # image_refs_relative_size = 100
-       crop = False
        input_ref_images = [] if input_ref_images is None else input_ref_images[:]
        ref_style_imgs = []
        if "I" in video_prompt_type and len(input_ref_images) > 0:

@@ -186,36 +146,15 @@ class model_factory:
            if image_stiching:
                # image stiching method
                stiched = input_ref_images[0]
-               if "K" in video_prompt_type :
-                   w, h = input_ref_images[0].size
-                   height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
-                   # actual rescale will happen in prepare_kontext
                for new_img in input_ref_images[1:]:
                    stiched = stitch_images(stiched, new_img)
                input_ref_images = [stiched]
            else:
-               first_ref = 0
-               if "K" in video_prompt_type:
-                   # image latents tiling method
-                   w, h = input_ref_images[0].size
-                   if crop :
-                       img = convert_image_to_tensor(input_ref_images[0])
-                       img = resize_and_centercrop_image(img, height, width)
-                       input_ref_images[0] = convert_tensor_to_image(img)
-                   else:
-                       height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
-                       input_ref_images[0] = input_ref_images[0].resize((width, height), resample=Image.Resampling.LANCZOS)
-                   first_ref = 1
-               for i in range(first_ref,len(input_ref_images)):
+               # latents stiching with resize
+               for i in range(len(input_ref_images)):
                    w, h = input_ref_images[i].size
-                   if crop:
-                       img = convert_image_to_tensor(input_ref_images[i])
-                       img = resize_and_centercrop_image(img, int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100))
-                       input_ref_images[i] = convert_tensor_to_image(img)
-                   else:
-                       image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas)
-                       input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
+                   image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas)
+                   input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
        else:
            input_ref_images = None

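For context, the surviving branch above scales every reference image to a fraction of the output canvas given by image_refs_relative_size (a percentage) while keeping its aspect ratio via calculate_new_dimensions. A standalone sketch of the same idea, with a simplified stand-in for calculate_new_dimensions (the real helper lives in shared/utils/utils.py and also honours the fit_into_canvas modes):

# Sketch of the relative-size reference resize, with a simplified dimension helper.
from PIL import Image

def _fit_dims(canvas_h, canvas_w, img_h, img_w, block=16):
    # Keep roughly the pixel budget of the canvas while preserving aspect ratio,
    # rounded to a block size (assumption: 16, as elsewhere in this repo).
    scale = (canvas_h * canvas_w / (img_h * img_w)) ** 0.5
    return round(img_h * scale / block) * block, round(img_w * scale / block) * block

def resize_refs(ref_images, height, width, image_refs_relative_size=100):
    out = []
    for img in ref_images:
        w, h = img.size
        target_h = int(height * image_refs_relative_size / 100)
        target_w = int(width * image_refs_relative_size / 100)
        new_h, new_w = _fit_dims(target_h, target_w, h, w)
        out.append(img.resize((new_w, new_h), resample=Image.Resampling.LANCZOS))
    return out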
View File

@@ -861,11 +861,6 @@ class HunyuanVideoSampler(Inference):
                freqs_cos, freqs_sin = self.get_rotary_pos_embed(target_frame_num, target_height, target_width, enable_RIFLEx)
            else:
                if self.avatar:
-                   w, h = input_ref_images.size
-                   target_height, target_width = calculate_new_dimensions(target_height, target_width, h, w, fit_into_canvas)
-                   if target_width != w or target_height != h:
-                       input_ref_images = input_ref_images.resize((target_width,target_height), resample=Image.Resampling.LANCZOS)
                    concat_dict = {'mode': 'timecat', 'bias': -1}
                    freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict)
                else:

View File

@@ -51,6 +51,23 @@ class family_handler():
        extra_model_def["tea_cache"] = True
        extra_model_def["mag_cache"] = True

+       if base_model_type in ["hunyuan_custom_edit"]:
+           extra_model_def["guide_preprocessing"] = {
+               "selection": ["MV", "PMV"],
+           }
+           extra_model_def["mask_preprocessing"] = {
+               "selection": ["A", "NA"],
+               "default" : "NA"
+           }
+       if base_model_type in ["hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_custom"]:
+           extra_model_def["image_ref_choices"] = {
+               "choices": [("Reference Image", "I")],
+               "letters_filter":"I",
+               "visible": False,
+           }
        if base_model_type in ["hunyuan_avatar"]: extra_model_def["no_background_removal"] = True
        if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar"]:

@@ -141,6 +158,10 @@ class family_handler():
        return hunyuan_model, pipe

+   @staticmethod
+   def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
+       pass
+
    @staticmethod
    def update_default_settings(base_model_type, model_def, ui_defaults):
        ui_defaults["embedded_guidance_scale"]= 6.0

View File

@@ -300,9 +300,6 @@ class LTXV:
            prefix_size, height, width = input_video.shape[-3:]
        else:
            if image_start != None:
-               frame_width, frame_height = image_start.size
-               if fit_into_canvas != None:
-                   height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas, 32)
                conditioning_media_paths.append(image_start.unsqueeze(1))
                conditioning_start_frames.append(0)
                conditioning_control_frames.append(False)

View File

@@ -26,6 +26,15 @@ class family_handler():
        extra_model_def["sliding_window"] = True
        extra_model_def["image_prompt_types_allowed"] = "TSEV"

+       extra_model_def["guide_preprocessing"] = {
+           "selection": ["", "PV", "DV", "EV", "V"],
+           "labels" : { "V": "Use LTXV raw format"}
+       }
+       extra_model_def["mask_preprocessing"] = {
+           "selection": ["", "A", "NA", "XA", "XNA"],
+       }
+
        return extra_model_def

    @staticmethod

View File

@@ -28,7 +28,7 @@ from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Aut
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
from diffusers import FlowMatchEulerDiscreteScheduler
from PIL import Image
-from shared.utils.utils import calculate_new_dimensions
+from shared.utils.utils import calculate_new_dimensions, convert_image_to_tensor, convert_tensor_to_image

XLA_AVAILABLE = False

@@ -563,6 +563,8 @@ class QwenImagePipeline(): #DiffusionPipeline
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
        image = None,
+       image_mask = None,
+       denoising_strength = 0,
        callback=None,
        pipeline=None,
        loras_slists=None,

@@ -694,14 +696,33 @@ class QwenImagePipeline(): #DiffusionPipeline
            image_width = image_width // multiple_of * multiple_of
            image_height = image_height // multiple_of * multiple_of
            ref_height, ref_width = 1568, 672
-           if height * width < ref_height * ref_width: ref_height , ref_width = height , width
-           if image_height * image_width > ref_height * ref_width:
-               image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
-           image = image.resize((image_width,image_height), resample=Image.Resampling.LANCZOS)
+           if image_mask is None:
+               if height * width < ref_height * ref_width: ref_height , ref_width = height , width
+               if image_height * image_width > ref_height * ref_width:
+                   image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
+               if (image_width,image_height) != image.size:
+                   image = image.resize((image_width,image_height), resample=Image.Resampling.LANCZOS)
+               image_mask_latents = None
+           else:
+               # _, image_width, image_height = min(
+               #     (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS
+               # )
+               image_height, image_width = calculate_new_dimensions(height, width, image_height, image_width, False, block_size=multiple_of)
+               # image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
+               height, width = image_height, image_width
+               image_mask_latents = convert_image_to_tensor(image_mask.resize((width // 16, height // 16), resample=Image.Resampling.LANCZOS))
+               image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1]
+               image_mask_rebuilt = image_mask_latents.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0)
+               convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png")
+               image_mask_latents = image_mask_latents.reshape(1, -1, 1).to(device)

            prompt_image = image
-           image = self.image_processor.preprocess(image, image_height, image_width)
-           image = image.unsqueeze(2)
+           if image.size != (image_width, image_height):
+               image = image.resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
+           image.save("nnn.png")
+           image = convert_image_to_tensor(image).unsqueeze(0).unsqueeze(2)

        has_neg_prompt = negative_prompt is not None or (
            negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None

@@ -744,6 +765,8 @@ class QwenImagePipeline(): #DiffusionPipeline
            generator,
            latents,
        )
+       original_image_latents = None if image_latents is None else image_latents.clone()
+
        if image is not None:
            img_shapes = [
                [

@@ -788,6 +811,15 @@ class QwenImagePipeline(): #DiffusionPipeline
        negative_txt_seq_lens = (
            negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
        )
+       morph = False
+       if image_mask_latents is not None and denoising_strength <= 1.:
+           first_step = int(len(timesteps) * (1. - denoising_strength))
+           if not morph:
+               latent_noise_factor = timesteps[first_step]/1000
+               latents = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor
+               timesteps = timesteps[first_step:]
+               self.scheduler.timesteps = timesteps
+               self.scheduler.sigmas= self.scheduler.sigmas[first_step:]

        # 6. Denoising loop
        self.scheduler.set_begin_index(0)

@@ -797,10 +829,15 @@ class QwenImagePipeline(): #DiffusionPipeline
            update_loras_slists(self.transformer, loras_slists, updated_num_steps)
            callback(-1, None, True, override_num_inference_steps = updated_num_steps)

        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue
+           if image_mask_latents is not None and denoising_strength <1. and i == first_step and morph:
+               latent_noise_factor = t/1000
+               latents = original_image_latents * (1.0 - latent_noise_factor) + latents * latent_noise_factor
            self._current_timestep = t
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latents.shape[0]).to(latents.dtype)

@@ -865,6 +902,12 @@ class QwenImagePipeline(): #DiffusionPipeline
                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+               if image_mask_latents is not None:
+                   next_t = timesteps[i+1] if i<len(timesteps)-1 else 0
+                   latent_noise_factor = next_t / 1000
+                   noisy_image = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor
+                   latents = noisy_image * (1-image_mask_latents) + image_mask_latents * latents
+                   noisy_image = None

                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():

@@ -878,7 +921,7 @@ class QwenImagePipeline(): #DiffusionPipeline
        self._current_timestep = None

        if output_type == "latent":
-           image = latents
+           output_image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = latents.to(self.vae.dtype)

@@ -891,7 +934,9 @@ class QwenImagePipeline(): #DiffusionPipeline
                latents.device, latents.dtype
            )
            latents = latents / latents_std + latents_mean
-           image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
+           output_image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
+           if image_mask is not None:
+               output_image = image.squeeze(2) * (1 - image_mask_rebuilt) + output_image.to(image) * image_mask_rebuilt

-       return image
+       return output_image

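For readers following the inpainting changes above: the mask is downsampled to the 16x16-pixel latent grid and flattened to match the packed latent sequence; with denoising_strength < 1 the loop starts from a partially re-noised copy of the original image latents; and after every scheduler step the latents outside the mask are reset to a re-noised copy of the original so only the masked region is actually denoised, followed by a final pixel-space composite with the rebuilt full-resolution mask. A self-contained sketch of that per-step blend on dummy tensors (shapes, the t/1000 noise factor and the 16-pixel latent patch follow the diff above; the "scheduler step" is faked):

# Sketch of masked-latent blending on dummy tensors (illustrative only).
import torch

torch.manual_seed(0)
steps = 8
timesteps = torch.linspace(1000, 0, steps)          # flow-matching style, t/1000 in [0, 1]
orig = torch.randn(1, 4096, 64)                     # "clean" packed image latents
mask = (torch.rand(1, 4096, 1) > 0.7).float()       # 1 = region to repaint
denoising_strength = 0.6

first_step = int(len(timesteps) * (1.0 - denoising_strength))
noise_factor = timesteps[first_step] / 1000
latents = orig * (1 - noise_factor) + torch.randn_like(orig) * noise_factor

for i, t in enumerate(timesteps[first_step:]):
    # stand-in for the transformer prediction + scheduler step
    latents = latents - 0.05 * torch.randn_like(latents)
    # re-noise the original to the *next* timestep and keep it outside the mask
    next_t = timesteps[first_step + i + 1] if first_step + i < len(timesteps) - 1 else 0
    nf = next_t / 1000
    noisy_orig = orig * (1 - nf) + torch.randn_like(orig) * nf
    latents = noisy_orig * (1 - mask) + mask * latents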
View File

@@ -9,7 +9,7 @@ def get_qwen_text_encoder_filename(text_encoder_quantization):
class family_handler():
    @staticmethod
    def query_model_def(base_model_type, model_def):
-       model_def_output = {
+       extra_model_def = {
            "image_outputs" : True,
            "sample_solvers":[
                ("Default", "default"),

@@ -18,8 +18,18 @@ class family_handler():
            "lock_image_refs_ratios": True,
        }

-       return model_def_output
+       if base_model_type in ["qwen_image_edit_20B"]:
+           extra_model_def["inpaint_support"] = True
+           extra_model_def["image_ref_choices"] = {
+               "choices": [
+                   ("None", ""),
+                   ("Conditional Images is first Main Subject / Landscape and may be followed by People / Objects", "KI"),
+                   ("Conditional Images are People / Objects", "I"),
+               ],
+               "letters_filter": "KI",
+           }
+
+       return extra_model_def

    @staticmethod
    def query_supported_types():

@@ -75,14 +85,18 @@ class family_handler():
        if ui_defaults.get("sample_solver", "") == "":
            ui_defaults["sample_solver"] = "default"

+       if settings_version < 2.32:
+           ui_defaults["denoising_strength"] = 1.
+
    @staticmethod
    def update_default_settings(base_model_type, model_def, ui_defaults):
        ui_defaults.update({
            "guidance_scale": 4,
            "sample_solver": "default",
        })
-       if model_def.get("reference_image", False):
+       if base_model_type in ["qwen_image_edit_20B"]:
            ui_defaults.update({
                "video_prompt_type": "KI",
+               "denoising_strength" : 1.,
            })

View File

@@ -103,6 +103,8 @@ class model_factory():
        n_prompt = None,
        sampling_steps: int = 20,
        input_ref_images = None,
+       image_guide= None,
+       image_mask= None,
        width= 832,
        height=480,
        guide_scale: float = 4,

@@ -114,6 +116,7 @@ class model_factory():
        VAE_tile_size = None,
        joint_pass = True,
        sample_solver='default',
+       denoising_strength = 1.,
        **bbargs
    ):
        # Generate with different aspect ratios

@@ -174,8 +177,9 @@ class model_factory():
        if n_prompt is None or len(n_prompt) == 0:
            n_prompt= "text, watermark, copyright, blurry, low resolution"

-       if input_ref_images is not None:
+       if image_guide is not None:
+           input_ref_images = [image_guide]
+       elif input_ref_images is not None:
            # image stiching method
            stiched = input_ref_images[0]
            if "K" in video_prompt_type :

@@ -190,6 +194,7 @@ class model_factory():
            prompt=input_prompt,
            negative_prompt=n_prompt,
            image = input_ref_images,
+           image_mask = image_mask,
            width=width,
            height=height,
            num_inference_steps=sampling_steps,

@@ -199,6 +204,7 @@ class model_factory():
            pipeline=self,
            loras_slists=loras_slists,
            joint_pass = joint_pass,
+           denoising_strength=denoising_strength,
            generator=torch.Generator(device="cuda").manual_seed(seed)
        )
        if image is None: return None

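The wrapper-level rule introduced above is simple: when an image_guide is supplied (editing/inpainting a concrete image) it replaces any reference-image list; otherwise the reference images are kept and stitched as before. A tiny sketch of that selection rule in isolation (the function name is illustrative, not part of the commit):

# Sketch of the guide-vs-reference priority (names illustrative).
from typing import List, Optional
from PIL import Image

def select_condition_images(image_guide: Optional[Image.Image],
                            input_ref_images: Optional[List[Image.Image]]) -> Optional[List[Image.Image]]:
    if image_guide is not None:
        # Editing / inpainting a specific image: it wins over reference images.
        return [image_guide]
    if input_ref_images:
        # Reference-image mode: caller may stitch these into a single canvas.
        return list(input_ref_images)
    return None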
View File

@@ -261,7 +261,7 @@ class WanAny2V:
    def vace_latent(self, z, m):
        return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]

-   def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, fill_max = False, outpainting_dims = None, return_mask = False):
+   def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, full_frame = False, outpainting_dims = None, return_mask = False):
        from shared.utils.utils import save_image
        ref_width, ref_height = ref_img.size
        if (ref_height, ref_width) == image_size and outpainting_dims == None:

@@ -270,18 +270,23 @@ class WanAny2V:
        else:
            if outpainting_dims != None:
                final_height, final_width = image_size
-               canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 8)
+               canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1)
            else:
                canvas_height, canvas_width = image_size
-           scale = min(canvas_height / ref_height, canvas_width / ref_width)
-           new_height = int(ref_height * scale)
-           new_width = int(ref_width * scale)
-           if fill_max and (canvas_height - new_height) < 16:
-               new_height = canvas_height
-           if fill_max and (canvas_width - new_width) < 16:
-               new_width = canvas_width
-           top = (canvas_height - new_height) // 2
-           left = (canvas_width - new_width) // 2
+           if full_frame:
+               new_height = canvas_height
+               new_width = canvas_width
+               top = left = 0
+           else:
+               # if fill_max and (canvas_height - new_height) < 16:
+               #     new_height = canvas_height
+               # if fill_max and (canvas_width - new_width) < 16:
+               #     new_width = canvas_width
+               scale = min(canvas_height / ref_height, canvas_width / ref_width)
+               new_height = int(ref_height * scale)
+               new_width = int(ref_width * scale)
+               top = (canvas_height - new_height) // 2
+               left = (canvas_width - new_width) // 2
            ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
            ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
            if outpainting_dims != None:

@@ -302,7 +307,7 @@ class WanAny2V:
            canvas = canvas.to(device)
        return ref_img.to(device), canvas

-   def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_video_guide_frames= [], start_frame = 0, fit_into_canvas = None, pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False):
+   def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_video_guide_frames= [], start_frame = 0, pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False):
        image_sizes = []
        trim_video_guide = len(keep_video_guide_frames)
        def conv_tensor(t, device):

@@ -533,22 +538,16 @@ class WanAny2V:
        any_end_frame = False
        if image_start is None:
            if infinitetalk:
+               new_shot = "Q" in video_prompt_type
                if input_frames is not None:
                    image_ref = input_frames[:, 0]
-                   new_shot = "Q" in video_prompt_type
+                   if input_video is None: input_video = input_frames[:, 0:1]
                else:
-                   if pre_video_frame is None:
-                       new_shot = True
-                   else:
-                       if input_ref_images is None:
-                           input_ref_images, new_shot = [pre_video_frame], False
-                       else:
-                           input_ref_images, new_shot = [img.resize(pre_video_frame.size, resample=Image.Resampling.LANCZOS) for img in input_ref_images], "Q" in video_prompt_type
-               if input_ref_images is None: raise Exception("Missing Reference Image")
+                   if input_ref_images is None:
+                       if pre_video_frame is None: raise Exception("Missing Reference Image")
+                       input_ref_images = [pre_video_frame]
                new_shot = new_shot and window_no <= len(input_ref_images)
                image_ref = convert_image_to_tensor(input_ref_images[ min(window_no, len(input_ref_images))-1 ])
-               if new_shot:
+               if new_shot or input_video is None:
                    input_video = image_ref.unsqueeze(1)
                else:
                    color_correction_strength = 0 #disable color correction as transition frames between shots may have a complete different color level than the colors of the new shot

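The rewritten fit_image_into_canvas above now has two placement modes: with full_frame=True the reference is stretched to fill the entire canvas (top/left offsets of zero), otherwise it is letterboxed with the usual min(scale_h, scale_w) fit and centered. A geometry-only sketch of that decision (no tensors; names illustrative, not the project's API):

# Geometry-only sketch of full_frame vs letterbox placement.
def plan_placement(ref_w, ref_h, canvas_w, canvas_h, full_frame=False):
    if full_frame:
        new_w, new_h, top, left = canvas_w, canvas_h, 0, 0
    else:
        scale = min(canvas_h / ref_h, canvas_w / ref_w)
        new_w, new_h = int(ref_w * scale), int(ref_h * scale)
        top = (canvas_h - new_h) // 2
        left = (canvas_w - new_w) // 2
    return new_w, new_h, top, left

# Example: a 640x480 reference on a 1280x720 canvas.
print(plan_placement(640, 480, 1280, 720))                    # letterboxed: (960, 720, 0, 160)
print(plan_placement(640, 480, 1280, 720, full_frame=True))   # stretched:   (1280, 720, 0, 0)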
View File

@@ -35,7 +35,7 @@ class family_handler():
            "label" : "Generation Type"
        }

-       extra_model_def["image_prompt_types_allowed"] = "TSEV"
+       extra_model_def["image_prompt_types_allowed"] = "TSV"

        return extra_model_def

View File

@@ -110,19 +110,79 @@ class family_handler():
            "tea_cache" : not (base_model_type in ["i2v_2_2", "ti2v_2_2" ] or multiple_submodels),
            "mag_cache" : True,
            "keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"],
+           "convert_image_guide_to_video" : True,
            "sample_solvers":[
                ("unipc", "unipc"),
                ("euler", "euler"),
                ("dpm++", "dpm++"),
                ("flowmatch causvid", "causvid"), ]
        })

+       if base_model_type in ["t2v"]:
+           extra_model_def["guide_custom_choices"] = {
+               "choices":[("Use Text Prompt Only", ""),("Video to Video guided by Text Prompt", "GUV")],
+               "default": "",
+               "letters_filter": "GUV",
+               "label": "Video to Video"
+           }
+
        if base_model_type in ["infinitetalk"]:
            extra_model_def["no_background_removal"] = True
-           # extra_model_def["at_least_one_image_ref_needed"] = True
+           extra_model_def["all_image_refs_are_background_ref"] = True
+           extra_model_def["guide_custom_choices"] = {
+               "choices":[
+                   ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Sharp Transitions", "QKI"),
+                   ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Smooth Transitions", "KI"),
+                   ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Sharp Transitions", "QRUV"),
+                   ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Smooth Transitions", "RUV"),
+                   ("Video to Video, amount of motion transferred depends on Denoising Strength - Sharp Transitions", "GQUV"),
+                   ("Video to Video, amount of motion transferred depends on Denoising Strength - Smooth Transitions", "GUV"),
+               ],
+               "default": "KI",
+               "letters_filter": "RGUVQKI",
+               "label": "Video to Video",
+               "show_label" : False,
+           }
+           # extra_model_def["at_least_one_image_ref_needed"] = True
+
+       if vace_class:
+           extra_model_def["guide_preprocessing"] = {
+               "selection": ["", "UV", "PV", "DV", "SV", "LV", "CV", "MV", "V", "PDV", "PSV", "PLV" , "DSV", "DLV", "SLV"],
+               "labels" : { "V": "Use Vace raw format"}
+           }
+           extra_model_def["mask_preprocessing"] = {
+               "selection": ["", "A", "NA", "XA", "XNA", "YA", "YNA", "WA", "WNA", "ZA", "ZNA"],
+           }
+           extra_model_def["image_ref_choices"] = {
+               "choices": [("None", ""),
+                   ("Inject only People / Objects", "I"),
+                   ("Inject Landscape and then People / Objects", "KI"),
+                   ("Inject Frames and then People / Objects", "FI"),
+               ],
+               "letters_filter": "KFI",
+           }
-       if base_model_type in ["standin"] or vace_class:
-           extra_model_def["lock_image_refs_ratios"] = True
+           extra_model_def["lock_image_refs_ratios"] = True
+
+       if base_model_type in ["standin"]:
+           extra_model_def["lock_image_refs_ratios"] = True
+           extra_model_def["image_ref_choices"] = {
+               "choices": [
+                   ("No Reference Image", ""),
+                   ("Reference Image is a Person Face", "I"),
+               ],
+               "letters_filter":"I",
+           }
+
+       if base_model_type in ["phantom_1.3B", "phantom_14B"]:
+           extra_model_def["image_ref_choices"] = {
+               "choices": [("Reference Image", "I")],
+               "letters_filter":"I",
+               "visible": False,
+           }

        if base_model_type in ["recam_1.3B"]:
            extra_model_def["keep_frames_video_guide_not_supported"] = True
            extra_model_def["model_modes"] = {

@@ -141,10 +201,18 @@ class family_handler():
                "default": 1,
                "label" : "Camera Movement Type"
            }
+           extra_model_def["guide_preprocessing"] = {
+               "selection": ["UV"],
+               "labels" : { "UV": "Control Video"},
+               "visible" : False,
+           }

        if vace_class or base_model_type in ["infinitetalk"]:
            image_prompt_types_allowed = "TVL"
        elif base_model_type in ["ti2v_2_2"]:
-           image_prompt_types_allowed = "TSEVL"
+           image_prompt_types_allowed = "TSVL"
+       elif test_multitalk(base_model_type) or base_model_type in ["fantasy"]:
+           image_prompt_types_allowed = "SVL"
        elif i2v:
            image_prompt_types_allowed = "SEVL"
        else:

View File

@@ -7,7 +7,6 @@ import psutil
# import ffmpeg
import imageio
from PIL import Image
import cv2
import torch
import torch.nn.functional as F

@@ -33,6 +32,8 @@ model_in_GPU = False
matanyone_in_GPU = False
bfloat16_supported = False

+import copy
+
# SAM generator
class MaskGenerator():
    def __init__(self, sam_checkpoint, device):
        global args_device

@@ -89,6 +90,7 @@ def get_frames_from_image(image_input, image_state):
        "last_frame_numer": 0,
        "fps": None
    }
+
    image_info = "Image Name: N/A,\nFPS: N/A,\nTotal Frames: {},\nImage Size:{}".format(len(frames), image_size)
    set_image_encoder_patch()
    select_SAM()

@@ -717,27 +719,33 @@ def load_unload_models(selected):
def get_vmc_event_handler():
    return load_unload_models

-def export_to_vace_video_input(foreground_video_output):
-   gr.Info("Masked Video Input transferred to Vace For Inpainting")
-   return "V#" + str(time.time()), foreground_video_output
-
-def export_image(image_refs, image_output):
-   gr.Info("Masked Image transferred to Current Video")
+def export_image(state, image_output):
+   ui_settings = get_current_model_settings(state)
+   image_refs = ui_settings["image_refs"]
    if image_refs == None:
        image_refs =[]
    image_refs.append( image_output)
-   return image_refs
+   ui_settings["image_refs"] = image_refs
+   gr.Info("Masked Image transferred to Current Image Generator")
+   return time.time()

-def export_image_mask(image_input, image_mask):
-   gr.Info("Input Image & Mask transferred to Current Video")
-   return Image.fromarray(image_input), image_mask
+def export_image_mask(state, image_input, image_mask):
+   ui_settings = get_current_model_settings(state)
+   ui_settings["image_guide"] = Image.fromarray(image_input)
+   ui_settings["image_mask"] = image_mask
+   gr.Info("Input Image & Mask transferred to Current Image Generator")
+   return time.time()

-def export_to_current_video_engine( foreground_video_output, alpha_video_output):
+def export_to_current_video_engine(state, foreground_video_output, alpha_video_output):
+   ui_settings = get_current_model_settings(state)
+   ui_settings["video_guide"] = foreground_video_output
+   ui_settings["video_mask"] = alpha_video_output
    gr.Info("Original Video and Full Mask have been transferred")
-   # return "MV#" + str(time.time()), foreground_video_output, alpha_video_output
-   return foreground_video_output, alpha_video_output
+   return time.time()

def teleport_to_video_tab(tab_state):

@@ -746,9 +754,10 @@ def teleport_to_video_tab(tab_state):
    return gr.Tabs(selected="video_gen")

-def display(tabs, tab_state, server_config, vace_video_input, vace_image_input, vace_video_mask, vace_image_mask, vace_image_refs):
+def display(tabs, tab_state, state, refresh_form_trigger, server_config, get_current_model_settings_fn): #, vace_video_input, vace_image_input, vace_video_mask, vace_image_mask, vace_image_refs):
    # my_tab.select(fn=load_unload_models, inputs=[], outputs=[])
-   global image_output_codec, video_output_codec
+   global image_output_codec, video_output_codec, get_current_model_settings
+   get_current_model_settings = get_current_model_settings_fn

    image_output_codec = server_config.get("image_output_codec", None)
    video_output_codec = server_config.get("video_output_codec", None)

@@ -871,7 +880,7 @@ def display(tabs, tab_state, server_config, vace_video_input, vace_image_input,
            template_frame = gr.Image(label="Start Frame", type="pil",interactive=True, elem_id="template_frame", visible=False, elem_classes="image")
            with gr.Row():
                clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False, min_width=100)
-               add_mask_button = gr.Button(value="Set Mask", interactive=True, visible=False, min_width=100)
+               add_mask_button = gr.Button(value="Add Mask", interactive=True, visible=False, min_width=100)
                remove_mask_button = gr.Button(value="Remove Mask", interactive=True, visible=False, min_width=100) # no use
                matting_button = gr.Button(value="Generate Video Matting", interactive=True, visible=False, min_width=100)
            with gr.Row():

@@ -892,7 +901,7 @@ def display(tabs, tab_state, server_config, vace_video_input, vace_image_input,
            with gr.Row(visible= True):
                export_to_current_video_engine_btn = gr.Button("Export to Control Video Input and Video Mask Input", visible= False)

-           export_to_current_video_engine_btn.click( fn=export_to_current_video_engine, inputs= [foreground_video_output, alpha_video_output], outputs= [vace_video_input, vace_video_mask]).then( #video_prompt_video_guide_trigger,
+           export_to_current_video_engine_btn.click( fn=export_to_current_video_engine, inputs= [state, foreground_video_output, alpha_video_output], outputs= [refresh_form_trigger]).then( #video_prompt_video_guide_trigger,
                fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs])

@@ -1089,9 +1098,9 @@ def display(tabs, tab_state, server_config, vace_video_input, vace_image_input,
            # with gr.Column(scale=2, visible= True):
            export_image_mask_btn = gr.Button(value="Set to Control Image & Mask", visible=False, elem_classes="new_button")

-           export_image_btn.click( fn=export_image, inputs= [vace_image_refs, foreground_image_output], outputs= [vace_image_refs]).then( #video_prompt_video_guide_trigger,
+           export_image_btn.click( fn=export_image, inputs= [state, foreground_image_output], outputs= [refresh_form_trigger]).then( #video_prompt_video_guide_trigger,
                fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs])
-           export_image_mask_btn.click( fn=export_image_mask, inputs= [image_input, alpha_image_output], outputs= [vace_image_input, vace_image_mask]).then( #video_prompt_video_guide_trigger,
+           export_image_mask_btn.click( fn=export_image_mask, inputs= [state, image_input, alpha_image_output], outputs= [refresh_form_trigger]).then( #video_prompt_video_guide_trigger,
                fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs])

    # first step: get the image information

@@ -1148,5 +1157,21 @@ def display(tabs, tab_state, server_config, vace_video_input, vace_image_input,
        outputs=[foreground_image_output, alpha_image_output,foreground_image_output, alpha_image_output,bbox_info, export_image_btn, export_image_mask_btn]
    )

+   nada = gr.State({})
+   # clear input
+   gr.on(
+       triggers=[image_input.clear], #image_input.change,
+       fn=restart,
+       inputs=[],
+       outputs=[
+           image_state,
+           interactive_state,
+           click_state,
+           foreground_image_output, alpha_image_output,
+           template_frame,
+           image_selection_slider, image_selection_slider, track_pause_number_slider,point_prompt, export_image_btn, export_image_mask_btn, bbox_info, clear_button_click,
+           add_mask_button, matting_button, template_frame, foreground_image_output, alpha_image_output, remove_mask_button, export_image_btn, export_image_mask_btn, mask_dropdown, nada, step2_title
+       ],
+       queue=False,
+       show_progress=False)

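The export helpers above no longer return media into fixed Gradio components; they write into the current model's settings dict (via get_current_model_settings) and return time.time(), which is routed into refresh_form_trigger so the main form reloads from state. A minimal standalone sketch of that "bump a hidden trigger to refresh from state" pattern (a toy demo, not the WanGP code):

# Standalone sketch: write into shared state, then bump a hidden Number so a
# .change listener rebuilds the visible components from that state.
import time
import gradio as gr

with gr.Blocks() as demo:
    state = gr.State({"image_refs": []})
    trigger = gr.Number(value=0, visible=False)
    listing = gr.Textbox(label="Items in state")
    btn = gr.Button("Export dummy item")

    def export_item(st):
        st["image_refs"].append(f"item_{len(st['image_refs'])}")
        return st, time.time()          # bumping the trigger fires refresh()

    def refresh(st):
        return ", ".join(st["image_refs"])

    btn.click(export_item, inputs=[state], outputs=[state, trigger])
    trigger.change(refresh, inputs=[state], outputs=[listing])

# demo.launch()  # uncomment to run locally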
View File

@@ -23,7 +23,7 @@ librosa==0.11.0
speechbrain==1.0.3

# UI & interaction
-gradio==5.23.0
+gradio==5.29.0
dashscope
loguru

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, Literal
import gradio as gr
import PIL
+import time
from PIL import Image as PILImage

FilePath = str

@@ -20,6 +21,9 @@ def get_list( objs):
        return []
    return [ obj[0] if isinstance(obj, tuple) else obj for obj in objs]

+def record_last_action(st, last_action):
+   st["last_action"] = last_action
+   st["last_time"] = time.time()
+
class AdvancedMediaGallery:
    def __init__(
        self,

@@ -60,9 +64,10 @@ class AdvancedMediaGallery:
        self.state: Optional[gr.State] = None
        self._initial_state: Dict[str, Any] = {
            "items": items,
-           "selected": (len(items) - 1) if items else None,
+           "selected": (len(items) - 1) if items else 0, # None,
            "single": bool(single_image_mode),
            "mode": self.media_mode,
+           "last_action": "",
        }

    # ---------------- helpers ----------------

@@ -210,6 +215,13 @@ class AdvancedMediaGallery:
    def _on_select(self, state: Dict[str, Any], gallery, evt: gr.SelectData) :
        # Mirror the selected index into state and the gallery (server-side selected_index)
+       st = get_state(state)
+       last_time = st.get("last_time", None)
+       if last_time is not None and abs(time.time()- last_time)< 0.5: # crappy trick to detect if onselect is unwanted (buggy gallery)
+           # print(f"ignored:{time.time()}, real {st['selected']}")
+           return gr.update(selected_index=st["selected"]), st
+
        idx = None
        if evt is not None and hasattr(evt, "index"):
            ix = evt.index

@@ -220,17 +232,28 @@ class AdvancedMediaGallery:
                idx = ix[0] * max(1, int(self.columns)) + ix[1]
            else:
                idx = ix[0]
-       st = get_state(state)
        n = len(get_list(gallery))
        sel = idx if (idx is not None and 0 <= idx < n) else None
+       # print(f"image selected evt index:{sel}/{evt.selected}")
        st["selected"] = sel
-       # return gr.update(selected_index=sel), st
-       # return gr.update(), st
-       return st
+       return gr.update(), st
+
+   def _on_upload(self, value: List[Any], state: Dict[str, Any]) :
+       # Fires when users upload via the Gallery itself.
+       # items_filtered = self._filter_items_by_mode(list(value or []))
+       items_filtered = list(value or [])
+       st = get_state(state)
+       new_items = self._paths_from_payload(items_filtered)
+       st["items"] = new_items
+       new_sel = len(new_items) - 1
+       st["selected"] = new_sel
+       record_last_action(st,"add")
+       return gr.update(selected_index=new_sel), st

    def _on_gallery_change(self, value: List[Any], state: Dict[str, Any]) :
        # Fires when users add/drag/drop/delete via the Gallery itself.
-       items_filtered = self._filter_items_by_mode(list(value or []))
+       # items_filtered = self._filter_items_by_mode(list(value or []))
+       items_filtered = list(value or [])
        st = get_state(state)
        st["items"] = items_filtered
        # Keep selection if still valid, else default to last

@@ -240,10 +263,9 @@ class AdvancedMediaGallery:
        else:
            new_sel = old_sel
        st["selected"] = new_sel
-       # return gr.update(value=items_filtered, selected_index=new_sel), st
-       # return gr.update(value=items_filtered), st
-       return gr.update(selected_index=new_sel), st
+       st["last_action"] ="gallery_change"
+       # print(f"gallery change: set sel {new_sel}")
+       return gr.update(), st

    def _on_add(self, files_payload: Any, state: Dict[str, Any], gallery):
        """

@@ -252,7 +274,8 @@ class AdvancedMediaGallery:
        and re-selects the last inserted item.
        """
        # New items (respect image/video mode)
-       new_items = self._filter_items_by_mode(self._paths_from_payload(files_payload))
+       # new_items = self._filter_items_by_mode(self._paths_from_payload(files_payload))
+       new_items = self._paths_from_payload(files_payload)

        st = get_state(state)
        cur: List[Any] = get_list(gallery)

@@ -298,30 +321,6 @@ class AdvancedMediaGallery:
            if k is not None:
                seen_new.add(k)

-       # Remove any existing occurrences of the incoming items from current list,
-       # BUT keep the currently selected item even if it's also in incoming.
-       cur_clean: List[Any] = []
-       # sel_item = cur[sel] if (sel is not None and 0 <= sel < len(cur)) else None
-       # for idx, it in enumerate(cur):
-       #     k = key_of(it)
-       #     if it is sel_item:
-       #         cur_clean.append(it)
-       #         continue
-       #     if k is not None and k in seen_new:
-       #         continue  # drop duplicate; we'll reinsert at the target spot
-       #     cur_clean.append(it)
-       # # Compute insertion position: right AFTER the (possibly shifted) selected item
-       # if sel_item is not None:
-       #     # find sel_item's new index in cur_clean
-       #     try:
-       #         pos_sel = cur_clean.index(sel_item)
-       #     except ValueError:
-       #         # Shouldn't happen, but fall back to end
-       #         pos_sel = len(cur_clean) - 1
-       #     insert_pos = pos_sel + 1
-       # else:
-       #     insert_pos = len(cur_clean)  # no selection -> append at end
-
        insert_pos = min(sel, len(cur) -1)
        cur_clean = cur
        # Build final list and selection

@@ -330,6 +329,8 @@ class AdvancedMediaGallery:
        st["items"] = merged
        st["selected"] = new_sel
+       record_last_action(st,"add")
+       # print(f"gallery add: set sel {new_sel}")
        return gr.update(value=merged, selected_index=new_sel), st

    def _on_remove(self, state: Dict[str, Any], gallery) :

@@ -342,8 +343,9 @@ class AdvancedMediaGallery:
            return gr.update(value=[], selected_index=None), st
        new_sel = min(sel, len(items) - 1)
        st["items"] = items; st["selected"] = new_sel
-       # return gr.update(value=items, selected_index=new_sel), st
-       return gr.update(value=items), st
+       record_last_action(st,"remove")
+       # print(f"gallery del: new sel {new_sel}")
+       return gr.update(value=items, selected_index=new_sel), st

    def _on_move(self, delta: int, state: Dict[str, Any], gallery) :
        st = get_state(state); items: List[Any] = get_list(gallery); sel = st.get("selected", None)

@@ -354,11 +356,15 @@ class AdvancedMediaGallery:
            return gr.update(value=items, selected_index=sel), st
        items[sel], items[j] = items[j], items[sel]
        st["items"] = items; st["selected"] = j
+       record_last_action(st,"move")
+       # print(f"gallery move: set sel {j}")
        return gr.update(value=items, selected_index=j), st

    def _on_clear(self, state: Dict[str, Any]) :
        st = {"items": [], "selected": None, "single": get_state(state).get("single", False), "mode": self.media_mode}
-       return gr.update(value=[], selected_index=0), st
+       record_last_action(st,"clear")
+       # print(f"Clear all")
+       return gr.update(value=[], selected_index=None), st

    def _on_toggle_single(self, to_single: bool, state: Dict[str, Any]) :
        st = get_state(state); st["single"] = bool(to_single)

@@ -382,30 +388,38 @@ class AdvancedMediaGallery:
    def mount(self, parent: Optional[gr.Blocks | gr.Group | gr.Row | gr.Column] = None, update_form = False):
        if parent is not None:
            with parent:
-               col = self._build_ui()
+               col = self._build_ui(update_form)
        else:
-           col = self._build_ui()
+           col = self._build_ui(update_form)
        if not update_form:
            self._wire_events()
        return col

-   def _build_ui(self) -> gr.Column:
+   def _build_ui(self, update = False) -> gr.Column:
        with gr.Column(elem_id=self.elem_id, elem_classes=self.elem_classes) as col:
            self.container = col
            self.state = gr.State(dict(self._initial_state))

-           self.gallery = gr.Gallery(
-               label=self.label,
-               value=self._initial_state["items"],
-               height=self.height,
-               columns=self.columns,
-               show_label=self.show_label,
-               preview= True,
-               # type="pil",
-               file_types= list(IMAGE_EXTS) if self.media_mode == "image" else list(VIDEO_EXTS),
-               selected_index=self._initial_state["selected"], # server-side selection
-           )
+           if update:
+               self.gallery = gr.update(
+                   value=self._initial_state["items"],
+                   selected_index=self._initial_state["selected"], # server-side selection
+                   label=self.label,
+                   show_label=self.show_label,
+               )
+           else:
+               self.gallery = gr.Gallery(
+                   value=self._initial_state["items"],
+                   label=self.label,
+                   height=self.height,
+                   columns=self.columns,
+                   show_label=self.show_label,
+                   preview= True,
+                   # type="pil", # very slow
+                   file_types= list(IMAGE_EXTS) if self.media_mode == "image" else list(VIDEO_EXTS),
+                   selected_index=self._initial_state["selected"], # server-side selection
+               )

            # One-line controls
            exts = sorted(IMAGE_EXTS if self.media_mode == "image" else VIDEO_EXTS) if self.accept_filter else None

@@ -418,10 +432,10 @@ class AdvancedMediaGallery:
                    size="sm",
                    min_width=1,
                )
-               self.btn_remove = gr.Button("Remove", size="sm", min_width=1)
+               self.btn_remove = gr.Button(" Remove ", size="sm", min_width=1)
                self.btn_left = gr.Button("◀ Left", size="sm", visible=not self._initial_state["single"], min_width=1)
                self.btn_right = gr.Button("Right ▶", size="sm", visible=not self._initial_state["single"], min_width=1)
-               self.btn_clear = gr.Button("Clear", variant="secondary", size="sm", visible=not self._initial_state["single"], min_width=1)
+               self.btn_clear = gr.Button(" Clear ", variant="secondary", size="sm", visible=not self._initial_state["single"], min_width=1)

        return col

@@ -430,14 +444,24 @@ class AdvancedMediaGallery:
        self.gallery.select(
            self._on_select,
            inputs=[self.state, self.gallery],
-           outputs=[self.state],
+           outputs=[self.gallery, self.state],
+           trigger_mode="always_last",
        )

        # Gallery value changed by user actions (click-to-add, drag-drop, internal remove, etc.)
-       self.gallery.change(
+       self.gallery.upload(
+           self._on_upload,
+           inputs=[self.gallery, self.state],
+           outputs=[self.gallery, self.state],
+           trigger_mode="always_last",
+       )
+
+       # Gallery value changed by user actions (click-to-add, drag-drop, internal remove, etc.)
+       self.gallery.upload(
            self._on_gallery_change,
            inputs=[self.gallery, self.state],
            outputs=[self.gallery, self.state],
+           trigger_mode="always_last",
        )

        # Add via UploadButton

@@ -445,6 +469,7 @@ class AdvancedMediaGallery:
            self._on_add,
            inputs=[self.upload_btn, self.state, self.gallery],
            outputs=[self.gallery, self.state],
+           trigger_mode="always_last",
        )

        # Remove selected

@@ -452,6 +477,7 @@ class AdvancedMediaGallery:
            self._on_remove,
            inputs=[self.state, self.gallery],
            outputs=[self.gallery, self.state],
+           trigger_mode="always_last",
        )

        # Reorder using selected index, keep same item selected

@@ -459,11 +485,13 @@ class AdvancedMediaGallery:
            lambda st, gallery: self._on_move(-1, st, gallery),
            inputs=[self.state, self.gallery],
            outputs=[self.gallery, self.state],
+           trigger_mode="always_last",
        )
        self.btn_right.click(
            lambda st, gallery: self._on_move(+1, st, gallery),
            inputs=[self.state, self.gallery],
            outputs=[self.gallery, self.state],
+           trigger_mode="always_last",
        )

        # Clear all

@@ -471,6 +499,7 @@ class AdvancedMediaGallery:
            self._on_clear,
            inputs=[self.state],
            outputs=[self.gallery, self.state],
+           trigger_mode="always_last",
        )

    # ---------------- public API ----------------

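The gallery changes above work around a Gradio quirk: programmatic updates to gr.Gallery can fire a spurious select event right after an add/remove/move, so every programmatic action records a timestamp (record_last_action) and _on_select ignores any selection that arrives within 0.5 s of it, restoring the previously selected index instead. A standalone sketch of that debounce idea in plain Python (no Gradio; class name illustrative):

# Standalone sketch of the 0.5 s "ignore spurious select" debounce.
import time

class SelectionDebouncer:
    def __init__(self, window=0.5):
        self.window = window
        self.last_action_time = None
        self.selected = None

    def record_action(self, selected_index):
        # Called after any programmatic gallery update (add/remove/move/clear).
        self.selected = selected_index
        self.last_action_time = time.time()

    def on_select(self, incoming_index):
        # A select event arriving too soon after a programmatic update is
        # treated as an echo of that update, not as a real user click.
        if self.last_action_time is not None and time.time() - self.last_action_time < self.window:
            return self.selected
        self.selected = incoming_index
        return incoming_index

deb = SelectionDebouncer()
deb.record_action(2)
print(deb.on_select(0))   # -> 2 (ignored: arrived right after a programmatic update)
time.sleep(0.6)
print(deb.on_select(0))   # -> 0 (accepted as a genuine user click)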
View File

@@ -19,6 +19,7 @@ import tempfile
import subprocess
import json
from functools import lru_cache
+os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg")

from PIL import Image

@@ -207,30 +208,62 @@ def get_outpainting_frame_location(final_height, final_width, outpainting_dims
    if (margin_left + width) > final_width or outpainting_right == 0: margin_left = final_width - width
    return height, width, margin_top, margin_left

-def calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = 16):
-   if fit_into_canvas == None:
+def rescale_and_crop(img, w, h):
+   ow, oh = img.size
+   target_ratio = w / h
+   orig_ratio = ow / oh
+
+   if orig_ratio > target_ratio:
+       # Crop width first
+       nw = int(oh * target_ratio)
+       img = img.crop(((ow - nw) // 2, 0, (ow + nw) // 2, oh))
+   else:
+       # Crop height first
+       nh = int(ow / target_ratio)
+       img = img.crop((0, (oh - nh) // 2, ow, (oh + nh) // 2))
+
+   return img.resize((w, h), Image.LANCZOS)
+
+def calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = 16):
+   if fit_into_canvas == None or fit_into_canvas == 2:
        # return image_height, image_width
        return canvas_height, canvas_width
-   if fit_into_canvas:
+   if fit_into_canvas == 1:
        scale1 = min(canvas_height / image_height, canvas_width / image_width)
        scale2 = min(canvas_width / image_height, canvas_height / image_width)
        scale = max(scale1, scale2)
-   else:
+   else: #0 or #2 (crop)
        scale = (canvas_height * canvas_width / (image_height * image_width))**(1/2)

    new_height = round( image_height * scale / block_size) * block_size
    new_width = round( image_width * scale / block_size) * block_size
    return new_height, new_width

-def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, ignore_first, fit_into_canvas = False ):
+def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fit_into_canvas, fit_crop, block_size = 16):
+   if fit_crop:
+       image = rescale_and_crop(image, canvas_width, canvas_height)
+       new_width, new_height = image.size
+   else:
+       image_width, image_height = image.size
+       new_height, new_width = calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = block_size )
+       image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
+   return image, new_height, new_width
+
+def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None ):
    if rm_background:
        session = new_session()

    output_list =[]
    for i, img in enumerate(img_list):
        width, height = img.size
-       if fit_into_canvas:
+       if fit_into_canvas == None or any_background_ref == 1 and i==0 or any_background_ref == 2:
+           if outpainting_dims is not None:
+               resized_image =img
+           elif img.size != (budget_width, budget_height):
+               resized_image= img.resize((budget_width, budget_height), resample=Image.Resampling.LANCZOS)
+           else:
+               resized_image =img
+       elif fit_into_canvas == 1:
            white_canvas = np.ones((budget_height, budget_width, 3), dtype=np.uint8) * 255
            scale = min(budget_height / height, budget_width / width)
            new_height = int(height * scale)

@@ -242,10 +275,10 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg
            resized_image = Image.fromarray(white_canvas)
        else:
            scale = (budget_height * budget_width / (height * width))**(1/2)
-           new_height = int( round(height * scale / 16) * 16)
-           new_width = int( round(width * scale / 16) * 16)
+           new_height = int( round(height * scale / block_size) * block_size)
+           new_width = int( round(width * scale / block_size) * block_size)
            resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
-       if rm_background and not (ignore_first and i == 0) :
+       if rm_background and not (any_background_ref and i==0 or any_background_ref == 2) :
            # resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
            resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
        output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200,

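In the utilities above, rescale_and_crop center-crops to the target aspect ratio before resizing, while calculate_new_dimensions now distinguishes three fit_into_canvas modes: None or 2 return the canvas size unchanged, 1 letterboxes into the canvas, and 0 keeps the image aspect ratio at the canvas pixel budget, rounded to block_size. A short usage sketch of the crop helper on synthetic images (assumes shared/utils/utils.py as the module path, consistent with the imports elsewhere in this commit; the printed sizes are what the function above should produce):

# Usage sketch for rescale_and_crop: center-crop to the target aspect ratio,
# then resize to exactly (w, h).
from PIL import Image
from shared.utils.utils import rescale_and_crop  # assumed module path

img = Image.new("RGB", (1920, 1080))        # 16:9 source
out = rescale_and_crop(img, 720, 720)       # 1:1 target
print(out.size)                             # (720, 720): 1080x1080 center crop, then resize

img2 = Image.new("RGB", (1000, 1500))       # portrait source, 2:3
out2 = rescale_and_crop(img2, 1280, 720)    # 16:9 target
print(out2.size)                            # (1280, 720): height cropped to 562 first, then resize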
839
wgp.py

File diff suppressed because it is too large Load Diff