diff --git a/shared/radial_attention/attention.py b/shared/radial_attention/attention.py
new file mode 100644
index 0000000..ab4a5a2
--- /dev/null
+++ b/shared/radial_attention/attention.py
@@ -0,0 +1,49 @@
+from typing import Optional, Tuple
+
+import torch
+from einops import rearrange
+from .attn_mask import RadialAttention, MaskMap
+
+def fill_radial_cache(radial_cache, nb_layers, lat_t, lat_h, lat_w):
+    # reset the cached block mask so it is regenerated for the new latent shape
+    MaskMap._log_mask = None
+
+    for i in range(nb_layers):
+        radial_cache[i] = WanSparseAttnProcessor2_0(i, lat_t, lat_h, lat_w)
+
+class WanSparseAttnProcessor2_0:
+    mask_map = None
+    dense_timestep = 0
+    dense_block = 0
+    decay_factor = 1.0
+    sparse_type = "radial"  # default to radial attention, can be changed to "dense" for dense attention
+    use_sage_attention = True
+
+    def __init__(self, layer_idx, lat_t=None, lat_h=None, lat_w=None):
+        self.layer_idx = layer_idx
+        if lat_t is not None:
+            # Wan patchifies latents with a 1x2x2 patch size, so each latent frame yields lat_h * lat_w // 4 tokens
+            self.mask_map = MaskMap(video_token_num=lat_t * lat_h * lat_w // 4, num_frame=lat_t)
+        # otherwise the class-level mask_map (e.g. set by setup_radial_attention) is used
+
+    def __call__(
+        self,
+        qkv_list,
+        timestep_no=0,
+    ) -> torch.Tensor:
+        query, key, value = qkv_list
+
+        batch_size = query.shape[0]
+        # flatten (batch_size, seq_len, num_heads, head_dim) to (batch_size * seq_len, num_heads, head_dim)
+        query = rearrange(query, "b s h d -> (b s) h d")
+        key = rearrange(key, "b s h d -> (b s) h d")
+        value = rearrange(value, "b s h d -> (b s) h d")
+        if timestep_no < self.dense_timestep or self.layer_idx < self.dense_block or self.sparse_type == "dense":
+            # early timesteps and early blocks fall back to dense attention
+            hidden_states = RadialAttention(
+                query=query, key=key, value=value, mask_map=self.mask_map, sparsity_type="dense", block_size=128, decay_factor=self.decay_factor, model_type="wan", pre_defined_mask=None, use_sage_attention=self.use_sage_attention
+            )
+        else:
+            # apply radial attention
+            hidden_states = RadialAttention(
+                query=query, key=key, value=value, mask_map=self.mask_map, sparsity_type="radial", block_size=128, decay_factor=self.decay_factor, model_type="wan", pre_defined_mask=None, use_sage_attention=self.use_sage_attention
+            )
+        # reshape back to (batch_size, seq_len, num_heads, head_dim)
+        hidden_states = rearrange(hidden_states, "(b s) h d -> b s h d", b=batch_size)
+
+        return hidden_states
diff --git a/shared/radial_attention/attn_mask.py b/shared/radial_attention/attn_mask.py
new file mode 100644
index 0000000..4c29f1c
--- /dev/null
+++ b/shared/radial_attention/attn_mask.py
@@ -0,0 +1,379 @@
+import torch
+import matplotlib.pyplot as plt
+from einops import rearrange, repeat
+from sageattention import sageattn
+from spas_sage_attn import block_sparse_sage2_attn_cuda
+
+try:
+    # flashinfer is only needed for the FlashInfer backend and the text-token paths below
+    import flashinfer
+except ImportError:
+    flashinfer = None
+# from sparse_sageattn import sparse_sageattn
+
+def get_cuda_arch_versions():
+    cuda_archs = []
+    for i in range(torch.cuda.device_count()):
+        major, minor = torch.cuda.get_device_capability(i)
+        cuda_archs.append(f"sm{major}{minor}")
+    return cuda_archs
+
+def sparge_mask_convert(mask: torch.Tensor, block_size: int = 128, arch="sm") -> torch.Tensor:
+    assert block_size in [128, 64], "Radial Attention only supports block sizes of 128 or 64"
+    assert mask.shape[0] == mask.shape[1], "Input mask must be square."
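+    # block_sparse_sage2_attn_cuda appears to consume the mask at a 64x128 (query-block x key-block)
+    # granularity on sm90 and at 128x64 on other architectures, so the square radial mask is
+    # interleaved or max-pooled below to match that layout.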
+ + if block_size == 128: + if arch == "sm90": + new_mask = torch.repeat_interleave(mask, 2, dim=0) + else: + new_mask = torch.repeat_interleave(mask, 2, dim=1) + + elif block_size == 64: + if arch == "sm90": + num_row, num_col = mask.shape + reshaped_mask = mask.view(num_row, num_col // 2, 2) + new_mask = torch.max(reshaped_mask, dim=2).values + else: + num_row, num_col = mask.shape + reshaped_mask = mask.view(num_row // 2, 2, num_col) + new_mask = torch.max(reshaped_mask, dim=1).values + + return new_mask + +def get_indptr_from_mask(mask, query): + # query shows the device of the indptr + # indptr (torch.Tensor) - the block index pointer of the block-sparse matrix on row dimension, + # shape `(MB + 1,)`, where `MB` is the number of blocks in the row dimension. + # The first element is always 0, and the last element is the number of blocks in the row dimension. + # The rest of the elements are the number of blocks in each row. + # the mask is already a block sparse mask + indptr = torch.zeros(mask.shape[0] + 1, device=query.device, dtype=torch.int32) + indptr[0] = 0 + row_counts = mask.sum(dim=1).flatten() # Ensure 1D output [num_blocks_row] + indptr[1:] = torch.cumsum(row_counts, dim=0) + return indptr + +def get_indices_from_mask(mask, query): + # indices (torch.Tensor) - the block indices of the block-sparse matrix on column dimension, + # shape `(nnz,),` where `nnz` is the number of non-zero blocks. + # The elements in `indices` array should be less than `NB`: the number of blocks in the column dimension. + nonzero_indices = torch.nonzero(mask) + indices = nonzero_indices[:, 1].to(dtype=torch.int32, device=query.device) + return indices + +def shrinkMaskStrict(mask, block_size=128): + seqlen = mask.shape[0] + block_num = seqlen // block_size + mask = mask[:block_num * block_size, :block_num * block_size].view(block_num, block_size, block_num, block_size) + col_densities = mask.sum(dim = 1) / block_size + # we want the minimum non-zero column density in the block + non_zero_densities = col_densities > 0 + high_density_cols = col_densities > 1/3 + frac_high_density_cols = high_density_cols.sum(dim=-1) / (non_zero_densities.sum(dim=-1) + 1e-9) + block_mask = frac_high_density_cols > 0.6 + block_mask[0:0] = True + block_mask[-1:-1] = True + return block_mask + +def pad_qkv(input_tensor, block_size=128): + """ + Pad the input tensor to be a multiple of the block size. 
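+    Zero rows are appended at the end of the sequence dimension; the original rows are copied unchanged.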
+ input shape: (seqlen, num_heads, hidden_dim) + """ + seqlen, num_heads, hidden_dim = input_tensor.shape + # Calculate the necessary padding + padding_length = (block_size - (seqlen % block_size)) % block_size + # Create a padded tensor with zeros + padded_tensor = torch.zeros((seqlen + padding_length, num_heads, hidden_dim), device=input_tensor.device, dtype=input_tensor.dtype) + # Copy the original tensor into the padded tensor + padded_tensor[:seqlen, :, :] = input_tensor + + return padded_tensor + +def get_diagonal_split_mask(i, j, token_per_frame, sparse_type, query): + assert(sparse_type in ["radial"]) + dist = abs(i - j) + group = dist.bit_length() + threshold = 128 # hardcoded threshold for now, which is equal to block-size + decay_length = 2 ** token_per_frame.bit_length() / 2 ** group + if decay_length >= threshold: + return torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool) + + split_factor = int(threshold / decay_length) + modular = dist % split_factor + if modular == 0: + return torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool) + else: + return torch.zeros((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool) + +def get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=1, block_size=128, model_type=None): + assert(sparse_type in ["radial"]) + dist = abs(i - j) + if model_type == "wan": + if dist < 1: + return token_per_frame + if dist == 1: + return token_per_frame // 2 + elif model_type == "hunyuan": + if dist <= 1: + return token_per_frame + else: + raise ValueError(f"Unknown model type: {model_type}") + group = dist.bit_length() + decay_length = 2 ** token_per_frame.bit_length() / 2 ** group * decay_factor + threshold = block_size + if decay_length >= threshold: + return decay_length + else: + return threshold + +def gen_log_mask_shrinked(query, s, video_token_num, num_frame, block_size=128, sparse_type="log", decay_factor=0.5, model_type=None): + """ + A more memory friendly version, we generate the attention mask of each frame pair at a time, + shrinks it, and stores it into the final result + """ + final_log_mask = torch.zeros((s // block_size, s // block_size), device=query.device, dtype=torch.bool) + token_per_frame = video_token_num // num_frame + video_text_border = video_token_num // block_size + + col_indices = torch.arange(0, token_per_frame, device=query.device).view(1, -1) + row_indices = torch.arange(0, token_per_frame, device=query.device).view(-1, 1) + final_log_mask[video_text_border:] = True + final_log_mask[:, video_text_border:] = True + for i in range(num_frame): + for j in range(num_frame): + local_mask = torch.zeros((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool) + if j == 0 and model_type == "wan": # this is attention sink + local_mask = torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool) + else: + window_width = get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=decay_factor, block_size=block_size, model_type=model_type) + local_mask = torch.abs(col_indices - row_indices) <= window_width + split_mask = get_diagonal_split_mask(i, j, token_per_frame, sparse_type, query) + local_mask = torch.logical_and(local_mask, split_mask) + + remainder_row = (i * token_per_frame) % block_size + remainder_col = (j * token_per_frame) % block_size + # get the padded size + all_length_row = remainder_row + ((token_per_frame - 1) // block_size + 1) * block_size + 
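+            # (token_per_frame is generally not a multiple of block_size, so each frame-pair mask is
+            # placed at its in-block offset and padded up to whole blocks before shrinking)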
all_length_col = remainder_col + ((token_per_frame - 1) // block_size + 1) * block_size + padded_local_mask = torch.zeros((all_length_row, all_length_col), device=query.device, dtype=torch.bool) + padded_local_mask[remainder_row:remainder_row + token_per_frame, remainder_col:remainder_col + token_per_frame] = local_mask + # shrink the mask + block_mask = shrinkMaskStrict(padded_local_mask, block_size=block_size) + # set the block mask to the final log mask + block_row_start = (i * token_per_frame) // block_size + block_col_start = (j * token_per_frame) // block_size + block_row_end = block_row_start + block_mask.shape[0] + block_col_end = block_col_start + block_mask.shape[1] + final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end] = torch.logical_or( + final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end], block_mask) + print(f"mask sparsity: {1 - final_log_mask.sum() / final_log_mask.numel()}") + return final_log_mask + +class MaskMap: + _log_mask = None + + def __init__(self, video_token_num=25440, num_frame=16): + self.video_token_num = video_token_num + self.num_frame = num_frame + + def queryLogMask(self, query, sparse_type, block_size=128, decay_factor=0.5, model_type=None): + if MaskMap._log_mask is None: + MaskMap._log_mask = torch.ones((query.shape[0] // block_size, query.shape[0] // block_size), device=query.device, dtype=torch.bool) + MaskMap._log_mask = gen_log_mask_shrinked(query, query.shape[0], self.video_token_num, self.num_frame, sparse_type=sparse_type, decay_factor=decay_factor, model_type=model_type, block_size=block_size) + return MaskMap._log_mask + +def SpargeSageAttnBackend(query, key, value, mask_map=None, video_mask=None, pre_defined_mask=None, block_size=128): + if video_mask.all(): + # dense case + kv_border = pre_defined_mask[0].sum() if pre_defined_mask is not None else key.shape[0] + output_video = sageattn( + query[:mask_map.video_token_num, :, :].unsqueeze(0), + key[:kv_border, :, :].unsqueeze(0), + value[:kv_border, :, :].unsqueeze(0), + tensor_layout="NHD", + )[0] + + if pre_defined_mask is not None: + output_text = flashinfer.single_prefill_with_kv_cache( + q=query[mask_map.video_token_num:, :, :], + k=key[:pre_defined_mask[0].sum(), :, :], + v=value[:pre_defined_mask[0].sum(), :, :], + causal=False, + return_lse=False, + ) + return torch.cat([output_video, output_text], dim=0) + else: + return output_video + + # sparse-sageattention only supports (b, h, s, d) layout, need rearrange first + query_hnd = rearrange(query.unsqueeze(0), "b s h d -> b h s d") + key_hnd = rearrange(key.unsqueeze(0), "b s h d -> b h s d") + value_hnd = rearrange(value.unsqueeze(0), "b s h d -> b h s d") + arch = get_cuda_arch_versions()[query.device.index] + converted_mask = repeat(sparge_mask_convert(mask=video_mask, block_size=block_size, arch=arch), "s t -> b h s t", b=query_hnd.shape[0], h=query_hnd.shape[1]) + + converted_mask = converted_mask.to(torch.int8) + if pre_defined_mask is None: + # wan case + output = block_sparse_sage2_attn_cuda( + query_hnd[:, :, :mask_map.video_token_num, :], + key_hnd[:, :, :mask_map.video_token_num, :], + value_hnd[:, :, :mask_map.video_token_num, :], + mask_id=converted_mask, + tensor_layout="HND", + ) + + # rearrange back to (s, h, d), we know that b = 1 + output = rearrange(output, "b h s d -> s (b h) d", b=1) + return output + + query_video = query_hnd[:, :, :mask_map.video_token_num, :] + key_video = key_hnd + value_video = value_hnd + kv_border = (pre_defined_mask[0].sum() + 63) // 64 + 
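+    # pre_defined_mask[0].sum() is the number of valid text tokens; kv_border is that length in units of
+    # 64-wide mask columns (presumably the kernel's KV granularity), and every column block past it is
+    # zeroed so attention ignores the padded text keys/values.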
converted_mask[:, :, :, kv_border:] = False + output_video = block_sparse_sage2_attn_cuda( + query_video, + key_video, + value_video, + mask_id=converted_mask[:, :, :mask_map.video_token_num // block_size, :].contiguous(), + tensor_layout="HND", + ) + + # rearrange back to (s, h, d), we know that b = 1 + output_video = rearrange(output_video, "b h s d -> s (b h) d", b=1) + + # gt = sparse_sageattn( + # query_video, + # key_video, + # value_video, + # mask_id=None, + # is_causal=False, + # tensor_layout="HND", + # )[0] + + + + # import pdb; pdb.set_trace() + + output_text = flashinfer.single_prefill_with_kv_cache( + q=query[mask_map.video_token_num:, :, :], + k=key[:pre_defined_mask[0].sum(), :, :], + v=value[:pre_defined_mask[0].sum(), :, :], + causal=False, + return_lse=False, + ) + + return torch.cat([output_video, output_text], dim=0) + + +def FlashInferBackend(query, key, value, mask_map=None, pre_defined_mask=None, bsr_wrapper=None): + if pre_defined_mask is not None: + video_video_o, video_video_o_lse = bsr_wrapper.run( + query[:mask_map.video_token_num, :, :], + key[:mask_map.video_token_num, :, :], + value[:mask_map.video_token_num, :, :], + return_lse=True + ) + # perform non-causal flashinfer on the text tokens + video_text_o, video_text_o_lse = flashinfer.single_prefill_with_kv_cache( + q=query[:mask_map.video_token_num, :, :], + k=key[mask_map.video_token_num:, :, :], + v=value[mask_map.video_token_num:, :, :], + causal=False, + return_lse=True, + custom_mask=pre_defined_mask[:mask_map.video_token_num, mask_map.video_token_num:] + ) + + # merge the two results + o_video, _ = flashinfer.merge_state(v_a=video_video_o, s_a=video_video_o_lse, v_b=video_text_o, s_b=video_text_o_lse) + + o_text = flashinfer.single_prefill_with_kv_cache( + q=query[mask_map.video_token_num:, :, :], + k=key, + v=value, + causal=False, + return_lse=False, + custom_mask=pre_defined_mask[mask_map.video_token_num:, :] + ) + + return torch.cat([o_video, o_text], dim=0) + else: + o = bsr_wrapper.run( + query[:mask_map.video_token_num, :, :], + key[:mask_map.video_token_num, :, :], + value[:mask_map.video_token_num, :, :] + ) + return o + +def RadialAttention(query, key, value, mask_map=None, sparsity_type="radial", block_size=128, decay_factor=1, model_type=None, pre_defined_mask=None, use_sage_attention=False): + orig_seqlen, num_head, hidden_dim = query.shape + + if sparsity_type == "dense": + video_mask = torch.ones((mask_map.video_token_num // block_size, mask_map.video_token_num // block_size), device=query.device, dtype=torch.bool) + else: + video_mask = mask_map.queryLogMask(query, sparsity_type, block_size=block_size, decay_factor=decay_factor, model_type=model_type) if mask_map else None + + backend = "sparse_sageattn" if use_sage_attention else "flashinfer" + + if backend == "flashinfer": + video_mask = video_mask[:mask_map.video_token_num // block_size, :mask_map.video_token_num // block_size] + # perform block-sparse attention on the video tokens + workspace_buffer = torch.empty(128 * 1024 * 1024, device=query.device, dtype=torch.uint8) + bsr_wrapper = flashinfer.BlockSparseAttentionWrapper( + workspace_buffer, + backend="fa2", + ) + + indptr = get_indptr_from_mask(video_mask, query) + indices = get_indices_from_mask(video_mask, query) + + bsr_wrapper.plan( + indptr=indptr, + indices=indices, + M=mask_map.video_token_num, + N=mask_map.video_token_num, + R=block_size, + C=block_size, + num_qo_heads=num_head, + num_kv_heads=num_head, + head_dim=hidden_dim, + q_data_type=query.dtype, + 
kv_data_type=key.dtype,
            o_data_type=query.dtype,
        )
        
        return FlashInferBackend(query, key, value, mask_map, pre_defined_mask, bsr_wrapper)
    elif backend == "sparse_sageattn":
        return SpargeSageAttnBackend(query, key, value, mask_map, video_mask, pre_defined_mask, block_size=block_size)

if __name__ == "__main__":
    query = torch.randn(1, 2, 4, 64).cuda()
    # mask = torch.tensor([
    #     [True, False, True, False],
    #     [False, True, False, True],
    #     [True, False, False, True],
    #     [False, True, True, False]
    # ], dtype=torch.bool)
    # indices = get_indices_from_mask(mask, query)
    # indptr = get_indptr_from_mask(mask, query)
    # print("Indices: ", indices)
    # print("Indptr: ", indptr)
    video_token_num = 3840 * 30
    num_frame = 30
    token_per_frame = video_token_num / num_frame
    padded_video_token_num = ((video_token_num + 1) // 128 + 1) * 128
    print("padded: ", padded_video_token_num)
    temporal_mask = gen_log_mask_shrinked(query, padded_video_token_num, video_token_num, num_frame, sparse_type="radial", decay_factor=1, model_type="hunyuan")
    plt.figure(figsize=(10, 8), dpi=500)

    plt.imshow(temporal_mask.cpu().numpy()[:, :], cmap='hot')
    plt.colorbar()
    plt.title("Temporal Mask")

    plt.savefig("temporal_mask.png",
                dpi=300,
                bbox_inches='tight',
                pad_inches=0.1)

    plt.close()
    # save the mask tensor
    torch.save(temporal_mask, "temporal_mask.pt")
diff --git a/shared/radial_attention/inference.py b/shared/radial_attention/inference.py
new file mode 100644
index 0000000..04ef186
--- /dev/null
+++ b/shared/radial_attention/inference.py
@@ -0,0 +1,41 @@
+import torch
+from diffusers.models.attention_processor import Attention
+from diffusers.models.attention import AttentionModuleMixin
+from .attention import WanSparseAttnProcessor2_0 as WanSparseAttnProcessor  # the processor class defined in attention.py
+from .attn_mask import MaskMap
+
+def setup_radial_attention(
+    pipe,
+    height,
+    width,
+    num_frames,
+    dense_layers=0,
+    dense_timesteps=0,
+    decay_factor=1.0,
+    sparsity_type="radial",
+    use_sage_attention=False,
+):
+
+    # number of latent frames after temporal VAE downsampling and patchification
+    num_frames = 1 + num_frames // (pipe.vae_scale_factor_temporal * pipe.transformer.config.patch_size[0])
+    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+    frame_size = int(height // mod_value) * int(width // mod_value)
+
+    AttnModule = WanSparseAttnProcessor
+    AttnModule.dense_block = dense_layers
+    AttnModule.dense_timestep = dense_timesteps
+    AttnModule.mask_map = MaskMap(video_token_num=frame_size * num_frames, num_frame=num_frames)
+    AttnModule.decay_factor = decay_factor
+    AttnModule.sparse_type = sparsity_type
+    AttnModule.use_sage_attention = use_sage_attention
+
+    print(f"Replacing Wan attention with {sparsity_type} attention")
+    print(f"video token num: {AttnModule.mask_map.video_token_num}, num frames: {num_frames}")
+    print(f"dense layers: {dense_layers}, dense timesteps: {dense_timesteps}, decay factor: {decay_factor}")
+
+    for layer_idx, m in enumerate(pipe.transformer.blocks):
+        m.attn1.processor.layer_idx = layer_idx
+
+    for _, m in pipe.transformer.named_modules():
+        if isinstance(m, AttentionModuleMixin) and hasattr(m.processor, 'layer_idx'):
+            layer_idx = m.processor.layer_idx
+            m.set_processor(AttnModule(layer_idx))
\ No newline at end of file
diff --git a/shared/radial_attention/sparse_transformer.py b/shared/radial_attention/sparse_transformer.py
new file mode 100644
index 0000000..96820d0
--- /dev/null
+++ b/shared/radial_attention/sparse_transformer.py
@@ -0,0 +1,424 @@
+# borrowed from svg-project/Sparse-VideoGen
+
+from typing import Any, Dict, Optional, Tuple,
Union + +import torch + +from diffusers.models.transformers.transformer_wan import WanTransformerBlock, WanTransformer3DModel +from diffusers import WanPipeline +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +logger = logging.get_logger(__name__) +from typing import Any, Callable, Dict, List, Optional, Union +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput +import torch.distributed as dist + +try: + from xfuser.core.distributed import ( + get_ulysses_parallel_world_size, + get_ulysses_parallel_rank, + get_sp_group + ) +except: + pass + +class WanTransformerBlock_Sparse(WanTransformerBlock): + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + rotary_emb: torch.Tensor, + numeral_timestep: Optional[int] = None, + ) -> torch.Tensor: + if temb.ndim == 4: + # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table.unsqueeze(0) + temb.float() + ).chunk(6, dim=2) + # batch_size, seq_len, 1, inner_dim + shift_msa = shift_msa.squeeze(2) + scale_msa = scale_msa.squeeze(2) + gate_msa = gate_msa.squeeze(2) + c_shift_msa = c_shift_msa.squeeze(2) + c_scale_msa = c_scale_msa.squeeze(2) + c_gate_msa = c_gate_msa.squeeze(2) + else: + # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table + temb.float() + ).chunk(6, dim=1) + + # 1. Self-attention + norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) + attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb, numerical_timestep=numeral_timestep) + hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states).contiguous() + + # 2. Cross-attention + norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) + attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = hidden_states + attn_output + + # 3. Feed-forward + norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as( + hidden_states + ) + ff_output = self.ffn(norm_hidden_states) + hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) + + return hidden_states + +class WanTransformer3DModel_Sparse(WanTransformer3DModel): + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + numeral_timestep: Optional[int] = None, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 
+ ) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + rotary_emb = self.rope(hidden_states) + + hidden_states = self.patch_embedding(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v) + if timestep.ndim == 2: + ts_seq_len = timestep.shape[1] + timestep = timestep.flatten() # batch_size * seq_len + else: + ts_seq_len = None + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len + ) + + if ts_seq_len is not None: + # batch_size, seq_len, 6, inner_dim + timestep_proj = timestep_proj.unflatten(2, (6, -1)) + else: + # batch_size, 6, inner_dim + timestep_proj = timestep_proj.unflatten(1, (6, -1)) + + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + + if dist.is_initialized() and get_ulysses_parallel_world_size() > 1: + # split video latents on dim TS + hidden_states = torch.chunk(hidden_states, get_ulysses_parallel_world_size(), dim=-2)[get_ulysses_parallel_rank()] + rotary_emb = ( + torch.chunk(rotary_emb[0], get_ulysses_parallel_world_size(), dim=1)[get_ulysses_parallel_rank()], + torch.chunk(rotary_emb[1], get_ulysses_parallel_world_size(), dim=1)[get_ulysses_parallel_rank()], + ) + + # 4. Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + for block in self.blocks: + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, numeral_timestep=numeral_timestep + ) + else: + for block in self.blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + numeral_timestep=numeral_timestep, + ) + + # 5. Output norm, projection & unpatchify + if temb.ndim == 3: + # batch_size, seq_len, inner_dim (wan 2.2 ti2v) + shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2) + shift = shift.squeeze(2) + scale = scale.squeeze(2) + else: + # batch_size, inner_dim + shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) + + # Move the shift and scale tensors to the same device as hidden_states. + # When using multi-GPU inference via accelerate these will be on the + # first device rather than the last device, which hidden_states ends up + # on. 
+ shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.proj_out(hidden_states) + + if dist.is_initialized() and get_ulysses_parallel_world_size() > 1: + hidden_states = get_sp_group().all_gather(hidden_states, dim=-2) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + +class WanPipeline_Sparse(WanPipeline): + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, defaults to `480`): + The height in pixels of the generated image. + width (`int`, defaults to `832`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. 
+ prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`): + The dtype to use for the torch.amp.autocast. + + Examples: + + Returns: + [`~WanPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. 
Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = latents.to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + numeral_timestep=i, + )[0] + + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + numeral_timestep=i, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) + +def replace_sparse_forward(): + WanTransformerBlock.forward = WanTransformerBlock_Sparse.forward + WanTransformer3DModel.forward = WanTransformer3DModel_Sparse.forward + WanPipeline.__call__ = WanPipeline_Sparse.__call__ \ No newline at end of file diff --git a/shared/radial_attention/utils.py b/shared/radial_attention/utils.py new file mode 100644 index 0000000..92b3c3e --- /dev/null +++ b/shared/radial_attention/utils.py @@ -0,0 +1,16 @@ +import os +import random +import numpy as np +import torch + +def set_seed(seed): + """ + Set the random seed for reproducibility. 
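+    Also forces deterministic cuDNN kernels and disables cuDNN benchmarking, which may reduce throughput.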
+ """ + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False \ No newline at end of file
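
For reference, a minimal sketch of how these modules are intended to be wired together (assuming the `shared.radial_attention` package is importable and a diffusers-format Wan checkpoint such as `Wan-AI/Wan2.1-T2V-1.3B-Diffusers`; everything outside this diff is illustrative, not part of the change):

    import torch
    from diffusers import WanPipeline
    from shared.radial_attention.sparse_transformer import replace_sparse_forward
    from shared.radial_attention.inference import setup_radial_attention
    from shared.radial_attention.utils import set_seed

    replace_sparse_forward()  # patch WanTransformerBlock / WanTransformer3DModel / WanPipeline with the sparse variants
    set_seed(42)

    pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    # swap in the radial attention processors before running inference
    setup_radial_attention(
        pipe, height=480, width=832, num_frames=81,
        dense_layers=1, dense_timesteps=2, decay_factor=1.0,
        sparsity_type="radial", use_sage_attention=True,
    )

    video = pipe("a corgi running on the beach", height=480, width=832, num_frames=81).frames[0]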