# Simulate lossy compression of a single image: the frame is encoded as a
# one-frame H.264/MP4 stream at the requested CRF and immediately decoded
# again, so the returned tensor carries realistic compression artifacts.

import io

import av
import numpy as np
import torch


def _encode_single_frame(output_file, image_array: np.ndarray, crf):
    """Encode a single RGB frame into ``output_file`` as a one-frame H.264 MP4."""
    container = av.open(output_file, "w", format="mp4")
    try:
        stream = container.add_stream(
            "libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
        )
        stream.height = image_array.shape[0]
        stream.width = image_array.shape[1]
        # libx264 expects yuv420p input, so convert from packed RGB first.
        av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(
            format="yuv420p"
        )
        container.mux(stream.encode(av_frame))
        # Flush any buffered packets from the encoder.
        container.mux(stream.encode())
    finally:
        container.close()


def _decode_single_frame(video_file):
    """Decode the first video frame from ``video_file`` as an RGB uint8 array."""
    container = av.open(video_file)
    try:
        stream = next(s for s in container.streams if s.type == "video")
        frame = next(container.decode(stream))
    finally:
        container.close()
    return frame.to_ndarray(format="rgb24")


def compress(image: torch.Tensor, crf=29):
    """Round-trip an (H, W, 3) float tensor in [0, 1] through H.264 at the given CRF.

    CRF 0 is treated as "no compression": the input is returned unchanged.
    """
    if crf == 0:
        return image

    # Crop to even height/width (required for yuv420p) and convert to uint8.
    image_array = (
        (image[: (image.shape[0] // 2) * 2, : (image.shape[1] // 2) * 2] * 255.0)
        .byte()
        .cpu()
        .numpy()
    )
    # Encode the frame into an in-memory MP4, then decode it straight back.
    with io.BytesIO() as output_file:
        _encode_single_frame(output_file, image_array, crf)
        video_bytes = output_file.getvalue()
    with io.BytesIO(video_bytes) as video_file:
        image_array = _decode_single_frame(video_file)
    # Restore the original dtype/device and the [0, 1] value range.
    tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
    return tensor
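

# A minimal usage sketch (not part of the original file): the frame shape and
# CRF value below are arbitrary assumptions, chosen only to illustrate the
# encode/decode round trip performed by `compress`.
if __name__ == "__main__":
    frame = torch.rand(480, 640, 3)     # random RGB frame, values in [0, 1]
    degraded = compress(frame, crf=29)  # same shape, with compression artifacts
    print(degraded.shape, (degraded - frame).abs().mean().item())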