mirror of https://github.com/Wan-Video/Wan2.1.git, synced 2025-11-04 06:15:17 +00:00

optimization for i2v with CausVid

This commit is contained in:
parent 9c3c6f3b74
commit 6706709230
@@ -492,8 +492,7 @@ class MMSingleStreamBlock(nn.Module):
         return img, txt
 
 class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
-    @staticmethod
-    def preprocess_loras(model_filename, sd):
+    def preprocess_loras(self, model_filename, sd):
         if not "i2v" in model_filename:
             return sd
         new_sd = {}
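Note: preprocess_loras is no longer a @staticmethod on HYVideoDiffusionTransformer, so it now has to be called on a loaded model instance (model.preprocess_loras(model_filename, sd)) rather than on the class itself; the same signature change is applied to WanModel in the last two hunks below.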
@@ -330,8 +330,11 @@ class WanI2V:
                 'current_step' :i,
             })
 
-
-            if joint_pass:
+            if guide_scale == 1:
+                noise_pred = self.model( [latent_model_input], context=[context], audio_scale = None if audio_scale == None else [audio_scale], x_id=0, **kwargs, )[0]
+                if self._interrupt:
+                    return None
+            elif joint_pass:
                 if audio_proj == None:
                     noise_pred_cond, noise_pred_uncond = self.model(
                         [latent_model_input, latent_model_input],
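Note: the new guide_scale == 1 branch is the actual optimization here. Classifier-free guidance blends a conditional and an unconditional prediction as uncond + guide_scale * (cond - uncond); at guide_scale == 1 this collapses to the conditional prediction alone, so the unconditional (and joint batched) forward pass is pure overhead and can be skipped, which is presumably the setting used with the CausVid-distilled model named in the commit title. A minimal sketch of that identity (illustrative only, not the repository's API; tensor shapes are made up):

    import torch

    def cfg_combine(cond: torch.Tensor, uncond: torch.Tensor, guide_scale: float) -> torch.Tensor:
        # Standard classifier-free guidance blend.
        return uncond + guide_scale * (cond - uncond)

    cond, uncond = torch.randn(1, 16, 4, 8, 8), torch.randn(1, 16, 4, 8, 8)
    # With guide_scale == 1 the unconditional term cancels exactly,
    # so a single conditional forward pass is enough.
    assert torch.allclose(cfg_combine(cond, uncond, 1.0), cond)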
@@ -347,13 +350,7 @@ class WanI2V:
                 if self._interrupt:
                     return None
             else:
-                noise_pred_cond = self.model(
-                    [latent_model_input],
-                    context=[context],
-                    audio_scale = None if audio_scale == None else [audio_scale],
-                    x_id=0,
-                    **kwargs,
-                )[0]
+                noise_pred_cond = self.model( [latent_model_input], context=[context], audio_scale = None if audio_scale == None else [audio_scale], x_id=0, **kwargs, )[0]
                 if self._interrupt:
                     return None
 
@@ -377,22 +374,24 @@ class WanI2V:
                     return None
             del latent_model_input
 
-            # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
-            if cfg_star_switch:
-                positive_flat = noise_pred_cond.view(batch_size, -1)
-                negative_flat = noise_pred_uncond.view(batch_size, -1)
+            if guide_scale > 1:
+                # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
+                if cfg_star_switch:
+                    positive_flat = noise_pred_cond.view(batch_size, -1)
+                    negative_flat = noise_pred_uncond.view(batch_size, -1)
 
-                alpha = optimized_scale(positive_flat,negative_flat)
-                alpha = alpha.view(batch_size, 1, 1, 1)
+                    alpha = optimized_scale(positive_flat,negative_flat)
+                    alpha = alpha.view(batch_size, 1, 1, 1)
 
-                if (i <= cfg_zero_step):
-                    noise_pred = noise_pred_cond*0.  # it would be faster not to compute noise_pred...
+                    if (i <= cfg_zero_step):
+                        noise_pred = noise_pred_cond*0.  # it would be faster not to compute noise_pred...
+                    else:
+                        noise_pred_uncond *= alpha
+                if audio_scale == None:
+                    noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
                 else:
-                    noise_pred_uncond *= alpha
-            if audio_scale == None:
-                noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
-            else:
-                noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond  - noise_pred_noaudio)
+                    noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond  - noise_pred_noaudio)
+
             noise_pred_uncond, noise_pred_noaudio = None, None
             temp_x0 = sample_scheduler.step(
                 noise_pred.unsqueeze(0),
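Note: moving the CFG Zero* rescaling and the guidance blend under if guide_scale > 1: keeps this block consistent with the new single-pass path: when guide_scale == 1, noise_pred was already assigned in the earlier branch and noise_pred_uncond was never computed, so nothing here should run. For orientation, a hedged sketch of the rescaling that cfg_star_switch gates, following the formula from the referenced CFG-Zero-star repository; the actual optimized_scale in this codebase may differ in details:

    import torch

    def optimized_scale(positive_flat: torch.Tensor, negative_flat: torch.Tensor) -> torch.Tensor:
        # Assumed CFG-Zero* formula: per-sample projection coefficient
        # alpha = <cond, uncond> / ||uncond||^2
        dot = (positive_flat * negative_flat).sum(dim=1, keepdim=True)
        norm_sq = (negative_flat * negative_flat).sum(dim=1, keepdim=True) + 1e-8
        return dot / norm_sq

    pos, neg = torch.randn(2, 1024), torch.randn(2, 1024)
    alpha = optimized_scale(pos, neg)   # shape (2, 1): one scale per batch element
    rescaled_uncond = neg * alpha       # rescaling applied before the guidance blend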
@@ -589,8 +589,7 @@ class MLPProj(torch.nn.Module):
 
 
 class WanModel(ModelMixin, ConfigMixin):
-    @staticmethod
-    def preprocess_loras(model_filename, sd):
+    def preprocess_loras(self, model_filename, sd):
 
         first = next(iter(sd), None)
         if first == None:
@@ -634,8 +633,8 @@ class WanModel(ModelMixin, ConfigMixin):
                     print(f"Lora alpha'{alpha_key}' is missing")
             new_sd.update(new_alphas)
             sd = new_sd
-
-        if "text2video" in model_filename:
+        from wgp import test_class_i2v
+        if not test_class_i2v(model_filename):
             new_sd = {}
             # convert loras for i2v to t2v
             for k,v in sd.items():
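Note: the old branch only converted LoRA keys for checkpoints whose filename contains "text2video"; routing the decision through wgp.test_class_i2v (implementation not shown in this diff) means every model that wgp does not classify as image-to-video gets the i2v-to-t2v key conversion, instead of relying on a single filename substring.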