From ab29bb281c2bcc1ccdf2138f0ecd3cc7f8f203fe Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=92=AE=E5=9C=A3=E8=99=93?=
Date: Wed, 25 Mar 2026 12:54:33 +0800
Subject: [PATCH] feat: fp8 dsa support

---
 docker/Dockerfile                                  |   6 +
 .../common/basemodel/attention/__init__.py         |   1 +
 .../basemodel/attention/create_utils.py            |   4 +
 .../basemodel/attention/nsa/__init__.py            |   8 +
 .../attention/nsa/flashmla_sparse.py               |   4 +-
 .../attention/nsa/fp8_flashmla_sparse.py           | 195 ++++++++++++++++++
 .../common/kv_cache_mem_manager/__init__.py        |   2 +
 .../deepseek3_2mem_manager.py                      |   3 +
 ...oken_group_quant_deepseek3_2mem_manager.py      |  83 ++++++++
 .../common/kv_cache_mem_manager/mem_utils.py       |  14 +-
 .../layer_infer/transformer_layer_infer.py         |  30 +--
 .../destindex_copy_kv_flashmla_fp8.py              | 127 ++++++++++++
 .../prefill_compact_kv_flashmla_fp8.py             | 190 +++++++++++++++++
 lightllm/server/api_cli.py                         |   4 +-
 lightllm/server/core/objs/start_args_type.py       |   2 +-
 15 files changed, 653 insertions(+), 20 deletions(-)
 create mode 100644 lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py
 create mode 100644 lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py
 create mode 100644 lightllm/models/deepseek3_2/triton_kernel/destindex_copy_kv_flashmla_fp8.py
 create mode 100644 lightllm/models/deepseek3_2/triton_kernel/prefill_compact_kv_flashmla_fp8.py

diff --git a/docker/Dockerfile b/docker/Dockerfile
index e766107ae..439ecddb3 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -4,6 +4,7 @@ FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04
 ARG PYTHON_VERSION=3.10
 ARG MAMBA_VERSION=24.7.1-0
 ARG VLLM_VERSION=0.16.0
+ARG FLASH_MLA_REF=47c35a7
 ARG TARGETPLATFORM
 ARG ENABLE_DEEPEP=1
 ARG ENABLE_NIXL=1
@@ -45,6 +46,11 @@ COPY ./requirements.txt /lightllm/requirements.txt
 RUN pip install -U pip
 RUN pip install -r /lightllm/requirements.txt --no-cache-dir
 RUN pip install --no-cache-dir vllm==${VLLM_VERSION}
