Mirror of https://github.com/Wan-Video/Wan2.1.git
Queue adaptations
commit 3246d8cf4d (parent f1b1836a7a)
gradio_server.py: 1551 changed lines (file diff suppressed because it is too large)
@@ -41,6 +41,11 @@ def optimized_scale(positive_flat, negative_flat):
     return st_star
 
+def resize_lanczos(img, h, w):
+    img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
+    img = img.resize((w, h), resample=Image.Resampling.LANCZOS)
+    return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)
+
 
 class WanI2V:
 
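Note: the added resize_lanczos helper round-trips a CHW float tensor through PIL to get true Lanczos filtering, which torch.nn.functional.interpolate does not offer. A minimal usage sketch (the [0, 1] input range and the example shapes are assumptions for illustration):

```python
import numpy as np
import torch
from PIL import Image

def resize_lanczos(img, h, w):
    # img: CHW float tensor in [0, 1]; PIL wants HWC uint8 in [0, 255]
    img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
    img = img.resize((w, h), resample=Image.Resampling.LANCZOS)  # PIL takes (width, height)
    return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)

# hypothetical usage: downscale a random 3x720x1280 frame to 480x832
frame = torch.rand(3, 720, 1280)
resized = resize_lanczos(frame, 480, 832)
assert resized.shape == (3, 480, 832)
```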
@@ -90,7 +95,6 @@ class WanI2V:
         self.num_train_timesteps = config.num_train_timesteps
         self.param_dtype = config.param_dtype
 
-        shard_fn = partial(shard_model, device_id=device_id)
         self.text_encoder = T5EncoderModel(
             text_len=config.text_len,
@@ -208,12 +212,12 @@ class WanI2V:
                 - H: Frame height (from max_area)
                 - W: Frame width (from max_area)
         """
-        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
+        img = TF.to_tensor(img)
         lat_frames = int((frame_num - 1) // self.vae_stride[0] + 1)
         any_end_frame = img2 != None
         if any_end_frame:
             any_end_frame = True
-            img2 = TF.to_tensor(img2).sub_(0.5).div_(0.5).to(self.device)
+            img2 = TF.to_tensor(img2)
             if add_frames_for_end_image:
                 frame_num += 1
                 lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
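Note: the latent-frame arithmetic above follows from the VAE's temporal stride; when an end image is appended, one extra input frame is added and the end frame gets its own latent slot. A quick sanity check, assuming the stride of 4 used by Wan's default configs:

```python
vae_stride_t = 4  # assumed temporal stride (vae_stride[0]) from Wan's default config

frame_num = 81
lat_frames = (frame_num - 1) // vae_stride_t + 1       # 80 // 4 + 1 = 21 latent frames

# with add_frames_for_end_image, frame_num becomes 82 and the
# appended end frame occupies an extra latent frame:
frame_num += 1
lat_frames_end = (frame_num - 2) // vae_stride_t + 2   # 80 // 4 + 2 = 22 latent frames
```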
@@ -229,6 +233,15 @@ class WanI2V:
         h = lat_h * self.vae_stride[1]
         w = lat_w * self.vae_stride[2]
 
+        clip_image_size = self.clip.model.image_size
+        img_interpolated = resize_lanczos(img, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0, 1).to(self.device)
+        img = resize_lanczos(img, clip_image_size, clip_image_size)
+        img = img.sub_(0.5).div_(0.5).to(self.device)
+        if img2 != None:
+            img_interpolated2 = resize_lanczos(img2, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0, 1).to(self.device)
+            img2 = resize_lanczos(img2, clip_image_size, clip_image_size)
+            img2 = img2.sub_(0.5).div_(0.5).to(self.device)
+
         max_seq_len = lat_frames * lat_h * lat_w // (self.patch_size[1] * self.patch_size[2])
         max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
 
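Note: the two resize targets serve different consumers: the full h x w frame (normalized from [0, 1] to [-1, 1] via sub_(0.5).div_(0.5)) feeds the VAE, while the clip_image_size square feeds the CLIP image encoder. The token count then falls out of the latent grid and patch size; a worked example under assumed Wan defaults (spatial VAE stride 8, patch_size (1, 2, 2), sp_size 1):

```python
import math

vae_stride = (4, 8, 8)   # assumed (t, h, w) strides from Wan's default config
patch_size = (1, 2, 2)   # assumed patchify factors
sp_size = 1              # no sequence parallelism

lat_frames, h, w = 21, 480, 832
lat_h, lat_w = h // vae_stride[1], w // vae_stride[2]   # 60, 104

max_seq_len = lat_frames * lat_h * lat_w // (patch_size[1] * patch_size[2])
max_seq_len = int(math.ceil(max_seq_len / sp_size)) * sp_size
print(max_seq_len)   # 21 * 60 * 104 // 4 = 32760 tokens
```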
@@ -273,21 +286,32 @@ class WanI2V:
         from mmgp import offload
 
+        # img_interpolated.save('aaa.png')
+        # img_interpolated = torch.from_numpy(np.array(img_interpolated).astype(np.float32) / 255.0).movedim(-1, 0)
+        # img_interpolated = torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode='lanczos')
+        # img_interpolated = img_interpolated.squeeze(0).transpose(0,2).transpose(1,0)
+        # img_interpolated = img_interpolated.clamp(-1, 1)
+        # img_interpolated = (img_interpolated + 1)/2
+        # img_interpolated = (img_interpolated*255).type(torch.uint8)
+        # img_interpolated = img_interpolated.cpu().numpy()
+        # xxx = Image.fromarray(img_interpolated, 'RGB')
+        # xxx.save('my.png')
+
         offload.last_offload_obj.unload_all()
         if any_end_frame:
-            img_interpolated = torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1).to(torch.bfloat16)
-            img2_interpolated = torch.nn.functional.interpolate(img2[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1).to(torch.bfloat16)
             mean2 = 0
             enc= torch.concat([
                     img_interpolated,
-                    torch.full( (3, frame_num-2, h, w), mean2, device="cpu", dtype= torch.bfloat16),
-                    img2_interpolated,
+                    torch.full( (3, frame_num-2, h, w), mean2, device=self.device, dtype= torch.bfloat16),
+                    img_interpolated2,
             ], dim=1).to(self.device)
         else:
             enc= torch.concat([
-                    torch.nn.functional.interpolate(
-                        img[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1).to(torch.bfloat16),
-                    torch.zeros(3, frame_num-1, h, w, device="cpu", dtype= torch.bfloat16)
+                    img_interpolated,
+                    torch.zeros(3, frame_num-1, h, w, device=self.device, dtype= torch.bfloat16)
             ], dim=1).to(self.device)
 
         lat_y = self.vae.encode([enc], VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
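Note: the conditioning tensor enc is laid out along the time axis as first frame, filler frames, and (optionally) end frame, now built from the precomputed Lanczos-resized frames instead of fresh bicubic interpolations, and allocated on self.device rather than the CPU. A shape-only sketch (stand-in tensors; devices and dtypes elided):

```python
import torch

C, h, w = 3, 480, 832
frame_num = 81
img_interpolated = torch.zeros(C, 1, h, w)    # stand-in for the resized first frame
img_interpolated2 = torch.zeros(C, 1, h, w)   # stand-in for the resized end frame

# with an end frame: [first | frame_num-2 filler | last]
enc = torch.concat([
    img_interpolated,
    torch.zeros(C, frame_num - 2, h, w),      # mean2 == 0, so full(...) is zeros here
    img_interpolated2,
], dim=1)
assert enc.shape[1] == frame_num

# without an end frame: [first | frame_num-1 filler]
enc = torch.concat([
    img_interpolated,
    torch.zeros(C, frame_num - 1, h, w),
], dim=1)
assert enc.shape[1] == frame_num
```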
@@ -333,7 +357,8 @@ class WanI2V:
             'seq_len': max_seq_len,
             'y': [y],
             'freqs' : freqs,
-            'pipeline' : self
+            'pipeline' : self,
+            'callback' : callback
         }
 
         arg_null = {
@@ -342,7 +367,8 @@ class WanI2V:
             'seq_len': max_seq_len,
             'y': [y],
             'freqs' : freqs,
-            'pipeline' : self
+            'pipeline' : self,
+            'callback' : callback
         }
 
         arg_both= {
@@ -352,7 +378,8 @@ class WanI2V:
             'seq_len': max_seq_len,
             'y': [y],
             'freqs' : freqs,
-            'pipeline' : self
+            'pipeline' : self,
+            'callback' : callback
         }
 
         if offload_model:
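Note: arg_c, arg_null, and arg_both are the kwarg sets forwarded to the model for the conditioned, unconditioned, and joint passes; all three now carry the callback, so progress can be reported from inside the forward pass. A sketch of the two-argument convention used at the WanI2V/WanT2V call sites below (the meaning of the boolean is an assumption; the model-level hook later in this diff passes an extra flag):

```python
def make_callback(total_steps):
    # hypothetical progress callback matching the call sites in this diff:
    # callback(-1, True) before the denoising loop, callback(i, False) per step
    def callback(step_idx, force_refresh):
        if step_idx < 0:
            print("denoising started")
        else:
            print(f"step {step_idx + 1}/{total_steps} (refresh={force_refresh})")
    return callback
```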
@@ -363,7 +390,7 @@ class WanI2V:
 
         # self.model.to(self.device)
         if callback != None:
-            callback(-1, None)
+            callback(-1, True)
 
         for i, t in enumerate(tqdm(timesteps)):
             offload.set_step_no_for_lora(self.model, i)
@@ -437,7 +464,7 @@ class WanI2V:
             del timestep
 
             if callback is not None:
-                callback(i, latent)
+                callback(i, False)
 
 
         x0 = [latent.to(self.device, dtype=torch.bfloat16)]
@@ -451,7 +478,7 @@ class WanI2V:
         video = self.vae.decode(x0, VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
 
         if any_end_frame and add_frames_for_end_image:
-            # video[:, -1:] = img2_interpolated
+            # video[:, -1:] = img_interpolated2
             video = video[:, :-1]
 
         else:
@@ -704,6 +704,7 @@ class WanModel(ModelMixin, ConfigMixin):
         is_uncond=False,
         max_steps = 0,
         slg_layers=None,
+        callback = None,
     ):
         r"""
         Forward pass through the diffusion model
@@ -835,12 +836,10 @@ class WanModel(ModelMixin, ConfigMixin):
             freqs=freqs,
             # context=context,
             context_lens=context_lens)
-
         for block_idx, block in enumerate(self.blocks):
             offload.shared_state["layer"] = block_idx
-            if "refresh" in offload.shared_state:
-                del offload.shared_state["refresh"]
-                offload.shared_state["callback"](-1, -1, True)
+            if callback != None:
+                callback(-1, False, True)
             if pipeline._interrupt:
                 if joint_pass:
                     return None, None
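Note: the per-block hook replaces a callback smuggled through offload.shared_state with the explicitly threaded callback argument, and the pipeline._interrupt check lets a queued job abort between transformer blocks rather than only between full denoising steps. A minimal sketch of this cooperative-cancellation pattern (hypothetical names):

```python
class Pipeline:
    def __init__(self):
        # flipped from another thread, e.g. by a UI "abort" button on the queue
        self._interrupt = False

    def run_blocks(self, blocks, x):
        for block in blocks:
            if self._interrupt:   # checked between blocks, so aborts are prompt
                return None       # caller treats None as "generation cancelled"
            x = block(x)
        return x
```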
@@ -268,7 +268,7 @@ class WanT2V:
         if self.model.enable_teacache:
             self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
         if callback != None:
-            callback(-1, None)
+            callback(-1, True)
         for i, t in enumerate(tqdm(timesteps)):
             latent_model_input = latents
             slg_layers_local = None
@@ -322,7 +322,7 @@ class WanT2V:
             del temp_x0
 
         if callback is not None:
-            callback(i, latents)
+            callback(i, False)
 
         x0 = latents
         if offload_model:
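Note: taken together, these queue adaptations change the pipeline-level callback contract from callback(step, latents) to callback(step, flag): the UI receives a step index it can use to drive a per-job progress display instead of intermediate latents. A hedged sketch of a consumer (the queue plumbing and names are hypothetical):

```python
# hypothetical consumer of the new callback protocol
class ProgressTracker:
    def __init__(self, total_steps):
        self.total_steps = total_steps

    def __call__(self, step_idx, flag):
        # step_idx == -1 signals "job starting"; otherwise a completed step
        if step_idx < 0:
            print("job started")
        else:
            print(f"progress: {step_idx + 1}/{self.total_steps}")

# would be passed as the `callback` argument threaded through
# arg_c / arg_null / arg_both above
```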