diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py index b96fff0d3f..f3bc5b5fd4 100644 --- a/src/MaxText/configs/types.py +++ b/src/MaxText/configs/types.py @@ -82,6 +82,7 @@ class QuantizationType(str, Enum): TE_FP8_CS = "te_fp8_currentscaling" TE_MXFP8 = "te_mxfp8" TE_NVFP4 = "te_nvfp4" + TE_NVFP4_NO_RHT = "te_nvfp4_no_rht" class KvQuantAxis(str, Enum): diff --git a/src/MaxText/layers/quantizations.py b/src/MaxText/layers/quantizations.py index d0f9353b6c..8228bc64d3 100644 --- a/src/MaxText/layers/quantizations.py +++ b/src/MaxText/layers/quantizations.py @@ -752,6 +752,7 @@ def _get_recipe(recipe_name: str): "te_fp8_currentscaling": recipe.Float8CurrentScaling, "te_mxfp8": recipe.MXFP8BlockScaling, "te_nvfp4": recipe.NVFP4BlockScaling, # pytype: disable=module-attr + "te_nvfp4_no_rht": functools.partial(recipe.NVFP4BlockScaling, disable_rht=True), # pytype: disable=module-attr } if recipe_name not in RECIPES: raise ValueError(f"Invalid TransformerEngine recipe: {recipe_name}")