+RUN git clone https://github.com/deepseek-ai/FlashMLA.git /root/FlashMLA && \
+    cd /root/FlashMLA && \
+    git checkout ${FLASH_MLA_REF} && \
+    git submodule update --init --recursive && \
+    FLASH_MLA_DISABLE_SM100=1 pip install --no-cache-dir .
 RUN apt-get update && apt-get install -y libnuma-dev && rm -rf /var/lib/apt/lists/*
diff --git a/lightllm/common/basemodel/attention/__init__.py b/lightllm/common/basemodel/attention/__init__.py
index 0eea52cc8..10cd3b086 100644
--- a/lightllm/common/basemodel/attention/__init__.py
+++ b/lightllm/common/basemodel/attention/__init__.py
@@ -12,6 +12,7 @@

 # NSA backend
 from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend
+from .nsa.fp8_flashmla_sparse import NsaFlashMlaFp8SparseAttBackend

 from .create_utils import (
     get_prefill_att_backend_class,
diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py
index 3ba16e218..2c4a34d32 100644
--- a/lightllm/common/basemodel/attention/create_utils.py
+++ b/lightllm/common/basemodel/attention/create_utils.py
@@ -15,6 +15,7 @@
 from .flashinfer.fp import FlashInferAttBackend
 from .flashinfer.mla import MlaFlashInferAttBackend
 from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend
+from .nsa.fp8_flashmla_sparse import NsaFlashMlaFp8SparseAttBackend

 logger = init_logger(__name__)

@@ -56,6 +57,9 @@
         "flashmla_sparse": NsaFlashMlaSparseAttBackend,
         # Future backends: "fa3", "tilelang", "aiter"
     },
+    "fp8kv_dsa": {
+        "flashmla_sparse": NsaFlashMlaFp8SparseAttBackend,
+    },
 }
diff --git a/lightllm/common/basemodel/attention/nsa/__init__.py b/lightllm/common/basemodel/attention/nsa/__init__.py
index 11a1ebfdc..f9db52dc2 100644
--- a/lightllm/common/basemodel/attention/nsa/__init__.py
+++ b/lightllm/common/basemodel/attention/nsa/__init__.py
@@ -5,9 +5,17 @@
     NsaFlashMlaSparsePrefillAttState,
     NsaFlashMlaSparseDecodeAttState,
 )
+from .fp8_flashmla_sparse import (
+    NsaFlashMlaFp8SparseAttBackend,
+    NsaFlashMlaFp8SparsePrefillAttState,
+    NsaFlashMlaFp8SparseDecodeAttState,
+)

 __all__ = [
     "NsaFlashMlaSparseAttBackend",
     "NsaFlashMlaSparsePrefillAttState",
     "NsaFlashMlaSparseDecodeAttState",
+    "NsaFlashMlaFp8SparseAttBackend",
+    "NsaFlashMlaFp8SparsePrefillAttState",
+    "NsaFlashMlaFp8SparseDecodeAttState",
 ]
diff --git a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py
index c43927f37..673b5896d 100644
--- a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py
+++ b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py
@@ -165,7 +165,7 @@ def _nsa_decode_att(
         from sgl_kernel.flash_attn import flash_attn_with_kvcache

         nsa_dict = att_control.nsa_decode_dict
-        topk_indices = nsa_dict["topk_indices"]
+        topk_mem_indices = nsa_dict["topk_mem_indices"]
         softmax_scale = nsa_dict["softmax_scale"]
         kv_lora_rank = nsa_dict["kv_lora_rank"]
         qk_rope_head_dim = nsa_dict["qk_rope_head_dim"]
@@ -181,7 +181,7 @@ def _nsa_decode_att(
             k_cache=k_rope,
             v_cache=kv_nope,
             qv=q_nope,
-            page_table=topk_indices,
+            page_table=topk_mem_indices,
             cache_seqlens=self.nsa_cache_seqlens,
             cu_seqlens_q=self.infer_state.b1_cu_q_seq_len,
             cu_seqlens_k_new=self.nsa_cu_seqlens_k_new,
diff --git a/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py
new file mode 100644
index 000000000..3972e039b
--- /dev/null
+++ b/lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py
@@ -0,0 +1,195 @@
+import dataclasses
+import torch
+from typing import TYPE_CHECKING, Tuple
+
+from ..base_att import AttControl, BaseAttBackend, BaseDecodeAttState, BasePrefillAttState
+from lightllm.utils.dist_utils import get_current_device_id
+
+if TYPE_CHECKING:
+    from lightllm.common.basemodel.infer_struct import InferStateInfo
+
+
+class NsaFlashMlaFp8SparseAttBackend(BaseAttBackend):
+    def __init__(self, model):
+        super().__init__(model=model)
+        device = get_current_device_id()
+        self.ragged_mem_buffers = [
+            torch.empty(model.graph_max_batch_size * model.max_seq_length, dtype=torch.int32, device=device)
+            for _ in range(2)
+        ]
+
+    def create_att_prefill_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparsePrefillAttState":
+        return NsaFlashMlaFp8SparsePrefillAttState(backend=self, infer_state=infer_state)
+
+    def create_att_decode_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparseDecodeAttState":
+        return NsaFlashMlaFp8SparseDecodeAttState(backend=self, infer_state=infer_state)
+
+
+@dataclasses.dataclass
+class NsaFlashMlaFp8SparsePrefillAttState(BasePrefillAttState):
+    ks: torch.Tensor = None
+    ke: torch.Tensor = None
+    lengths: torch.Tensor = None
+    ragged_mem_index: torch.Tensor = None
+
+    def init_state(self):
+        self.backend: NsaFlashMlaFp8SparseAttBackend = self.backend
+        self.ragged_mem_index = torch.empty(
+            self.infer_state.total_token_num,
+            dtype=torch.int32,
+            device=get_current_device_id(),
+        )
+        from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke
+
+        self.ks, self.ke, self.lengths = gen_nsa_ks_ke(
+            b_seq_len=self.infer_state.b_seq_len,
+            b_q_seq_len=self.infer_state.b_q_seq_len,
+            b_req_idx=self.infer_state.b_req_idx,
+            req_to_token_index=self.infer_state.req_manager.req_to_token_indexs,
+            q_token_num=self.infer_state.total_token_num - self.infer_state.prefix_total_token_num,
+            ragged_mem_index=self.ragged_mem_index,
+            hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID,
+        )
+        return
+
+    def prefill_att(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        att_control: AttControl = AttControl(),
+        alloc_func=torch.empty,
+    ) -> torch.Tensor:
+        assert att_control.nsa_prefill, "nsa_prefill must be True for NSA prefill attention"
+        assert att_control.nsa_prefill_dict is not None, "nsa_prefill_dict is required"
+        return self._nsa_prefill_att(q=q, packed_kv=k, att_control=att_control)
+
+    def _nsa_prefill_att(
+        self,
+        q: torch.Tensor,
+        packed_kv: torch.Tensor,
+        att_control: AttControl,
+    ) -> torch.Tensor:
+        import flash_mla
+
+        nsa_dict = att_control.nsa_prefill_dict
+        topk_indices = nsa_dict["topk_indices"]
+        softmax_scale = nsa_dict["softmax_scale"]
+        kv_lora_rank = nsa_dict["kv_lora_rank"]
+        topk_mem_indices = nsa_dict["topk_mem_indices"]
+        prefill_cache_kv = nsa_dict["prefill_cache_kv"]
+
+        if self.infer_state.prefix_total_token_num > 0:
+            kv, topk_indices = self.infer_state.mem_manager.get_prefill_kv_cache_and_remap_indices(
+                packed_kv=packed_kv,
+                topk_indices=topk_mem_indices,
+                prefill_mem_index=self.infer_state.mem_index,
+                prefill_cache_kv=prefill_cache_kv,
+            )
+        else:
+            kv = prefill_cache_kv
+
+        if topk_indices.ndim == 2:
+            topk_indices = topk_indices.unsqueeze(1)
+
+        mla_out, _, _ = flash_mla.flash_mla_sparse_fwd(
+            q=q,
+            kv=kv,
+            indices=topk_indices,
+            sm_scale=softmax_scale,
+            d_v=kv_lora_rank,
+        )
+        return mla_out
+
+
+@dataclasses.dataclass
+class NsaFlashMlaFp8SparseDecodeAttState(BaseDecodeAttState):
+    ks: torch.Tensor = None
+    ke: torch.Tensor = None
+    lengths: torch.Tensor = None
+    ragged_mem_index: torch.Tensor = None
+    flashmla_sched_meta: object = None
+
+    def init_state(self):
+        self.backend: NsaFlashMlaFp8SparseAttBackend = self.backend
+        model = self.backend.model
+        use_cuda_graph = (
+            self.infer_state.batch_size <= model.graph_max_batch_size
+            and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch
+        )
+
+        if use_cuda_graph:
+            self.ragged_mem_index = self.backend.ragged_mem_buffers[self.infer_state.microbatch_index]
+        else:
+            self.ragged_mem_index = torch.empty(
+                self.infer_state.total_token_num,
+                dtype=torch.int32,
+                device=get_current_device_id(),
+            )
+
+        from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke
+
+        self.ks, self.ke, self.lengths = gen_nsa_ks_ke(
+            b_seq_len=self.infer_state.b_seq_len,
+            b_q_seq_len=self.infer_state.b_q_seq_len,
+            b_req_idx=self.infer_state.b_req_idx,
+            req_to_token_index=self.infer_state.req_manager.req_to_token_indexs,
+            q_token_num=self.infer_state.b_seq_len.shape[0],
+            ragged_mem_index=self.ragged_mem_index,
+            hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID,
+        )
+        import flash_mla
+
+        self.flashmla_sched_meta, _ = flash_mla.get_mla_metadata()
+        return
+
+    def decode_att(
+        self,
+        q: Tuple[torch.Tensor, torch.Tensor],
+        k: torch.Tensor,
+        v: torch.Tensor,
+        att_control: AttControl = AttControl(),
+        alloc_func=torch.empty,
+    ) -> torch.Tensor:
+        assert att_control.nsa_decode, "nsa_decode must be True for NSA decode attention"
+        assert att_control.nsa_decode_dict is not None, "nsa_decode_dict is required"
+        return self._nsa_decode_att(q=q, packed_kv=k, att_control=att_control)
+
+    def _nsa_decode_att(
+        self,
+        q: Tuple[torch.Tensor, torch.Tensor],
+        packed_kv: torch.Tensor,
+        att_control: AttControl,
+    ) -> torch.Tensor:
+        import flash_mla
+
+        nsa_dict = att_control.nsa_decode_dict
+        topk_mem_indices = nsa_dict["topk_mem_indices"]
+        softmax_scale = nsa_dict["softmax_scale"]
+        kv_lora_rank = nsa_dict["kv_lora_rank"]
+
+        if topk_mem_indices.ndim == 2:
+            topk_mem_indices = topk_mem_indices.unsqueeze(1)
+        assert topk_mem_indices.shape[1] == 1, "FlashMLA sparse decode path currently expects seq_len_q == 1"
+
+        q_nope, q_rope = q
+        q_all = torch.cat([q_nope, q_rope], dim=-1).unsqueeze(1).contiguous()
+        kv = torch.as_strided(
+            packed_kv,
+            size=(packed_kv.shape[0], 1, 1, packed_kv.shape[-1]),
+            stride=(packed_kv.stride(0), packed_kv.shape[-1], packed_kv.shape[-1], packed_kv.stride(-1)),
+        )
+
+        o_tensor, _ = flash_mla.flash_mla_with_kvcache(
+            q=q_all,
+            k_cache=kv,
+            block_table=None,
+            cache_seqlens=None,
+            head_dim_v=kv_lora_rank,
+            tile_scheduler_metadata=self.flashmla_sched_meta,
+            softmax_scale=softmax_scale,
+            causal=False,
+            is_fp8_kvcache=True,
+            indices=topk_mem_indices,
+        )
+        return o_tensor[:, 0, :, :]  # [b, 1, h, d] -> [b, h, d]
diff --git a/lightllm/common/kv_cache_mem_manager/__init__.py b/lightllm/common/kv_cache_mem_manager/__init__.py
index bfccc8b48..79e75b348 100644
--- a/lightllm/common/kv_cache_mem_manager/__init__.py
+++ b/lightllm/common/kv_cache_mem_manager/__init__.py
@@ -3,6 +3,7 @@
 from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager
 from .deepseek2_mem_manager import Deepseek2MemoryManager
 from .deepseek3_2mem_manager import Deepseek3_2MemoryManager
+from .fp8_per_token_group_quant_deepseek3_2mem_manager import FP8PerTokenGroupQuantDeepseek3_2MemoryManager
 from .fp8_static_per_head_quant_mem_manager import FP8StaticPerHeadQuantMemManager
 from .fp8_static_per_tensor_quant_mem_manager import FP8StaticPerTensorQuantMemManager
@@ -13,6 +14,7 @@
     "PPLINT8KVMemoryManager",
     "Deepseek2MemoryManager",
     "Deepseek3_2MemoryManager",
+    "FP8PerTokenGroupQuantDeepseek3_2MemoryManager",
     "FP8StaticPerHeadQuantMemManager",
     "FP8StaticPerTensorQuantMemManager",
 ]
diff --git a/lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py
index fbf9f88c8..66f37a16f 100644
--- a/lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py
+++ b/lightllm/common/kv_cache_mem_manager/deepseek3_2mem_manager.py
@@ -34,3 +34,6 @@ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv:
     def get_att_input_params(self, layer_index: int) -> Any:
         kv = self.kv_buffer[layer_index][:, :, : (self.head_dim - (144 // 2))]
         return kv
+
+    def get_indexer_k_buffer(self, layer_index: int) -> torch.Tensor:
+        return self.kv_buffer[layer_index].view(dtype=torch.uint8)[:, :, -132:]
diff --git a/lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py b/lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py
new file mode 100644
index 000000000..b4464cd12
--- /dev/null
+++ b/lightllm/common/kv_cache_mem_manager/fp8_per_token_group_quant_deepseek3_2mem_manager.py
@@ -0,0 +1,83 @@
+import torch
+from typing import Any
+
+from .deepseek2_mem_manager import Deepseek2MemoryManager
+
+
+class FP8PerTokenGroupQuantDeepseek3_2MemoryManager(Deepseek2MemoryManager):
+    kv_nope_dim = 512
+    kv_rope_dim = 64
+    # 576 = 512 + 64
+    kv_head_dim = kv_nope_dim + kv_rope_dim
+
+    quant_group_size = 128
+    # 4 = 512 / 128
+    quant_group_num = kv_nope_dim // quant_group_size
+    # 4 * 4 = quant_group_num * fp32
+    # 64 * 2 = kv_rope_dim * bfloat16
+    # 656 bytes = 512 + (4 * 4) + (64 * 2)
+    flashmla_bytes_per_token = kv_nope_dim + quant_group_num * 4 + kv_rope_dim * 2
+
+    indexer_head_dim = 128
+    # 128 + 4 = indexer_head_dim + fp32
+    # 132 bytes = 128 + 4
+    indexer_bytes_per_token = indexer_head_dim + 4
+
+    # 16-byte alignment to satisfy FlashMLA's alignment requirement
+    alignment = 16
+    total_bytes_per_token = (
+        (flashmla_bytes_per_token + indexer_bytes_per_token + alignment - 1) // alignment * alignment
+    )
+
+    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
+        assert head_num == 1, "DeepSeek-V3.2 DSA FP8 path expects MQA-style head_num == 1"
+        self.prefill_dtype = dtype
+        super().__init__(size, torch.uint8, head_num, self.total_bytes_per_token, layer_num, always_copy, mem_fraction)
+
+    def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor):
+        from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_kv_flashmla_fp8 import (
+            destindex_copy_kv_flashmla_fp8,
+        )
+
+        rope_dim = 64
+        kv_lora_rank = kv.shape[2] - rope_dim
+        assert kv_lora_rank == 512, f"Expected kv_lora_rank=512, got {kv_lora_rank}"
+
+        o_nope = self.kv_buffer[layer_index][:, :, :512].view(torch.float8_e4m3fn)
+        o_scale = self.kv_buffer[layer_index][:, :, 512:528].view(torch.float32)
+        o_rope = self.kv_buffer[layer_index][:, :, 528 : self.flashmla_bytes_per_token].view(torch.bfloat16)
+        destindex_copy_kv_flashmla_fp8(
+            kv[:, :, :kv_lora_rank],
+            kv[:, :, kv_lora_rank:],
+            mem_index,
+            o_nope,
+            o_scale,
+            o_rope,
+        )
+
+    def get_att_input_params(self, layer_index: int) -> Any:
+        return self.kv_buffer[layer_index][:, :, : self.flashmla_bytes_per_token]
+
+    def get_indexer_k_buffer(self, layer_index: int) -> torch.Tensor:
+        begin = self.flashmla_bytes_per_token
+        end = begin + self.indexer_bytes_per_token
+        return self.kv_buffer[layer_index][:, :, begin:end]
+
+    def get_prefill_kv_cache_and_remap_indices(
+        self,
+        packed_kv: torch.Tensor,
+        topk_indices: torch.Tensor,
+        prefill_mem_index: torch.Tensor,
+        prefill_cache_kv: torch.Tensor,
+    ):
+        from lightllm.models.deepseek3_2.triton_kernel.prefill_compact_kv_flashmla_fp8 import (
+            get_prefill_kv_cache_and_remap_indices_triton,
+        )
+
+        return get_prefill_kv_cache_and_remap_indices_triton(
+            packed_kv=packed_kv,
+            topk_mem_indices=topk_indices,
+            prefill_mem_index=prefill_mem_index,
+            prefill_cache_kv=prefill_cache_kv,
+            prefill_dtype=self.prefill_dtype,
+        )
diff --git a/lightllm/common/kv_cache_mem_manager/mem_utils.py b/lightllm/common/kv_cache_mem_manager/mem_utils.py
index 36ca8646a..79ea44879 100644
--- a/lightllm/common/kv_cache_mem_manager/mem_utils.py
+++ b/lightllm/common/kv_cache_mem_manager/mem_utils.py
@@ -24,7 +24,12 @@ def select_mem_manager_class():
         from lightllm.models import Deepseek3_2TpPartModel

         if issubclass(model_class, Deepseek3_2TpPartModel):
-            mem_class = Deepseek3_2MemoryManager
+            if get_env_start_args().llm_kv_type == "fp8kv_dsa":
+                from . import FP8PerTokenGroupQuantDeepseek3_2MemoryManager
+
+                mem_class = FP8PerTokenGroupQuantDeepseek3_2MemoryManager
+            else:
+                mem_class = Deepseek3_2MemoryManager
             logger.info(f"Model kv cache using default, mem_manager class: {mem_class}")
             return mem_class
@@ -55,4 +60,9 @@ def select_mem_manager_class():
 @lru_cache(maxsize=None)
 def used_mem_manager_has_scale() -> bool:
     mem_class = select_mem_manager_class()
-    return mem_class in [PPLINT8KVMemoryManager, PPLINT4KVMemoryManager, FP8StaticPerHeadQuantMemManager, FP8StaticPerTensorQuantMemManager]
+    return mem_class in [
+        PPLINT8KVMemoryManager,
+        PPLINT4KVMemoryManager,
+        FP8StaticPerHeadQuantMemManager,
+        FP8StaticPerTensorQuantMemManager,
+    ]
diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py
index b00612017..58c544e82 100644
--- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py
@@ -1,12 +1,11 @@
 import torch
-from typing import Union
+from typing import Any
 from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
 from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer
 from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight
 from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward
 from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd
 from lightllm.common.basemodel.attention.base_att import AttControl
-from lightllm.common.basemodel.attention.nsa import NsaFlashMlaSparsePrefillAttState, NsaFlashMlaSparseDecodeAttState
 from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant
 from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks
 from lightllm.models.deepseek3_2.triton_kernel.extract_indexer_ks import extract_indexer_ks
@@ -74,9 +73,9 @@ def _context_attention_kernel(
         q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
         q_all = torch.cat([q_nope, q_rope], dim=-1)

-        # 计算 topk_indices
+        # compute topk indices
         att_state = infer_state.prefill_att_state
-        topk_indices = self.indexer.get_indices(
+        topk_mem_indices, topk_indices = self.indexer._get_indices(
             hidden_states=infer_state.get_topk_indices_params["hidden_states"],
             q_lora=infer_state.get_topk_indices_params["q_lora"],
             infer_state=infer_state,
@@ -89,7 +88,9 @@ def _context_attention_kernel(
         att_control = AttControl(
             nsa_prefill=True,
             nsa_prefill_dict={
+                "topk_mem_indices": topk_mem_indices,
                 "topk_indices": topk_indices,
+                "prefill_cache_kv": kv,
                 "softmax_scale": self.softmax_scale,
                 "kv_lora_rank": self.kv_lora_rank,
             },
@@ -114,9 +115,9 @@ def _token_attention_kernel(
         q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
         q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)

-        # 计算 topk_indices
+        # compute topk mem indices
         att_state = infer_state.decode_att_state
-        topk_indices = self.indexer.get_indices(
+        topk_mem_indices, _ = self.indexer._get_indices(
             hidden_states=infer_state.get_topk_indices_params["hidden_states"],
             q_lora=infer_state.get_topk_indices_params["q_lora"],
             infer_state=infer_state,
@@ -129,7 +130,8 @@ def _token_attention_kernel(
         att_control = AttControl(
             nsa_decode=True,
             nsa_decode_dict={
-                "topk_indices": topk_indices,
+                "layer_index": self.layer_num_,
+                "topk_mem_indices": topk_mem_indices,
                 "softmax_scale": self.softmax_scale,
                 "kv_lora_rank": self.kv_lora_rank,
                 "qk_rope_head_dim": self.qk_rope_head_dim,
@@ -163,14 +165,14 @@ def __init__(self, layer_idx: int, network_config: dict, tp_world_size: int):
         self.tp_world_size_ = tp_world_size
         self.tp_index_n_heads = self.index_n_heads // self.tp_world_size_

-    def get_indices(
+    def _get_indices(
         self,
         hidden_states: torch.Tensor,
         q_lora: torch.Tensor,
         infer_state: Deepseek2InferStateInfo,
-        att_state: Union[NsaFlashMlaSparsePrefillAttState, NsaFlashMlaSparseDecodeAttState],
+        att_state: Any,
         layer_weight: Deepseek3_2TransformerLayerWeight,
-    ) -> torch.Tensor:
+    ):

         q, k = self._get_q_k_bf16(hidden_states, q_lora, infer_state, layer_weight)

@@ -195,7 +197,7 @@ def get_indices(
             K_fp8=k_fp8,
             K_scale=k_scale,
             DestLoc=infer_state.mem_index,
-            O_buffer=infer_state.mem_manager.kv_buffer[self.layer_idx_].view(dtype=torch.uint8)[:, :, -132:],
+            O_buffer=infer_state.mem_manager.get_indexer_k_buffer(self.layer_idx_),
         )

         weights = layer_weight.weights_proj_.mm(hidden_states) * self.index_n_heads_scale
@@ -211,7 +213,7 @@ def get_indices(
         mtp_step = get_env_start_args().mtp_step
         # Use efficient Triton kernel to extract FP8 keys and scales from buffer
         k_fp8_, k_scale_ = extract_indexer_ks(
-            I_buffer=infer_state.mem_manager.kv_buffer[self.layer_idx_].view(dtype=torch.uint8)[:, :, -132:],
+            I_buffer=infer_state.mem_manager.get_indexer_k_buffer(self.layer_idx_),
             b_seq_len=infer_state.b_seq_len,
             b_req_idx=infer_state.b_req_idx,
             req_to_token_indexs=infer_state.req_manager.req_to_token_indexs,
@@ -236,12 +238,12 @@ def get_indices(
         # 将 topk index 转化为 mem index
         from ..triton_kernel.topk_index_to_mem_index import trans_topk_index_to_mem_index

-        b_topk_index = trans_topk_index_to_mem_index(
+        b_topk_mem_index = trans_topk_index_to_mem_index(
             topk_index=b_topk_index,
             ragged_mem_index=att_state.ragged_mem_index,
         )

-        return b_topk_index
+        return b_topk_mem_index, b_topk_index

     @staticmethod
     def _rotate_activation(x: torch.Tensor) -> torch.Tensor:
diff --git a/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_kv_flashmla_fp8.py b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_kv_flashmla_fp8.py
new file mode 100644
index 000000000..d5859c686
--- /dev/null
+++ b/lightllm/models/deepseek3_2/triton_kernel/destindex_copy_kv_flashmla_fp8.py
@@ -0,0 +1,127 @@
+import torch
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _quant_scale(max_nope, fp8_max):
+    return tl.exp2(tl.ceil(tl.log2(tl.maximum(max_nope / fp8_max, 1e-4))))
+
+
+@triton.jit
+def _fwd_kernel_destindex_copy_kv_flashmla_fp8(
+    KV_nope,
+    KV_rope,
+    Dest_loc,
+    O_nope,
+    O_scale,
+    O_rope,
+    stride_kv_nope_bs,
+    stride_kv_nope_h,
+    stride_kv_nope_d,
+    stride_kv_rope_bs,
+    stride_kv_rope_h,
+    stride_kv_rope_d,
+    stride_o_nope_bs,
+    stride_o_nope_h,
+    stride_o_nope_d,
+    stride_o_scale_bs,
+    stride_o_scale_h,
+    stride_o_scale_d,
+    stride_o_rope_bs,
+    stride_o_rope_h,
+    stride_o_rope_d,
+    FP8_MIN: tl.constexpr,
+    FP8_MAX: tl.constexpr,
+    BLOCK_DMODEL_NOPE: tl.constexpr,
+    BLOCK_DMODEL_ROPE: tl.constexpr,
+    GROUP_SIZE: tl.constexpr,
+):
+    cur_index = tl.program_id(0)
+    dest_index = tl.load(Dest_loc + cur_index).to(tl.int64)
+
+    offs_rope = tl.arange(0, BLOCK_DMODEL_ROPE)
+
+    # This kernel is only used by the DeepSeek-V3.2 DSA FP8 path, which
+    # stores a single MQA-style KV head per token. Keep all accesses 1-D so
+    # Triton treats per-tile scales as scalars instead of 1-element blocks.
+    kv_rope_ptrs = KV_rope + cur_index * stride_kv_rope_bs + stride_kv_rope_d * offs_rope
+
+    kv_rope = tl.load(kv_rope_ptrs)
+
+    o_rope_ptrs = O_rope + dest_index * stride_o_rope_bs + stride_o_rope_d * offs_rope
+    tl.store(o_rope_ptrs, kv_rope)
+
+    num_tiles = BLOCK_DMODEL_NOPE // GROUP_SIZE
+    for tile_idx in range(0, num_tiles):
+        offs_tile = tile_idx * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
+        kv_nope_tile_ptrs = KV_nope + cur_index * stride_kv_nope_bs + stride_kv_nope_d * offs_tile
+        kv_nope_tile = tl.load(kv_nope_tile_ptrs)
+        max_nope = tl.max(tl.abs(kv_nope_tile), axis=0)
+        kv_scale = _quant_scale(max_nope, FP8_MAX)
+        kv_nope_fp8 = tl.clamp(kv_nope_tile / kv_scale, min=FP8_MIN, max=FP8_MAX).to(tl.float8e4nv)
+
+        o_nope_ptrs = (
+            O_nope
+            + dest_index * stride_o_nope_bs
+            + (tile_idx * GROUP_SIZE) * stride_o_nope_d
+            + tl.arange(0, GROUP_SIZE) * stride_o_nope_d
+        )
+        tl.store(o_nope_ptrs, kv_nope_fp8)
+
+        o_scale_ptrs = O_scale + dest_index * stride_o_scale_bs + tile_idx * stride_o_scale_d
+        tl.store(o_scale_ptrs, kv_scale.to(tl.float32))
+    return
+
+
+@torch.no_grad()
+def destindex_copy_kv_flashmla_fp8(
+    KV_nope: torch.Tensor,
+    KV_rope: torch.Tensor,
+    DestLoc: torch.Tensor,
+    O_nope: torch.Tensor,
+    O_scale: torch.Tensor,
+    O_rope: torch.Tensor,
+):
+    seq_len = DestLoc.shape[0]
+    kv_nope_head_dim = KV_nope.shape[2]
+    kv_rope_head_dim = KV_rope.shape[2]
+
+    assert kv_nope_head_dim == 512, f"Expected kv_nope_head_dim=512, got {kv_nope_head_dim}"
+    assert kv_rope_head_dim == 64, f"Expected kv_rope_head_dim=64, got {kv_rope_head_dim}"
+    assert O_nope.shape[2] == 512
+    assert O_scale.shape[2] == 4
+    assert O_rope.shape[2] == 64
+
+    _fwd_kernel_destindex_copy_kv_flashmla_fp8[(seq_len,)](
+        KV_nope,
+        KV_rope,
+        DestLoc,
+        O_nope,
+        O_scale,
+        O_rope,
+        KV_nope.stride(0),
+        KV_nope.stride(1),
+        KV_nope.stride(2),
+        KV_rope.stride(0),
+        KV_rope.stride(1),
+        KV_rope.stride(2),
+        O_nope.stride(0),
+        O_nope.stride(1),
+        O_nope.stride(2),
+        O_scale.stride(0),
+        O_scale.stride(1),
+        O_scale.stride(2),
+        O_rope.stride(0),
+        O_rope.stride(1),
+        O_rope.stride(2),
+        FP8_MIN=torch.finfo(torch.float8_e4m3fn).min,
+        FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
+        BLOCK_DMODEL_NOPE=512,
+        BLOCK_DMODEL_ROPE=64,
+        GROUP_SIZE=128,
+        num_warps=4,
+        num_stages=1,
+    )
+    return
diff --git a/lightllm/models/deepseek3_2/triton_kernel/prefill_compact_kv_flashmla_fp8.py b/lightllm/models/deepseek3_2/triton_kernel/prefill_compact_kv_flashmla_fp8.py
new file mode 100644
index 000000000..f8c24e08a
--- /dev/null
+++ b/lightllm/models/deepseek3_2/triton_kernel/prefill_compact_kv_flashmla_fp8.py
@@ -0,0 +1,190 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _build_prefill_row_table_kernel(
+    prefill_mem_index_ptr,
+    row_table_ptr,
+    prefill_token_num,
+):
+    pid = tl.program_id(0)
+    if pid < prefill_token_num:
+        mem_index = tl.load(prefill_mem_index_ptr + pid)
+        tl.store(row_table_ptr + mem_index, pid)
+
+
+@triton.jit
+def _fill_compact_kv_kernel(
+    packed_nope_ptr,
+    packed_scale_ptr,
+    packed_rope_ptr,
+    unique_mem_index_ptr,
+    prefill_row_table_ptr,
+    prefill_kv_ptr,
+    compact_kv_ptr,
+    packed_nope_stride_s,
+    packed_nope_stride_d,
+    packed_scale_stride_s,
+    packed_scale_stride_d,
+    packed_rope_stride_s,
+    packed_rope_stride_d,
+    prefill_kv_stride_s,
+    prefill_kv_stride_d,
+    compact_kv_stride_s,
+    compact_kv_stride_d,
+    unique_num,
+    KV_NOPE_DIM: tl.constexpr,
+    KV_ROPE_DIM: tl.constexpr,
+    GROUP_SIZE: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    pid_s = tl.program_id(0)
+    pid_block = tl.program_id(1)
+
+    if pid_s >= unique_num:
+        return
+
+    mem_index = tl.load(unique_mem_index_ptr + pid_s)
+    prefill_row = tl.load(prefill_row_table_ptr + mem_index)
+    offs_d = tl.arange(0, BLOCK_D)
+
+    if prefill_row != -1:
+        if pid_block < (KV_NOPE_DIM // GROUP_SIZE):
+            mask = offs_d < GROUP_SIZE
+            value = tl.load(
+                prefill_kv_ptr
+                + prefill_row * prefill_kv_stride_s
+                + (pid_block * GROUP_SIZE + offs_d) * prefill_kv_stride_d,
+                mask=mask,
+            ).to(tl.float32)
+            tl.store(
+                compact_kv_ptr + pid_s * compact_kv_stride_s + (pid_block * GROUP_SIZE + offs_d) * compact_kv_stride_d,
+                value,
+                mask=mask,
+            )
+        else:
+            mask = offs_d < KV_ROPE_DIM
+            value = tl.load(
+                prefill_kv_ptr + prefill_row * prefill_kv_stride_s + (KV_NOPE_DIM + offs_d) * prefill_kv_stride_d,
+                mask=mask,
+            ).to(tl.float32)
+            tl.store(
+                compact_kv_ptr + pid_s * compact_kv_stride_s + (KV_NOPE_DIM + offs_d) * compact_kv_stride_d,
+                value,
+                mask=mask,
+            )
+    else:
+        if pid_block < (KV_NOPE_DIM // GROUP_SIZE):
+            mask = offs_d < GROUP_SIZE
+            src_fp8 = tl.load(
+                packed_nope_ptr
+                + mem_index * packed_nope_stride_s
+                + (pid_block * GROUP_SIZE + offs_d) * packed_nope_stride_d,
+                mask=mask,
+            )
+            scale = tl.load(packed_scale_ptr + mem_index * packed_scale_stride_s + pid_block * packed_scale_stride_d)
+            value = src_fp8.to(tl.float32) * scale
+            tl.store(
+                compact_kv_ptr + pid_s * compact_kv_stride_s + (pid_block * GROUP_SIZE + offs_d) * compact_kv_stride_d,
+                value,
+                mask=mask,
+            )
+        else:
+            mask = offs_d < KV_ROPE_DIM
+            value = tl.load(
+                packed_rope_ptr + mem_index * packed_rope_stride_s + offs_d * packed_rope_stride_d,
+                mask=mask,
+            ).to(tl.float32)
+            tl.store(
+                compact_kv_ptr + pid_s * compact_kv_stride_s + (KV_NOPE_DIM + offs_d) * compact_kv_stride_d,
+                value,
+                mask=mask,
+            )
+
+
+@torch.no_grad()
+def get_prefill_kv_cache_and_remap_indices_triton(
+    packed_kv: torch.Tensor,
+    topk_mem_indices: torch.Tensor,
+    prefill_mem_index: torch.Tensor,
+    prefill_cache_kv: torch.Tensor,
+    prefill_dtype: torch.dtype,
+):
+    squeeze_h_kv = topk_mem_indices.ndim == 2
+    if squeeze_h_kv:
+        topk_mem_indices = topk_mem_indices.unsqueeze(1)
+
+    original_shape = topk_mem_indices.shape
+    flat_topk = topk_mem_indices.reshape(-1).contiguous().to(torch.int32)
+
+    if flat_topk.numel() == 0:
+        empty_kv = torch.empty((0, 1, 576), dtype=prefill_dtype, device=packed_kv.device)
+        remapped = topk_mem_indices.clone()
+        if squeeze_h_kv:
+            remapped = remapped.squeeze(1)
+        return empty_kv, remapped
+
+    valid_mask = flat_topk != -1
+    valid_topk = flat_topk[valid_mask]
+    if valid_topk.numel() == 0:
+        empty_kv = torch.empty((0, 1, 576), dtype=prefill_dtype, device=packed_kv.device)
+        remapped = torch.full(original_shape, -1, dtype=torch.int32, device=packed_kv.device)
+        if squeeze_h_kv:
+            remapped = remapped.squeeze(1)
+        return empty_kv, remapped
+
+    table_size = packed_kv.shape[0]
+
+    prefill_row_table = torch.full((table_size,), -1, dtype=torch.int32, device=packed_kv.device)
+    _build_prefill_row_table_kernel[(prefill_mem_index.numel(),)](
+        prefill_mem_index_ptr=prefill_mem_index.to(torch.int32).contiguous(),
+        row_table_ptr=prefill_row_table,
+        prefill_token_num=prefill_mem_index.numel(),
+        num_warps=4,
+    )
+
+    unique_mem_index, inverse = torch.unique(valid_topk, sorted=False, return_inverse=True)
+    unique_mem_index = unique_mem_index.to(torch.int32)
+    unique_count = unique_mem_index.numel()
+    remapped_flat = torch.full_like(flat_topk, -1)
+    remapped_flat[valid_mask] = inverse.to(torch.int32)
+
+    compact_kv = torch.empty((unique_count, 1, 576), dtype=prefill_dtype, device=packed_kv.device)
+    packed_nope = packed_kv[:, :, :512].view(torch.float8_e4m3fn).view(-1, 512)
+    packed_scale = packed_kv[:, :, 512:528].view(torch.float32).view(-1, 4)
+    packed_rope = packed_kv[:, :, 528:].view(torch.bfloat16).view(-1, 64)
+    prefill_kv_2d = prefill_cache_kv.view(-1, 576)
+    compact_kv_2d = compact_kv.view(-1, 576)
+
+    _fill_compact_kv_kernel[(unique_count, 5)](
+        packed_nope_ptr=packed_nope,
+        packed_scale_ptr=packed_scale,
+        packed_rope_ptr=packed_rope,
+        unique_mem_index_ptr=unique_mem_index,
+        prefill_row_table_ptr=prefill_row_table,
+        prefill_kv_ptr=prefill_kv_2d,
+        compact_kv_ptr=compact_kv_2d,
+        packed_nope_stride_s=packed_nope.stride(0),
+        packed_nope_stride_d=packed_nope.stride(1),
+        packed_scale_stride_s=packed_scale.stride(0),
+        packed_scale_stride_d=packed_scale.stride(1),
+        packed_rope_stride_s=packed_rope.stride(0),
+        packed_rope_stride_d=packed_rope.stride(1),
+        prefill_kv_stride_s=prefill_kv_2d.stride(0),
+        prefill_kv_stride_d=prefill_kv_2d.stride(1),
+        compact_kv_stride_s=compact_kv_2d.stride(0),
+        compact_kv_stride_d=compact_kv_2d.stride(1),
+        unique_num=unique_count,
+        KV_NOPE_DIM=512,
+        KV_ROPE_DIM=64,
+        GROUP_SIZE=128,
+        BLOCK_D=128,
+        num_warps=4,
+    )
+
+    remapped = remapped_flat.view(original_shape)
+    if squeeze_h_kv:
+        remapped = remapped.squeeze(1)
+    return compact_kv, remapped
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index af39e5da7..5d95c2348 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -374,13 +374,15 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--llm_kv_type",
         type=str,
-        choices=["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt"],
+        choices=["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt", "fp8kv_dsa"],
         default="None",
         help="""kv type used in llm, None for dtype that llm used in config.json.
                 fp8kv_sph: use float8_e4m3fn to store kv cache for inference, quant way is static per head kv quant.
                 fp8kv_spt: use float8_e4m3fn to store kv cache for inference, quant way is static per tensor kv quant.
+                fp8kv_dsa: use DeepSeek-V3.2 DSA-specific FlashMLA FP8 sparse KV cache,
+                intended for the deepseek_v32 model path.
                 fp8kv_sph and fp8kv_spt requires --kv_quant_calibration_config_path to load pre-computed FP8 scales.
                Note: fp8kv_spt requires flashinfer-python>=0.6.5 (default is 0.6.3,
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index c7c3975f7..f980f0ad1 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -127,7 +127,7 @@ class StartArgs:
         default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "sdpa", "xformers"]}
     )
     llm_kv_type: str = field(
-        default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt"]}
+        default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt", "fp8kv_dsa"]}
     )
     llm_kv_quant_group_size: int = field(default=8)
     sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "sglang_kernel"]})