Mirror of https://github.com/Wan-Video/Wan2.1.git
Queue adaptations
commit 3246d8cf4d (parent f1b1836a7a)
gradio_server.py: 1367 changes (file diff suppressed because it is too large)
@@ -41,6 +41,11 @@ def optimized_scale(positive_flat, negative_flat):
     return st_star

+def resize_lanczos(img, h, w):
+    img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
+    img = img.resize((w,h), resample=Image.Resampling.LANCZOS)
+    return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)
+


 class WanI2V:

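The hunk above adds a `resize_lanczos` helper that round-trips a CHW tensor through PIL, since `torch.nn.functional.interpolate` offers no Lanczos mode. A minimal usage sketch (toy shapes; the helper body is copied from the hunk):

    import numpy as np
    import torch
    from PIL import Image

    def resize_lanczos(img, h, w):
        # img: CHW float tensor in [0, 1]; PIL wants HWC uint8
        img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
        img = img.resize((w, h), resample=Image.Resampling.LANCZOS)
        return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)

    x = torch.rand(3, 64, 48)       # arbitrary toy frame in [0, 1]
    y = resize_lanczos(x, 128, 96)  # helper takes (h, w); PIL's resize takes (w, h)
    assert y.shape == (3, 128, 96)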
@@ -90,7 +95,6 @@ class WanI2V:

         self.num_train_timesteps = config.num_train_timesteps
         self.param_dtype = config.param_dtype
-
         shard_fn = partial(shard_model, device_id=device_id)
         self.text_encoder = T5EncoderModel(
             text_len=config.text_len,
@@ -208,12 +212,12 @@ class WanI2V:
                 - H: Frame height (from max_area)
                 - W: Frame width from max_area)
         """
-        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
+        img = TF.to_tensor(img)
         lat_frames = int((frame_num - 1) // self.vae_stride[0] + 1)
         any_end_frame = img2 !=None
         if any_end_frame:
             any_end_frame = True
-            img2 = TF.to_tensor(img2).sub_(0.5).div_(0.5).to(self.device)
+            img2 = TF.to_tensor(img2)
             if add_frames_for_end_image:
                 frame_num +=1
                 lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
@@ -229,6 +233,15 @@ class WanI2V:
         h = lat_h * self.vae_stride[1]
         w = lat_w * self.vae_stride[2]

+        clip_image_size = self.clip.model.image_size
+        img_interpolated = resize_lanczos(img, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device)
+        img = resize_lanczos(img, clip_image_size, clip_image_size)
+        img = img.sub_(0.5).div_(0.5).to(self.device)
+        if img2!= None:
+            img_interpolated2 = resize_lanczos(img2, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device)
+            img2 = resize_lanczos(img2, clip_image_size, clip_image_size)
+            img2 = img2.sub_(0.5).div_(0.5).to(self.device)
+
         max_seq_len = lat_frames * lat_h * lat_w // ( self.patch_size[1] * self.patch_size[2])
         max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size

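Together with the `TF.to_tensor` change in the previous hunk, this resizes the source image twice, once to the latent-pixel target (h, w) for the VAE branch and once to the CLIP encoder's square input, applying the [-1, 1] normalization only after resizing. A sketch of the ordering, assuming the `resize_lanczos` helper from the sketch above is in scope and using toy sizes:

    import torch

    img = torch.rand(3, 480, 832)          # toy [0, 1] CHW frame
    clip_image_size, h, w = 224, 480, 832  # toy sizes; the real ones come from the CLIP model and the VAE stride

    # VAE branch: resize first, then normalize to [-1, 1]
    img_interpolated = resize_lanczos(img, h, w).sub_(0.5).div_(0.5)
    # CLIP branch: same source tensor, resized to the encoder's square input
    img_clip = resize_lanczos(img, clip_image_size, clip_image_size).sub_(0.5).div_(0.5)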
@@ -273,21 +286,32 @@ class WanI2V:

        from mmgp import offload

+
+        # img_interpolated.save('aaa.png')
+
+        # img_interpolated = torch.from_numpy(np.array(img_interpolated).astype(np.float32) / 255.0).movedim(-1, 0)
+
+        # img_interpolated = torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode='lanczos')
+        # img_interpolated = img_interpolated.squeeze(0).transpose(0,2).transpose(1,0)
+        # img_interpolated = img_interpolated.clamp(-1, 1)
+        # img_interpolated = (img_interpolated + 1)/2
+        # img_interpolated = (img_interpolated*255).type(torch.uint8)
+        # img_interpolated = img_interpolated.cpu().numpy()
+        # xxx = Image.fromarray(img_interpolated, 'RGB')
+        # xxx.save('my.png')
+
        offload.last_offload_obj.unload_all()
        if any_end_frame:
-            img_interpolated = torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1).to(torch.bfloat16)
-            img2_interpolated = torch.nn.functional.interpolate(img2[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1).to(torch.bfloat16)
            mean2 = 0
            enc= torch.concat([
                    img_interpolated,
-                    torch.full( (3, frame_num-2, h, w), mean2, device="cpu", dtype= torch.bfloat16),
-                    img2_interpolated,
+                    torch.full( (3, frame_num-2, h, w), mean2, device=self.device, dtype= torch.bfloat16),
+                    img_interpolated2,
            ], dim=1).to(self.device)
        else:
            enc= torch.concat([
-                    torch.nn.functional.interpolate(
-                    img[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1).to(torch.bfloat16),
-                    torch.zeros(3, frame_num-1, h, w, device="cpu", dtype= torch.bfloat16)
+                    img_interpolated,
+                    torch.zeros(3, frame_num-1, h, w, device=self.device, dtype= torch.bfloat16)
            ], dim=1).to(self.device)

        lat_y = self.vae.encode([enc], VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
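Two behavioral changes here: the bicubic `torch.nn.functional.interpolate` calls are dropped in favor of the Lanczos tensors precomputed above (`img_interpolated`, `img_interpolated2`), and the filler frames are allocated on `self.device` instead of the CPU. A toy sketch of the conditioning layout handed to the VAE (shapes are illustrative only):

    import torch

    frame_num, h, w = 17, 32, 32  # toy sizes
    first = torch.zeros(3, 1, h, w, dtype=torch.bfloat16)           # stand-in for img_interpolated
    last = torch.zeros(3, 1, h, w, dtype=torch.bfloat16)            # stand-in for img_interpolated2
    middle = torch.full((3, frame_num - 2, h, w), 0., dtype=torch.bfloat16)

    enc = torch.concat([first, middle, last], dim=1)  # frames stacked along dim 1
    assert enc.shape == (3, frame_num, h, w)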
@@ -333,7 +357,8 @@ class WanI2V:
             'seq_len': max_seq_len,
             'y': [y],
             'freqs' : freqs,
-            'pipeline' : self
+            'pipeline' : self,
+            'callback' : callback
         }

         arg_null = {
@@ -342,7 +367,8 @@ class WanI2V:
             'seq_len': max_seq_len,
             'y': [y],
             'freqs' : freqs,
-            'pipeline' : self
+            'pipeline' : self,
+            'callback' : callback
         }

         arg_both= {
@@ -352,7 +378,8 @@ class WanI2V:
             'seq_len': max_seq_len,
             'y': [y],
             'freqs' : freqs,
-            'pipeline' : self
+            'pipeline' : self,
+            'callback' : callback
         }

         if offload_model:
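These three hunks thread the new `callback` into each kwarg dict passed to the model. A hedged illustration of the shared structure (dummy values stand in for the real tensors; the first dict's name is not visible in this diff, so `arg_c` is an assumption):

    max_seq_len, y, freqs, pipeline, callback = 75600, object(), object(), object(), print

    shared = dict(seq_len=max_seq_len, y=[y], freqs=freqs, pipeline=pipeline, callback=callback)
    arg_c = {"context": ["ctx"], **shared}          # name "arg_c" is assumed
    arg_null = {"context": ["ctx_null"], **shared}
    assert arg_c["callback"] is arg_null["callback"]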
@@ -363,7 +390,7 @@ class WanI2V:

        # self.model.to(self.device)
        if callback != None:
-            callback(-1, None)
+            callback(-1, True)

        for i, t in enumerate(tqdm(timesteps)):
            offload.set_step_no_for_lora(self.model, i)
@@ -437,7 +464,7 @@ class WanI2V:
            del timestep

            if callback is not None:
-                callback(i, latent)
+                callback(i, False)


        x0 = [latent.to(self.device, dtype=torch.bfloat16)]
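Across the diff, the callback no longer receives latents (`callback(i, latent)`) but only flags: `callback(-1, True)` before sampling, `callback(i, False)` after each denoising step, and `callback(-1, False, True)` from inside the transformer blocks (see the WanModel hunks below). A hedged sketch of a callback compatible with all three positional patterns (parameter names are guesses; the diff only shows positional use):

    def make_progress_callback(total_steps):
        def callback(step, starting, refresh_only=False):
            if step < 0:
                print("denoising starting" if starting else "UI refresh tick")
            else:
                print(f"step {step + 1}/{total_steps} done")
        return callback

    cb = make_progress_callback(30)
    cb(-1, True)         # before the sampling loop
    cb(0, False)         # after the first step
    cb(-1, False, True)  # per-block refresh tick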
@@ -451,7 +478,7 @@ class WanI2V:
            video = self.vae.decode(x0, VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]

            if any_end_frame and add_frames_for_end_image:
-                # video[:, -1:] = img2_interpolated
+                # video[:, -1:] = img_interpolated2
                video = video[:, :-1]

        else:
@@ -704,6 +704,7 @@ class WanModel(ModelMixin, ConfigMixin):
        is_uncond=False,
        max_steps = 0,
        slg_layers=None,
+        callback = None,
    ):
        r"""
        Forward pass through the diffusion model
@@ -835,12 +836,10 @@ class WanModel(ModelMixin, ConfigMixin):
            freqs=freqs,
            # context=context,
            context_lens=context_lens)

        for block_idx, block in enumerate(self.blocks):
            offload.shared_state["layer"] = block_idx
-            if "refresh" in offload.shared_state:
-                del offload.shared_state["refresh"]
-                offload.shared_state["callback"](-1, -1, True)
+            if callback != None:
+                callback(-1, False, True)
            if pipeline._interrupt:
                if joint_pass:
                    return None, None
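The old `offload.shared_state["refresh"]` handshake is replaced by invoking the `callback` newly accepted by `forward`, once per transformer block, next to the existing `pipeline._interrupt` early-out. A hypothetical miniature of that control flow (not the real WanModel):

    import torch

    class TinyPipeline:
        _interrupt = False  # flipped from another thread in the real server

    class TinyModel(torch.nn.Module):
        def __init__(self, n_blocks=4):
            super().__init__()
            self.blocks = torch.nn.ModuleList(torch.nn.Identity() for _ in range(n_blocks))

        def forward(self, x, pipeline, callback=None):
            for block_idx, block in enumerate(self.blocks):
                if callback is not None:
                    callback(-1, False, True)  # same positional pattern as the diff
                if pipeline._interrupt:
                    return None                # abort mid-pass
                x = block(x)
            return x

    out = TinyModel()(torch.zeros(2), TinyPipeline(), callback=lambda *a: None)
    assert out is not None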
@@ -268,7 +268,7 @@ class WanT2V:
        if self.model.enable_teacache:
            self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
        if callback != None:
-            callback(-1, None)
+            callback(-1, True)
        for i, t in enumerate(tqdm(timesteps)):
            latent_model_input = latents
            slg_layers_local = None
@@ -322,7 +322,7 @@ class WanT2V:
                del temp_x0

            if callback is not None:
-                callback(i, latents)
+                callback(i, False)

        x0 = latents
        if offload_model: