Queue adaptations

DeepBeepMeep 2025-04-01 20:32:30 +02:00
parent f1b1836a7a
commit 3246d8cf4d
4 changed files with 1061 additions and 570 deletions

File diff suppressed because it is too large

View File

@@ -41,6 +41,11 @@ def optimized_scale(positive_flat, negative_flat):
     return st_star

+def resize_lanczos(img, h, w):
+    img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
+    img = img.resize((w,h), resample=Image.Resampling.LANCZOS)
+    return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)
+
 class WanI2V:
@@ -90,7 +95,6 @@ class WanI2V:
         self.num_train_timesteps = config.num_train_timesteps
         self.param_dtype = config.param_dtype
         shard_fn = partial(shard_model, device_id=device_id)
         self.text_encoder = T5EncoderModel(
             text_len=config.text_len,
@@ -208,12 +212,12 @@ class WanI2V:
                 - H: Frame height (from max_area)
                 - W: Frame width from max_area)
         """
-        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
+        img = TF.to_tensor(img)
         lat_frames = int((frame_num - 1) // self.vae_stride[0] + 1)
         any_end_frame = img2 !=None
         if any_end_frame:
             any_end_frame = True
-            img2 = TF.to_tensor(img2).sub_(0.5).div_(0.5).to(self.device)
+            img2 = TF.to_tensor(img2)
             if add_frames_for_end_image:
                 frame_num +=1
                 lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
@@ -229,6 +233,15 @@ class WanI2V:
         h = lat_h * self.vae_stride[1]
         w = lat_w * self.vae_stride[2]

+        clip_image_size = self.clip.model.image_size
+        img_interpolated = resize_lanczos(img, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device)
+        img = resize_lanczos(img, clip_image_size, clip_image_size)
+        img = img.sub_(0.5).div_(0.5).to(self.device)
+        if img2!= None:
+            img_interpolated2 = resize_lanczos(img2, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device)
+            img2 = resize_lanczos(img2, clip_image_size, clip_image_size)
+            img2 = img2.sub_(0.5).div_(0.5).to(self.device)
+
         max_seq_len = lat_frames * lat_h * lat_w // ( self.patch_size[1] * self.patch_size[2])
         max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
@@ -273,21 +286,32 @@ class WanI2V:
             from mmgp import offload
+            # img_interpolated.save('aaa.png')
+            # img_interpolated = torch.from_numpy(np.array(img_interpolated).astype(np.float32) / 255.0).movedim(-1, 0)
+            # img_interpolated = torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode='lanczos')
+            # img_interpolated = img_interpolated.squeeze(0).transpose(0,2).transpose(1,0)
+            # img_interpolated = img_interpolated.clamp(-1, 1)
+            # img_interpolated = (img_interpolated + 1)/2
+            # img_interpolated = (img_interpolated*255).type(torch.uint8)
+            # img_interpolated = img_interpolated.cpu().numpy()
+            # xxx = Image.fromarray(img_interpolated, 'RGB')
+            # xxx.save('my.png')
             offload.last_offload_obj.unload_all()
             if any_end_frame:
-                img_interpolated = torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1).to(torch.bfloat16)
-                img2_interpolated = torch.nn.functional.interpolate(img2[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1).to(torch.bfloat16)
                 mean2 = 0
                 enc= torch.concat([
                         img_interpolated,
-                        torch.full( (3, frame_num-2, h, w), mean2, device="cpu", dtype= torch.bfloat16),
-                        img2_interpolated,
+                        torch.full( (3, frame_num-2, h, w), mean2, device=self.device, dtype= torch.bfloat16),
+                        img_interpolated2,
                 ], dim=1).to(self.device)
             else:
                 enc= torch.concat([
-                        torch.nn.functional.interpolate(
-                            img[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1).to(torch.bfloat16),
-                        torch.zeros(3, frame_num-1, h, w, device="cpu", dtype= torch.bfloat16)
+                        img_interpolated,
+                        torch.zeros(3, frame_num-1, h, w, device=self.device, dtype= torch.bfloat16)
                 ], dim=1).to(self.device)

         lat_y = self.vae.encode([enc], VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
@@ -333,7 +357,8 @@ class WanI2V:
             'seq_len': max_seq_len,
             'y': [y],
             'freqs' : freqs,
-            'pipeline' : self
+            'pipeline' : self,
+            'callback' : callback
         }

         arg_null = {
@@ -342,7 +367,8 @@ class WanI2V:
             'seq_len': max_seq_len,
             'y': [y],
             'freqs' : freqs,
-            'pipeline' : self
+            'pipeline' : self,
+            'callback' : callback
         }

         arg_both= {
@@ -352,7 +378,8 @@ class WanI2V:
             'seq_len': max_seq_len,
             'y': [y],
             'freqs' : freqs,
-            'pipeline' : self
+            'pipeline' : self,
+            'callback' : callback
         }

         if offload_model:
@@ -363,7 +390,7 @@ class WanI2V:
         # self.model.to(self.device)
         if callback != None:
-            callback(-1, None)
+            callback(-1, True)

         for i, t in enumerate(tqdm(timesteps)):
             offload.set_step_no_for_lora(self.model, i)
@@ -437,7 +464,7 @@ class WanI2V:
             del timestep

             if callback is not None:
-                callback(i, latent)
+                callback(i, False)

         x0 = [latent.to(self.device, dtype=torch.bfloat16)]
@@ -451,7 +478,7 @@ class WanI2V:
             video = self.vae.decode(x0, VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]

             if any_end_frame and add_frames_for_end_image:
-                # video[:, -1:] = img2_interpolated
+                # video[:, -1:] = img_interpolated2
                 video = video[:, :-1]
             else:
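Note: the hunks above replace the bicubic torch.nn.functional.interpolate call with a PIL-based Lanczos resize done once up front: one copy of the input image at the CLIP encoder's input resolution and one at the VAE-aligned frame size, both then normalized to [-1, 1]. A minimal self-contained sketch of that flow (the 512 CLIP size and the 480x832 frame size are illustrative assumptions, not values taken from this diff):

import numpy as np
import torch
from PIL import Image

def resize_lanczos(img, h, w):
    # [C, H, W] float tensor in [0, 1] -> PIL image -> Lanczos resample -> tensor
    pil = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
    pil = pil.resize((w, h), resample=Image.Resampling.LANCZOS)
    return torch.from_numpy(np.array(pil).astype(np.float32) / 255.0).movedim(-1, 0)

img = torch.rand(3, 480, 640)       # toy [C, H, W] frame in [0, 1]
clip_image_size = 512               # assumed CLIP input resolution
h, w = 480, 832                     # assumed VAE-aligned frame size
clip_in = resize_lanczos(img, clip_image_size, clip_image_size).sub_(0.5).div_(0.5)
vae_in = resize_lanczos(img, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0, 1)
print(clip_in.shape, vae_in.shape)  # [3, 512, 512] and [3, 1, 480, 832]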

View File

@@ -704,6 +704,7 @@ class WanModel(ModelMixin, ConfigMixin):
         is_uncond=False,
         max_steps = 0,
         slg_layers=None,
+        callback = None,
     ):
         r"""
         Forward pass through the diffusion model
@@ -835,12 +836,10 @@ class WanModel(ModelMixin, ConfigMixin):
                 freqs=freqs,
                 # context=context,
                 context_lens=context_lens)

         for block_idx, block in enumerate(self.blocks):
             offload.shared_state["layer"] = block_idx
-            if "refresh" in offload.shared_state:
-                del offload.shared_state["refresh"]
-                offload.shared_state["callback"](-1, -1, True)
+            if callback != None:
+                callback(-1, False, True)
             if pipeline._interrupt:
                 if joint_pass:
                     return None, None
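Note: this hunk drops the offload.shared_state["refresh"]/["callback"] side channel in favor of an explicit callback argument threaded through WanModel.forward. The diff itself never spells out what the positional flags mean; one plausible reading of the call sites, as a hedged sketch in which the flag names are guesses:

def make_progress_callback(num_steps):
    # Assumed protocol, inferred from the call sites in this commit:
    #   callback(-1, True)        - sampling is about to start
    #   callback(i, False)        - denoising step i just finished
    #   callback(-1, False, True) - refresh ping from inside the block loop
    def callback(step_no, override, refresh=False):
        if refresh:
            return  # keep the UI responsive mid-forward; nothing to record
        if step_no < 0:
            print("sampling started")
        else:
            print(f"step {step_no + 1}/{num_steps} done")
    return callback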

View File

@@ -268,7 +268,7 @@ class WanT2V:
         if self.model.enable_teacache:
             self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
         if callback != None:
-            callback(-1, None)
+            callback(-1, True)
         for i, t in enumerate(tqdm(timesteps)):
             latent_model_input = latents
             slg_layers_local = None
@@ -322,7 +322,7 @@ class WanT2V:
             del temp_x0
             if callback is not None:
-                callback(i, latents)
+                callback(i, False)
         x0 = latents
         if offload_model:
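Note: WanT2V.generate now drives the same callback protocol as WanI2V — one call with step -1 before the denoising loop, then one per completed step — and latents are no longer passed out through the callback. A self-contained sketch of that loop shape (step_fn is a hypothetical stand-in for the scheduler/model update and is not part of the diffed code):

from tqdm import tqdm

def denoise(timesteps, step_fn, callback=None):
    if callback is not None:
        callback(-1, True)             # announce the start of sampling
    latents = None
    for i, t in enumerate(tqdm(timesteps)):
        latents = step_fn(t, latents)  # hypothetical per-step update
        if callback is not None:
            callback(i, False)         # report progress only, not tensors
    return latents

denoise(list(range(4)), lambda t, latents: latents)  # smoke test: 4 no-op steps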