15 changes: 9 additions & 6 deletions src/MaxText/configs/models/deepseek3-671b-2dfsdp.yml
@@ -60,6 +60,7 @@ mesh_axes: ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']
data_sharding: [['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']]
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_norm_length', ['context']],
@@ -68,10 +69,12 @@ logical_axis_rules: [
['embed_no_exp', ['fsdp']],
['q_lora', ['fsdp']],
['kv_lora', ['fsdp']],
['q_lora_up_proj', ['fsdp_transpose']],
['kv_lora_up_proj', ['fsdp_transpose']],
['q_heads', ['fsdp_transpose']],
['kv_heads', ['fsdp_transpose']],
['heads', ['fsdp_transpose']],
['mlp', ['fsdp_transpose']],
['q_lora_up_proj', ['fsdp_transpose', 'expert']],
['kv_lora_up_proj', ['fsdp_transpose', 'expert']],
['q_heads', ['fsdp_transpose', 'expert']],
['kv_heads', ['fsdp_transpose', 'expert']],
['heads', ['fsdp_transpose', 'expert']],
['mlp', ['fsdp_transpose', 'expert']],
['mlp_only_fsdp_transpose', ['fsdp_transpose']],
['mlp_only_tensor', ['expert']],
]
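
These rules map logical weight and activation axes to mesh axes; the added 'expert' entries let the LoRA up-projections, attention heads, and MLP weights shard over the expert axis in addition to fsdp_transpose. As a rough sketch of how one such rule resolves, the snippet below uses Flax's logical-partitioning helper (assumed to be flax.linen.logical_to_mesh_axes); it is illustrative only and not part of this change.

from flax import linen as nn

# Two rules in the same format as the config above; each logical axis may map
# to one or more mesh axes.
rules = (
    ("embed_no_exp", ("fsdp",)),
    ("mlp", ("fsdp_transpose", "expert")),
)

# A weight whose dimensions carry the logical names (embed_no_exp, mlp) should
# resolve to PartitionSpec('fsdp', ('fsdp_transpose', 'expert')) under these rules.
spec = nn.logical_to_mesh_axes(("embed_no_exp", "mlp"), rules)
print(spec)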
50 changes: 0 additions & 50 deletions src/MaxText/configs/models/deepseek3-tiny.yml

This file was deleted.

127 changes: 127 additions & 0 deletions src/MaxText/kernels/sort_activations.py
@@ -0,0 +1,127 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Token sorting for MoE layers."""

import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.custom_vjp, nondiff_argnums=(2,))
def route(
    tokens: jax.Array,
    selected_experts: jax.Array,
    use_custom_mosaic_kernel: bool,
) -> jax.Array:
  """Route tokens to selected experts."""
  return _route_fwd(tokens, selected_experts, use_custom_mosaic_kernel)[0]


def _route_fwd(
    tokens: jax.Array,
    selected_experts: jax.Array,
    use_custom_mosaic_kernel: bool,
) -> tuple[jax.Array, jax.Array]:
  return (
      _route_impl(tokens, selected_experts, use_custom_mosaic_kernel),
      selected_experts,
  )


def _route_bwd(
    use_custom_mosaic_kernel: bool,
    residuals: jax.Array,
    grads: jax.Array,
) -> tuple[jax.Array, None]:
  selected_experts = residuals
  return _unroute_impl(grads, selected_experts, use_custom_mosaic_kernel), None


route.defvjp(_route_fwd, _route_bwd)


@functools.partial(jax.custom_vjp, nondiff_argnums=(2,))
def unroute(
    tokens: jax.Array,
    selected_experts: jax.Array,
    use_custom_mosaic_kernel: bool,
) -> jax.Array:
  """Invert `route`: restore original token order and sum the expert copies."""
  return _unroute_fwd(tokens, selected_experts, use_custom_mosaic_kernel)[0]


def _unroute_fwd(
    tokens: jax.Array,
    selected_experts: jax.Array,
    use_custom_mosaic_kernel: bool,
) -> tuple[jax.Array, jax.Array]:
  return (
      _unroute_impl(tokens, selected_experts, use_custom_mosaic_kernel),
      selected_experts,
  )


def _unroute_bwd(
    use_custom_mosaic_kernel: bool, residuals: jax.Array, grads: jax.Array
) -> tuple[jax.Array, None]:
  selected_experts = residuals
  return _route_impl(grads, selected_experts, use_custom_mosaic_kernel), None


unroute.defvjp(_unroute_fwd, _unroute_bwd)


def _route_impl(
    tokens: jax.Array,
    selected_experts: jax.Array,
    use_custom_mosaic_kernel: bool,
) -> jax.Array:
  """Gather `tokens` according to `selected_experts`."""
  assert (
      tokens.shape[0] == selected_experts.shape[0]
      and selected_experts.ndim == 2
  ), f"{tokens.shape=}, {selected_experts.shape=}"
  if use_custom_mosaic_kernel:
    raise NotImplementedError("Custom Mosaic kernel not implemented.")
  inds = jnp.argsort(jnp.ravel(selected_experts)) // selected_experts.shape[1]
  return _sort_impl(tokens, inds, use_custom_mosaic_kernel)


def _unroute_impl(
    tokens: jax.Array,
    selected_experts: jax.Array,
    use_custom_mosaic_kernel: bool,
) -> jax.Array:
  """Scatter `tokens` back to their original order, summing each token's copies."""
  assert (
      tokens.shape[0] == selected_experts.shape[0] * selected_experts.shape[1]
      and selected_experts.ndim == 2
  )
  inds = jnp.argsort(jnp.argsort(jnp.ravel(selected_experts)))
  return jnp.sum(
      jnp.reshape(
          _sort_impl(tokens, inds, use_custom_mosaic_kernel),
          (-1, selected_experts.shape[1]) + tokens.shape[1:],
      ),
      axis=1,
  )


def _sort_impl(
    tokens: jax.Array, inds: jax.Array, use_custom_mosaic_kernel: bool
) -> jax.Array:
  if use_custom_mosaic_kernel:
    raise NotImplementedError("Custom Mosaic kernel not implemented.")
  else:
    return tokens[inds, ...]
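
As a quick sanity check of the semantics, here is an illustrative round trip through route and unroute. It is not part of the change; the import path is inferred from the file location and the shapes are arbitrary.

import jax
import jax.numpy as jnp

from MaxText.kernels.sort_activations import route, unroute

num_tokens, hidden_dim, num_experts, top_k = 8, 4, 4, 2
key_tok, key_exp = jax.random.split(jax.random.PRNGKey(0))
tokens = jax.random.normal(key_tok, (num_tokens, hidden_dim))
selected_experts = jax.random.randint(key_exp, (num_tokens, top_k), 0, num_experts)

# route() replicates each token top_k times and orders the copies by expert id,
# so each expert sees a contiguous block of its assigned tokens.
routed = route(tokens, selected_experts, False)
assert routed.shape == (num_tokens * top_k, hidden_dim)

# unroute() inverts the permutation and sums the top_k copies of each token,
# so with the copies left untouched it returns top_k * tokens.
restored = unroute(routed, selected_experts, False)
assert jnp.allclose(restored, top_k * tokens, atol=1e-5)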
9 changes: 4 additions & 5 deletions src/MaxText/layers/decoders.py
@@ -43,7 +43,6 @@
from MaxText.layers.quantizations import AqtQuantization as Quant
from MaxText.layers import (
deepseek,
deepseek_batchsplit,
gemma,
gemma2,
gemma3,
@@ -405,10 +404,10 @@ def get_decoder_layers(self):
      case DecoderBlockType.MIXTRAL:
        return [mixtral.MixtralDecoderLayerToLinen]
      case DecoderBlockType.DEEPSEEK:
        if self.config.use_batch_split_schedule:
          return [deepseek_batchsplit.DeepSeekDenseLayerToLinen, deepseek_batchsplit.DeepSeekMoELayerToLinen]
        else:
          return [deepseek.DeepSeekDenseLayerToLinen, deepseek.DeepSeekMoELayerToLinen]
        return [
            deepseek.DeepSeekDenseLayerToLinen,
            deepseek.DeepSeekMoELayerToLinen,
        ]
      case DecoderBlockType.GEMMA:
        return [gemma.GemmaDecoderLayerToLinen]
      case DecoderBlockType.GEMMA2:
28 changes: 22 additions & 6 deletions src/MaxText/layers/deepseek.py
@@ -18,26 +18,27 @@

from typing import Optional

from flax import nnx
from jax.ad_checkpoint import checkpoint_name
from jax.sharding import Mesh
import jax.numpy as jnp

from flax import nnx

from jax.sharding import Mesh
from MaxText.common_types import Config
from MaxText.common_types import MODEL_MODE_PREFILL
from maxtext.inference import page_manager
from MaxText.layers import attention_mla
from MaxText.layers import deepseek_batchsplit
from MaxText.layers import initializers
from MaxText.layers import linears
from MaxText.layers import moe
from MaxText.layers import nnx_wrappers
from MaxText.layers import quantizations
from MaxText.layers.linears import Dropout
from MaxText.layers.normalizations import RMSNorm
from MaxText.sharding import maybe_shard_with_logical, create_sharding
from maxtext.inference import page_manager
from MaxText.sharding import create_sharding
from MaxText.sharding import maybe_shard_with_logical
from maxtext.utils import max_utils


# -----------------------------------------
# The Decoder Layer for DeepSeek v3
# -----------------------------------------
@@ -366,6 +367,21 @@ def __call__(
    # Unpack inputs if it's a tuple (e.g. from a previous layer returning (hidden_states, kv_cache))
    if isinstance(inputs, tuple):
      inputs = inputs[0]

    # If using batch split schedule, call the batch split version of the layer.
    if self.config.use_batch_split_schedule:
      outputs = deepseek_batchsplit.batch_split_schedule(
          inputs,
          nnx.to_pure_dict(nnx.state(self, nnx.Param)),
          decoder_positions,
          decoder_segment_ids,
          model_mode=model_mode,
          mesh=self.mesh,
          quant=self.quant,
          cfg=self.config,
      )
      return outputs, None

    x = self.with_logical_constraint(inputs)
    x = checkpoint_name(x, "decoder_layer_input")

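For context on the parameter hand-off above: nnx.state(self, nnx.Param) collects the layer's parameters and nnx.to_pure_dict strips the variable wrappers, so the batch-split schedule receives a plain pytree. The tiny example below shows that extraction pattern with a made-up module; it is illustrative only and not from this PR.

from flax import nnx
import jax.numpy as jnp


class TinyLayer(nnx.Module):
  """Hypothetical stand-in for a decoder layer."""

  def __init__(self):
    self.w = nnx.Param(jnp.ones((4, 4)))


layer = TinyLayer()
# Filter the module state to Params only, then drop the Variable wrappers,
# leaving a nested dict of plain arrays that a functional schedule can consume.
params = nnx.to_pure_dict(nnx.state(layer, nnx.Param))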