-
Notifications
You must be signed in to change notification settings - Fork 314
Deepseekv3.2 #1246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
blueswhen
wants to merge
9
commits into
main
Choose a base branch
from
deepseekv3.2
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Deepseekv3.2 #1246
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
191 changes: 191 additions & 0 deletions
191
lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,191 @@ | ||
| import dataclasses | ||
| import torch | ||
| from typing import TYPE_CHECKING, Tuple | ||
|
|
||
| from ..base_att import AttControl, BaseAttBackend, BaseDecodeAttState, BasePrefillAttState | ||
| from lightllm.utils.dist_utils import get_current_device_id | ||
|
|
||
| if TYPE_CHECKING: | ||
| from lightllm.common.basemodel.infer_struct import InferStateInfo | ||
|
|
||
|
|
||
| class NsaFlashMlaFp8SparseAttBackend(BaseAttBackend): | ||
| def __init__(self, model): | ||
| super().__init__(model=model) | ||
| device = get_current_device_id() | ||
| self.ragged_mem_buffers = [ | ||
| torch.empty(model.graph_max_batch_size * model.max_seq_length, dtype=torch.int32, device=device) | ||
| for _ in range(2) | ||
| ] | ||
|
|
||
| def create_att_prefill_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparsePrefillAttState": | ||
| return NsaFlashMlaFp8SparsePrefillAttState(backend=self, infer_state=infer_state) | ||
|
|
||
| def create_att_decode_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparseDecodeAttState": | ||
| return NsaFlashMlaFp8SparseDecodeAttState(backend=self, infer_state=infer_state) | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class NsaFlashMlaFp8SparsePrefillAttState(BasePrefillAttState): | ||
| ks: torch.Tensor = None | ||
| ke: torch.Tensor = None | ||
| lengths: torch.Tensor = None | ||
| ragged_mem_index: torch.Tensor = None | ||
|
|
||
| def init_state(self): | ||
| self.backend: NsaFlashMlaFp8SparseAttBackend = self.backend | ||
| self.ragged_mem_index = torch.empty( | ||
| self.infer_state.total_token_num, | ||
| dtype=torch.int32, | ||
| device=get_current_device_id(), | ||
| ) | ||
| from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke | ||
|
|
||
| self.ks, self.ke, self.lengths = gen_nsa_ks_ke( | ||
| b_seq_len=self.infer_state.b_seq_len, | ||
| b_q_seq_len=self.infer_state.b_q_seq_len, | ||
| b_req_idx=self.infer_state.b_req_idx, | ||
| req_to_token_index=self.infer_state.req_manager.req_to_token_indexs, | ||
| q_token_num=self.infer_state.total_token_num - self.infer_state.prefix_total_token_num, | ||
| ragged_mem_index=self.ragged_mem_index, | ||
| hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID, | ||
| ) | ||
| return | ||
|
|
||
| def prefill_att( | ||
| self, | ||
| q: torch.Tensor, | ||
| k: torch.Tensor, | ||
| v: torch.Tensor, | ||
| att_control: AttControl = AttControl(), | ||
| alloc_func=torch.empty, | ||
| ) -> torch.Tensor: | ||
| assert att_control.nsa_prefill, "nsa_prefill must be True for NSA prefill attention" | ||
| assert att_control.nsa_prefill_dict is not None, "nsa_prefill_dict is required" | ||
| return self._nsa_prefill_att(q=q, att_control=att_control) | ||
|
|
||
| def _nsa_prefill_att( | ||
| self, | ||
| q: torch.Tensor, | ||
| att_control: AttControl, | ||
| ) -> torch.Tensor: | ||
| import flash_mla | ||
|
|
||
| nsa_dict = att_control.nsa_prefill_dict | ||
| topk_indices = nsa_dict["topk_indices"] | ||
| softmax_scale = nsa_dict["softmax_scale"] | ||
| kv_lora_rank = nsa_dict["kv_lora_rank"] | ||
| layer_index = nsa_dict["layer_index"] | ||
| topk_mem_indices = nsa_dict["topk_mem_indices"] | ||
| prefill_cache_kv = nsa_dict["prefill_cache_kv"] | ||
|
|
||
| if self.infer_state.prefix_total_token_num > 0: | ||
| kv, topk_indices = self.infer_state.mem_manager.get_prefill_kv_cache_and_remap_indices( | ||
| layer_index=layer_index, | ||
| topk_indices=topk_mem_indices, | ||
| prefill_mem_index=self.infer_state.mem_index, | ||
| prefill_cache_kv=prefill_cache_kv, | ||
| ) | ||
| else: | ||
| kv = prefill_cache_kv | ||
|
|
||
| if topk_indices.ndim == 2: | ||
| topk_indices = topk_indices.unsqueeze(1) | ||
|
|
||
| mla_out, _, _ = flash_mla.flash_mla_sparse_fwd( | ||
| q=q, | ||
| kv=kv, | ||
| indices=topk_indices, | ||
| sm_scale=softmax_scale, | ||
| d_v=kv_lora_rank, | ||
| ) | ||
| return mla_out | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class NsaFlashMlaFp8SparseDecodeAttState(BaseDecodeAttState): | ||
| ks: torch.Tensor = None | ||
| ke: torch.Tensor = None | ||
| lengths: torch.Tensor = None | ||
| ragged_mem_index: torch.Tensor = None | ||
| flashmla_sched_meta: object = None | ||
|
|
||
| def init_state(self): | ||
| self.backend: NsaFlashMlaFp8SparseAttBackend = self.backend | ||
| model = self.backend.model | ||
| use_cuda_graph = ( | ||
| self.infer_state.batch_size <= model.graph_max_batch_size | ||
| and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch | ||
| ) | ||
|
|
||
| if use_cuda_graph: | ||
| self.ragged_mem_index = self.backend.ragged_mem_buffers[self.infer_state.microbatch_index] | ||
| else: | ||
| self.ragged_mem_index = torch.empty( | ||
| self.infer_state.total_token_num, | ||
| dtype=torch.int32, | ||
| device=get_current_device_id(), | ||
| ) | ||
|
|
||
| from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke | ||
|
|
||
| self.ks, self.ke, self.lengths = gen_nsa_ks_ke( | ||
| b_seq_len=self.infer_state.b_seq_len, | ||
| b_q_seq_len=self.infer_state.b_q_seq_len, | ||
| b_req_idx=self.infer_state.b_req_idx, | ||
| req_to_token_index=self.infer_state.req_manager.req_to_token_indexs, | ||
| q_token_num=self.infer_state.b_seq_len.shape[0], | ||
| ragged_mem_index=self.ragged_mem_index, | ||
| hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID, | ||
| ) | ||
| import flash_mla | ||
|
|
||
| self.flashmla_sched_meta, _ = flash_mla.get_mla_metadata() | ||
| return | ||
|
|
||
| def decode_att( | ||
| self, | ||
| q: Tuple[torch.Tensor, torch.Tensor], | ||
| k: torch.Tensor, | ||
| v: torch.Tensor, | ||
| att_control: AttControl = AttControl(), | ||
| alloc_func=torch.empty, | ||
| ) -> torch.Tensor: | ||
| assert att_control.nsa_decode, "nsa_decode must be True for NSA decode attention" | ||
| assert att_control.nsa_decode_dict is not None, "nsa_decode_dict is required" | ||
| return self._nsa_decode_att(q=q, kv=k, att_control=att_control) | ||
|
|
||
| def _nsa_decode_att( | ||
| self, | ||
| q: Tuple[torch.Tensor, torch.Tensor], | ||
| kv: torch.Tensor, | ||
| att_control: AttControl, | ||
| ) -> torch.Tensor: | ||
| import flash_mla | ||
|
|
||
| nsa_dict = att_control.nsa_decode_dict | ||
| topk_mem_indices = nsa_dict["topk_mem_indices"] | ||
| softmax_scale = nsa_dict["softmax_scale"] | ||
| kv_lora_rank = nsa_dict["kv_lora_rank"] | ||
|
|
||
| if topk_mem_indices.ndim == 2: | ||
| topk_mem_indices = topk_mem_indices.unsqueeze(1) | ||
| assert topk_mem_indices.shape[1] == 1, "FlashMLA sparse decode path currently expects seq_len_q == 1" | ||
|
|
||
| q_nope, q_rope = q | ||
| q_all = torch.cat([q_nope, q_rope], dim=-1).unsqueeze(1).contiguous() | ||
|
|
||
| o_tensor, _ = flash_mla.flash_mla_with_kvcache( | ||
| q=q_all, | ||
| k_cache=kv, | ||
| block_table=None, | ||
| cache_seqlens=None, | ||
| head_dim_v=kv_lora_rank, | ||
| tile_scheduler_metadata=self.flashmla_sched_meta, | ||
| num_splits=None, | ||
| softmax_scale=softmax_scale, | ||
| causal=False, | ||
| is_fp8_kvcache=True, | ||
| indices=topk_mem_indices.contiguous(), | ||
| ) | ||
| return o_tensor[:, 0, :, :] # [b, 1, h, d] -> [b, h, d] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
96 changes: 96 additions & 0 deletions
96
lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,96 @@ | ||
| import torch | ||
| from typing import Any | ||
|
|
||
| from .deepseek2_mem_manager import Deepseek2MemoryManager | ||
|
|
||
|
|
||
| class FP8PerTokenGroupQuantDeepseek3_2MemoryManager(Deepseek2MemoryManager): | ||
| kv_nope_dim = 512 | ||
| kv_rope_dim = 64 | ||
| # 576 = 512 + 64 | ||
| kv_head_dim = kv_nope_dim + kv_rope_dim | ||
|
|
||
| quant_group_size = 128 | ||
| # 4 = 512 / 128 | ||
| quant_group_num = kv_nope_dim // quant_group_size | ||
| # 4 * 4 = quant_group_num * fp32 | ||
| # 64 * 2 = kv_rope_dim * bfloat16 | ||
| # 656 bytes = 512 + (4 * 4) + (64 * 2) | ||
| flashmla_bytes_per_token = kv_nope_dim + quant_group_num * 4 + kv_rope_dim * 2 | ||
|
|
||
| indexer_head_dim = 128 | ||
| # 128 + 4 = indexer_head_dim + fp32 | ||
| # 132 bytes = 128 + 4 | ||
| indexer_bytes_per_token = indexer_head_dim + 4 | ||
|
|
||
| # Merged per-token layout in kv_buffer: | ||
| # [flashmla (656 bytes) | indexer (132 bytes) | padding (12 bytes)] = 800 bytes total | ||
| # Padded to 16-byte alignment so that FlashMLA CUDA kernel vectorized | ||
| # loads (uint4 = 16 bytes) always access aligned addresses. | ||
| # This allows parent class PD/move/page/p2p methods to transfer both | ||
| # flashmla and indexer data together in a single kv_buffer. | ||
| _ALIGNMENT = 16 | ||
| total_bytes_per_token = ( | ||
| (flashmla_bytes_per_token + indexer_bytes_per_token + _ALIGNMENT - 1) // _ALIGNMENT * _ALIGNMENT | ||
| ) | ||
|
|
||
| def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): | ||
| assert head_num == 1, "DeepSeek-V3.2 DSA FP8 path expects MQA-style head_num == 1" | ||
| self.prefill_dtype = dtype | ||
| super().__init__(size, torch.uint8, head_num, self.total_bytes_per_token, layer_num, always_copy, mem_fraction) | ||
|
|
||
| def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): | ||
| from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_kv_flashmla_fp8 import ( | ||
| destindex_copy_kv_flashmla_fp8, | ||
| ) | ||
|
|
||
| rope_dim = 64 | ||
| kv_lora_rank = kv.shape[2] - rope_dim | ||
| assert kv_lora_rank == 512, f"Expected kv_lora_rank=512, got {kv_lora_rank}" | ||
|
|
||
| o_nope = self.kv_buffer[layer_index][:, :, :512].view(torch.float8_e4m3fn) | ||
| o_scale = self.kv_buffer[layer_index][:, :, 512:528].view(torch.float32) | ||
| o_rope = self.kv_buffer[layer_index][:, :, 528 : self.flashmla_bytes_per_token].view(torch.bfloat16) | ||
| destindex_copy_kv_flashmla_fp8( | ||
| kv[:, :, :kv_lora_rank], | ||
| kv[:, :, kv_lora_rank:], | ||
| mem_index, | ||
| o_nope, | ||
| o_scale, | ||
| o_rope, | ||
| ) | ||
|
|
||
| def get_att_input_params(self, layer_index: int) -> Any: | ||
| # Return an as_strided view: shape (N, 1, 1, 656) with stride (788, 656, 656, 1). | ||
| # FlashMLA CUDA kernel uses stride(0) for block offset and stride(1) for row offset, | ||
| # so it correctly reads 656 contiguous bytes per token even though tokens are 788 bytes apart. | ||
| # This avoids any data copy while passing FlashMLA's shape and stride checks. | ||
| buf = self.kv_buffer[layer_index] | ||
| return torch.as_strided( | ||
| buf, | ||
| size=(buf.shape[0], 1, 1, self.flashmla_bytes_per_token), | ||
| stride=(buf.stride(0), self.flashmla_bytes_per_token, self.flashmla_bytes_per_token, 1), | ||
| ) | ||
|
|
||
| def get_indexer_k_buffer(self, layer_index: int) -> torch.Tensor: | ||
| indexer_end = self.flashmla_bytes_per_token + self.indexer_bytes_per_token | ||
| return self.kv_buffer[layer_index][:, :, self.flashmla_bytes_per_token : indexer_end] | ||
|
|
||
| def get_prefill_kv_cache_and_remap_indices( | ||
| self, | ||
| layer_index: int, | ||
| topk_indices: torch.Tensor, | ||
| prefill_mem_index: torch.Tensor, | ||
| prefill_cache_kv: torch.Tensor, | ||
| ): | ||
| from lightllm.models.deepseek3_2.triton_kernel.prefill_compact_kv_flashmla_fp8 import ( | ||
| get_prefill_kv_cache_and_remap_indices_triton, | ||
| ) | ||
|
|
||
| return get_prefill_kv_cache_and_remap_indices_triton( | ||
| packed_kv=self.kv_buffer[layer_index][:, :, : self.flashmla_bytes_per_token], | ||
| topk_mem_indices=topk_indices, | ||
| prefill_mem_index=prefill_mem_index, | ||
| prefill_cache_kv=prefill_cache_kv, | ||
| prefill_dtype=self.prefill_dtype, | ||
| ) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To reduce the final Docker image size, it's a good practice to clean up build-time dependencies and source files within the same
RUNlayer. After installingFlashMLA, the cloned repository at/root/FlashMLAis no longer needed and can be removed.