Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ FROM nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu22.04
ARG PYTHON_VERSION=3.10
ARG MAMBA_VERSION=24.7.1-0
ARG VLLM_VERSION=0.16.0
ARG FLASH_MLA_REF=47c35a7
ARG TARGETPLATFORM
ARG ENABLE_DEEPEP=1
ARG ENABLE_NIXL=1
Expand Down Expand Up @@ -45,6 +46,11 @@ COPY ./requirements.txt /lightllm/requirements.txt
RUN pip install -U pip
RUN pip install -r /lightllm/requirements.txt --no-cache-dir
RUN pip install --no-cache-dir vllm==${VLLM_VERSION}
RUN git clone https://github.com/deepseek-ai/FlashMLA.git /root/FlashMLA && \
cd /root/FlashMLA && \
git checkout ${FLASH_MLA_REF} && \
git submodule update --init --recursive && \
FLASH_MLA_DISABLE_SM100=1 pip install --no-cache-dir .
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To reduce the final Docker image size, it's a good practice to clean up build-time dependencies and source files within the same RUN layer. After installing FlashMLA, the cloned repository at /root/FlashMLA is no longer needed and can be removed.

    FLASH_MLA_DISABLE_SM100=1 pip install --no-cache-dir . && rm -rf /root/FlashMLA


