diff --git a/benchmarks/microbenchmarks/test/test_benchmark_inference.py b/benchmarks/microbenchmarks/test/test_benchmark_inference.py
index f3e853866d..23c5173760 100644
--- a/benchmarks/microbenchmarks/test/test_benchmark_inference.py
+++ b/benchmarks/microbenchmarks/test/test_benchmark_inference.py
@@ -49,37 +49,6 @@ def test_run_inference(self, mock_string_to_config):
hasattr(result, "quantized_model_compiled_inference_time_in_ms")
)
- @patch("benchmarks.microbenchmarks.benchmark_inference.string_to_config")
- def test_run_inference_with_semi_sparse_marlin(self, mock_string_to_config):
- """Test running inference with sparsity configurations"""
- # Mock string_to_config to return valid configs
- from torchao.dtypes import MarlinSparseLayout
- from torchao.quantization import Int4WeightOnlyConfig
-
- # Test with semi-sparse config
- mock_string_to_config.return_value = Int4WeightOnlyConfig(
- layout=MarlinSparseLayout(),
- version=1,
- )
- config = BenchmarkConfig(
- quantization="marlin",
- sparsity="semi-sparse",
- params={
- "high_precision_dtype": "torch.float32",
- "device": "cpu",
- "model_type": "linear",
- },
- shape_name="custom",
- shape=[64, 64, 64], # Use dimensions divisible by 64
- output_dir=self.temp_dir,
- benchmark_mode="inference",
- )
- result = run(config)
- self.assertIsInstance(result, BenchmarkResult)
- self.assertTrue(
- hasattr(result, "quantized_model_compiled_inference_time_in_ms")
- )
-
@patch("benchmarks.microbenchmarks.benchmark_inference.string_to_config")
def test_run_inference_with_block_sparsity(self, mock_string_to_config):
"""Test running inference with sparsity configurations"""
diff --git a/benchmarks/microbenchmarks/utils.py b/benchmarks/microbenchmarks/utils.py
index 2c6a443a86..26e746a3f5 100644
--- a/benchmarks/microbenchmarks/utils.py
+++ b/benchmarks/microbenchmarks/utils.py
@@ -19,7 +19,6 @@
Float8WeightOnlyConfig,
FPXWeightOnlyConfig,
GemliteUIntXWeightOnlyConfig,
- Int4WeightOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
@@ -195,18 +194,6 @@ def string_to_config(
return Int8DynamicActivationInt8WeightConfig(weight_only_decode=True)
else:
return Int8DynamicActivationInt8WeightConfig()
- if "int4wo" in quantization:
- use_hqq = False
- if "hqq" in quantization:
- use_hqq = True
- group_size = int(quantization.split("-")[1])
- assert group_size in [
- 32,
- 64,
- 128,
- 256,
- ], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
- return Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq, version=1)
elif "int8adq-int4w-symm" in quantization:
from torchao.dtypes import CutlassInt4PackedLayout
@@ -226,10 +213,6 @@ def string_to_config(
act_mapping_type=MappingType.SYMMETRIC,
layout=MarlinQQQLayout(),
)
- elif sparsity is not None and ("semi" in sparsity or "2:4" in sparsity):
- from torchao.dtypes import MarlinSparseLayout
-
- return Int4WeightOnlyConfig(layout=MarlinSparseLayout(), version=1)
if "fp6" in quantization:
return FPXWeightOnlyConfig(3, 2)
elif "uintx" in quantization:
diff --git a/docs/source/serialization.rst b/docs/source/serialization.rst
index 64818f53ef..e1985aea26 100644
--- a/docs/source/serialization.rst
+++ b/docs/source/serialization.rst
@@ -36,7 +36,7 @@ Here is the serialization and deserialization flow::
print(f"original model size: {get_model_size_in_bytes(m) / 1024 / 1024} MB")
example_inputs = m.example_inputs(dtype=dtype, device="cuda")
- quantize_(m, Int4WeightOnlyConfig(version=1))
+ quantize_(m, Int4WeightOnlyConfig())
print(f"quantized model size: {get_model_size_in_bytes(m) / 1024 / 1024} MB")
ref = m(*example_inputs)
diff --git a/docs/source/torchao_vllm_integration.md b/docs/source/torchao_vllm_integration.md
index ef3719cb9f..0ef745a0ac 100644
--- a/docs/source/torchao_vllm_integration.md
+++ b/docs/source/torchao_vllm_integration.md
@@ -44,8 +44,6 @@ from torchao.quantization import Int4WeightOnlyConfig
# Example configuration
config = Int4WeightOnlyConfig(
group_size=128,
- use_hqq=True,
- version=1,
)
assert isinstance(config, AOBaseConfig)
```
@@ -66,7 +64,7 @@ config = FqnToConfig({
"model.layers.0.self_attn.q_proj": Int4WeightOnlyConfig(group_size=64),
"model.layers.0.self_attn.k_proj": Int4WeightOnlyConfig(group_size=64),
"model.layers.0.mlp.gate_proj": Int8WeightOnlyConfig(),
- "_default": Int4WeightOnlyConfig(group_size=128, version=1) # Default for other modules
+ "_default": Int4WeightOnlyConfig(group_size=128) # Default for other modules
})
```
(usage-examples)=
@@ -81,7 +79,7 @@ from torchao.quantization import Int4WeightOnlyConfig
# Create quantization configuration
quantization_config = TorchAoConfig(
- quant_type=Int4WeightOnlyConfig(group_size=128, use_hqq=True, version=1)
+ quant_type=Int4WeightOnlyConfig(group_size=128)
)
# Load and automatically quantize the model
diff --git a/scripts/quick_start.py b/scripts/quick_start.py
index 482919c620..0c13a79d01 100644
--- a/scripts/quick_start.py
+++ b/scripts/quick_start.py
@@ -38,8 +38,7 @@ def forward(self, x):
# | torchao quantization |
# ========================
-# torch 2.4+ only
-quantize_(model, Int4WeightOnlyConfig(group_size=32, version=1))
+quantize_(model, Int4WeightOnlyConfig(group_size=32))
# =============
diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py
index 20cf1c6311..e799a5b655 100644
--- a/test/dtypes/test_affine_quantized.py
+++ b/test/dtypes/test_affine_quantized.py
@@ -17,8 +17,6 @@
from torchao.core.config import AOBaseConfig
from torchao.dtypes import (
CutlassInt4PackedLayout,
- Int4CPULayout,
- Int4XPULayout,
PlainLayout,
SemiSparseLayout,
to_affine_quantized_intx,
@@ -28,14 +26,13 @@
Float8WeightOnlyConfig,
GemliteUIntXWeightOnlyConfig,
Int4DynamicActivationInt4WeightConfig,
- Int4WeightOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
quantize_,
)
-from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
-from torchao.testing.utils import skip_if_no_cuda, skip_if_no_gemlite, skip_if_rocm
+from torchao.quantization.quant_primitives import MappingType
+from torchao.testing.utils import skip_if_no_gemlite, skip_if_rocm
from torchao.utils import (
check_cpu_version,
check_xpu_version,
@@ -62,24 +59,11 @@ def get_quantization_functions(
]
if do_int4:
if check_cpu_version(device):
- base_functions.append(
- Int4WeightOnlyConfig(group_size=32, layout=Int4CPULayout(), version=1)
- )
+ pass
elif check_xpu_version(device):
- base_functions.append(
- Int4WeightOnlyConfig(group_size=32, layout=Int4XPULayout(), version=1)
- )
- if int4_zp_int:
- base_functions.append(
- Int4WeightOnlyConfig(
- group_size=32,
- layout=Int4XPULayout(),
- zero_point_domain=ZeroPointDomain.INT,
- version=1,
- )
- )
+ pass
else:
-            base_functions.append(Int4WeightOnlyConfig(group_size=32, version=1))
+            pass
if device == "cuda" and not is_ROCM():
base_functions.append(
Int8DynamicActivationInt4WeightConfig(
@@ -107,26 +90,6 @@ class TestAffineQuantized(TestCase):
["xpu"] if torch.xpu.is_available() else []
)
- @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
- def test_tensor_core_layout_transpose(self):
- linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=_DEVICE)
- t = linear.weight
- shape = t.shape
- apply_int4_weight_only_quant = Int4WeightOnlyConfig(group_size=32, version=1)
- quantize_(linear, apply_int4_weight_only_quant)
- ql = linear
- aqt = ql.weight
- aqt_shape = aqt.shape
- self.assertEqual(aqt_shape, shape)
-
- # transpose shape test
- for _ in range(10):
- t = t.t()
- aqt = aqt.t()
- shape = t.shape
- aqt_shape = aqt.shape
- self.assertEqual(aqt_shape, shape)
-
@unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available")
def test_weights_only(self):
for device in self.GPU_DEVICES:
@@ -338,20 +301,6 @@ def test_alias(self, device, dtype):
quantize_(dummy, Int8DynamicActivationInt8WeightConfig())
_ = dummy.weight[...]
- @common_utils.parametrize("device", [_DEVICE])
- @common_utils.parametrize("dtype", [torch.bfloat16])
- @skip_if_no_cuda()
- @skip_if_rocm("ROCm enablement in progress")
- def test_slice_int4wo(self, device, dtype):
- # in_feature not divisible by 1024
- # out_feature not divisible by 8
- # to test slice + padding for int4 weight only quantization
- dummy = nn.Linear(256, 321, dtype=dtype, device=device)
- quantize_(dummy, Int4WeightOnlyConfig(version=1))
- # make sure these run without error
- _ = dummy.weight.narrow(0, 0, 64)
- _ = dummy.weight.narrow(1, 0, 128)
-
@common_utils.parametrize("device", [_DEVICE])
@common_utils.parametrize("dtype", [torch.float16, torch.bfloat16])
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
@@ -452,58 +401,6 @@ def test_matmul(self, device, dtype):
# make sure it runs
torch.matmul(x, w.t())
- @common_utils.parametrize("device", [_DEVICE])
- @common_utils.parametrize("dtype", [torch.bfloat16])
- @skip_if_no_cuda()
- @skip_if_rocm("ROCm enablement in progress")
- def test_slice_and_copy_int4wo(self, device, dtype):
- l = torch.nn.Linear(1024, 1024).to(_DEVICE).to(torch.bfloat16)
- l.weight = torch.nn.Parameter(
- torch.zeros(1024, 1024, dtype=torch.bfloat16, device=_DEVICE)
- )
- quantize_(l, Int4WeightOnlyConfig(version=1))
- param = l.weight
- param_data = param.data
- param_data = param_data.narrow(0, 0, 512)
- assert (
- param.data.tensor_impl.packed_weight.data_ptr()
- == param_data.tensor_impl.packed_weight.data_ptr()
- )
- assert (
- param.data.tensor_impl.scale_and_zero.data_ptr()
- == param_data.tensor_impl.scale_and_zero.data_ptr()
- )
- assert param.data.dequantize()[0][0] == 0
-
- # dummy_l has random input (shouldn't be 0)
- dummy_l = torch.nn.Linear(1024, 1024).to(_DEVICE).to(torch.bfloat16)
- quantize_(dummy_l, Int4WeightOnlyConfig(version=1))
- quantized = dummy_l.weight
- quantized = quantized.narrow(0, 0, 512)
-
- param_data.copy_(quantized)
-
- # making sure param.data is updated
- assert param.data.dequantize()[0][0] != 0
-
- @common_utils.parametrize("device", [_DEVICE])
- @common_utils.parametrize("dtype", [torch.bfloat16])
- @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
- @skip_if_rocm("ROCm enablement in progress")
- def test_mm_int4wo(self, device, dtype):
- weight = torch.randn(512, 1024).to(device).to(dtype)
- weight = weight.t()
-
- l = torch.nn.Linear(512, 1024).to(device).to(dtype)
- l.weight = torch.nn.Parameter(weight)
- quantize_(l, Int4WeightOnlyConfig(version=1))
- # weight shape: 1024 x 512
- weight = l.weight
-
- input = torch.randn(1, 512, device=device, dtype=dtype)
- # make sure it runs
- torch.nn.functional.linear(input, weight)
-
common_utils.instantiate_parametrized_tests(TestAffineQuantized)
common_utils.instantiate_parametrized_tests(TestAffineQuantizedBasic)
diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py
index 09bdfa8e61..3cd924f0c9 100644
--- a/test/hqq/test_hqq_affine.py
+++ b/test/hqq/test_hqq_affine.py
@@ -8,13 +8,11 @@
import torch
from torchao.quantization import (
- Int4WeightOnlyConfig,
MappingType,
UIntXWeightOnlyConfig,
ZeroPointDomain,
quantize_,
)
-from torchao.testing.utils import skip_if_rocm
cuda_available = torch.cuda.is_available()
@@ -54,12 +52,7 @@ def _eval_hqq(dtype):
in_features=in_features, out_features=out_features, bias=False
)
dummy_linear.weight.data = W
- if dtype == torch.uint4:
- config = Int4WeightOnlyConfig(
- group_size=max(block_size), use_hqq=True, version=1
- )
- else:
- config = UIntXWeightOnlyConfig(dtype, group_size=max(block_size), use_hqq=True)
+ config = UIntXWeightOnlyConfig(dtype, group_size=max(block_size), use_hqq=True)
quantize_(dummy_linear, config)
q_tensor_hqq = dummy_linear.weight
@@ -113,14 +106,6 @@ def test_hqq_plain_5bit(self):
ref_dot_product_error=0.000704,
)
- @skip_if_rocm("ROCm enablement in progress")
- def test_hqq_plain_4bit(self):
- self._test_hqq(
- dtype=torch.uint4,
- ref_dequantize_error=0.000487,
- ref_dot_product_error=0.001472,
- )
-
def test_hqq_plain_3bit(self):
self._test_hqq(
dtype=torch.uint3,
diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 6ca3a8a3c5..6feccfe2f5 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -122,30 +122,6 @@ def _int8da_int8w_api(
)
-def _int4wo_api(mod, use_hqq=False):
- if check_cpu_version(next(mod.parameters()).device):
- quantize_(
- mod,
- Int4WeightOnlyConfig(
- layout=Int4CPULayout(),
- use_hqq=use_hqq,
- set_inductor_config=False,
- version=1,
- ),
- )
- unwrap_tensor_subclass(mod)
- elif check_xpu_version(next(mod.parameters()).device):
- quantize_(
- mod,
- Int4WeightOnlyConfig(
- layout=Int4XPULayout(), set_inductor_config=False, version=1
- ),
- )
- unwrap_tensor_subclass(mod)
- else:
- quantize_(mod, Int4WeightOnlyConfig(set_inductor_config=False, version=1))
-
-
def _int8da_int4w_api(mod):
quantize_(mod, Int8DynamicActivationInt4WeightConfig(set_inductor_config=False))
@@ -154,7 +130,6 @@ def _int8da_int4w_api(mod):
TENSOR_SUBCLASS_APIS = [
_int8wo_api,
_int8da_int8w_api,
- _int4wo_api,
]
@@ -622,32 +597,6 @@ def test_int8_weight_only_quant_with_freeze(self, device, dtype):
_int8wo_api, device, 40, test_dtype=dtype
)
- @parameterized.expand(COMMON_DEVICE_DTYPE)
- @skip_if_xpu("XPU enablement in progress")
- def test_int4_weight_only_quant_subclass_api(self, device, dtype):
- if dtype != torch.bfloat16:
- self.skipTest(f"Fails for {dtype}")
- for test_shape in [(16, 1024, 16)] + (
- [(1, 1024, 256)] if device == _DEVICE else []
- ):
- self._test_lin_weight_subclass_api_impl(
- _int4wo_api, device, 15, test_shape=test_shape, test_dtype=dtype
- )
-
- @parameterized.expand(COMMON_DEVICE_DTYPE)
- @skip_if_xpu("XPU enablement in progress")
- def test_int4_weight_only_hqq_quant_subclass_api(self, device, dtype):
- if dtype != torch.bfloat16:
- self.skipTest(f"Fails for {dtype}")
- for test_shape in [(16, 1024, 16), (1, 1024, 256)]:
- api = partial(
- _int4wo_api,
- use_hqq=True,
- )
- self._test_lin_weight_subclass_api_impl(
- api, device, 15, test_shape=test_shape, test_dtype=dtype
- )
-
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not has_gemlite, "gemlite not available")
def test_gemlite_layout(self, device, dtype):
@@ -943,13 +892,6 @@ def test_save_load_int8woqtensors(self, device, dtype):
undo_recommended_configs()
self._test_handle_save_load_meta_impl(_int8wo_api, device, test_dtype=dtype)
- @parameterized.expand(COMMON_DEVICE_DTYPE)
- @torch.no_grad()
- def test_save_load_int4woqtensors(self, device, dtype):
- if dtype != torch.bfloat16:
- self.skipTest(f"Fails for {dtype}")
- self._test_handle_save_load_meta_impl(_int4wo_api, device, 20, test_dtype=dtype)
-
class UtilsUnitTest(unittest.TestCase):
def test_shape_logger(self):
diff --git a/test/integration/test_load_and_run_checkpoint.py b/test/integration/test_load_and_run_checkpoint.py
index 2cd0465b92..be6866c059 100644
--- a/test/integration/test_load_and_run_checkpoint.py
+++ b/test/integration/test_load_and_run_checkpoint.py
@@ -28,23 +28,9 @@
# high precision model, used for testing config deprecation warning
_HIGH_PRECISION_MODEL = "facebook/opt-125m"
-_DEPRECATED_SINGLE_LINEAR_MODEL_INFO = [
- # model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-v1-0.14.dev
- (
- "torchao-testing/single-linear-Int4WeightOnlyConfig-v1-0.14.dev",
- 1,
- "Int4WeightOnlyConfig",
- ),
-]
+_DEPRECATED_SINGLE_LINEAR_MODEL_INFO = []
-_DEPRECATED_MODEL_INFO = [
- # model card: https://huggingface.co/torchao-testing/opt-125m-Int4WeightOnlyConfig-v1-0.14.dev
- (
- "torchao-testing/opt-125m-Int4WeightOnlyConfig-v1-0.14.dev",
- 1,
- "Int4WeightOnlyConfig",
- ),
-]
+_DEPRECATED_MODEL_INFO = []
_SINGLE_LINEAR_MODEL_INFO = [
# model card: https://huggingface.co/torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev
diff --git a/test/quantization/test_gptq.py b/test/quantization/test_gptq.py
index 7774c3d7f7..5efc1b9481 100644
--- a/test/quantization/test_gptq.py
+++ b/test/quantization/test_gptq.py
@@ -11,13 +11,10 @@
from torch.testing._internal.common_utils import TestCase
from torchao._models.llama.model import (
- ModelArgs,
Transformer,
prepare_inputs_for_model,
)
from torchao._models.llama.tokenizer import get_tokenizer
-from torchao.quantization import Int4WeightOnlyConfig, quantize_
-from torchao.quantization.utils import compute_error
from torchao.utils import get_current_accelerator_device
torch.manual_seed(0)
@@ -162,55 +159,6 @@ def test_multitensor_input_recorder(self):
self.assertTrue(isinstance(MT_input[2][2], MultiTensor))
self.assertEqual(MT_input[3], torch.float)
- @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
- def test_gptq_with_input_recorder(self):
- from torchao.quantization.GPTQ import (
- Int4WeightOnlyGPTQQuantizer,
- MultiTensorInputRecorder,
- )
-
- torch.set_default_dtype(torch.bfloat16)
-
- config = ModelArgs(n_layer=2)
-
- with torch.device(_DEVICE):
- model = Transformer(config)
- model.setup_caches(max_batch_size=2, max_seq_length=100)
- idx = torch.randint(1, 10000, (10, 2, 50)).to(torch.int32)
- test_input = prepare_inputs_for_model(idx[0])
- import copy
-
- model2 = copy.deepcopy(model)
- out = model(*test_input)
- quantize_(model2, Int4WeightOnlyConfig(version=1))
-
- outq = model2(*test_input)
- del model2
-
- input_recorder = MultiTensorInputRecorder()
- for i in range(10):
- input = prepare_inputs_for_model(idx[i])
- input_recorder(*input)
-
- args = input_recorder.get_recorded_inputs()
-
- if _DEVICE.type == "xpu":
- from torchao.dtypes import Int4XPULayout
-
- quantizer = Int4WeightOnlyGPTQQuantizer(
- device=torch.device("xpu"), layout=Int4XPULayout()
- )
- else:
- quantizer = Int4WeightOnlyGPTQQuantizer()
-
- quantizer.quantize(model, *args)
-
- outgptq = model(*test_input)
-
- self.assertGreater(compute_error(outgptq, out), 30)
- self.assertGreater(compute_error(outgptq, out), compute_error(outq, out))
- torch.set_default_dtype(torch.float32)
-
if __name__ == "__main__":
unittest.main()
diff --git a/test/quantization/test_moe_quant.py b/test/quantization/test_moe_quant.py
index 61000babc1..e5a903e445 100644
--- a/test/quantization/test_moe_quant.py
+++ b/test/quantization/test_moe_quant.py
@@ -12,7 +12,6 @@
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl
-from torchao.dtypes.uintx.tensor_core_tiled_layout import TensorCoreTiledAQTTensorImpl
from torchao.prototype.moe_quant.quantizable_moe_modules import (
MOEFeedForwardAOQuantizable,
)
@@ -26,7 +25,6 @@
AffineQuantizedTensor,
Float8DynamicActivationFloat8WeightConfig,
Float8WeightOnlyConfig,
- Int4WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
LinearActivationQuantizedTensor,
@@ -109,51 +107,6 @@ def _test_impl_moe_quant(
self.assertGreaterEqual(compute_error(out_q, out), 10)
self.assertGreaterEqual(compute_error(out_qc, out), 10)
- @parameterized.expand(
- [
- ("single_token", 1, False),
- ("multiple_tokens", 8, False),
- ]
- )
- def test_int4wo_fake_dim(self, name, num_tokens, fullgraph):
- if not torch.cuda.is_available():
- self.skipTest("Need CUDA available")
-
- config = MoEQuantConfig(
- Int4WeightOnlyConfig(version=1),
- use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
- )
- tensor_impl_class = TensorCoreTiledAQTTensorImpl
-
- self._test_impl_moe_quant(
- config=config,
- num_tokens=num_tokens,
- tensor_impl_class=tensor_impl_class,
- fullgraph=fullgraph,
- )
-
- @parameterized.expand(
- [
- ("single_token", 1, True),
- ("multiple_tokens", 8, False),
- ]
- )
- def test_int4wo_base(self, name, num_tokens, fullgraph):
- if not torch.cuda.is_available():
- self.skipTest("Need CUDA available")
- if not is_sm_at_least_90():
- self.skipTest("Requires CUDA capability >= 9.0")
-
- config = MoEQuantConfig(Int4WeightOnlyConfig(version=1))
- tensor_impl_class = TensorCoreTiledAQTTensorImpl
-
- self._test_impl_moe_quant(
- config=config,
- num_tokens=num_tokens,
- tensor_impl_class=tensor_impl_class,
- fullgraph=fullgraph,
- )
-
@parameterized.expand(
[
("single_token", 1, False),
diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 1bcc5e3349..0d054405d6 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -2064,14 +2064,6 @@ def test_infer_int4_weight_only_config(self):
_infer_fake_quantize_configs,
)
- base_config = Int4WeightOnlyConfig(version=1)
- (act_config, weight_config) = _infer_fake_quantize_configs(base_config)
- self.assertIsNone(act_config)
- self.assertIsInstance(weight_config, IntxFakeQuantizeConfig)
- self.assertEqual(weight_config.dtype, torch.uint4)
- self.assertEqual(weight_config.group_size, 128)
- self.assertFalse(weight_config.is_symmetric)
-
base_config = Int4WeightOnlyConfig(version=2)
(act_config, weight_config) = _infer_fake_quantize_configs(base_config)
self.assertIsNone(act_config)
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index d6efc8a906..bc09f199a7 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -28,8 +28,6 @@
from torchao._models.llama.tokenizer import get_tokenizer
from torchao.dtypes import (
AffineQuantizedTensor,
- Int4CPULayout,
- Int4XPULayout,
PlainLayout,
TensorCoreTiledLayout,
)
@@ -73,7 +71,6 @@
get_current_accelerator_device,
is_sm_at_least_89,
is_sm_at_least_90,
- torch_version_at_least,
unwrap_tensor_subclass,
)
@@ -219,34 +216,6 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
compiled = m(*example_inputs)
torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)
- @unittest.skipIf(not torch.xpu.is_available(), "Need XPU available")
- @unittest.skipIf(not torch_version_at_least("2.8.0"), "only works for torch 2.8+")
- def test_int4_wo_quant_save_load(self):
- m = ToyLinearModel().eval().cpu()
-
- def api(model):
- quantize_(model, Int4WeightOnlyConfig(layout=Int4XPULayout(), version=1))
- unwrap_tensor_subclass(model)
-
- api(m)
-
- example_inputs = m.example_inputs()
- ref = m(*example_inputs)
- with tempfile.NamedTemporaryFile() as f:
- torch.save(m.state_dict(), f)
- f.seek(0)
- state_dict = torch.load(f)
-
- m2 = ToyLinearModel().eval().cpu()
- api(m2)
-
- m2.load_state_dict(state_dict)
- m2 = m2.to(device="xpu")
- example_inputs = map(lambda x: x.xpu(), example_inputs)
- res = m2(*example_inputs)
-
- torch.testing.assert_close(ref, res.cpu())
-
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
def test_int8_wo_quant_save_load(self):
m = ToyLinearModel().eval().cpu()
@@ -524,37 +493,11 @@ def reset_memory():
assert param.device.type == _DEVICE.type
self.assertLess(memory_streaming, memory_baseline)
- @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
- @common_utils.parametrize("x_dim", [2, 3])
- @common_utils.parametrize("use_hqq", [True, False])
- def test_int4wo_cpu(self, dtype, x_dim, use_hqq):
- device = "cpu"
- m = ToyLinearModel().eval().to(dtype).to(device)
- example_inputs = m.example_inputs(dtype=dtype, device=device)
- if x_dim == 3:
- example_inputs = (example_inputs[0].unsqueeze(0),)
-
- with torch.no_grad():
- quantize_(
- m,
- Int4WeightOnlyConfig(
- group_size=32, layout=Int4CPULayout(), use_hqq=use_hqq, version=1
- ),
- )
- # ensure the expected op is in the code
- _, code = torch._inductor.utils.run_and_get_code(
- torch.compile(m, fullgraph=True, dynamic=True),
- *example_inputs,
- )
- assert "_weight_int4pack_mm_for_cpu" in code[0]
- assert "aten.mm.default" not in code[0]
-
# TODO(#1690): move to new config names
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
@common_utils.parametrize(
"config",
[
- Int4WeightOnlyConfig(version=1),
Float8WeightOnlyConfig(),
Float8DynamicActivationFloat8WeightConfig(),
Float8StaticActivationFloat8WeightConfig(scale=torch.tensor([1.0])),
@@ -621,7 +564,7 @@ def test_workflow_e2e_numerics(self, config):
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
def test_module_fqn_to_config_default(self):
- config1 = Int4WeightOnlyConfig(group_size=32, version=1)
+ config1 = Int4WeightOnlyConfig(group_size=32)
config2 = Int8WeightOnlyConfig()
config = ModuleFqnToConfig({"_default": config1, "linear2": config2})
model = ToyLinearModel().to(_DEVICE).to(dtype=torch.bfloat16)
@@ -635,7 +578,7 @@ def test_module_fqn_to_config_default(self):
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
def test_module_fqn_to_config_module_name(self):
- config1 = Int4WeightOnlyConfig(group_size=32, version=1)
+ config1 = Int4WeightOnlyConfig(group_size=32)
config2 = Int8WeightOnlyConfig()
config = ModuleFqnToConfig({"linear1": config1, "linear2": config2})
model = ToyLinearModel().to(_DEVICE).to(dtype=torch.bfloat16)
@@ -774,7 +717,7 @@ def test_module_fqn_to_config_embedding_linear(self):
@unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
def test_module_fqn_to_config_skip(self):
- config1 = Int4WeightOnlyConfig(group_size=32, version=1)
+ config1 = Int4WeightOnlyConfig(group_size=32)
config = ModuleFqnToConfig({"_default": config1, "linear2": None})
model = ToyLinearModel().to(_DEVICE).to(dtype=torch.bfloat16)
example_inputs = model.example_inputs(device=_DEVICE, dtype=torch.bfloat16)
@@ -784,25 +727,6 @@ def test_module_fqn_to_config_skip(self):
assert isinstance(model.linear1.weight._layout, TensorCoreTiledLayout)
assert not isinstance(model.linear2.weight, AffineQuantizedTensor)
- @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
- def test_int4wo_cuda_serialization(self):
- config = Int4WeightOnlyConfig(group_size=32, version=1)
- model = ToyLinearModel().to(_DEVICE).to(dtype=torch.bfloat16)
- # quantize in cuda
- quantize_(model, config)
- example_inputs = model.example_inputs(device=_DEVICE, dtype=torch.bfloat16)
- model(*example_inputs)
- with tempfile.NamedTemporaryFile() as ckpt:
- # save checkpoint in cuda
- torch.save(model.state_dict(), ckpt)
- # load checkpoint on cpu then move checkpoint to cuda
- # This is what torchtune does: https://github.com/pytorch/torchtune/blob/v0.6.1/torchtune/training/checkpointing/_utils.py#L253
- sd = torch.load(ckpt.name, weights_only=False, map_location="cpu")
- for k, v in sd.items():
- sd[k] = v.to(_DEVICE)
- # load state_dict in cuda
- model.load_state_dict(sd, assign=True)
-
def test_config_deprecation(self):
"""
Test that old config functions like `Int8DynamicActivationInt4WeightConfig` trigger deprecation warnings.
diff --git a/test/sparsity/test_marlin.py b/test/sparsity/test_marlin.py
index dd07c31172..786b5f2618 100644
--- a/test/sparsity/test_marlin.py
+++ b/test/sparsity/test_marlin.py
@@ -3,23 +3,18 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
-import copy
import pytest
import torch
from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests
-from torchao.dtypes import MarlinSparseLayout
-from torchao.quantization.quant_api import Int4WeightOnlyConfig, quantize_
from torchao.quantization.quant_primitives import (
MappingType,
choose_qparams_affine,
quantize_affine,
)
from torchao.sparsity.marlin import inject_24, pack_to_marlin_24, unpack_from_marlin_24
-from torchao.sparsity.sparse_api import apply_fake_sparsity
-from torchao.testing.utils import skip_if_rocm
class SparseMarlin24(TestCase):
@@ -42,47 +37,6 @@ def setUp(self):
for param in self.model.parameters():
param.requires_grad = False
- @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
- @skip_if_rocm("ROCm enablement in progress")
- def test_quant_sparse_marlin_layout_eager(self):
- apply_fake_sparsity(self.model)
- model_copy = copy.deepcopy(self.model)
-
- # Quantized
- quantize_(model_copy.bfloat16(), Int4WeightOnlyConfig(version=1))
- dense_result = model_copy(self.input.bfloat16()).half()
-
- # Sparse + quantized
- quantize_(
- self.model, Int4WeightOnlyConfig(layout=MarlinSparseLayout(), version=1)
- )
- sparse_result = self.model(self.input)
- assert torch.allclose(dense_result, sparse_result, atol=3e-1), (
- "Results are not close"
- )
-
- @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
- @skip_if_rocm("ROCm enablement in progress")
- def test_quant_sparse_marlin_layout_compile(self):
- apply_fake_sparsity(self.model)
- model_copy = copy.deepcopy(self.model)
-
- # Quantized
- quantize_(model_copy.bfloat16(), Int4WeightOnlyConfig(version=1))
- model_copy.foward = torch.compile(model_copy.forward, fullgraph=True)
- dense_result = model_copy(self.input.bfloat16()).half()
-
- # Sparse + quantized
- quantize_(
- self.model, Int4WeightOnlyConfig(layout=MarlinSparseLayout(), version=1)
- )
- self.model.forward = torch.compile(self.model.forward, fullgraph=True)
- sparse_result = self.model(self.input)
-
- assert torch.allclose(dense_result, sparse_result, atol=3e-1), (
- "Results are not close"
- )
-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
def test_pack_unpack_equivalence(self):
num_bits = 4
diff --git a/test/sparsity/test_sparse_api.py b/test/sparsity/test_sparse_api.py
index 66cd032a9a..a8caa984e4 100644
--- a/test/sparsity/test_sparse_api.py
+++ b/test/sparsity/test_sparse_api.py
@@ -11,13 +11,12 @@
from torch import nn
from torch.testing._internal import common_utils
-from torchao.dtypes import MarlinSparseLayout, SemiSparseLayout
+from torchao.dtypes import SemiSparseLayout
from torchao.quantization import (
Float8DynamicActivationFloat8SemiSparseWeightConfig,
Float8DynamicActivationFloat8WeightConfig,
)
from torchao.quantization.quant_api import (
- Int4WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
quantize_,
)
@@ -94,38 +93,6 @@ def test_quant_semi_sparse(self, compile):
torch.testing.assert_close(dense_result, sparse_result, rtol=1e-2, atol=1e-2)
- @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
- @common_utils.parametrize("compile", [True, False])
- def test_sparse_marlin(self, compile):
- if not torch.backends.cusparselt.is_available():
- self.skipTest("Need cuSPARSELt")
-
- input = torch.rand((256, 256)).half().cuda()
- model = (
- nn.Sequential(
- nn.Linear(256, 1024),
- nn.Linear(1024, 256),
- )
- .half()
- .cuda()
- .eval()
- )
-
- apply_fake_sparsity(model)
- model_copy = copy.deepcopy(model)
-
- # Quantized
- quantize_(model_copy.bfloat16(), Int4WeightOnlyConfig(version=1))
- dense_result = model_copy(input.bfloat16()).half()
-
- # Sparse + quantized
- quantize_(model, Int4WeightOnlyConfig(layout=MarlinSparseLayout(), version=1))
- if compile:
- model = torch.compile(model)
- sparse_result = model(input)
-
- torch.testing.assert_close(dense_result, sparse_result, atol=3e-1, rtol=3e-1)
-
@unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.parametrize("compile", [True, False])
diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py
index 002045215c..c974d7ae08 100644
--- a/torchao/_models/llama/eval.py
+++ b/torchao/_models/llama/eval.py
@@ -24,7 +24,6 @@
Float8DynamicActivationFloat8WeightConfig,
Float8WeightOnlyConfig,
FPXWeightOnlyConfig,
- Int4WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
PerBlock,
@@ -84,19 +83,6 @@ def run_evaluation(
quantize_(model, Int8DynamicActivationInt8WeightConfig())
if "fp6" in quantization:
quantize_(model, FPXWeightOnlyConfig(3, 2))
- if "int4wo" in quantization and not "gptq" in quantization:
- if "hqq" in quantization:
- use_hqq = True
- else:
- use_hqq = False
- groupsize = int(quantization.split("-")[1])
- assert groupsize in [32, 64, 128, 256], (
- f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
- )
- quantize_(
- model.to(device),
- Int4WeightOnlyConfig(group_size=groupsize, use_hqq=use_hqq, version=1),
- )
if "uintx" in quantization:
# uintx-nbits-groupsize
# "uintx-2-64"
@@ -119,12 +105,6 @@ def run_evaluation(
dtype = _NBITS_TO_DTYPE[nbits]
group_size = int(_quant_args[2])
quantize_(model, UIntXWeightOnlyConfig(dtype, group_size, use_hqq=use_hqq))
- if "marlin" in quantization:
- from torchao.dtypes import MarlinSparseLayout
-
- quantize_(
- model, Int4WeightOnlyConfig(layout=MarlinSparseLayout(), version=1)
- )
if "int4wo" in quantization and "gptq" in quantization:
# avoid circular imports
from torchao._models._eval import LMEvalInputRecorder
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index fc3d371139..84350bb474 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -345,7 +345,6 @@ def ffn_or_attn_only(mod, fqn):
FPXWeightOnlyConfig,
GemliteUIntXWeightOnlyConfig,
Int4DynamicActivationInt4WeightConfig,
- Int4WeightOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
@@ -417,24 +416,7 @@ def ffn_or_attn_only(mod, fqn):
)
else:
quantize_(model, Int8DynamicActivationInt8WeightConfig())
- if "int4wo" in quantization:
- use_hqq = False
- if "hqq" in quantization:
- use_hqq = True
- group_size = int(quantization.split("-")[1])
- assert group_size in [
- 32,
- 64,
- 128,
- 256,
- ], (
- f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
- )
- quantize_(
- model,
- Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq, version=1),
- )
- elif "int4dq-" in quantization:
+ if "int4dq-" in quantization:
from torchao.dtypes import CutlassInt4PackedLayout
nbits = int(quantization.removeprefix("int4dq-"))
@@ -471,14 +453,6 @@ def ffn_or_attn_only(mod, fqn):
layout=MarlinQQQLayout(),
),
)
- elif "semi" in sparsity:
- from torchao.dtypes import MarlinSparseLayout
-
- quantize_(
- model,
- Int4WeightOnlyConfig(layout=MarlinSparseLayout(), version=1),
- filter_fn=ffn_or_attn_only,
- )
if "fp6" in quantization:
quantize_(model, FPXWeightOnlyConfig(3, 2))
elif "embed-int8wo" in quantization:
diff --git a/torchao/_models/mixtral-moe/generate.py b/torchao/_models/mixtral-moe/generate.py
index 39ee6a4dcb..660d3c8316 100644
--- a/torchao/_models/mixtral-moe/generate.py
+++ b/torchao/_models/mixtral-moe/generate.py
@@ -244,7 +244,6 @@ def main(
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
Float8WeightOnlyConfig,
- Int4WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
Int8DynamicActivationIntxWeightConfig,
Int8WeightOnlyConfig,
@@ -273,15 +272,6 @@ def main(
use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
)
- elif "int4wo-base" in moe_quant:
- config = MoEQuantConfig(Int4WeightOnlyConfig(version=1))
-
- elif "int4wo" in moe_quant:
- config = MoEQuantConfig(
- Int4WeightOnlyConfig(version=1),
- use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE,
- )
-
elif "fp8wo-base" in moe_quant:
config = MoEQuantConfig(Float8WeightOnlyConfig())
diff --git a/torchao/_models/sam/eval_combo.py b/torchao/_models/sam/eval_combo.py
index 467e24a9b6..d52ecef980 100644
--- a/torchao/_models/sam/eval_combo.py
+++ b/torchao/_models/sam/eval_combo.py
@@ -22,7 +22,6 @@
from torchao.dtypes import SemiSparseLayout
from torchao.prototype.quantization.autoquant_v2 import autoquant_v2
from torchao.quantization import (
- Int4WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
autoquant,
quantize_,
@@ -392,22 +391,6 @@ def mlp_only(mod, name):
mlp_lin1_only,
)
sparsify_(predictor.model.image_encoder, semi_sparse_weight(), mlp_lin2_only)
- elif compress == "int4_weight_only_sparse":
- # apply sparsify first to set qparams
- apply_fake_sparsity(predictor.model.image_encoder, filter_fn=mlp_only)
- from torchao.dtypes import MarlinSparseLayout
-
- quantize_(
- predictor.model.image_encoder,
- Int8DynamicActivationInt8WeightConfig(),
- attn_only,
- )
- quantize_(
- predictor.model.image_encoder,
- Int4WeightOnlyConfig(layout=MarlinSparseLayout(), version=1),
- mlp_lin1_only,
- )
- sparsify_(predictor.model.image_encoder, semi_sparse_weight(), mlp_lin2_only)
elif compress is not None and "autoquant_v2" in compress:
example_input = torch.randn(
diff --git a/torchao/prototype/autoround/README.md b/torchao/prototype/autoround/README.md
index a67b3be9f0..fcb896ba51 100644
--- a/torchao/prototype/autoround/README.md
+++ b/torchao/prototype/autoround/README.md
@@ -114,7 +114,7 @@ quantize_(model, apply_auto_round(), is_target_module)
| autoround-4bit* | 0.6338 | 0.4566 | 0.7661 | 0.6646 | 0.5688 | 0.7130 |
> [!NOTE]
-> - `torchao-int4wo` quantizes the model to 4 bits with a group size of 128 (`Int4WeightOnlyConfig(group_size=128, version=1)`) while leaving the `lm-head` unquantized.
+> - `torchao-int4wo` quantizes the model to 4 bits with a group size of 128 (`Int4WeightOnlyConfig(group_size=128)`) while leaving the `lm-head` unquantized.
> - `auto-round-4bit` uses the deafult configuration from [quick start](#quick-start).
> - `auto-round-4bit*` follows the same settings as `auto-round-4bit`, but with `gradient_accumulate_steps=2` and `batch_size=4`, which accumulating two batches(4 samples per batch) before performing the backward pass.
> - To reproduce results, run `eval_autoround.py` with `AO_USE_DETERMINISTIC_ALGORITHMS=1`.
diff --git a/torchao/prototype/autoround/eval_autoround.py b/torchao/prototype/autoround/eval_autoround.py
index 4f6850be88..59eb5d6093 100644
--- a/torchao/prototype/autoround/eval_autoround.py
+++ b/torchao/prototype/autoround/eval_autoround.py
@@ -105,7 +105,7 @@ def main(args):
quantize_(
model,
- Int4WeightOnlyConfig(group_size=args.group_size, version=1),
+ Int4WeightOnlyConfig(group_size=args.group_size),
filter_fn=filter_fn,
device=model_device,
)
diff --git a/torchao/prototype/hqq/example.py b/torchao/prototype/hqq/example.py
deleted file mode 100644
index cda96f6b3c..0000000000
--- a/torchao/prototype/hqq/example.py
+++ /dev/null
@@ -1,174 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD 3-Clause license found in the
-# LICENSE file in the root directory of this source tree.
-import torch
-
-from torchao.dtypes import PlainLayout, TensorCoreTiledLayout
-from torchao.dtypes.affine_quantized_tensor import (
- to_affine_quantized_intx,
-)
-from torchao.quantization import (
- MappingType,
- ZeroPointDomain,
-)
-
-# Parameters
-device, compute_dtype = "cuda:0", torch.bfloat16
-group_size, axis = 64, 1
-in_features, out_features = 4096, 11800
-
-torch.random.manual_seed(100)
-linear_layer = torch.nn.Linear(in_features, out_features, bias=False, device=device)
-x = torch.randn((1, linear_layer.in_features), dtype=torch.float, device=device) / 20.0
-y_ref = linear_layer(x)
-W = linear_layer.weight.data.clone().to(device=device, dtype=compute_dtype)
-del linear_layer.weight
-
-################################################################################################
-# AffineQuantizedTensor example
-################################################################################################
-print("-------------------------------------------------------------------")
-print("AffineQuantizedTensor example")
-print("-------------------------------------------------------------------")
-mapping_type = MappingType.ASYMMETRIC
-block_size = (1, group_size)
-target_dtype = torch.uint8 # until sub-byte dtypes are supported
-preserve_zero = False
-zero_point_domain = ZeroPointDomain.FLOAT
-zero_point_dtype = compute_dtype
-_layout = PlainLayout()
-
-for nbits in list(range(2, 9))[::-1]:
- print(
- "------------------------------------------------------------------------------"
- )
- q_tensor_default = to_affine_quantized_intx(
- input_float=W,
- mapping_type=mapping_type,
- block_size=block_size,
- target_dtype=target_dtype,
- quant_min=0,
- quant_max=2**nbits - 1,
- zero_point_domain=zero_point_domain,
- preserve_zero=preserve_zero,
- _layout=_layout,
- )
-
- linear_layer.weight = q_tensor_default
- print(
- "nbits",
- nbits,
- "| Default dequantization error",
- (W - q_tensor_default.dequantize()).abs().mean().item(),
- )
- print(
- "nbits",
- nbits,
- "| Default Dot product error",
- (y_ref - linear_layer(x.to(compute_dtype))).abs().mean().item(),
- )
- # nbits 4 | Default dequantization error 0.001953125
- # nbits 4 | Default Dot product error 0.005926903802901506
-
- q_tensor_hqq = to_affine_quantized_intx(
- input_float=W,
- mapping_type=mapping_type,
- block_size=block_size,
- target_dtype=target_dtype,
- quant_min=0,
- quant_max=2**nbits - 1,
- zero_point_domain=zero_point_domain,
- preserve_zero=preserve_zero,
- _layout=_layout,
- use_hqq=True,
- )
-
- linear_layer.weight = q_tensor_hqq
- print(
- "nbits",
- nbits,
- "| HQQ dequantization error",
- (W - q_tensor_hqq.dequantize()).abs().mean().item(),
- )
- print(
- "nbits",
- nbits,
- "| HQQ Dot product error",
- (y_ref - linear_layer(x.to(compute_dtype))).abs().mean().item(),
- )
- # nbits 4 | HQQ dequantization error 0.0004863739013671875
- # nbits 4 | HQQ Dot product error 0.0014713306445628405
-
-################################################################################################
-# quant_api example
-################################################################################################
-print("-------------------------------------------------------------------")
-print("Quant API example")
-print("-------------------------------------------------------------------")
-
-from torchao.quantization.quant_api import Int4WeightOnlyConfig
-
-nbits = 4
-target_dtype = torch.int32
-inner_k_tiles = 8
-_layout = TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles)
-
-int4_weight_only_patch_fct = Int4WeightOnlyConfig(
- group_size=group_size, inner_k_tiles=inner_k_tiles, version=1
-)
-linear_layer_default = torch.nn.Linear(
- in_features, out_features, bias=False, device=device
-)
-linear_layer_default.weight.data = W.clone()
-linear_layer_default = int4_weight_only_patch_fct(linear_layer_default)
-print(
- "nbits",
- nbits,
- "| Default dequantization error",
- (W - linear_layer_default(torch.eye(W.shape[1], dtype=W.dtype, device=W.device)).T)
- .abs()
- .mean()
- .item(),
-)
-print(
- "nbits",
- nbits,
- "| Default Dot product error",
- (y_ref - linear_layer_default(x.to(compute_dtype))).abs().mean().item(),
-)
-# nbits 4 | Default dequantization error 0.000492095947265625
-# nbits 4 | Default Dot product error 0.0015244047390297055
-
-
-q_tensor_hqq = to_affine_quantized_intx(
- input_float=W,
- mapping_type=mapping_type,
- block_size=block_size,
- target_dtype=target_dtype,
- quant_min=0,
- quant_max=2**nbits - 1,
- zero_point_domain=zero_point_domain,
- preserve_zero=preserve_zero,
- _layout=_layout,
- use_hqq=True,
-)
-linear_layer.weight = q_tensor_hqq
-print(
- "nbits",
- nbits,
- "| HQQ dequantization error",
- (W - linear_layer(torch.eye(W.shape[1], dtype=W.dtype, device=W.device)).T)
- .abs()
- .mean()
- .item(),
-)
-print(
- "nbits",
- nbits,
- "| HQQ Dot product error",
- (y_ref - linear_layer(x.to(compute_dtype))).abs().mean().item(),
-)
-# nbits 4 | HQQ dequantization error 0.0004863739013671875
-# nbits 4 | HQQ Dot product error 0.0014699687017127872
diff --git a/torchao/prototype/moe_quant/llama4_quant.py b/torchao/prototype/moe_quant/llama4_quant.py
index ae6abccea5..81c0ba580f 100644
--- a/torchao/prototype/moe_quant/llama4_quant.py
+++ b/torchao/prototype/moe_quant/llama4_quant.py
@@ -77,7 +77,7 @@ def convert_fn(module):
quantize_(
model,
- MoEQuantConfig(Int4WeightOnlyConfig(version=1)),
+ MoEQuantConfig(Int4WeightOnlyConfig()),
cond_ffn_filter,
device="cuda",
)
diff --git a/torchao/prototype/quantization/quant_api.py b/torchao/prototype/quantization/quant_api.py
index 851a3597ac..6550026295 100644
--- a/torchao/prototype/quantization/quant_api.py
+++ b/torchao/prototype/quantization/quant_api.py
@@ -458,10 +458,6 @@ def _uintx_weight_only_transform(
block_size = (1, group_size)
if use_hqq:
- if dtype == torch.uint4:
- logger.warning(
- "Recommended to use `Int4WeightOnlyConfig(group_size, use_hqq=True, version=1)` for the best performance"
- )
quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype]
dtype = torch.uint8
eps = None
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 85ab42bbd0..ff1d095176 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -455,7 +455,7 @@ def quantize_(
from torchao.quantization.quant_api import Int4WeightOnlyConfig
m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
- quantize_(m, Int4WeightOnlyConfig(group_size=32, version=1))
+ quantize_(m, Int4WeightOnlyConfig(group_size=32))
"""
torch._C._log_api_usage_once("torchao.quantization.quantize_")
@@ -734,27 +734,12 @@ class Int4WeightOnlyConfig(AOBaseConfig):
`int4_packing_format`: the packing format for int4 tensor, used in version 2 only
`int4_choose_qparams_algorithm`: variants of choose qparams algorithm to use for int4,
currently support TINYGEMM ("tinygemm") and HQQ ("hqq"), used in version 2 only
- `layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)`, used in version 1 only
- `use_hqq`: whether to use hqq or default quantization mode, default is False, used in version 1 only
- `zero_point_domain`: data type of zeros points, choices are [ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE], used in version 1 only
`set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values. used in both version 1 and 2
- `preserve_zero`: whether to preserve zero, default is None. Will be set to True if zero_point_domain is ZeroPointDomain.INT, used in version 1 only
- `version`: version of the config to use, only subset of above args are valid for version 1, and subset of above args are valid for version 2, default is 2, see note for more details
-
- Note:
- Current state for Int4WeightOnlyConfig is that it supports both v1 (legacy) and v2
-
- For v2 (version = 2), only `group_size`, `int4_packing_format`, `int4_choose_qparams_algorithm` and `set_inductor_config` are valid, all other args will be ignored
- For v1 (version = 1), only `group_size`, `layout`, `use_hqq`, `zero_point_domain`, `preserve_zero` and `set_inductor_config` are valid, we plan to deprecate v1 in torchao 0.15 to make this config
- less confusing
+ `version`: version of the config to use, default is 2
"""
group_size: int = 128
- layout: Optional[TensorCoreTiledLayout] = TensorCoreTiledLayout(inner_k_tiles=8)
- use_hqq: bool = False
- zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.NONE
set_inductor_config: bool = True
- preserve_zero: Optional[bool] = None
# only used in version >= 2
int4_packing_format: Int4PackingFormat = Int4PackingFormat.PLAIN
int4_choose_qparams_algorithm: Int4ChooseQParamsAlgorithm = (
@@ -773,10 +758,7 @@ def _int4_weight_only_quantize_tensor(weight, config):
# for now, make these local variables to allow the rest of the function
# to be a direct copy-paste
group_size = config.group_size
- layout = config.layout
- use_hqq = config.use_hqq
int4_choose_qparams_algorithm = config.int4_choose_qparams_algorithm
- zero_point_domain = config.zero_point_domain
int4_packing_format = config.int4_packing_format
if weight.shape[-1] % group_size != 0:
@@ -787,108 +769,49 @@ def _int4_weight_only_quantize_tensor(weight, config):
block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
- if config.version == 2:
- block_size = list(block_size)
-
- if int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.HQQ:
- assert int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D, (
- f"Int4ChooseQParamsAlgorithm.HQQ is not supported by packing format {int4_packing_format}, "
- f"it's only supported by Int4PackingFormat.TILE_PACKED_TO_4D currently"
- )
-
- if int4_packing_format == Int4PackingFormat.PRESHUFFLED:
- new_weight = Int4PreshuffledTensor.from_hp(
- weight,
- block_size,
- activation_dtype=torch.bfloat16,
- )
- return new_weight
- elif int4_packing_format == Int4PackingFormat.PLAIN:
- new_weight = Int4Tensor.from_hp(
- weight,
- block_size,
- )
- return new_weight
- elif int4_packing_format == Int4PackingFormat.PLAIN_INT32:
- new_weight = Int4PlainInt32Tensor.from_hp(
- weight,
- block_size,
- )
- return new_weight
- elif int4_packing_format == Int4PackingFormat.MARLIN_SPARSE:
- new_weight = Int4MarlinSparseTensor.from_hp(
- weight,
- block_size,
- )
- return new_weight
- elif int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D:
- new_weight = Int4TilePackedTo4dTensor.from_hp(
- weight,
- block_size,
- int4_choose_qparams_algorithm=int4_choose_qparams_algorithm,
- )
- return new_weight
- else:
- raise ValueError(f"Unsupported int4 packing format: {int4_packing_format}")
-
- assert config.version == 1
-
- warnings.warn(
- "Config Deprecation: version 1 of Int4WeightOnlyConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2948 for more details"
- )
- mapping_type = MappingType.ASYMMETRIC
- target_dtype = torch.int32
- quant_min = 0
- quant_max = 15
- eps = 1e-6
- zero_point_dtype = (
- weight.dtype if isinstance(layout, Int4CPULayout) else torch.bfloat16
- )
+ assert config.version == 2
+ block_size = list(block_size)
- # nonlocal zero_point_domain
- assert type(layout) in LAYOUT_TO_ZERO_POINT_DOMAIN.keys(), (
- f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}"
- )
- if zero_point_domain == ZeroPointDomain.NONE:
- # the first value is the default one
- zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)][0]
- else:
- assert zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)], (
- f"Layout only support {LAYOUT_TO_ZERO_POINT_DOMAIN[layout]}"
+ if int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.HQQ:
+ assert int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D, (
+ f"Int4ChooseQParamsAlgorithm.HQQ is not supported by packing format {int4_packing_format}, "
+ f"it's only supported by Int4PackingFormat.TILE_PACKED_TO_4D currently"
)
- if zero_point_domain == ZeroPointDomain.INT and isinstance(layout, Int4XPULayout):
- zero_point_dtype = torch.int32
-
- preserve_zero = (
- config.preserve_zero
- if config.preserve_zero is not None
- else LAYOUT_TO_PRESERVE_ZEROS[type(layout)]
- )
- # Sparse Marlin only supports symmetric quantization.
- # NOTE: If we start having lots of layouts that require different configurations,
- # we should consider moving this logic somewhere else.
- if isinstance(layout, MarlinSparseLayout):
- mapping_type = MappingType.SYMMETRIC
- assert group_size == 128 or group_size == weight.shape[-1], (
- f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {group_size}"
+ if int4_packing_format == Int4PackingFormat.PRESHUFFLED:
+ new_weight = Int4PreshuffledTensor.from_hp(
+ weight,
+ block_size,
+ activation_dtype=torch.bfloat16,
)
-
- new_weight = to_affine_quantized_intx(
- weight,
- mapping_type,
- block_size,
- target_dtype,
- quant_min,
- quant_max,
- eps,
- zero_point_dtype=zero_point_dtype,
- preserve_zero=preserve_zero,
- zero_point_domain=zero_point_domain,
- _layout=layout,
- use_hqq=use_hqq,
- )
- return new_weight
+ return new_weight
+ elif int4_packing_format == Int4PackingFormat.PLAIN:
+ new_weight = Int4Tensor.from_hp(
+ weight,
+ block_size,
+ )
+ return new_weight
+ elif int4_packing_format == Int4PackingFormat.PLAIN_INT32:
+ new_weight = Int4PlainInt32Tensor.from_hp(
+ weight,
+ block_size,
+ )
+ return new_weight
+ elif int4_packing_format == Int4PackingFormat.MARLIN_SPARSE:
+ new_weight = Int4MarlinSparseTensor.from_hp(
+ weight,
+ block_size,
+ )
+ return new_weight
+ elif int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D:
+ new_weight = Int4TilePackedTo4dTensor.from_hp(
+ weight,
+ block_size,
+ int4_choose_qparams_algorithm=int4_choose_qparams_algorithm,
+ )
+ return new_weight
+ else:
+ raise ValueError(f"Unsupported int4 packing format: {int4_packing_format}")
@register_quantize_module_handler(Int4WeightOnlyConfig)