67 changes: 0 additions & 67 deletions test/test_ops.py
@@ -18,7 +18,6 @@
 from torch.testing._internal.optests import opcheck
 
 import torchao
-from torchao.dtypes.floatx import from_scaled_tc_floatx
 from torchao.quantization.marlin_qqq import (
     marlin_qqq_workspace,
     pack_to_marlin_qqq,
@@ -56,72 +55,6 @@


 class TestOps(TestCase):
-    def _create_floatx_inputs(
-        self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device, dtype
-    ):
-        # Randomly initialize each byte
-        nbits = 1 + ebits + mbits
-        floatx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8)
-        scale = torch.rand(OC).to(dtype) + 0.5
-        fp16_act = torch.rand(BS, IC).to(dtype) + 0.5
-        return floatx_weight.to(device), scale.to(device), fp16_act.to(device)
-
-    @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
-    @parametrize("ebits,mbits", [(3, 2), (2, 2)])
-    @parametrize("dtype", [torch.half, torch.bfloat16])
-    def test_quant_llm_linear(self, ebits, mbits, dtype):
-        BS = 2
-        OC = 256
-        IC = 256
-        splitK = 1
-        floatx_weight, scale, fp16_act = self._create_floatx_inputs(
-            ebits, mbits, BS, OC, IC, "cuda", dtype
-        )
-
-        # smoke test
-        torchao.ops.quant_llm_linear(
-            ebits, mbits, fp16_act, floatx_weight, scale, splitK
-        )
-
-        # comprehensive testing
-        test_utils = [
-            "test_schema",
-            "test_autograd_registration",
-            "test_faketensor",
-            "test_aot_dispatch_dynamic",
-        ]
-        opcheck(
-            torch.ops.torchao.quant_llm_linear,
-            (ebits, mbits, fp16_act, floatx_weight, scale, splitK),
-            test_utils=test_utils,
-        )
-
-    @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
-    @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
-    @parametrize("ebits,mbits", [(3, 2), (2, 2)])
-    @parametrize("dtype", [torch.half, torch.bfloat16])
-    def test_quant_llm_linear_correctness(
-        self, ebits, mbits, BS, OC, IC, splitK, dtype
-    ):
-        # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py
-        floatx_weight, scale, fp16_act = self._create_floatx_inputs(
-            ebits, mbits, BS, OC, IC, "cuda", dtype
-        )
-
-        results_floatx = torchao.ops.quant_llm_linear(
-            ebits, mbits, fp16_act, floatx_weight, scale, splitK
-        )
-
-        fp16_weight = from_scaled_tc_floatx(floatx_weight, ebits, mbits, scale).to(
-            dtype
-        )
-        results_fp16 = fp16_act @ fp16_weight.T
-
-        error = (results_floatx - results_fp16).abs().mean()
-        gt = results_fp16.abs().mean()
-        relative_error = error / gt
-        rtol = 1e-2 if dtype == torch.bfloat16 else 1e-3
-        assert relative_error < rtol
-
     def _scaled_dot_product_int8_op_ref(
         self,
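For reference, the deleted correctness test judged the FPx kernel by an aggregate mean relative error against a plain fp16 matmul on weights dequantized with `from_scaled_tc_floatx`, rather than an elementwise `torch.allclose`. A minimal standalone sketch of that acceptance criterion, restated from the removed code (the function name here is illustrative, not a torchao API):

```python
import torch

def mean_relative_error(actual: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
    # Mean absolute error, normalized by the mean magnitude of the reference.
    # A single aggregate ratio tolerates scattered elementwise outliers that
    # would fail a pointwise allclose check.
    return (actual - reference).abs().mean() / reference.abs().mean()

# The deleted test asserted this ratio stayed below 1e-2 for bfloat16
# and below 1e-3 for float16.
```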
7 changes: 0 additions & 7 deletions torchao/csrc/cuda/fp6_llm/README.md

This file was deleted.

73 changes: 0 additions & 73 deletions torchao/csrc/cuda/fp6_llm/configs.h

This file was deleted.
