Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-04 14:16:57 +00:00)
Commit f9f63cbc79 ("intermediate commit"), parent 99fd9aea32
@ -7,8 +7,6 @@
|
||||
"https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_bf16.safetensors",
|
||||
"https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_quanto_bf16_int8.safetensors"
|
||||
],
|
||||
"image_outputs": true,
|
||||
"reference_image": true,
|
||||
"flux-model": "flux-dev-kontext"
|
||||
},
|
||||
"prompt": "add a hat",
|
||||
|
||||
@ -6,8 +6,6 @@
|
||||
"modules": [ ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_projector_bf16.safetensors"]],
|
||||
"URLs": "flux",
|
||||
"loras": ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_dit_lora_bf16.safetensors"],
|
||||
"image_outputs": true,
|
||||
"reference_image": true,
|
||||
"flux-model": "flux-dev-uso"
|
||||
},
|
||||
"prompt": "the man is wearing a hat",
|
||||
|
||||
@ -9,9 +9,7 @@
|
||||
],
|
||||
"attention": {
|
||||
"<89": "sdpa"
|
||||
},
|
||||
"reference_image": true,
|
||||
"image_outputs": true
|
||||
}
|
||||
},
|
||||
"prompt": "add a hat",
|
||||
"resolution": "1280x720",
|
||||
|
||||
@ -13,28 +13,41 @@ class family_handler():
|
||||
flux_schnell = flux_model == "flux-schnell"
|
||||
flux_chroma = flux_model == "flux-chroma"
|
||||
flux_uso = flux_model == "flux-dev-uso"
|
||||
model_def_output = {
|
||||
flux_kontext = flux_model == "flux-dev-kontext"
|
||||
|
||||
extra_model_def = {
|
||||
"image_outputs" : True,
|
||||
"no_negative_prompt" : not flux_chroma,
|
||||
}
|
||||
if flux_chroma:
|
||||
model_def_output["guidance_max_phases"] = 1
|
||||
extra_model_def["guidance_max_phases"] = 1
|
||||
elif not flux_schnell:
|
||||
model_def_output["embedded_guidance"] = True
|
||||
extra_model_def["embedded_guidance"] = True
|
||||
if flux_uso :
|
||||
model_def_output["any_image_refs_relative_size"] = True
|
||||
model_def_output["no_background_removal"] = True
|
||||
|
||||
model_def_output["image_ref_choices"] = {
|
||||
extra_model_def["any_image_refs_relative_size"] = True
|
||||
extra_model_def["no_background_removal"] = True
|
||||
extra_model_def["image_ref_choices"] = {
|
||||
"choices":[("No Reference Image", ""),("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "KI"),
|
||||
("Up to two Images are Style Images", "KIJ")],
|
||||
"default": "KI",
|
||||
"letters_filter": "KIJ",
|
||||
"label": "Reference Images / Style Images"
|
||||
}
|
||||
model_def_output["lock_image_refs_ratios"] = True
|
||||
|
||||
if flux_kontext:
|
||||
extra_model_def["image_ref_choices"] = {
|
||||
"choices": [
|
||||
("None", ""),
|
||||
("Conditional Images is first Main Subject / Landscape and may be followed by People / Objects", "KI"),
|
||||
("Conditional Images are People / Objects", "I"),
|
||||
],
|
||||
"letters_filter": "KI",
|
||||
}
|
||||
|
||||
return model_def_output
|
||||
|
||||
extra_model_def["lock_image_refs_ratios"] = True
|
||||
|
||||
return extra_model_def
|
||||
|
||||
@staticmethod
|
||||
def query_supported_types():
|
||||
@ -122,10 +135,12 @@ class family_handler():
|
||||
def update_default_settings(base_model_type, model_def, ui_defaults):
|
||||
flux_model = model_def.get("flux-model", "flux-dev")
|
||||
flux_uso = flux_model == "flux-dev-uso"
|
||||
flux_kontext = flux_model == "flux-dev-kontext"
|
||||
ui_defaults.update({
|
||||
"embedded_guidance": 2.5,
|
||||
})
|
||||
if model_def.get("reference_image", False):
|
||||
})
|
||||
|
||||
if flux_kontext or flux_uso:
|
||||
ui_defaults.update({
|
||||
"video_prompt_type": "KI",
|
||||
})
|
||||
|
||||
@ -24,44 +24,6 @@ from .util import (
|
||||
|
||||
from PIL import Image
|
||||
|
||||
def resize_and_centercrop_image(image, target_height_ref1, target_width_ref1):
|
||||
target_height_ref1 = int(target_height_ref1 // 64 * 64)
|
||||
target_width_ref1 = int(target_width_ref1 // 64 * 64)
|
||||
h, w = image.shape[-2:]
|
||||
if h < target_height_ref1 or w < target_width_ref1:
|
||||
# compute the aspect ratio
|
||||
aspect_ratio = w / h
|
||||
if h < target_height_ref1:
|
||||
new_h = target_height_ref1
|
||||
new_w = new_h * aspect_ratio
|
||||
if new_w < target_width_ref1:
|
||||
new_w = target_width_ref1
|
||||
new_h = new_w / aspect_ratio
|
||||
else:
|
||||
new_w = target_width_ref1
|
||||
new_h = new_w / aspect_ratio
|
||||
if new_h < target_height_ref1:
|
||||
new_h = target_height_ref1
|
||||
new_w = new_h * aspect_ratio
|
||||
else:
|
||||
aspect_ratio = w / h
|
||||
tgt_aspect_ratio = target_width_ref1 / target_height_ref1
|
||||
if aspect_ratio > tgt_aspect_ratio:
|
||||
new_h = target_height_ref1
|
||||
new_w = new_h * aspect_ratio
|
||||
else:
|
||||
new_w = target_width_ref1
|
||||
new_h = new_w / aspect_ratio
|
||||
# resize the image with TVF.resize
|
||||
image = TVF.resize(image, (math.ceil(new_h), math.ceil(new_w)))
|
||||
# compute the center-crop parameters
|
||||
top = (image.shape[-2] - target_height_ref1) // 2
|
||||
left = (image.shape[-1] - target_width_ref1) // 2
|
||||
# center-crop with TVF.crop
|
||||
image = TVF.crop(image, top, left, target_height_ref1, target_width_ref1)
|
||||
return image
|
||||
|
||||
|
||||
def stitch_images(img1, img2):
|
||||
# Resize img2 to match img1's height
|
||||
width1, height1 = img1.size
|
||||
@ -171,8 +133,6 @@ class model_factory:
|
||||
device="cuda"
|
||||
flux_dev_uso = self.name in ['flux-dev-uso']
|
||||
image_stiching = not self.name in ['flux-dev-uso'] #and False
|
||||
# image_refs_relative_size = 100
|
||||
crop = False
|
||||
input_ref_images = [] if input_ref_images is None else input_ref_images[:]
|
||||
ref_style_imgs = []
|
||||
if "I" in video_prompt_type and len(input_ref_images) > 0:
|
||||
@ -186,36 +146,15 @@ class model_factory:
|
||||
if image_stiching:
|
||||
# image stiching method
|
||||
stiched = input_ref_images[0]
|
||||
if "K" in video_prompt_type :
|
||||
w, h = input_ref_images[0].size
|
||||
height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
|
||||
# actual rescale will happen in prepare_kontext
|
||||
for new_img in input_ref_images[1:]:
|
||||
stiched = stitch_images(stiched, new_img)
|
||||
input_ref_images = [stiched]
|
||||
else:
|
||||
first_ref = 0
|
||||
if "K" in video_prompt_type:
|
||||
# image latents tiling method
|
||||
w, h = input_ref_images[0].size
|
||||
if crop :
|
||||
img = convert_image_to_tensor(input_ref_images[0])
|
||||
img = resize_and_centercrop_image(img, height, width)
|
||||
input_ref_images[0] = convert_tensor_to_image(img)
|
||||
else:
|
||||
height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
|
||||
input_ref_images[0] = input_ref_images[0].resize((width, height), resample=Image.Resampling.LANCZOS)
|
||||
first_ref = 1
|
||||
|
||||
for i in range(first_ref,len(input_ref_images)):
|
||||
# latents stiching with resize
|
||||
for i in range(len(input_ref_images)):
|
||||
w, h = input_ref_images[i].size
|
||||
if crop:
|
||||
img = convert_image_to_tensor(input_ref_images[i])
|
||||
img = resize_and_centercrop_image(img, int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100))
|
||||
input_ref_images[i] = convert_tensor_to_image(img)
|
||||
else:
|
||||
image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas)
|
||||
input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
|
||||
image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas)
|
||||
input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
|
||||
else:
|
||||
input_ref_images = None
|
||||
|
||||
|
||||
@ -861,11 +861,6 @@ class HunyuanVideoSampler(Inference):
|
||||
freqs_cos, freqs_sin = self.get_rotary_pos_embed(target_frame_num, target_height, target_width, enable_RIFLEx)
|
||||
else:
|
||||
if self.avatar:
|
||||
w, h = input_ref_images.size
|
||||
target_height, target_width = calculate_new_dimensions(target_height, target_width, h, w, fit_into_canvas)
|
||||
if target_width != w or target_height != h:
|
||||
input_ref_images = input_ref_images.resize((target_width,target_height), resample=Image.Resampling.LANCZOS)
|
||||
|
||||
concat_dict = {'mode': 'timecat', 'bias': -1}
|
||||
freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict)
|
||||
else:
|
||||
|
||||
@ -51,6 +51,23 @@ class family_handler():
|
||||
extra_model_def["tea_cache"] = True
|
||||
extra_model_def["mag_cache"] = True
|
||||
|
||||
if base_model_type in ["hunyuan_custom_edit"]:
|
||||
extra_model_def["guide_preprocessing"] = {
|
||||
"selection": ["MV", "PMV"],
|
||||
}
|
||||
|
||||
extra_model_def["mask_preprocessing"] = {
|
||||
"selection": ["A", "NA"],
|
||||
"default" : "NA"
|
||||
}
|
||||
|
||||
if base_model_type in ["hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_custom"]:
|
||||
extra_model_def["image_ref_choices"] = {
|
||||
"choices": [("Reference Image", "I")],
|
||||
"letters_filter":"I",
|
||||
"visible": False,
|
||||
}
|
||||
|
||||
if base_model_type in ["hunyuan_avatar"]: extra_model_def["no_background_removal"] = True
|
||||
|
||||
if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar"]:
|
||||
@ -141,6 +158,10 @@ class family_handler():
|
||||
|
||||
return hunyuan_model, pipe
|
||||
|
||||
@staticmethod
|
||||
def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def update_default_settings(base_model_type, model_def, ui_defaults):
|
||||
ui_defaults["embedded_guidance_scale"]= 6.0
|
||||
|
||||
@ -300,9 +300,6 @@ class LTXV:
|
||||
prefix_size, height, width = input_video.shape[-3:]
|
||||
else:
|
||||
if image_start != None:
|
||||
frame_width, frame_height = image_start.size
|
||||
if fit_into_canvas != None:
|
||||
height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas, 32)
|
||||
conditioning_media_paths.append(image_start.unsqueeze(1))
|
||||
conditioning_start_frames.append(0)
|
||||
conditioning_control_frames.append(False)
|
||||
|
||||
@ -26,6 +26,15 @@ class family_handler():
|
||||
extra_model_def["sliding_window"] = True
|
||||
extra_model_def["image_prompt_types_allowed"] = "TSEV"
|
||||
|
||||
extra_model_def["guide_preprocessing"] = {
|
||||
"selection": ["", "PV", "DV", "EV", "V"],
|
||||
"labels" : { "V": "Use LTXV raw format"}
|
||||
}
|
||||
|
||||
extra_model_def["mask_preprocessing"] = {
|
||||
"selection": ["", "A", "NA", "XA", "XNA"],
|
||||
}
|
||||
|
||||
return extra_model_def
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -28,7 +28,7 @@ from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Aut
|
||||
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
|
||||
from diffusers import FlowMatchEulerDiscreteScheduler
|
||||
from PIL import Image
|
||||
from shared.utils.utils import calculate_new_dimensions
|
||||
from shared.utils.utils import calculate_new_dimensions, convert_image_to_tensor, convert_tensor_to_image
|
||||
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
@ -563,6 +563,8 @@ class QwenImagePipeline(): #DiffusionPipeline
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
image = None,
|
||||
image_mask = None,
|
||||
denoising_strength = 0,
|
||||
callback=None,
|
||||
pipeline=None,
|
||||
loras_slists=None,
|
||||
@ -694,14 +696,33 @@ class QwenImagePipeline(): #DiffusionPipeline
|
||||
image_width = image_width // multiple_of * multiple_of
|
||||
image_height = image_height // multiple_of * multiple_of
|
||||
ref_height, ref_width = 1568, 672
|
||||
if height * width < ref_height * ref_width: ref_height , ref_width = height , width
|
||||
if image_height * image_width > ref_height * ref_width:
|
||||
image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
|
||||
|
||||
image = image.resize((image_width,image_height), resample=Image.Resampling.LANCZOS)
|
||||
if image_mask is None:
|
||||
if height * width < ref_height * ref_width: ref_height , ref_width = height , width
|
||||
if image_height * image_width > ref_height * ref_width:
|
||||
image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
|
||||
if (image_width,image_height) != image.size:
|
||||
image = image.resize((image_width,image_height), resample=Image.Resampling.LANCZOS)
|
||||
image_mask_latents = None
|
||||
else:
|
||||
# _, image_width, image_height = min(
|
||||
# (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS
|
||||
# )
|
||||
image_height, image_width = calculate_new_dimensions(height, width, image_height, image_width, False, block_size=multiple_of)
|
||||
# image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
|
||||
height, width = image_height, image_width
|
||||
image_mask_latents = convert_image_to_tensor(image_mask.resize((width // 16, height // 16), resample=Image.Resampling.LANCZOS))
|
||||
image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1]
|
||||
image_mask_rebuilt = image_mask_latents.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0)
|
||||
convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png")
|
||||
image_mask_latents = image_mask_latents.reshape(1, -1, 1).to(device)
|
||||
|
||||
prompt_image = image
|
||||
image = self.image_processor.preprocess(image, image_height, image_width)
|
||||
image = image.unsqueeze(2)
|
||||
if image.size != (image_width, image_height):
|
||||
image = image.resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
|
||||
|
||||
image.save("nnn.png")
|
||||
image = convert_image_to_tensor(image).unsqueeze(0).unsqueeze(2)
|
||||
|
||||
has_neg_prompt = negative_prompt is not None or (
|
||||
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
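The hunk above prepares the inpainting mask at two resolutions: a per-token mask on the 16x16 latent grid that gates the denoising updates, and a full-resolution rebuilt mask used later to paste the result back over the source pixels. A minimal sketch of that preparation, assuming the mask arrives as a PIL image and using torchvision's `to_tensor` as a stand-in for the repo's `convert_image_to_tensor`:

```python
import torch
from PIL import Image
from torchvision.transforms.functional import to_tensor

def prepare_mask_latents(image_mask: Image.Image, height: int, width: int, device="cpu"):
    """Downscale a user mask to the 16x16 latent grid and keep a full-res copy (sketch)."""
    # 16x16 pixels per latent token: resize the mask to the latent grid
    small = image_mask.resize((width // 16, height // 16), resample=Image.Resampling.LANCZOS)
    t = to_tensor(small) * 2 - 1                              # CHW in [-1, 1], like convert_image_to_tensor
    mask_latents = torch.where(t > -0.5, 1.0, 0.0)[0:1]       # binarize, keep a single channel
    # rebuild a pixel-space 0/1 mask by repeating each latent cell 16x16 (used for final compositing)
    mask_rebuilt = mask_latents.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0)
    # flatten to one scalar per latent token so it can gate the packed latents
    mask_latents = mask_latents.reshape(1, -1, 1).to(device)
    return mask_latents, mask_rebuilt
```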
|
||||
@ -744,6 +765,8 @@ class QwenImagePipeline(): #DiffusionPipeline
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
original_image_latents = None if image_latents is None else image_latents.clone()
|
||||
|
||||
if image is not None:
|
||||
img_shapes = [
|
||||
[
|
||||
@ -788,6 +811,15 @@ class QwenImagePipeline(): #DiffusionPipeline
|
||||
negative_txt_seq_lens = (
|
||||
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
|
||||
)
|
||||
morph = False
|
||||
if image_mask_latents is not None and denoising_strength <= 1.:
|
||||
first_step = int(len(timesteps) * (1. - denoising_strength))
|
||||
if not morph:
|
||||
latent_noise_factor = timesteps[first_step]/1000
|
||||
latents = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor
|
||||
timesteps = timesteps[first_step:]
|
||||
self.scheduler.timesteps = timesteps
|
||||
self.scheduler.sigmas= self.scheduler.sigmas[first_step:]
|
||||
|
||||
# 6. Denoising loop
|
||||
self.scheduler.set_begin_index(0)
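With `denoising_strength` below 1 and `morph` disabled, the block added above skips the earliest steps and starts from a noised copy of the source latents, trimming the scheduler's timesteps and sigmas accordingly. A compact sketch of that setup, assuming flow-matching timesteps expressed on a 0-1000 scale:

```python
import torch

def start_from_image(timesteps: torch.Tensor, sigmas: torch.Tensor,
                     image_latents: torch.Tensor, denoising_strength: float):
    """Skip the earliest denoising steps and start from a partially noised image (sketch)."""
    first_step = int(len(timesteps) * (1.0 - denoising_strength))
    noise_level = timesteps[first_step] / 1000.0   # 0 = clean image, 1 = pure noise
    latents = image_latents * (1.0 - noise_level) + torch.randn_like(image_latents) * noise_level
    # the sampler then only runs the remaining steps
    return latents, timesteps[first_step:], sigmas[first_step:], first_step
```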
|
||||
@ -797,10 +829,15 @@ class QwenImagePipeline(): #DiffusionPipeline
|
||||
update_loras_slists(self.transformer, loras_slists, updated_num_steps)
|
||||
callback(-1, None, True, override_num_inference_steps = updated_num_steps)
|
||||
|
||||
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
if image_mask_latents is not None and denoising_strength <1. and i == first_step and morph:
|
||||
latent_noise_factor = t/1000
|
||||
latents = original_image_latents * (1.0 - latent_noise_factor) + latents * latent_noise_factor
|
||||
|
||||
self._current_timestep = t
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
@ -865,6 +902,12 @@ class QwenImagePipeline(): #DiffusionPipeline
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
if image_mask_latents is not None:
|
||||
next_t = timesteps[i+1] if i<len(timesteps)-1 else 0
|
||||
latent_noise_factor = next_t / 1000
|
||||
noisy_image = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor
|
||||
latents = noisy_image * (1-image_mask_latents) + image_mask_latents * latents
|
||||
noisy_image = None
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
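The added block keeps the unmasked region locked to the source image during sampling: after every scheduler step the original latents are re-noised to the next timestep's level and composited with the current latents through the token mask (mask = 1 where the model may edit). A sketch of that update, under the same 0-1000 timestep convention:

```python
import torch

def masked_inpaint_step(latents, original_image_latents, image_mask_latents, next_t):
    """Pin the unmasked region to the (re-noised) source after a scheduler step (sketch)."""
    noise_level = next_t / 1000.0
    noisy_image = original_image_latents * (1.0 - noise_level) \
        + torch.randn_like(original_image_latents) * noise_level
    return noisy_image * (1.0 - image_mask_latents) + image_mask_latents * latents
```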
|
||||
@ -878,7 +921,7 @@ class QwenImagePipeline(): #DiffusionPipeline
|
||||
|
||||
self._current_timestep = None
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
output_image = latents
|
||||
else:
|
||||
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
||||
latents = latents.to(self.vae.dtype)
|
||||
@ -891,7 +934,9 @@ class QwenImagePipeline(): #DiffusionPipeline
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents / latents_std + latents_mean
|
||||
image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
|
||||
output_image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
|
||||
if image_mask is not None:
|
||||
output_image = image.squeeze(2) * (1 - image_mask_rebuilt) + output_image.to(image) * image_mask_rebuilt
|
||||
|
||||
|
||||
return image
|
||||
return output_image
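Renaming the result to `output_image` keeps the preprocessed source image available so that, when a mask is supplied, the decoded pixels are pasted back only inside the masked area. Roughly, the compositing step amounts to the following sketch:

```python
import torch

def composite_inpaint_result(decoded: torch.Tensor, source: torch.Tensor,
                             mask_rebuilt: torch.Tensor) -> torch.Tensor:
    """Outside the mask the original pixels are kept verbatim (sketch of the hunk above)."""
    return source * (1.0 - mask_rebuilt) + decoded.to(source) * mask_rebuilt
```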
|
||||
|
||||
@ -9,7 +9,7 @@ def get_qwen_text_encoder_filename(text_encoder_quantization):
|
||||
class family_handler():
|
||||
@staticmethod
|
||||
def query_model_def(base_model_type, model_def):
|
||||
model_def_output = {
|
||||
extra_model_def = {
|
||||
"image_outputs" : True,
|
||||
"sample_solvers":[
|
||||
("Default", "default"),
|
||||
@ -18,8 +18,18 @@ class family_handler():
|
||||
"lock_image_refs_ratios": True,
|
||||
}
|
||||
|
||||
if base_model_type in ["qwen_image_edit_20B"]:
|
||||
extra_model_def["inpaint_support"] = True
|
||||
extra_model_def["image_ref_choices"] = {
|
||||
"choices": [
|
||||
("None", ""),
|
||||
("Conditional Images is first Main Subject / Landscape and may be followed by People / Objects", "KI"),
|
||||
("Conditional Images are People / Objects", "I"),
|
||||
],
|
||||
"letters_filter": "KI",
|
||||
}
|
||||
|
||||
return model_def_output
|
||||
return extra_model_def
|
||||
|
||||
@staticmethod
|
||||
def query_supported_types():
|
||||
@ -75,14 +85,18 @@ class family_handler():
|
||||
if ui_defaults.get("sample_solver", "") == "":
|
||||
ui_defaults["sample_solver"] = "default"
|
||||
|
||||
if settings_version < 2.32:
|
||||
ui_defaults["denoising_strength"] = 1.
|
||||
|
||||
@staticmethod
|
||||
def update_default_settings(base_model_type, model_def, ui_defaults):
|
||||
ui_defaults.update({
|
||||
"guidance_scale": 4,
|
||||
"sample_solver": "default",
|
||||
})
|
||||
if model_def.get("reference_image", False):
|
||||
if base_model_type in ["qwen_image_edit_20B"]:
|
||||
ui_defaults.update({
|
||||
"video_prompt_type": "KI",
|
||||
"denoising_strength" : 1.,
|
||||
})
|
||||
|
||||
|
||||
@ -103,6 +103,8 @@ class model_factory():
|
||||
n_prompt = None,
|
||||
sampling_steps: int = 20,
|
||||
input_ref_images = None,
|
||||
image_guide= None,
|
||||
image_mask= None,
|
||||
width= 832,
|
||||
height=480,
|
||||
guide_scale: float = 4,
|
||||
@ -114,6 +116,7 @@ class model_factory():
|
||||
VAE_tile_size = None,
|
||||
joint_pass = True,
|
||||
sample_solver='default',
|
||||
denoising_strength = 1.,
|
||||
**bbargs
|
||||
):
|
||||
# Generate with different aspect ratios
|
||||
@ -174,8 +177,9 @@ class model_factory():
|
||||
|
||||
if n_prompt is None or len(n_prompt) == 0:
|
||||
n_prompt= "text, watermark, copyright, blurry, low resolution"
|
||||
|
||||
if input_ref_images is not None:
|
||||
if image_guide is not None:
|
||||
input_ref_images = [image_guide]
|
||||
elif input_ref_images is not None:
|
||||
# image stiching method
|
||||
stiched = input_ref_images[0]
|
||||
if "K" in video_prompt_type :
|
||||
@ -190,6 +194,7 @@ class model_factory():
|
||||
prompt=input_prompt,
|
||||
negative_prompt=n_prompt,
|
||||
image = input_ref_images,
|
||||
image_mask = image_mask,
|
||||
width=width,
|
||||
height=height,
|
||||
num_inference_steps=sampling_steps,
|
||||
@ -199,6 +204,7 @@ class model_factory():
|
||||
pipeline=self,
|
||||
loras_slists=loras_slists,
|
||||
joint_pass = joint_pass,
|
||||
denoising_strength=denoising_strength,
|
||||
generator=torch.Generator(device="cuda").manual_seed(seed)
|
||||
)
|
||||
if image is None: return None
|
||||
|
||||
@ -261,7 +261,7 @@ class WanAny2V:
|
||||
def vace_latent(self, z, m):
|
||||
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
|
||||
|
||||
def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, fill_max = False, outpainting_dims = None, return_mask = False):
|
||||
def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, full_frame = False, outpainting_dims = None, return_mask = False):
|
||||
from shared.utils.utils import save_image
|
||||
ref_width, ref_height = ref_img.size
|
||||
if (ref_height, ref_width) == image_size and outpainting_dims == None:
|
||||
@ -270,18 +270,23 @@ class WanAny2V:
|
||||
else:
|
||||
if outpainting_dims != None:
|
||||
final_height, final_width = image_size
|
||||
canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 8)
|
||||
canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1)
|
||||
else:
|
||||
canvas_height, canvas_width = image_size
|
||||
scale = min(canvas_height / ref_height, canvas_width / ref_width)
|
||||
new_height = int(ref_height * scale)
|
||||
new_width = int(ref_width * scale)
|
||||
if fill_max and (canvas_height - new_height) < 16:
|
||||
if full_frame:
|
||||
new_height = canvas_height
|
||||
if fill_max and (canvas_width - new_width) < 16:
|
||||
new_width = canvas_width
|
||||
top = (canvas_height - new_height) // 2
|
||||
left = (canvas_width - new_width) // 2
|
||||
top = left = 0
|
||||
else:
|
||||
# if fill_max and (canvas_height - new_height) < 16:
|
||||
# new_height = canvas_height
|
||||
# if fill_max and (canvas_width - new_width) < 16:
|
||||
# new_width = canvas_width
|
||||
scale = min(canvas_height / ref_height, canvas_width / ref_width)
|
||||
new_height = int(ref_height * scale)
|
||||
new_width = int(ref_width * scale)
|
||||
top = (canvas_height - new_height) // 2
|
||||
left = (canvas_width - new_width) // 2
|
||||
ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
|
||||
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
|
||||
if outpainting_dims != None:
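The `fill_max` heuristic is replaced here by an explicit `full_frame` flag: either the reference is stretched to cover the whole canvas, or it is letterboxed (aspect-preserving fit, centered), with outpainting margins now computed at block size 1. A sketch of the placement logic, using a hypothetical helper name:

```python
def compute_placement(ref_height: int, ref_width: int,
                      canvas_height: int, canvas_width: int, full_frame: bool):
    """Return (new_h, new_w, top, left) for placing a reference image on the canvas (sketch)."""
    if full_frame:
        # cover the whole canvas, no margins
        return canvas_height, canvas_width, 0, 0
    # letterbox: largest aspect-preserving scale that still fits, then center
    scale = min(canvas_height / ref_height, canvas_width / ref_width)
    new_height, new_width = int(ref_height * scale), int(ref_width * scale)
    top = (canvas_height - new_height) // 2
    left = (canvas_width - new_width) // 2
    return new_height, new_width, top, left
```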
|
||||
@ -302,7 +307,7 @@ class WanAny2V:
|
||||
canvas = canvas.to(device)
|
||||
return ref_img.to(device), canvas
|
||||
|
||||
def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_video_guide_frames= [], start_frame = 0, fit_into_canvas = None, pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False):
|
||||
def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_video_guide_frames= [], start_frame = 0, pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False):
|
||||
image_sizes = []
|
||||
trim_video_guide = len(keep_video_guide_frames)
|
||||
def conv_tensor(t, device):
|
||||
@ -533,22 +538,16 @@ class WanAny2V:
|
||||
any_end_frame = False
|
||||
if image_start is None:
|
||||
if infinitetalk:
|
||||
new_shot = "Q" in video_prompt_type
|
||||
if input_frames is not None:
|
||||
image_ref = input_frames[:, 0]
|
||||
if input_video is None: input_video = input_frames[:, 0:1]
|
||||
new_shot = "Q" in video_prompt_type
|
||||
else:
|
||||
if pre_video_frame is None:
|
||||
new_shot = True
|
||||
else:
|
||||
if input_ref_images is None:
|
||||
input_ref_images, new_shot = [pre_video_frame], False
|
||||
else:
|
||||
input_ref_images, new_shot = [img.resize(pre_video_frame.size, resample=Image.Resampling.LANCZOS) for img in input_ref_images], "Q" in video_prompt_type
|
||||
if input_ref_images is None: raise Exception("Missing Reference Image")
|
||||
if input_ref_images is None:
|
||||
if pre_video_frame is None: raise Exception("Missing Reference Image")
|
||||
input_ref_images = [pre_video_frame]
|
||||
new_shot = new_shot and window_no <= len(input_ref_images)
|
||||
image_ref = convert_image_to_tensor(input_ref_images[ min(window_no, len(input_ref_images))-1 ])
|
||||
if new_shot:
|
||||
if new_shot or input_video is None:
|
||||
input_video = image_ref.unsqueeze(1)
|
||||
else:
|
||||
color_correction_strength = 0 #disable color correction as transition frames between shots may have a complete different color level than the colors of the new shot
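For InfiniteTalk, the rewritten branch picks one reference image per sliding window and only starts a new shot while references remain; a missing reference list falls back to the previous window's last frame. A simplified sketch of the selection rule (the letter "Q" in `video_prompt_type` requests sharp transitions; `window_no` is 1-based, as in the code above):

```python
def pick_window_reference(input_ref_images, window_no: int, video_prompt_type: str,
                          pre_video_frame=None):
    """Choose the reference image and new-shot flag for the current sliding window (sketch)."""
    if input_ref_images is None:
        if pre_video_frame is None:
            raise Exception("Missing Reference Image")
        input_ref_images = [pre_video_frame]          # reuse the last frame of the previous window
    # only cut to a new shot while there are still unused reference images
    new_shot = "Q" in video_prompt_type and window_no <= len(input_ref_images)
    image_ref = input_ref_images[min(window_no, len(input_ref_images)) - 1]
    return image_ref, new_shot
```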
|
||||
|
||||
@ -35,7 +35,7 @@ class family_handler():
|
||||
"label" : "Generation Type"
|
||||
}
|
||||
|
||||
extra_model_def["image_prompt_types_allowed"] = "TSEV"
|
||||
extra_model_def["image_prompt_types_allowed"] = "TSV"
|
||||
|
||||
|
||||
return extra_model_def
|
||||
|
||||
@ -110,19 +110,79 @@ class family_handler():
|
||||
"tea_cache" : not (base_model_type in ["i2v_2_2", "ti2v_2_2" ] or multiple_submodels),
|
||||
"mag_cache" : True,
|
||||
"keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"],
|
||||
"convert_image_guide_to_video" : True,
|
||||
"sample_solvers":[
|
||||
("unipc", "unipc"),
|
||||
("euler", "euler"),
|
||||
("dpm++", "dpm++"),
|
||||
("flowmatch causvid", "causvid"), ]
|
||||
})
|
||||
|
||||
|
||||
if base_model_type in ["t2v"]:
|
||||
extra_model_def["guide_custom_choices"] = {
|
||||
"choices":[("Use Text Prompt Only", ""),("Video to Video guided by Text Prompt", "GUV")],
|
||||
"default": "",
|
||||
"letters_filter": "GUV",
|
||||
"label": "Video to Video"
|
||||
}
|
||||
|
||||
if base_model_type in ["infinitetalk"]:
|
||||
extra_model_def["no_background_removal"] = True
|
||||
# extra_model_def["at_least_one_image_ref_needed"] = True
|
||||
extra_model_def["all_image_refs_are_background_ref"] = True
|
||||
extra_model_def["guide_custom_choices"] = {
|
||||
"choices":[
|
||||
("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Sharp Transitions", "QKI"),
|
||||
("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Smooth Transitions", "KI"),
|
||||
("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Sharp Transitions", "QRUV"),
|
||||
("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Smooth Transitions", "RUV"),
|
||||
("Video to Video, amount of motion transferred depends on Denoising Strength - Sharp Transitions", "GQUV"),
|
||||
("Video to Video, amount of motion transferred depends on Denoising Strength - Smooth Transitions", "GUV"),
|
||||
],
|
||||
"default": "KI",
|
||||
"letters_filter": "RGUVQKI",
|
||||
"label": "Video to Video",
|
||||
"show_label" : False,
|
||||
}
|
||||
|
||||
# extra_model_def["at_least_one_image_ref_needed"] = True
|
||||
if vace_class:
|
||||
extra_model_def["guide_preprocessing"] = {
|
||||
"selection": ["", "UV", "PV", "DV", "SV", "LV", "CV", "MV", "V", "PDV", "PSV", "PLV" , "DSV", "DLV", "SLV"],
|
||||
"labels" : { "V": "Use Vace raw format"}
|
||||
}
|
||||
extra_model_def["mask_preprocessing"] = {
|
||||
"selection": ["", "A", "NA", "XA", "XNA", "YA", "YNA", "WA", "WNA", "ZA", "ZNA"],
|
||||
}
|
||||
|
||||
extra_model_def["image_ref_choices"] = {
|
||||
"choices": [("None", ""),
|
||||
("Inject only People / Objects", "I"),
|
||||
("Inject Landscape and then People / Objects", "KI"),
|
||||
("Inject Frames and then People / Objects", "FI"),
|
||||
],
|
||||
"letters_filter": "KFI",
|
||||
}
|
||||
|
||||
if base_model_type in ["standin"] or vace_class:
|
||||
extra_model_def["lock_image_refs_ratios"] = True
|
||||
|
||||
if base_model_type in ["standin"]:
|
||||
extra_model_def["lock_image_refs_ratios"] = True
|
||||
extra_model_def["image_ref_choices"] = {
|
||||
"choices": [
|
||||
("No Reference Image", ""),
|
||||
("Reference Image is a Person Face", "I"),
|
||||
],
|
||||
"letters_filter":"I",
|
||||
}
|
||||
|
||||
if base_model_type in ["phantom_1.3B", "phantom_14B"]:
|
||||
extra_model_def["image_ref_choices"] = {
|
||||
"choices": [("Reference Image", "I")],
|
||||
"letters_filter":"I",
|
||||
"visible": False,
|
||||
}
|
||||
|
||||
if base_model_type in ["recam_1.3B"]:
|
||||
extra_model_def["keep_frames_video_guide_not_supported"] = True
|
||||
extra_model_def["model_modes"] = {
|
||||
@ -141,10 +201,18 @@ class family_handler():
|
||||
"default": 1,
|
||||
"label" : "Camera Movement Type"
|
||||
}
|
||||
extra_model_def["guide_preprocessing"] = {
|
||||
"selection": ["UV"],
|
||||
"labels" : { "UV": "Control Video"},
|
||||
"visible" : False,
|
||||
}
|
||||
|
||||
if vace_class or base_model_type in ["infinitetalk"]:
|
||||
image_prompt_types_allowed = "TVL"
|
||||
elif base_model_type in ["ti2v_2_2"]:
|
||||
image_prompt_types_allowed = "TSEVL"
|
||||
image_prompt_types_allowed = "TSVL"
|
||||
elif test_multitalk(base_model_type) or base_model_type in ["fantasy"]:
|
||||
image_prompt_types_allowed = "SVL"
|
||||
elif i2v:
|
||||
image_prompt_types_allowed = "SEVL"
|
||||
else:
|
||||
|
||||
@ -7,7 +7,6 @@ import psutil
|
||||
# import ffmpeg
|
||||
import imageio
|
||||
from PIL import Image
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -33,6 +32,8 @@ model_in_GPU = False
|
||||
matanyone_in_GPU = False
|
||||
bfloat16_supported = False
|
||||
# SAM generator
|
||||
import copy
|
||||
|
||||
class MaskGenerator():
|
||||
def __init__(self, sam_checkpoint, device):
|
||||
global args_device
|
||||
@ -89,6 +90,7 @@ def get_frames_from_image(image_input, image_state):
|
||||
"last_frame_numer": 0,
|
||||
"fps": None
|
||||
}
|
||||
|
||||
image_info = "Image Name: N/A,\nFPS: N/A,\nTotal Frames: {},\nImage Size:{}".format(len(frames), image_size)
|
||||
set_image_encoder_patch()
|
||||
select_SAM()
|
||||
@ -717,27 +719,33 @@ def load_unload_models(selected):
|
||||
def get_vmc_event_handler():
|
||||
return load_unload_models
|
||||
|
||||
def export_to_vace_video_input(foreground_video_output):
|
||||
gr.Info("Masked Video Input transferred to Vace For Inpainting")
|
||||
return "V#" + str(time.time()), foreground_video_output
|
||||
|
||||
|
||||
def export_image(image_refs, image_output):
|
||||
gr.Info("Masked Image transferred to Current Video")
|
||||
def export_image(state, image_output):
|
||||
ui_settings = get_current_model_settings(state)
|
||||
image_refs = ui_settings["image_refs"]
|
||||
if image_refs == None:
|
||||
image_refs =[]
|
||||
image_refs.append( image_output)
|
||||
return image_refs
|
||||
ui_settings["image_refs"] = image_refs
|
||||
gr.Info("Masked Image transferred to Current Image Generator")
|
||||
return time.time()
|
||||
|
||||
def export_image_mask(image_input, image_mask):
|
||||
gr.Info("Input Image & Mask transferred to Current Video")
|
||||
return Image.fromarray(image_input), image_mask
|
||||
def export_image_mask(state, image_input, image_mask):
|
||||
ui_settings = get_current_model_settings(state)
|
||||
ui_settings["image_guide"] = Image.fromarray(image_input)
|
||||
ui_settings["image_mask"] = image_mask
|
||||
|
||||
gr.Info("Input Image & Mask transferred to Current Image Generator")
|
||||
return time.time()
|
||||
|
||||
|
||||
def export_to_current_video_engine( foreground_video_output, alpha_video_output):
|
||||
def export_to_current_video_engine(state, foreground_video_output, alpha_video_output):
|
||||
ui_settings = get_current_model_settings(state)
|
||||
ui_settings["video_guide"] = foreground_video_output
|
||||
ui_settings["video_mask"] = alpha_video_output
|
||||
|
||||
gr.Info("Original Video and Full Mask have been transferred")
|
||||
# return "MV#" + str(time.time()), foreground_video_output, alpha_video_output
|
||||
return foreground_video_output, alpha_video_output
|
||||
return time.time()
|
||||
|
||||
|
||||
def teleport_to_video_tab(tab_state):
|
||||
@ -746,9 +754,10 @@ def teleport_to_video_tab(tab_state):
|
||||
return gr.Tabs(selected="video_gen")
|
||||
|
||||
|
||||
def display(tabs, tab_state, server_config, vace_video_input, vace_image_input, vace_video_mask, vace_image_mask, vace_image_refs):
|
||||
def display(tabs, tab_state, state, refresh_form_trigger, server_config, get_current_model_settings_fn): #, vace_video_input, vace_image_input, vace_video_mask, vace_image_mask, vace_image_refs):
|
||||
# my_tab.select(fn=load_unload_models, inputs=[], outputs=[])
|
||||
global image_output_codec, video_output_codec
|
||||
global image_output_codec, video_output_codec, get_current_model_settings
|
||||
get_current_model_settings = get_current_model_settings_fn
|
||||
|
||||
image_output_codec = server_config.get("image_output_codec", None)
|
||||
video_output_codec = server_config.get("video_output_codec", None)
|
||||
@ -871,7 +880,7 @@ def display(tabs, tab_state, server_config, vace_video_input, vace_image_input,
|
||||
template_frame = gr.Image(label="Start Frame", type="pil",interactive=True, elem_id="template_frame", visible=False, elem_classes="image")
|
||||
with gr.Row():
|
||||
clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False, min_width=100)
|
||||
add_mask_button = gr.Button(value="Set Mask", interactive=True, visible=False, min_width=100)
|
||||
add_mask_button = gr.Button(value="Add Mask", interactive=True, visible=False, min_width=100)
|
||||
remove_mask_button = gr.Button(value="Remove Mask", interactive=True, visible=False, min_width=100) # no use
|
||||
matting_button = gr.Button(value="Generate Video Matting", interactive=True, visible=False, min_width=100)
|
||||
with gr.Row():
|
||||
@ -892,7 +901,7 @@ def display(tabs, tab_state, server_config, vace_video_input, vace_image_input,
|
||||
with gr.Row(visible= True):
|
||||
export_to_current_video_engine_btn = gr.Button("Export to Control Video Input and Video Mask Input", visible= False)
|
||||
|
||||
export_to_current_video_engine_btn.click( fn=export_to_current_video_engine, inputs= [foreground_video_output, alpha_video_output], outputs= [vace_video_input, vace_video_mask]).then( #video_prompt_video_guide_trigger,
|
||||
export_to_current_video_engine_btn.click( fn=export_to_current_video_engine, inputs= [state, foreground_video_output, alpha_video_output], outputs= [refresh_form_trigger]).then( #video_prompt_video_guide_trigger,
|
||||
fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs])
|
||||
|
||||
|
||||
@ -1089,9 +1098,9 @@ def display(tabs, tab_state, server_config, vace_video_input, vace_image_input,
|
||||
# with gr.Column(scale=2, visible= True):
|
||||
export_image_mask_btn = gr.Button(value="Set to Control Image & Mask", visible=False, elem_classes="new_button")
|
||||
|
||||
export_image_btn.click( fn=export_image, inputs= [vace_image_refs, foreground_image_output], outputs= [vace_image_refs]).then( #video_prompt_video_guide_trigger,
|
||||
export_image_btn.click( fn=export_image, inputs= [state, foreground_image_output], outputs= [refresh_form_trigger]).then( #video_prompt_video_guide_trigger,
|
||||
fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs])
|
||||
export_image_mask_btn.click( fn=export_image_mask, inputs= [image_input, alpha_image_output], outputs= [vace_image_input, vace_image_mask]).then( #video_prompt_video_guide_trigger,
|
||||
export_image_mask_btn.click( fn=export_image_mask, inputs= [state, image_input, alpha_image_output], outputs= [refresh_form_trigger]).then( #video_prompt_video_guide_trigger,
|
||||
fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs])
|
||||
|
||||
# first step: get the image information
|
||||
@ -1148,5 +1157,21 @@ def display(tabs, tab_state, server_config, vace_video_input, vace_image_input,
|
||||
outputs=[foreground_image_output, alpha_image_output,foreground_image_output, alpha_image_output,bbox_info, export_image_btn, export_image_mask_btn]
|
||||
)
|
||||
|
||||
|
||||
|
||||
nada = gr.State({})
|
||||
# clear input
|
||||
gr.on(
|
||||
triggers=[image_input.clear], #image_input.change,
|
||||
fn=restart,
|
||||
inputs=[],
|
||||
outputs=[
|
||||
image_state,
|
||||
interactive_state,
|
||||
click_state,
|
||||
foreground_image_output, alpha_image_output,
|
||||
template_frame,
|
||||
image_selection_slider, image_selection_slider, track_pause_number_slider,point_prompt, export_image_btn, export_image_mask_btn, bbox_info, clear_button_click,
|
||||
add_mask_button, matting_button, template_frame, foreground_image_output, alpha_image_output, remove_mask_button, export_image_btn, export_image_mask_btn, mask_dropdown, nada, step2_title
|
||||
],
|
||||
queue=False,
|
||||
show_progress=False)
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ librosa==0.11.0
|
||||
speechbrain==1.0.3
|
||||
|
||||
# UI & interaction
|
||||
gradio==5.23.0
|
||||
gradio==5.29.0
|
||||
dashscope
|
||||
loguru
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, Literal
|
||||
|
||||
import gradio as gr
|
||||
import PIL
|
||||
import time
|
||||
from PIL import Image as PILImage
|
||||
|
||||
FilePath = str
|
||||
@ -20,6 +21,9 @@ def get_list( objs):
|
||||
return []
|
||||
return [ obj[0] if isinstance(obj, tuple) else obj for obj in objs]
|
||||
|
||||
def record_last_action(st, last_action):
|
||||
st["last_action"] = last_action
|
||||
st["last_time"] = time.time()
|
||||
class AdvancedMediaGallery:
|
||||
def __init__(
|
||||
self,
|
||||
@ -60,9 +64,10 @@ class AdvancedMediaGallery:
|
||||
self.state: Optional[gr.State] = None
|
||||
self._initial_state: Dict[str, Any] = {
|
||||
"items": items,
|
||||
"selected": (len(items) - 1) if items else None,
|
||||
"selected": (len(items) - 1) if items else 0, # None,
|
||||
"single": bool(single_image_mode),
|
||||
"mode": self.media_mode,
|
||||
"last_action": "",
|
||||
}
|
||||
|
||||
# ---------------- helpers ----------------
|
||||
@ -210,6 +215,13 @@ class AdvancedMediaGallery:
|
||||
|
||||
def _on_select(self, state: Dict[str, Any], gallery, evt: gr.SelectData) :
|
||||
# Mirror the selected index into state and the gallery (server-side selected_index)
|
||||
|
||||
st = get_state(state)
|
||||
last_time = st.get("last_time", None)
|
||||
if last_time is not None and abs(time.time()- last_time)< 0.5: # crappy trick to detect if onselect is unwanted (buggy gallery)
|
||||
# print(f"ignored:{time.time()}, real {st['selected']}")
|
||||
return gr.update(selected_index=st["selected"]), st
|
||||
|
||||
idx = None
|
||||
if evt is not None and hasattr(evt, "index"):
|
||||
ix = evt.index
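`record_last_action` timestamps every programmatic gallery update, and `_on_select` above ignores select events that arrive within half a second of such an update, treating them as echoes of the Gallery component's own refresh rather than real user clicks. The pattern in isolation, as a hedged sketch:

```python
import time

def record_last_action(st: dict, last_action: str):
    """Remember what the app itself just did to the gallery, and when."""
    st["last_action"] = last_action
    st["last_time"] = time.time()

def should_ignore_select(st: dict, window: float = 0.5) -> bool:
    """Debounce spurious gallery.select events fired right after a programmatic update."""
    last_time = st.get("last_time")
    return last_time is not None and abs(time.time() - last_time) < window
```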
|
||||
@ -220,17 +232,28 @@ class AdvancedMediaGallery:
|
||||
idx = ix[0] * max(1, int(self.columns)) + ix[1]
|
||||
else:
|
||||
idx = ix[0]
|
||||
st = get_state(state)
|
||||
n = len(get_list(gallery))
|
||||
sel = idx if (idx is not None and 0 <= idx < n) else None
|
||||
# print(f"image selected evt index:{sel}/{evt.selected}")
|
||||
st["selected"] = sel
|
||||
# return gr.update(selected_index=sel), st
|
||||
# return gr.update(), st
|
||||
return st
|
||||
return gr.update(), st
|
||||
|
||||
def _on_upload(self, value: List[Any], state: Dict[str, Any]) :
|
||||
# Fires when users upload via the Gallery itself.
|
||||
# items_filtered = self._filter_items_by_mode(list(value or []))
|
||||
items_filtered = list(value or [])
|
||||
st = get_state(state)
|
||||
new_items = self._paths_from_payload(items_filtered)
|
||||
st["items"] = new_items
|
||||
new_sel = len(new_items) - 1
|
||||
st["selected"] = new_sel
|
||||
record_last_action(st,"add")
|
||||
return gr.update(selected_index=new_sel), st
|
||||
|
||||
def _on_gallery_change(self, value: List[Any], state: Dict[str, Any]) :
|
||||
# Fires when users add/drag/drop/delete via the Gallery itself.
|
||||
items_filtered = self._filter_items_by_mode(list(value or []))
|
||||
# items_filtered = self._filter_items_by_mode(list(value or []))
|
||||
items_filtered = list(value or [])
|
||||
st = get_state(state)
|
||||
st["items"] = items_filtered
|
||||
# Keep selection if still valid, else default to last
|
||||
@ -240,10 +263,9 @@ class AdvancedMediaGallery:
|
||||
else:
|
||||
new_sel = old_sel
|
||||
st["selected"] = new_sel
|
||||
# return gr.update(value=items_filtered, selected_index=new_sel), st
|
||||
# return gr.update(value=items_filtered), st
|
||||
|
||||
return gr.update(), st
|
||||
st["last_action"] ="gallery_change"
|
||||
# print(f"gallery change: set sel {new_sel}")
|
||||
return gr.update(selected_index=new_sel), st
|
||||
|
||||
def _on_add(self, files_payload: Any, state: Dict[str, Any], gallery):
|
||||
"""
|
||||
@ -252,7 +274,8 @@ class AdvancedMediaGallery:
|
||||
and re-selects the last inserted item.
|
||||
"""
|
||||
# New items (respect image/video mode)
|
||||
new_items = self._filter_items_by_mode(self._paths_from_payload(files_payload))
|
||||
# new_items = self._filter_items_by_mode(self._paths_from_payload(files_payload))
|
||||
new_items = self._paths_from_payload(files_payload)
|
||||
|
||||
st = get_state(state)
|
||||
cur: List[Any] = get_list(gallery)
|
||||
@ -298,30 +321,6 @@ class AdvancedMediaGallery:
|
||||
if k is not None:
|
||||
seen_new.add(k)
|
||||
|
||||
# Remove any existing occurrences of the incoming items from current list,
|
||||
# BUT keep the currently selected item even if it's also in incoming.
|
||||
cur_clean: List[Any] = []
|
||||
# sel_item = cur[sel] if (sel is not None and 0 <= sel < len(cur)) else None
|
||||
# for idx, it in enumerate(cur):
|
||||
# k = key_of(it)
|
||||
# if it is sel_item:
|
||||
# cur_clean.append(it)
|
||||
# continue
|
||||
# if k is not None and k in seen_new:
|
||||
# continue # drop duplicate; we'll reinsert at the target spot
|
||||
# cur_clean.append(it)
|
||||
|
||||
# # Compute insertion position: right AFTER the (possibly shifted) selected item
|
||||
# if sel_item is not None:
|
||||
# # find sel_item's new index in cur_clean
|
||||
# try:
|
||||
# pos_sel = cur_clean.index(sel_item)
|
||||
# except ValueError:
|
||||
# # Shouldn't happen, but fall back to end
|
||||
# pos_sel = len(cur_clean) - 1
|
||||
# insert_pos = pos_sel + 1
|
||||
# else:
|
||||
# insert_pos = len(cur_clean) # no selection -> append at end
|
||||
insert_pos = min(sel, len(cur) -1)
|
||||
cur_clean = cur
|
||||
# Build final list and selection
|
||||
@ -330,6 +329,8 @@ class AdvancedMediaGallery:
|
||||
|
||||
st["items"] = merged
|
||||
st["selected"] = new_sel
|
||||
record_last_action(st,"add")
|
||||
# print(f"gallery add: set sel {new_sel}")
|
||||
return gr.update(value=merged, selected_index=new_sel), st
|
||||
|
||||
def _on_remove(self, state: Dict[str, Any], gallery) :
|
||||
@ -342,8 +343,9 @@ class AdvancedMediaGallery:
|
||||
return gr.update(value=[], selected_index=None), st
|
||||
new_sel = min(sel, len(items) - 1)
|
||||
st["items"] = items; st["selected"] = new_sel
|
||||
# return gr.update(value=items, selected_index=new_sel), st
|
||||
return gr.update(value=items), st
|
||||
record_last_action(st,"remove")
|
||||
# print(f"gallery del: new sel {new_sel}")
|
||||
return gr.update(value=items, selected_index=new_sel), st
|
||||
|
||||
def _on_move(self, delta: int, state: Dict[str, Any], gallery) :
|
||||
st = get_state(state); items: List[Any] = get_list(gallery); sel = st.get("selected", None)
|
||||
@ -354,11 +356,15 @@ class AdvancedMediaGallery:
|
||||
return gr.update(value=items, selected_index=sel), st
|
||||
items[sel], items[j] = items[j], items[sel]
|
||||
st["items"] = items; st["selected"] = j
|
||||
record_last_action(st,"move")
|
||||
# print(f"gallery move: set sel {j}")
|
||||
return gr.update(value=items, selected_index=j), st
|
||||
|
||||
def _on_clear(self, state: Dict[str, Any]) :
|
||||
st = {"items": [], "selected": None, "single": get_state(state).get("single", False), "mode": self.media_mode}
|
||||
return gr.update(value=[], selected_index=0), st
|
||||
record_last_action(st,"clear")
|
||||
# print(f"Clear all")
|
||||
return gr.update(value=[], selected_index=None), st
|
||||
|
||||
def _on_toggle_single(self, to_single: bool, state: Dict[str, Any]) :
|
||||
st = get_state(state); st["single"] = bool(to_single)
|
||||
@ -382,30 +388,38 @@ class AdvancedMediaGallery:
|
||||
def mount(self, parent: Optional[gr.Blocks | gr.Group | gr.Row | gr.Column] = None, update_form = False):
|
||||
if parent is not None:
|
||||
with parent:
|
||||
col = self._build_ui()
|
||||
col = self._build_ui(update_form)
|
||||
else:
|
||||
col = self._build_ui()
|
||||
col = self._build_ui(update_form)
|
||||
if not update_form:
|
||||
self._wire_events()
|
||||
return col
|
||||
|
||||
def _build_ui(self) -> gr.Column:
|
||||
def _build_ui(self, update = False) -> gr.Column:
|
||||
with gr.Column(elem_id=self.elem_id, elem_classes=self.elem_classes) as col:
|
||||
self.container = col
|
||||
|
||||
self.state = gr.State(dict(self._initial_state))
|
||||
|
||||
self.gallery = gr.Gallery(
|
||||
label=self.label,
|
||||
value=self._initial_state["items"],
|
||||
height=self.height,
|
||||
columns=self.columns,
|
||||
show_label=self.show_label,
|
||||
preview= True,
|
||||
# type="pil",
|
||||
file_types= list(IMAGE_EXTS) if self.media_mode == "image" else list(VIDEO_EXTS),
|
||||
selected_index=self._initial_state["selected"], # server-side selection
|
||||
)
|
||||
if update:
|
||||
self.gallery = gr.update(
|
||||
value=self._initial_state["items"],
|
||||
selected_index=self._initial_state["selected"], # server-side selection
|
||||
label=self.label,
|
||||
show_label=self.show_label,
|
||||
)
|
||||
else:
|
||||
self.gallery = gr.Gallery(
|
||||
value=self._initial_state["items"],
|
||||
label=self.label,
|
||||
height=self.height,
|
||||
columns=self.columns,
|
||||
show_label=self.show_label,
|
||||
preview= True,
|
||||
# type="pil", # very slow
|
||||
file_types= list(IMAGE_EXTS) if self.media_mode == "image" else list(VIDEO_EXTS),
|
||||
selected_index=self._initial_state["selected"], # server-side selection
|
||||
)
|
||||
|
||||
# One-line controls
|
||||
exts = sorted(IMAGE_EXTS if self.media_mode == "image" else VIDEO_EXTS) if self.accept_filter else None
|
||||
@ -418,10 +432,10 @@ class AdvancedMediaGallery:
|
||||
size="sm",
|
||||
min_width=1,
|
||||
)
|
||||
self.btn_remove = gr.Button("Remove", size="sm", min_width=1)
|
||||
self.btn_remove = gr.Button(" Remove ", size="sm", min_width=1)
|
||||
self.btn_left = gr.Button("◀ Left", size="sm", visible=not self._initial_state["single"], min_width=1)
|
||||
self.btn_right = gr.Button("Right ▶", size="sm", visible=not self._initial_state["single"], min_width=1)
|
||||
self.btn_clear = gr.Button("Clear", variant="secondary", size="sm", visible=not self._initial_state["single"], min_width=1)
|
||||
self.btn_clear = gr.Button(" Clear ", variant="secondary", size="sm", visible=not self._initial_state["single"], min_width=1)
|
||||
|
||||
return col
|
||||
|
||||
@ -430,14 +444,24 @@ class AdvancedMediaGallery:
|
||||
self.gallery.select(
|
||||
self._on_select,
|
||||
inputs=[self.state, self.gallery],
|
||||
outputs=[self.state],
|
||||
outputs=[self.gallery, self.state],
|
||||
trigger_mode="always_last",
|
||||
)
|
||||
|
||||
# Gallery value changed by user actions (click-to-add, drag-drop, internal remove, etc.)
|
||||
self.gallery.change(
|
||||
self.gallery.upload(
|
||||
self._on_upload,
|
||||
inputs=[self.gallery, self.state],
|
||||
outputs=[self.gallery, self.state],
|
||||
trigger_mode="always_last",
|
||||
)
|
||||
|
||||
# Gallery value changed by user actions (click-to-add, drag-drop, internal remove, etc.)
|
||||
self.gallery.upload(
|
||||
self._on_gallery_change,
|
||||
inputs=[self.gallery, self.state],
|
||||
outputs=[self.gallery, self.state],
|
||||
trigger_mode="always_last",
|
||||
)
|
||||
|
||||
# Add via UploadButton
|
||||
@ -445,6 +469,7 @@ class AdvancedMediaGallery:
|
||||
self._on_add,
|
||||
inputs=[self.upload_btn, self.state, self.gallery],
|
||||
outputs=[self.gallery, self.state],
|
||||
trigger_mode="always_last",
|
||||
)
|
||||
|
||||
# Remove selected
|
||||
@ -452,6 +477,7 @@ class AdvancedMediaGallery:
|
||||
self._on_remove,
|
||||
inputs=[self.state, self.gallery],
|
||||
outputs=[self.gallery, self.state],
|
||||
trigger_mode="always_last",
|
||||
)
|
||||
|
||||
# Reorder using selected index, keep same item selected
|
||||
@ -459,11 +485,13 @@ class AdvancedMediaGallery:
|
||||
lambda st, gallery: self._on_move(-1, st, gallery),
|
||||
inputs=[self.state, self.gallery],
|
||||
outputs=[self.gallery, self.state],
|
||||
trigger_mode="always_last",
|
||||
)
|
||||
self.btn_right.click(
|
||||
lambda st, gallery: self._on_move(+1, st, gallery),
|
||||
inputs=[self.state, self.gallery],
|
||||
outputs=[self.gallery, self.state],
|
||||
trigger_mode="always_last",
|
||||
)
|
||||
|
||||
# Clear all
|
||||
@ -471,6 +499,7 @@ class AdvancedMediaGallery:
|
||||
self._on_clear,
|
||||
inputs=[self.state],
|
||||
outputs=[self.gallery, self.state],
|
||||
trigger_mode="always_last",
|
||||
)
|
||||
|
||||
# ---------------- public API ----------------
|
||||
|
||||
@ -19,6 +19,7 @@ import tempfile
|
||||
import subprocess
|
||||
import json
|
||||
from functools import lru_cache
|
||||
os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg")
|
||||
|
||||
|
||||
from PIL import Image
|
||||
@ -207,30 +208,62 @@ def get_outpainting_frame_location(final_height, final_width, outpainting_dims
|
||||
if (margin_left + width) > final_width or outpainting_right == 0: margin_left = final_width - width
|
||||
return height, width, margin_top, margin_left
|
||||
|
||||
def calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = 16):
|
||||
if fit_into_canvas == None:
|
||||
def rescale_and_crop(img, w, h):
|
||||
ow, oh = img.size
|
||||
target_ratio = w / h
|
||||
orig_ratio = ow / oh
|
||||
|
||||
if orig_ratio > target_ratio:
|
||||
# Crop width first
|
||||
nw = int(oh * target_ratio)
|
||||
img = img.crop(((ow - nw) // 2, 0, (ow + nw) // 2, oh))
|
||||
else:
|
||||
# Crop height first
|
||||
nh = int(ow / target_ratio)
|
||||
img = img.crop((0, (oh - nh) // 2, ow, (oh + nh) // 2))
|
||||
|
||||
return img.resize((w, h), Image.LANCZOS)
|
||||
|
||||
def calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = 16):
|
||||
if fit_into_canvas == None or fit_into_canvas == 2:
|
||||
# return image_height, image_width
|
||||
return canvas_height, canvas_width
|
||||
if fit_into_canvas:
|
||||
if fit_into_canvas == 1:
|
||||
scale1 = min(canvas_height / image_height, canvas_width / image_width)
|
||||
scale2 = min(canvas_width / image_height, canvas_height / image_width)
|
||||
scale = max(scale1, scale2)
|
||||
else:
|
||||
else: #0 or #2 (crop)
|
||||
scale = (canvas_height * canvas_width / (image_height * image_width))**(1/2)
|
||||
|
||||
new_height = round( image_height * scale / block_size) * block_size
|
||||
new_width = round( image_width * scale / block_size) * block_size
|
||||
return new_height, new_width
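`calculate_new_dimensions` now distinguishes three policies selected by the numeric `fit_into_canvas` code (None or 2: return the canvas size unchanged; 1: largest aspect-preserving fit; 0: keep roughly the canvas pixel budget), and the new `rescale_and_crop` adds a center-crop path used when `fit_crop` is set. A small usage sketch, assuming both helpers are importable from `shared.utils.utils` as in the imports shown earlier; the output sizes in the comments are illustrative:

```python
from PIL import Image
from shared.utils.utils import calculate_new_dimensions, rescale_and_crop

img = Image.new("RGB", (1920, 800))          # PIL size is (width, height)

# fit_crop path: center-crop to the canvas aspect ratio, then resize exactly
cropped = rescale_and_crop(img, 1280, 720)   # -> 1280x720

# fit_into_canvas == 0: keep roughly the canvas pixel budget, preserve the image's
# own aspect ratio, snap both sides to the block size
h0, w0 = calculate_new_dimensions(720, 1280, 800, 1920, 0, block_size=16)   # ~624x1488

# fit_into_canvas == 1: largest aspect-preserving scale that still fits the canvas
h1, w1 = calculate_new_dimensions(720, 1280, 800, 1920, 1, block_size=16)   # ~528x1280

# fit_into_canvas in (None, 2): the canvas size is returned unchanged
h2, w2 = calculate_new_dimensions(720, 1280, 800, 1920, None)               # 720x1280
```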
|
||||
|
||||
def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, ignore_first, fit_into_canvas = False ):
|
||||
def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fit_into_canvas, fit_crop, block_size = 16):
|
||||
if fit_crop:
|
||||
image = rescale_and_crop(image, canvas_width, canvas_height)
|
||||
new_width, new_height = image.size
|
||||
else:
|
||||
image_width, image_height = image.size
|
||||
new_height, new_width = calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = block_size )
|
||||
image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
|
||||
return image, new_height, new_width
|
||||
|
||||
def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None ):
|
||||
if rm_background:
|
||||
session = new_session()
|
||||
|
||||
output_list =[]
|
||||
for i, img in enumerate(img_list):
|
||||
width, height = img.size
|
||||
|
||||
if fit_into_canvas:
|
||||
if fit_into_canvas == None or any_background_ref == 1 and i==0 or any_background_ref == 2:
|
||||
if outpainting_dims is not None:
|
||||
resized_image =img
|
||||
elif img.size != (budget_width, budget_height):
|
||||
resized_image= img.resize((budget_width, budget_height), resample=Image.Resampling.LANCZOS)
|
||||
else:
|
||||
resized_image =img
|
||||
elif fit_into_canvas == 1:
|
||||
white_canvas = np.ones((budget_height, budget_width, 3), dtype=np.uint8) * 255
|
||||
scale = min(budget_height / height, budget_width / width)
|
||||
new_height = int(height * scale)
|
||||
@ -242,10 +275,10 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg
|
||||
resized_image = Image.fromarray(white_canvas)
|
||||
else:
|
||||
scale = (budget_height * budget_width / (height * width))**(1/2)
|
||||
new_height = int( round(height * scale / 16) * 16)
|
||||
new_width = int( round(width * scale / 16) * 16)
|
||||
new_height = int( round(height * scale / block_size) * block_size)
|
||||
new_width = int( round(width * scale / block_size) * block_size)
|
||||
resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
|
||||
if rm_background and not (ignore_first and i == 0) :
|
||||
if rm_background and not (any_background_ref and i==0 or any_background_ref == 2) :
|
||||
# resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
|
||||
resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
|
||||
output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200,