RUN apt-get update && apt-get install -y libnuma-dev && rm -rf /var/lib/apt/lists/*

Expand Down
1 change: 1 addition & 0 deletions lightllm/common/basemodel/attention/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

# NSA backend
from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend
from .nsa.fp8_flashmla_sparse import NsaFlashMlaFp8SparseAttBackend

from .create_utils import (
get_prefill_att_backend_class,
Expand Down
4 changes: 4 additions & 0 deletions lightllm/common/basemodel/attention/create_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .flashinfer.fp import FlashInferAttBackend
from .flashinfer.mla import MlaFlashInferAttBackend
from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend
from .nsa.fp8_flashmla_sparse import NsaFlashMlaFp8SparseAttBackend

logger = init_logger(__name__)

Expand Down Expand Up @@ -56,6 +57,9 @@
"flashmla_sparse": NsaFlashMlaSparseAttBackend,
# Future backends: "fa3", "tilelang", "aiter"
},
"fp8kv_dsa": {
"flashmla_sparse": NsaFlashMlaFp8SparseAttBackend,
},
}


Expand Down
8 changes: 8 additions & 0 deletions lightllm/common/basemodel/attention/nsa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,17 @@
NsaFlashMlaSparsePrefillAttState,
NsaFlashMlaSparseDecodeAttState,
)
from .fp8_flashmla_sparse import (
NsaFlashMlaFp8SparseAttBackend,
NsaFlashMlaFp8SparsePrefillAttState,
NsaFlashMlaFp8SparseDecodeAttState,
)

__all__ = [
"NsaFlashMlaSparseAttBackend",
"NsaFlashMlaSparsePrefillAttState",
"NsaFlashMlaSparseDecodeAttState",
"NsaFlashMlaFp8SparseAttBackend",
"NsaFlashMlaFp8SparsePrefillAttState",
"NsaFlashMlaFp8SparseDecodeAttState",
]
4 changes: 2 additions & 2 deletions lightllm/common/basemodel/attention/nsa/flashmla_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def _nsa_decode_att(
from sgl_kernel.flash_attn import flash_attn_with_kvcache

nsa_dict = att_control.nsa_decode_dict
topk_indices = nsa_dict["topk_indices"]
topk_mem_indices = nsa_dict["topk_mem_indices"]
softmax_scale = nsa_dict["softmax_scale"]
kv_lora_rank = nsa_dict["kv_lora_rank"]
qk_rope_head_dim = nsa_dict["qk_rope_head_dim"]
Expand All @@ -181,7 +181,7 @@ def _nsa_decode_att(
k_cache=k_rope,
v_cache=kv_nope,
qv=q_nope,
page_table=topk_indices,
page_table=topk_mem_indices,
cache_seqlens=self.nsa_cache_seqlens,
cu_seqlens_q=self.infer_state.b1_cu_q_seq_len,
cu_seqlens_k_new=self.nsa_cu_seqlens_k_new,
Expand Down
191 changes: 191 additions & 0 deletions lightllm/common/basemodel/attention/nsa/fp8_flashmla_sparse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import dataclasses
import torch
from typing import TYPE_CHECKING, Tuple

from ..base_att import AttControl, BaseAttBackend, BaseDecodeAttState, BasePrefillAttState
from lightllm.utils.dist_utils import get_current_device_id

if TYPE_CHECKING:
from lightllm.common.basemodel.infer_struct import InferStateInfo


class NsaFlashMlaFp8SparseAttBackend(BaseAttBackend):
def __init__(self, model):
super().__init__(model=model)
device = get_current_device_id()
self.ragged_mem_buffers = [
torch.empty(model.graph_max_batch_size * model.max_seq_length, dtype=torch.int32, device=device)
for _ in range(2)
]

def create_att_prefill_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparsePrefillAttState":
return NsaFlashMlaFp8SparsePrefillAttState(backend=self, infer_state=infer_state)

def create_att_decode_state(self, infer_state: "InferStateInfo") -> "NsaFlashMlaFp8SparseDecodeAttState":
return NsaFlashMlaFp8SparseDecodeAttState(backend=self, infer_state=infer_state)


@dataclasses.dataclass
class NsaFlashMlaFp8SparsePrefillAttState(BasePrefillAttState):
ks: torch.Tensor = None
ke: torch.Tensor = None
lengths: torch.Tensor = None
ragged_mem_index: torch.Tensor = None

def init_state(self):
self.backend: NsaFlashMlaFp8SparseAttBackend = self.backend
self.ragged_mem_index = torch.empty(
self.infer_state.total_token_num,
dtype=torch.int32,
device=get_current_device_id(),
)
from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke

self.ks, self.ke, self.lengths = gen_nsa_ks_ke(
b_seq_len=self.infer_state.b_seq_len,
b_q_seq_len=self.infer_state.b_q_seq_len,
b_req_idx=self.infer_state.b_req_idx,
req_to_token_index=self.infer_state.req_manager.req_to_token_indexs,
q_token_num=self.infer_state.total_token_num - self.infer_state.prefix_total_token_num,
ragged_mem_index=self.ragged_mem_index,
hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID,
)
return

def prefill_att(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
att_control: AttControl = AttControl(),
alloc_func=torch.empty,
) -> torch.Tensor:
assert att_control.nsa_prefill, "nsa_prefill must be True for NSA prefill attention"
assert att_control.nsa_prefill_dict is not None, "nsa_prefill_dict is required"
return self._nsa_prefill_att(q=q, att_control=att_control)

def _nsa_prefill_att(
self,
q: torch.Tensor,
att_control: AttControl,
) -> torch.Tensor:
import flash_mla

nsa_dict = att_control.nsa_prefill_dict
topk_indices = nsa_dict["topk_indices"]
softmax_scale = nsa_dict["softmax_scale"]
kv_lora_rank = nsa_dict["kv_lora_rank"]
layer_index = nsa_dict["layer_index"]
topk_mem_indices = nsa_dict["topk_mem_indices"]
prefill_cache_kv = nsa_dict["prefill_cache_kv"]

if self.infer_state.prefix_total_token_num > 0:
kv, topk_indices = self.infer_state.mem_manager.get_prefill_kv_cache_and_remap_indices(
layer_index=layer_index,
topk_indices=topk_mem_indices,
prefill_mem_index=self.infer_state.mem_index,
prefill_cache_kv=prefill_cache_kv,
)
else:
kv = prefill_cache_kv

if topk_indices.ndim == 2:
topk_indices = topk_indices.unsqueeze(1)

mla_out, _, _ = flash_mla.flash_mla_sparse_fwd(
q=q,
kv=kv,
indices=topk_indices,
sm_scale=softmax_scale,
d_v=kv_lora_rank,
)
return mla_out


@dataclasses.dataclass
class NsaFlashMlaFp8SparseDecodeAttState(BaseDecodeAttState):
ks: torch.Tensor = None
ke: torch.Tensor = None
lengths: torch.Tensor = None
ragged_mem_index: torch.Tensor = None
flashmla_sched_meta: object = None

def init_state(self):
self.backend: NsaFlashMlaFp8SparseAttBackend = self.backend
model = self.backend.model
use_cuda_graph = (
self.infer_state.batch_size <= model.graph_max_batch_size
and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch
)

if use_cuda_graph:
self.ragged_mem_index = self.backend.ragged_mem_buffers[self.infer_state.microbatch_index]
else:
self.ragged_mem_index = torch.empty(
self.infer_state.total_token_num,
dtype=torch.int32,
device=get_current_device_id(),
)

from lightllm.common.basemodel.triton_kernel.gen_nsa_ks_ke import gen_nsa_ks_ke

self.ks, self.ke, self.lengths = gen_nsa_ks_ke(
b_seq_len=self.infer_state.b_seq_len,
b_q_seq_len=self.infer_state.b_q_seq_len,
b_req_idx=self.infer_state.b_req_idx,
req_to_token_index=self.infer_state.req_manager.req_to_token_indexs,
q_token_num=self.infer_state.b_seq_len.shape[0],
ragged_mem_index=self.ragged_mem_index,
hold_req_idx=self.infer_state.req_manager.HOLD_REQUEST_ID,
)
import flash_mla

self.flashmla_sched_meta, _ = flash_mla.get_mla_metadata()
return

def decode_att(
self,
q: Tuple[torch.Tensor, torch.Tensor],
k: torch.Tensor,
v: torch.Tensor,
att_control: AttControl = AttControl(),
alloc_func=torch.empty,
) -> torch.Tensor:
assert att_control.nsa_decode, "nsa_decode must be True for NSA decode attention"
assert att_control.nsa_decode_dict is not None, "nsa_decode_dict is required"
return self._nsa_decode_att(q=q, kv=k, att_control=att_control)

def _nsa_decode_att(
self,
q: Tuple[torch.Tensor, torch.Tensor],
kv: torch.Tensor,
att_control: AttControl,
) -> torch.Tensor:
import flash_mla

nsa_dict = att_control.nsa_decode_dict
topk_mem_indices = nsa_dict["topk_mem_indices"]
softmax_scale = nsa_dict["softmax_scale"]
kv_lora_rank = nsa_dict["kv_lora_rank"]

if topk_mem_indices.ndim == 2:
topk_mem_indices = topk_mem_indices.unsqueeze(1)
assert topk_mem_indices.shape[1] == 1, "FlashMLA sparse decode path currently expects seq_len_q == 1"

q_nope, q_rope = q
q_all = torch.cat([q_nope, q_rope], dim=-1).unsqueeze(1).contiguous()

o_tensor, _ = flash_mla.flash_mla_with_kvcache(
q=q_all,
k_cache=kv,
block_table=None,
cache_seqlens=None,
head_dim_v=kv_lora_rank,
tile_scheduler_metadata=self.flashmla_sched_meta,
num_splits=None,
softmax_scale=softmax_scale,
causal=False,
is_fp8_kvcache=True,
indices=topk_mem_indices.contiguous(),
)
return o_tensor[:, 0, :, :] # [b, 1, h, d] -> [b, h, d]
2 changes: 2 additions & 0 deletions lightllm/common/kv_cache_mem_manager/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager
from .deepseek2_mem_manager import Deepseek2MemoryManager
from .deepseek3_2mem_manager import Deepseek3_2MemoryManager
from .fp8_per_token_group_quant_deepseek3_2mem_manager import FP8PerTokenGroupQuantDeepseek3_2MemoryManager
from .fp8_static_per_head_quant_mem_manager import FP8StaticPerHeadQuantMemManager
from .fp8_static_per_tensor_quant_mem_manager import FP8StaticPerTensorQuantMemManager

Expand All @@ -13,6 +14,7 @@
"PPLINT8KVMemoryManager",
"Deepseek2MemoryManager",
"Deepseek3_2MemoryManager",
"FP8PerTokenGroupQuantDeepseek3_2MemoryManager",
"FP8StaticPerHeadQuantMemManager",
"FP8StaticPerTensorQuantMemManager",
]
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,6 @@ def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv:
def get_att_input_params(self, layer_index: int) -> Any:
kv = self.kv_buffer[layer_index][:, :, : (self.head_dim - (144 // 2))]
return kv

def get_indexer_k_buffer(self, layer_index: int) -> torch.Tensor:
return self.kv_buffer[layer_index].view(dtype=torch.uint8)[:, :, -132:]
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import torch
from typing import Any

from .deepseek2_mem_manager import Deepseek2MemoryManager


class FP8PerTokenGroupQuantDeepseek3_2MemoryManager(Deepseek2MemoryManager):
kv_nope_dim = 512
kv_rope_dim = 64
# 576 = 512 + 64
kv_head_dim = kv_nope_dim + kv_rope_dim

quant_group_size = 128
# 4 = 512 / 128
quant_group_num = kv_nope_dim // quant_group_size
# 4 * 4 = quant_group_num * fp32
# 64 * 2 = kv_rope_dim * bfloat16
# 656 bytes = 512 + (4 * 4) + (64 * 2)
flashmla_bytes_per_token = kv_nope_dim + quant_group_num * 4 + kv_rope_dim * 2

indexer_head_dim = 128
# 128 + 4 = indexer_head_dim + fp32
# 132 bytes = 128 + 4
indexer_bytes_per_token = indexer_head_dim + 4

# Merged per-token layout in kv_buffer:
# [flashmla (656 bytes) | indexer (132 bytes) | padding (12 bytes)] = 800 bytes total
# Padded to 16-byte alignment so that FlashMLA CUDA kernel vectorized
# loads (uint4 = 16 bytes) always access aligned addresses.
# This allows parent class PD/move/page/p2p methods to transfer both
# flashmla and indexer data together in a single kv_buffer.
_ALIGNMENT = 16
total_bytes_per_token = (
(flashmla_bytes_per_token + indexer_bytes_per_token + _ALIGNMENT - 1) // _ALIGNMENT * _ALIGNMENT
)

def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
assert head_num == 1, "DeepSeek-V3.2 DSA FP8 path expects MQA-style head_num == 1"
self.prefill_dtype = dtype
super().__init__(size, torch.uint8, head_num, self.total_bytes_per_token, layer_num, always_copy, mem_fraction)

def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor):
from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_kv_flashmla_fp8 import (
destindex_copy_kv_flashmla_fp8,
)

rope_dim = 64
kv_lora_rank = kv.shape[2] - rope_dim
assert kv_lora_rank == 512, f"Expected kv_lora_rank=512, got {kv_lora_rank}"

o_nope = self.kv_buffer[layer_index][:, :, :512].view(torch.float8_e4m3fn)
o_scale = self.kv_buffer[layer_index][:, :, 512:528].view(torch.float32)
o_rope = self.kv_buffer[layer_index][:, :, 528 : self.flashmla_bytes_per_token].view(torch.bfloat16)
destindex_copy_kv_flashmla_fp8(
kv[:, :, :kv_lora_rank],
kv[:, :, kv_lora_rank:],
mem_index,
o_nope,
o_scale,
o_rope,
)

def get_att_input_params(self, layer_index: int) -> Any:
# Return an as_strided view: shape (N, 1, 1, 656) with stride (788, 656, 656, 1).
# FlashMLA CUDA kernel uses stride(0) for block offset and stride(1) for row offset,
# so it correctly reads 656 contiguous bytes per token even though tokens are 788 bytes apart.
# This avoids any data copy while passing FlashMLA's shape and stride checks.
buf = self.kv_buffer[layer_index]
return torch.as_strided(
buf,
size=(buf.shape[0], 1, 1, self.flashmla_bytes_per_token),
stride=(buf.stride(0), self.flashmla_bytes_per_token, self.flashmla_bytes_per_token, 1),
)

def get_indexer_k_buffer(self, layer_index: int) -> torch.Tensor:
indexer_end = self.flashmla_bytes_per_token + self.indexer_bytes_per_token
return self.kv_buffer[layer_index][:, :, self.flashmla_bytes_per_token : indexer_end]

def get_prefill_kv_cache_and_remap_indices(
self,
layer_index: int,
topk_indices: torch.Tensor,
prefill_mem_index: torch.Tensor,
prefill_cache_kv: torch.Tensor,
):
from lightllm.models.deepseek3_2.triton_kernel.prefill_compact_kv_flashmla_fp8 import (
get_prefill_kv_cache_and_remap_indices_triton,
)

return get_prefill_kv_cache_and_remap_indices_triton(
packed_kv=self.kv_buffer[layer_index][:, :, : self.flashmla_bytes_per_token],
topk_mem_indices=topk_indices,
prefill_mem_index=prefill_mem_index,
prefill_cache_kv=prefill_cache_kv,
prefill_dtype=self.prefill_dtype,
)
Loading