Mirror of https://github.com/Wan-Video/Wan2.1.git
Synced 2025-11-04 06:15:17 +00:00

Various Memory Optimisations

This commit is contained in:
parent 94d9b4aa4d
commit 52d7ba9260
@@ -4,24 +4,33 @@ from transformers import Wav2Vec2Model, Wav2Vec2Processor

 from .model import FantasyTalkingAudioConditionModel
 from .utils import get_audio_features

+import gc, torch

 def parse_audio(audio_path, num_frames, fps = 23, device = "cuda"):
     fantasytalking = FantasyTalkingAudioConditionModel(None, 768, 2048).to(device)
     from mmgp import offload
     from accelerate import init_empty_weights
     from fantasytalking.model import AudioProjModel

+    torch.set_grad_enabled(False)

     with init_empty_weights():
         proj_model = AudioProjModel( 768, 2048)
     offload.load_model_data(proj_model, "ckpts/fantasy_proj_model.safetensors")
-    proj_model.to(device).eval().requires_grad_(False)
+    proj_model.to("cpu").eval().requires_grad_(False)

     wav2vec_model_dir = "ckpts/wav2vec"
     wav2vec_processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_dir)
-    wav2vec = Wav2Vec2Model.from_pretrained(wav2vec_model_dir).to(device).eval().requires_grad_(False)
+    wav2vec = Wav2Vec2Model.from_pretrained(wav2vec_model_dir, device_map="cpu").eval().requires_grad_(False)
+    wav2vec.to(device)
+    proj_model.to(device)
     audio_wav2vec_fea = get_audio_features( wav2vec, wav2vec_processor, audio_path, fps, num_frames )

     audio_proj_fea = proj_model(audio_wav2vec_fea)
     pos_idx_ranges = fantasytalking.split_audio_sequence( audio_proj_fea.size(1), num_frames=num_frames )
     audio_proj_split, audio_context_lens = fantasytalking.split_tensor_with_padding( audio_proj_fea, pos_idx_ranges, expand_length=4 )  # [b,21,9+8,768]
+    wav2vec, proj_model= None, None
+    gc.collect()
+    torch.cuda.empty_cache()

     return audio_proj_split, audio_context_lens
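
The hunk above is the heart of this memory optimisation: the audio models are instantiated on CPU, moved to the GPU only for the forward pass, and then released. A minimal sketch of that pattern (not the repo's code; it uses a stand-in nn.Linear instead of AudioProjModel and assumes a CUDA device is available):

import gc
import torch
import torch.nn as nn

def run_once_and_release(x, device="cuda"):
    torch.set_grad_enabled(False)       # no autograd graph during inference
    model = nn.Linear(768, 2048)        # stand-in for the audio projection model
    model.eval().requires_grad_(False)  # weights start out on CPU
    model.to(device)                    # move to GPU only for the forward pass
    y = model(x.to(device)).cpu()       # keep just the result, back on CPU
    model = None                        # drop the last reference to the weights
    gc.collect()                        # let Python reclaim the tensors
    torch.cuda.empty_cache()            # hand cached VRAM blocks back to the driver
    return y
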
@@ -16,7 +16,7 @@ gradio==5.23.0
 numpy>=1.23.5,<2
 einops
 moviepy==1.0.3
-mmgp==3.4.2
+mmgp==3.4.3
 peft==0.14.0
 mutagen
 pydantic==2.10.6
@@ -103,7 +103,7 @@ class WanI2V:
         # dtype = torch.float16
         self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) #, forcedConfigPath= "c:/temp/i2v720p/config.json")
         self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
-        # offload.change_dtype(self.model, dtype, True)
+        offload.change_dtype(self.model, dtype, True)
         # offload.save_model(self.model, "wan2.1_image2video_720p_14B_mbf16.safetensors", config_file_path="c:/temp/i2v720p/config.json")
         # offload.save_model(self.model, "wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors",do_quantize=True, config_file_path="c:/temp/i2v720p/config.json")
         # offload.save_model(self.model, "wan2.1_image2video_720p_14B_quanto_mfp16_int8.safetensors",do_quantize=True, config_file_path="c:/temp/i2v720p/config.json")
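
lock_layers_dtypes() and change_dtype() are mmgp.offload APIs; the underlying idea is to convert the bulk of the transformer to the compute dtype while pinning numerically sensitive layers to float32 when mixed precision is enabled. A rough plain-PyTorch sketch of that idea (illustrative only, not the mmgp implementation):

import torch
import torch.nn as nn

def convert_with_locked_norms(model, dtype=torch.bfloat16, mixed_precision=True):
    model.to(dtype)                        # bulk conversion of all weights
    if mixed_precision:
        for module in model.modules():
            if isinstance(module, nn.LayerNorm):
                module.to(torch.float32)   # keep normalization layers in fp32
    return model
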
@@ -403,9 +403,7 @@ class WanI2V:
             if callback is not None:
                 callback(i, latent, False)

-        x0 = [latent]
-
-        # x0 = [lat_y]
+        x0 = [latent]
         video = self.vae.decode(x0, VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]

         if any_end_frame and add_frames_for_end_image:
@@ -312,8 +312,6 @@ class WanI2VCrossAttention(WanSelfAttention):
         del x
         self.norm_q(q)
         q= q.view(b, -1, n, d)
-        if audio_scale != None:
-            audio_x = self.processor(q, audio_proj, grid_sizes[0], audio_context_lens)
         k = self.k(context)
         self.norm_k(k)
         k = k.view(b, -1, n, d)
@@ -323,6 +321,8 @@ class WanI2VCrossAttention(WanSelfAttention):
         del k,v
         x = pay_attention(qkv_list)

+        if audio_scale != None:
+            audio_x = self.processor(q, audio_proj, grid_sizes[0], audio_context_lens)
         k_img = self.k_img(context_img)
         self.norm_k_img(k_img)
         k_img = k_img.view(b, -1, n, d)
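
The two hunks above defer the audio cross-attention call until after pay_attention(); together with the existing del statements this keeps fewer large tensors alive at the same time. A generic illustration of that rule (unrelated to the repo's pay_attention implementation):

import torch

def attention_like(q, k, v):
    # Dropping local references early lets the CUDA allocator reuse the memory,
    # provided the caller no longer holds those tensors either.
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    del k
    weights = scores.softmax(dim=-1)
    del scores
    out = weights @ v
    del weights, v
    return out
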

wgp.py | 53

@@ -40,7 +40,7 @@ global_queue_ref = []
 AUTOSAVE_FILENAME = "queue.zip"
 PROMPT_VARS_MAX = 10

-target_mmgp_version = "3.4.2"
+target_mmgp_version = "3.4.3"
 from importlib.metadata import version
 mmgp_version = version("mmgp")
 if mmgp_version != target_mmgp_version:
@@ -50,6 +50,7 @@ lock = threading.Lock()
 current_task_id = None
 task_id = 0


 def download_ffmpeg():
     if os.name != 'nt': return
     exes = ['ffmpeg.exe', 'ffprobe.exe', 'ffplay.exe']
@@ -1421,6 +1422,7 @@ for path in  ["wan2.1_Vace_1.3B_preview_bf16.safetensors", "sky_reels2_diffusion
 "wan2.1_image2video_720p_14B_quanto_int8.safetensors", "wan2.1_image2video_720p_14B_quanto_fp16_int8.safetensors", "wan2.1_image2video_720p_14B_bf16.safetensors"
 ]:
     if Path(os.path.join("ckpts" , path)).is_file():
         print(f"Removing old version of model '{path}'. A new version of this model will be downloaded next time you use it.")
         os.remove( os.path.join("ckpts" , path))

@@ -1511,14 +1513,21 @@ def get_model_filename(model_type, quantization):
         quantization = "bf16"

     if len(choices) <= 1:
-        return choices[0]
-
-    sub_choices = [ name for name in choices if quantization in name]
-    if len(sub_choices) > 0:
-        return sub_choices[0]
+        raw_filename = choices[0]
     else:
-        return choices[0]
+        sub_choices = [ name for name in choices if quantization in name]
+        if len(sub_choices) > 0:
+            raw_filename = sub_choices[0]
+        else:
+            raw_filename = choices[0]
+
+    if transformer_dtype == torch.float16   :
+        if "quanto_int8" in raw_filename:
+            raw_filename = raw_filename.replace("quanto_int8", "quanto_fp16_int8")
+        elif "quanto_mbf16_int8":
+            raw_filename= raw_filename.replace("quanto_mbf16_int8", "quanto_mfp16_int8")
+    return raw_filename

 def get_settings_file_name(model_filename):
     return  os.path.join(args.settings, get_model_type(model_filename) + "_settings.json")

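
Stripped of the web UI's globals, the selection logic above boils down to: pick the first checkpoint matching the requested quantization, then rewrite the quantized suffix when the transformer runs in float16. A self-contained restatement (pick_checkpoint is not a function in the repo):

import torch

def pick_checkpoint(choices, quantization, dtype):
    matching = [name for name in choices if quantization in name]
    filename = matching[0] if matching else choices[0]
    if dtype == torch.float16:
        if "quanto_int8" in filename:
            filename = filename.replace("quanto_int8", "quanto_fp16_int8")
        elif "quanto_mbf16_int8" in filename:
            filename = filename.replace("quanto_mbf16_int8", "quanto_mfp16_int8")
    return filename

# pick_checkpoint(["wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors"], "int8", torch.float16)
# -> "wan2.1_image2video_720p_14B_quanto_mfp16_int8.safetensors"
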
@@ -1599,6 +1608,13 @@ def get_default_settings(filename):
         ui_defaults["num_inference_steps"] = default_number_steps
     return ui_defaults

+major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
+if  major < 8:
+    print("Switching to f16 models as GPU architecture doesn't support bf16")
+    transformer_dtype = torch.float16
+else:
+    transformer_dtype = torch.float16 if args.fp16 else torch.bfloat16
+
 transformer_types = server_config.get("transformer_types", [])
 transformer_type = transformer_types[0] if len(transformer_types) > 0 else  model_types[0]
 transformer_quantization =server_config.get("transformer_quantization", "int8")
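
bfloat16 requires an Ampere-class GPU (CUDA compute capability 8.0 or newer), which is what the capability check above guards against: older cards fall back to float16. The same decision, isolated as a small sketch:

import torch

def pick_transformer_dtype(prefer_fp16=False, device=None):
    major, _minor = torch.cuda.get_device_capability(device)
    if major < 8:
        return torch.float16    # pre-Ampere GPUs have no bf16 support
    return torch.float16 if prefer_fp16 else torch.bfloat16
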
@@ -1892,32 +1908,17 @@ def load_models(model_filename):
     global transformer_filename

     perc_reserved_mem_max = args.perc_reserved_mem_max

-    major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
-    if  major < 8:
-        print("Switching to f16 model as GPU architecture doesn't support bf16")
-        default_dtype = torch.float16
-    else:
-        default_dtype = torch.float16 if args.fp16 else torch.bfloat16
     model_filelist = get_dependent_models(model_filename, quantization= transformer_quantization) + [model_filename]
-    updated_model_filename = []
     for filename in model_filelist:
-        if default_dtype == torch.float16   :
-            if "quanto_int8" in filename:
-                filename = filename.replace("quanto_int8", "quanto_fp16_int8")
-            elif "quanto_mbf16_int8":
-                filename = filename.replace("quanto_mbf16_int8", "quanto_mfp16_int8")
-        updated_model_filename.append(filename)
         download_models(filename, text_encoder_filename)
-    model_filelist = updated_model_filename
     VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float
     mixed_precision_transformer =  server_config.get("mixed_precision","0") == "1"
     transformer_filename = None
     new_transformer_filename = model_filelist[-1]
     if test_class_i2v(new_transformer_filename):
-        wan_model, pipe = load_i2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
+        wan_model, pipe = load_i2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
     else:
-        wan_model, pipe = load_t2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
+        wan_model, pipe = load_t2v_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
     wan_model._model_file_name = new_transformer_filename
     kwargs = { "extraModelsToQuantize": None}
     if profile == 2 or profile == 4:
@@ -1926,7 +1927,7 @@ def load_models(model_filename):
         #     kwargs["partialPinning"] = True
     elif profile == 3:
         kwargs["budgets"] = { "*" : "70%" }
-    offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = default_dtype, **kwargs)
+    offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs)
     if len(args.gpu) > 0:
         torch.set_default_device(args.gpu)
     transformer_filename = new_transformer_filename
@@ -2410,6 +2411,7 @@ def generate_video(
 ):
     global wan_model, offloadobj, reload_needed
     gen = get_gen_info(state)
+    torch.set_grad_enabled(False)

     file_list = gen["file_list"]
     prompt_no = gen["prompt_no"]
@@ -2574,6 +2576,7 @@ def generate_video(
     if seed == None or seed <0:
         seed = random.randint(0, 999999999)

+    torch.set_grad_enabled(False)
     global save_path
     os.makedirs(save_path, exist_ok=True)
     abort = False
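
torch.set_grad_enabled(False) at the start of generate_video ensures no autograd graph is recorded during generation, so intermediate activations are not kept alive for a backward pass that will never happen. A scoped alternative with the same effect, shown purely as an illustration:

import torch

@torch.inference_mode()
def run_pipeline(model, latent):
    # Nothing inside this call records gradients or autograd metadata.
    return model(latent)
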