modelopt/torch/quantization/nn/modules/tensor_quantizer.py (21 additions, 0 deletions)
@@ -425,6 +425,27 @@ def is_mx_format(self):
            and self.block_sizes.get("scale_bits", None) == (8, 0)
        )

    @property
    def is_mxfp4(self):
        """Check if is MXFP4."""
        return (
            self.is_mx_format and self.num_bits == (2, 1) and self.block_sizes.get(-1, None) == 32
        )

    @property
    def is_mxfp6(self):
        """Check if is MXFP6."""
        return (
            self.is_mx_format and self.num_bits == (3, 2) and self.block_sizes.get(-1, None) == 32
        )

    @property
    def is_mxfp8(self):
        """Check if is MXFP8."""
        return (
            self.is_mx_format and self.num_bits == (4, 3) and self.block_sizes.get(-1, None) == 32
        )

    @property
    def is_static_block_quant(self):
        """Check if is static block quantization."""
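For orientation (not part of the diff): a minimal sketch of the attribute values that the new properties classify as MXFP8, mirroring what the unit test further below sets by hand; how these attributes are normally populated from a quantization config is outside this excerpt.

# Illustration only: a quantizer carrying an MXFP8 recipe.
quantizer = TensorQuantizer()
quantizer.num_bits = (4, 3)  # FP8 E4M3 elements
quantizer.block_sizes = {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}  # 32-element blocks, E8M0 scales
assert quantizer.is_mx_format
assert quantizer.is_mxfp8 and not quantizer.is_mxfp4 and not quantizer.is_mxfp6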
modelopt/torch/quantization/plugins/huggingface.py (88 additions, 4 deletions)
@@ -28,6 +28,13 @@
except ImportError:
    Shard = None

try:
    import kitchen
    from kitchen.fa import KitchenFlashAttentionModule
    from kitchen.triton_module import triton_fa_params
except ImportError:
    kitchen = None

import torch.nn as nn
import transformers
from transformers.models.t5.modeling_t5 import T5Attention
@@ -56,17 +63,94 @@ def _setup(self):
        self.q_bmm_quantizer = TensorQuantizer()
        self.k_bmm_quantizer = TensorQuantizer()
        self.v_bmm_quantizer = TensorQuantizer()
        self.softmax_quantizer = TensorQuantizer()
        self.kitchen_attn_fn = None
        self.use_kitchen = False

    def _init_kitchen_attn_fn(self):
        if not self.softmax_quantizer.is_enabled:
            self.kitchen_attn_fn = "disabled"
            return
        self.use_kitchen = True
        if self.softmax_quantizer.is_mxfp8:
            qfa_params = triton_fa_params.QTritonFAParams(
                backend="triton",
                qk_dot_precisions="bf16@bf16",
                pv_dot_precisions="mxfp8_e4m3_emulation@bf16",
                dp_v_x_do_dot_precisions="bf16@bf16",
                dp_do_x_v_dot_precisions="bf16@bf16",
                dq_ds_x_k_dot_precisions="bf16@bf16",
                dk_ds_x_q_dot_precisions="bf16@bf16",
                dv_p_x_do_dot_precisions="bf16@bf16",
                use_natural_transcendental_func=False,  # Differs from the default.
            )
        else:
            raise NotImplementedError(f"softmax_quantizer not supported: {self.softmax_quantizer}")

        self.kitchen_attn_fn = KitchenFlashAttentionModule(
            num_attention_heads=self.config.num_attention_heads,
            kv_channels=self.config.head_dim,
            num_gqa_groups=None,  # Not self.config.num_key_value_heads; kitchen does not support GQA.
            attention_dropout=self.config.attention_dropout,
            qkv_format="sbhd",  # Not used here, but the only format supported in forward.
            attn_mask_type="causal",
            window_size=getattr(self.config, "sliding_window", None),
            sequence_parallel=False,
            get_rng_state_tracker=None,
            layer_number=None,
            attention_type="self",
            softmax_scale=None,  # Converted to the same default as SDPA: 1/sqrt(dim_q).
            qfa_params=qfa_params,
        )
    @staticmethod
    def _quantized_attention(
        original_attention_interface,
        self,
        query_states,
        key_states,
        value_states,
        *args,
        **kwargs,
    ):
        if kitchen is not None and self.kitchen_attn_fn is None:
            self._init_kitchen_attn_fn()

        query_states = self.q_bmm_quantizer(query_states)
        key_states = self.k_bmm_quantizer(key_states)
        value_states = self.v_bmm_quantizer(value_states)
        if not self.use_kitchen:
            return original_attention_interface(
                self, query_states, key_states, value_states, *args, **kwargs
            )

        query_sequence_length = query_states.shape[2]
        if query_states.shape[2] < key_states.shape[2]:  # For decoding stage.
            shape = list(query_states.shape)
            shape[2] = key_states.shape[2] - query_states.shape[2]
            query_states = torch.cat(
                [
                    torch.empty(shape, dtype=query_states.dtype, device=query_states.device),
                    query_states,
                ],
                dim=2,
            )

        n_repeat = self.config.num_attention_heads // self.config.num_key_value_heads
        if n_repeat > 1:
            key_states = key_states.repeat_interleave(n_repeat, dim=1)
            value_states = value_states.repeat_interleave(n_repeat, dim=1)
        # kitchen only supports sbhd; we have bhsd.
        query_states = query_states.permute(2, 0, 1, 3)
        key_states = key_states.permute(2, 0, 1, 3)
        value_states = value_states.permute(2, 0, 1, 3)
        attn_out = self.kitchen_attn_fn(query_states, key_states, value_states)
        attn_out = attn_out[-query_sequence_length:, :, :]
        # Output is sb(h*d); we need bshd.
        attn_out = attn_out.reshape(
            (attn_out.shape[0], attn_out.shape[1], query_states.shape[2], -1)
        ).permute(1, 0, 2, 3)
        return attn_out.contiguous(), None

    def forward(self, *args, **kwargs):
        """Forward method for KV cache quantization compatible with new_attention_interface in transformers >= 4.48.0.
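As a reading aid (not part of the diff): the tensor-layout handling in _quantized_attention can be sketched in isolation. Transformers hands over bhsd tensors, while the kitchen kernel expects sbhd and has no GQA support, so KV heads are expanded and all three tensors are permuted. The shapes below are made up for illustration.

# Illustration only: GQA expansion and bhsd -> sbhd permutation on dummy tensors.
import torch

b, h_q, h_kv, s, d = 2, 4, 2, 8, 32
q = torch.randn(b, h_q, s, d)   # bhsd, as received from transformers
k = torch.randn(b, h_kv, s, d)
v = torch.randn(b, h_kv, s, d)

n_repeat = h_q // h_kv          # expand KV heads; the kernel has no GQA support
k = k.repeat_interleave(n_repeat, dim=1)
v = v.repeat_interleave(n_repeat, dim=1)

q, k, v = (t.permute(2, 0, 1, 3) for t in (q, k, v))  # bhsd -> sbhd
assert q.shape == k.shape == v.shape == (s, b, h_q, d)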
tests/unit/torch/quantization/plugins/test_attention_quant.py (84 additions, 0 deletions)
@@ -13,11 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from _test_utils.torch.transformers_models import get_tiny_bert, get_tiny_llama, get_tiny_t5
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention

try:
    import kitchen
except ImportError:
    kitchen = None

import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.plugins.huggingface import _QuantAttention
@@ -54,6 +63,7 @@ def forward(self, hidden_states, **kwargs):
kv_cache_config = {
    "quant_cfg": {
        "*[kv]_bmm_quantizer": {"num_bits": 4, "enable": True},
        "*softmax_quantizer": {"enable": False},
    },
    "algorithm": "max",
}
@@ -147,3 +157,77 @@ def test_kv_quant_bert():
    assert output is not None
    assert output.start_logits is not None
    assert output.end_logits is not None


@pytest.mark.skipif(kitchen is None, reason="kitchen is not installed.")
def test_kitchen_fa():
    batch_size = 2
    num_q_heads = 4
    num_kv_heads = 2
    seqlen = 8
    hidden_size = 128

    config = LlamaConfig(
        hidden_size=hidden_size,
        num_attention_heads=num_q_heads,
        num_key_value_heads=num_kv_heads,
    )
    original_attention = LlamaAttention(config, layer_idx=0)

    q_states = torch.randn(
        batch_size, num_q_heads, seqlen, hidden_size, dtype=torch.bfloat16, device="cuda"
    )
    k_states = torch.randn(
        batch_size, num_kv_heads, seqlen, hidden_size, dtype=torch.bfloat16, device="cuda"
    )
    v_states = torch.randn(
        batch_size, num_kv_heads, seqlen, hidden_size, dtype=torch.bfloat16, device="cuda"
    )

    # Convert it to _QuantAttention using the convert() class method
    quant_attention = _QuantAttention.convert(original_attention)
    quant_attention.config._attn_implementation = "sdpa"
    assert hasattr(quant_attention, "q_bmm_quantizer")
    assert hasattr(quant_attention, "k_bmm_quantizer")
    assert hasattr(quant_attention, "v_bmm_quantizer")
    assert hasattr(quant_attention, "softmax_quantizer")
    quant_attention.softmax_quantizer.disable()
    module = inspect.getmodule(quant_attention.get_attn_type(quant_attention))
    orig_attn_fn = module.ALL_ATTENTION_FUNCTIONS["sdpa"]

    output = quant_attention._quantized_attention(
        orig_attn_fn,
        quant_attention,
        q_states,
        k_states,
        v_states,
        attention_mask=None,
    )
    expected = output[0]

    config = LlamaConfig(
        hidden_size=hidden_size,
        num_attention_heads=num_q_heads,
        num_key_value_heads=num_kv_heads,
    )
    original_attention = LlamaAttention(config, layer_idx=0)
    quant_attention = _QuantAttention.convert(original_attention)
    quant_attention.config._attn_implementation = "sdpa"
    quant_attention.softmax_quantizer.num_bits = (4, 3)
    quant_attention.softmax_quantizer.block_sizes = {
        -1: 32,
        "type": "dynamic",
        "scale_bits": (8, 0),
    }
    output = quant_attention._quantized_attention(
        None,
        quant_attention,
        q_states,
        k_states,
        v_states,
        attention_mask=None,
    )
    diff = (expected - output[0]).abs()
    assert torch.allclose(expected, output[0], atol=0.75, rtol=0.75), (
        f"{diff.max().item(), diff.mean().item(), diff.std().item()}"
    )
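A possible follow-up usage (hypothetical, not in this PR): test_kitchen_fa enables the MXFP8 path by setting softmax_quantizer attributes by hand; the same recipe expressed as a quant_cfg entry might look like the sketch below. Whether mtq.quantize accepts this exact block for the softmax quantizer is an assumption.

# Hypothetical config sketch, mirroring the attribute values used in test_kitchen_fa.
mxfp8_attention_cfg = {
    "quant_cfg": {
        "*softmax_quantizer": {
            "num_bits": (4, 3),
            "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)},
            "enable": True,
        },
    },
    "algorithm": "max",
}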