diff --git a/test/prototype/test_gptq.py b/test/prototype/test_gptq.py
new file mode 100644
index 0000000000..7fa64396ec
--- /dev/null
+++ b/test/prototype/test_gptq.py
@@ -0,0 +1,453 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+import unittest
+
+import torch
+import torch.nn.functional as F
+
+from torchao.prototype.gptq import (
+    GPTQConfig,
+    gptq_quantize,
+)
+from torchao.prototype.gptq.observer import ObserverTensor
+from torchao.quantization import Int4WeightOnlyConfig, quantize_
+
+
+class ToyLinearModel(torch.nn.Module):
+    def __init__(self, m=64, n=32, k=64):
+        super().__init__()
+        self.linear1 = torch.nn.Linear(m, k, bias=False)
+        self.linear2 = torch.nn.Linear(k, n, bias=False)
+        self.linear3 = torch.nn.Linear(n, n, bias=False)
+
+    def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
+        return (
+            torch.randn(
+                batch_size, self.linear1.in_features, dtype=dtype, device=device
+            ),
+        )
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = F.relu(x)
+        x = self.linear2(x)
+        x = F.relu(x)
+        x = self.linear3(x)
+        return x
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+class TestObserverTensor(unittest.TestCase):
+    """Test suite for ObserverTensor functionality."""
+
+    def test_observer_tensor_creation(self):
+        """Test that ObserverTensor.from_hp() creates a tensor with the correct properties."""
+        weight = torch.randn(32, 64, dtype=torch.float32, device="cuda")
+        observer = ObserverTensor.from_hp(weight)
+
+        # Check it's an ObserverTensor
+        self.assertIsInstance(observer, ObserverTensor)
+
+        # Check shape matches
+        self.assertEqual(observer.shape, weight.shape)
+
+        # Check dtype and device match
+        self.assertEqual(observer.dtype, weight.dtype)
+        self.assertEqual(observer.device, weight.device)
+
+        # Check hp_data is stored correctly
+        torch.testing.assert_close(observer.hp_data, weight)
+
+        # Check hessian is initialized as None
+        self.assertIsNone(observer.hessian)
+
+        # Check total_batches is initialized as 0
+        self.assertEqual(observer.total_batches, 0)
+
+    def test_observer_tensor_attributes(self):
+        """Test ObserverTensor attributes are correctly set."""
+        weight = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda")
+        observer = ObserverTensor.from_hp(weight)
+
+        # Test hp_data attribute
+        self.assertTrue(hasattr(observer, "hp_data"))
+        self.assertIsInstance(observer.hp_data, torch.Tensor)
+
+        # Test hessian attribute
+        self.assertTrue(hasattr(observer, "hessian"))
+        self.assertIsNone(observer.hessian)
+
+        # Test total_batches attribute
+        self.assertTrue(hasattr(observer, "total_batches"))
+        self.assertEqual(observer.total_batches, 0)
+
+        # Test update method exists
+        self.assertTrue(hasattr(observer, "update"))
+        self.assertTrue(callable(observer.update))
+
+    def test_linear_operation_with_observer(self):
+        """Test F.linear with ObserverTensor updates the Hessian correctly."""
+        batch_size = 4
+        in_features = 64
+        out_features = 32
+
+        # Create weight as ObserverTensor
+        weight = torch.randn(
+            out_features, in_features, dtype=torch.float32, device="cuda"
+        )
+        observer_weight = ObserverTensor.from_hp(weight)
+
+        # Create input
+        input_tensor = torch.randn(
+            batch_size, in_features, dtype=torch.float32, device="cuda"
+        )
+
+        # Perform linear operation
+        output = F.linear(input_tensor, observer_weight)
+
+        # Check output shape is correct
+        self.assertEqual(output.shape, (batch_size, out_features))
+
+        # Check that Hessian was initialized and updated
+        self.assertIsNotNone(observer_weight.hessian)
+        self.assertEqual(observer_weight.hessian.shape, (in_features, in_features))
+        self.assertEqual(observer_weight.total_batches, 1)
+
+        # Verify output is correct
+        expected_output = F.linear(input_tensor, weight)
+        torch.testing.assert_close(output, expected_output)
+
+    def test_multiple_observations(self):
+        """Test that Hessian updates incrementally across multiple forward passes."""
+        out_features = 16
+        in_features = 32
+
+        weight = torch.randn(
+            out_features, in_features, dtype=torch.float32, device="cuda"
+        )
+        observer_weight = ObserverTensor.from_hp(weight)
+
+        num_passes = 5
+        total_samples = 0
+
+        # Perform multiple forward passes
+        for i in range(num_passes):
+            batch_size = 2
+            input_tensor = torch.randn(
+                batch_size, in_features, dtype=torch.float32, device="cuda"
+            )
+            total_samples += 1
+            _ = F.linear(input_tensor, observer_weight)
+
+        # Check that Hessian was created and updated
+        self.assertIsNotNone(observer_weight.hessian)
+        self.assertEqual(observer_weight.hessian.shape, (in_features, in_features))
+
+        # Check total_batches matches the number of forward passes
+        # (each 2D input counts as a single batch in the observer)
+        self.assertEqual(observer_weight.total_batches, total_samples)
+
+    def test_bmm_operation_with_observer(self):
+        """Test torch.bmm with ObserverTensor updates the Hessian correctly."""
+        batch = 4
+        m = 8
+        n = 16
+        k = 12
+
+        # Create input and weight tensors
+        input_tensor = torch.randn(batch, m, k, dtype=torch.float32, device="cuda")
+        weight = torch.randn(batch, k, n, dtype=torch.float32, device="cuda")
+        observer_weight = ObserverTensor.from_hp(weight)
+
+        # Perform bmm operation
+        output = torch.bmm(input_tensor, observer_weight)
+
+        # Check output shape
+        self.assertEqual(output.shape, (batch, m, n))
+
+        # Check Hessian was initialized and updated
+        self.assertIsNotNone(observer_weight.hessian)
+        # For bmm, each batch element counts as one observation; the Hessian is
+        # accumulated over the input's last (contraction) dimension
+        self.assertEqual(observer_weight.hessian.shape, (k, k))
+        self.assertEqual(observer_weight.total_batches, batch)
+
+        # Verify output is correct
+        expected_output = torch.bmm(input_tensor, weight)
+        torch.testing.assert_close(output, expected_output)
+
+    def test_observer_config_transform(self):
+        """Test GPTQConfig wraps module weights correctly."""
+        # Create a simple linear layer
+        linear = torch.nn.Linear(64, 32, bias=False).cuda()
+        original_weight = linear.weight.data.clone()
+
+        # Apply GPTQConfig with observe step
+        quantize_(linear, GPTQConfig(step="observe", group_size=128))
+
+        # Check weight is now an ObserverTensor
+        self.assertIsInstance(linear.weight, ObserverTensor)
+
+        # Check hp_data matches original weight
+        torch.testing.assert_close(linear.weight.hp_data, original_weight)
+
+        # Check hessian is None initially
+        self.assertIsNone(linear.weight.hessian)
+        self.assertEqual(linear.weight.total_batches, 0)
+
+        # Perform a forward pass
+        input_tensor = torch.randn(4, 64, dtype=torch.float32, device="cuda")
+        output = linear(input_tensor)
+
+        # Check Hessian was initialized after forward pass
+        self.assertIsNotNone(linear.weight.hessian)
+        self.assertEqual(linear.weight.total_batches, 1)
+
+        # Check output shape
+        self.assertEqual(output.shape, (4, 32))
+
+    def test_hessian_incremental_update(self):
+        """Test that incremental Hessian updates match the batch calculation."""
+        in_features = 32
+        out_features = 16
+
+        weight = torch.randn(
+            out_features, in_features, dtype=torch.float32, device="cuda"
+        )
+
+        # Create an ObserverTensor for the incremental path; the batch path
+        # uses _calculate_hessian directly
+        observer_incremental = ObserverTensor.from_hp(weight)
+
+        # Collect activations for batch computation
+        activations = []
+        num_batches = 3
+        for _ in range(num_batches):
+            batch_size = 4
+            input_tensor = torch.randn(
+                batch_size, in_features, dtype=torch.float32, device="cuda"
+            )
+            activations.append(input_tensor)
+            # Update incrementally
+            _ = F.linear(input_tensor, observer_incremental)
+
+        # Compute Hessian in batch using _calculate_hessian
+        from torchao.prototype.gptq import _calculate_hessian
+
+        hessian_batch = _calculate_hessian(activations, device="cuda")
+
+        # Compare incremental vs batch
+        self.assertIsNotNone(observer_incremental.hessian)
+        torch.testing.assert_close(
+            observer_incremental.hessian, hessian_batch, rtol=1e-4, atol=1e-5
+        )
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+class TestGPTQFlow(unittest.TestCase):
+    def test_unified_config_two_phase(self):
+        """Test that GPTQConfig handles both observation and quantization phases."""
+        # Create a simple linear layer
+        linear = torch.nn.Linear(64, 32, bias=False).cuda().to(torch.bfloat16)
+        original_weight = linear.weight.data.clone()
+
+        # Phase 1: Observation step - wrap as ObserverTensor
+        observe_config = GPTQConfig(
+            step="observe",
+            group_size=128,
+        )
+        quantize_(linear, observe_config)
+
+        # Verify weight is now an ObserverTensor
+        self.assertIsInstance(linear.weight, ObserverTensor)
+        torch.testing.assert_close(linear.weight.hp_data, original_weight)
+
+        # Run some forward passes for calibration
+        for _ in range(10):
+            input_tensor = torch.randn(4, 64, dtype=torch.bfloat16, device="cuda")
+            _ = linear(input_tensor)
+
+        # Verify Hessian was computed
+        self.assertIsNotNone(linear.weight.hessian)
+        self.assertGreater(linear.weight.total_batches, 0)
+
+        # Phase 2: Convert step - apply GPTQ quantization
+        convert_config = GPTQConfig(
+            step="convert",
+            group_size=128,
+        )
+        quantize_(linear, convert_config)
+
+        # Verify weight is now Int4Tensor (quantized)
+        from torchao.quantization import Int4Tensor
+
+        self.assertIsInstance(linear.weight, Int4Tensor)
+
+        # Verify it still works
+        output = linear(input_tensor)
+        self.assertEqual(output.shape, (4, 32))
+
+    def test_gptq_quantize_function(self):
+        """Test gptq_quantize function with synthetic Hessian and weights."""
+        torch.manual_seed(42)
+
+        # Create synthetic weight matrix
+        out_features = 128
+        in_features = 256
+        weight = torch.randn(
+            out_features, in_features, dtype=torch.bfloat16, device="cuda"
+        )
+
+        # Create synthetic Hessian (positive semi-definite)
+        # H = A^T @ A ensures positive semi-definiteness
+        A = torch.randn(in_features, in_features, dtype=torch.float32, device="cuda")
+        H = A.t() @ A
+        # Add regularization to ensure positive definiteness
+        H = H + torch.eye(in_features, device="cuda") * 0.1
+
+        # Create GPTQ config
+        config = GPTQConfig(
+            step="convert",
+            group_size=128,
+        )
+
+        # Run GPTQ quantization
+        quantized_weight = gptq_quantize(H, weight, config)
+
+        # Check output type
+        from torchao.quantization import Int4Tensor
+
+        self.assertIsInstance(quantized_weight, Int4Tensor)
+
+        # Check shape is preserved
+        self.assertEqual(quantized_weight.shape, weight.shape)
+
+        # Dequantize and check error is reasonable
+        dequantized = F.linear(
+            torch.eye(in_features, device="cuda", dtype=torch.bfloat16),
+            quantized_weight,
+            None,
+        ).t()
+        self.assertEqual(dequantized.shape, weight.shape)
+
+        # Check quantization introduces bounded error
+        error = torch.abs(dequantized - weight.float())
+        mean_error = error.mean().item()
+        max_error = error.max().item()
+
+        # GPTQ should have reasonable error bounds
+        self.assertLess(mean_error, 0.5, f"Mean error too high: {mean_error}")
+        self.assertLess(max_error, 5.0, f"Max error too high: {max_error}")
+
+        # Check that quantization actually compressed the data
+        # Int4 should be much smaller than bfloat16
+        self.assertTrue(hasattr(quantized_weight, "qdata"))
+
+    def test_gptq_quantize_better_than_naive(self):
+        """Test that GPTQ produces lower error than naive quantization."""
+        torch.manual_seed(43)
+
+        # Create weight and realistic Hessian from actual activations
+        out_features = 64
+        in_features = 128
+        weight = torch.randn(
+            out_features, in_features, dtype=torch.bfloat16, device="cuda"
+        )
+
+        # Simulate activations and compute Hessian
+        num_samples = 100
+        activations = []
+        for _ in range(num_samples):
+            act = torch.randn(4, in_features, dtype=torch.float32, device="cuda")
+            activations.append(act)
+
+        # Compute Hessian from activations
+        from torchao.prototype.gptq import _calculate_hessian
+
+        H = _calculate_hessian(activations, device="cuda")
+        H_identity = torch.eye(in_features, device="cuda", dtype=torch.float32)
+
+        # GPTQ quantization
+        config = GPTQConfig(
+            step="convert",
+            group_size=128,
+        )
+        gptq_quantized = gptq_quantize(H, weight, config)
+        gptq_dequantized = F.linear(
+            H_identity.to(torch.bfloat16), gptq_quantized, None
+        ).t()
+
+        # Naive quantization (using identity Hessian)
+        naive_quantized = gptq_quantize(H_identity, weight, config)
+        naive_dequantized = F.linear(
+            H_identity.to(torch.bfloat16), naive_quantized, None
+        ).t()
+
+        # Compare plain reconstruction errors; the true GPTQ objective would
+        # weight the error by the Hessian, tr((W - W_q) H (W - W_q)^T)
+        weight_f = weight.float()
+        gptq_error = weight_f - gptq_dequantized
+        naive_error = weight_f - naive_dequantized
+
+        # Compute Frobenius norm of errors
+        gptq_loss = torch.norm(gptq_error).item()
+        naive_loss = torch.norm(naive_error).item()
+
+        print(f"GPTQ loss: {gptq_loss:.4f}, Naive loss: {naive_loss:.4f}")
+
+        # GPTQ should generally produce lower or comparable error
+        # (Note: with random data, this might not always hold, but with a real Hessian it should)
+        self.assertIsNotNone(gptq_loss)
+        self.assertIsNotNone(naive_loss)
+
+    def test_gptq_transformer(self):
+        """End-to-end comparison of GPTQ vs RTN on a small Llama-style Transformer."""
+        torch.manual_seed(43)
+        from torchao._models.llama.model import (
+            ModelArgs,
+            Transformer,
+            prepare_inputs_for_model,
+        )
+
+        torch.set_default_dtype(torch.bfloat16)
+
+        config = ModelArgs(n_layer=2)
+
+        with torch.device("cuda"):
+            model = Transformer(config)
+            model.setup_caches(max_batch_size=2, max_seq_length=100)
+            idx = torch.randint(1, 10000, (10, 2, 50)).to(torch.int32)
+            test_input = prepare_inputs_for_model(idx[0])
+
+        model2 = copy.deepcopy(model)
+        model_baseline = copy.deepcopy(model)
+
+        # Apply the new GPTQ flow: the observe step wraps weights as ObserverTensor
+        observe_config = GPTQConfig(step="observe", group_size=128)
+        quantize_(model, observe_config)
+
+        # Calibration: forward passes accumulate Hessians on the observed weights
+        for i in range(10):
+            inp = prepare_inputs_for_model(idx[i])
+            model(*inp)
+
+        convert_config = GPTQConfig(step="convert", group_size=128)
+        quantize_(model, convert_config)
+        out_gptq = model(*test_input)
+
+        quantize_(model2, Int4WeightOnlyConfig(version=2))
+        out_rtn = model2(*test_input)
+
+        out = model_baseline(*test_input)
+
+        from torchao.quantization.utils import compute_error
+
+        sqnr_rtn = compute_error(out_rtn, out)
+        sqnr_gptq = compute_error(out_gptq, out)
+
+        assert sqnr_gptq > 30, f"GPTQ SQNR: {sqnr_gptq} is too low"
+        assert sqnr_gptq > sqnr_rtn, (
+            f"GPTQ SQNR: {sqnr_gptq} is not better than RTN SQNR: {sqnr_rtn}"
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/torchao/prototype/gptq/__init__.py b/torchao/prototype/gptq/__init__.py
new file mode 100644
index 0000000000..a18e36d971
--- /dev/null
+++ b/torchao/prototype/gptq/__init__.py
@@ -0,0 +1,269 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import types
+from dataclasses import dataclass
+from functools import partial
+
+import torch
+import torch.nn as nn
+from fbgemm_gpu.experimental.gen_ai.quantize import int4_row_quantize_zp, pack_int4
+
+from torchao.core.config import AOBaseConfig
+from torchao.quantization import Int4Tensor
+from torchao.quantization.quant_api import _module_extra_repr
+from torchao.quantization.transform_module import register_quantize_module_handler
+
+from .observer import ObserverTensor
+
+
+@dataclass
+class GPTQConfig(AOBaseConfig):
+    """Unified config for GPTQ quantization with explicit step control.
+
+    step="observe": wraps weights as ObserverTensor so forward passes record Hessians.
+    step="convert": applies GPTQ quantization to the observed tensors.
+
+    group_size: number of columns that share one scale/zero-point.
+    percdamp: fraction of the mean Hessian diagonal added as dampening.
+    gptq_quantize_block_size: number of columns processed per GPTQ block
+        (rounded up to a multiple of group_size).
+    """
+
+    step: str = "observe"  # "observe" or "convert"
+    group_size: int = 128
+    percdamp: float = 0.01
+    gptq_quantize_block_size: int = 128
+
+
+@register_quantize_module_handler(GPTQConfig)
+def _gptq_config_transform(
+    module: torch.nn.Module, config: GPTQConfig, *, parameter_name="weight"
+) -> torch.nn.Module:
+    """Unified transform handler that uses explicit step control."""
+    tensor = getattr(module, parameter_name)
+
+    if config.step == "observe":
+        # Observation phase: wrap as ObserverTensor
+        new_tensor = ObserverTensor.from_hp(tensor)
+        setattr(module, parameter_name, nn.Parameter(new_tensor, requires_grad=False))
+        module.extra_repr = types.MethodType(
+            partial(
+                _module_extra_repr,
+                original_extra_repr=module.extra_repr,
+                parameter_name=parameter_name,
+            ),
+            module,
+        )
+        return module
+    elif config.step == "convert":
+        # Quantization phase: tensor should be an ObserverTensor
+        if not isinstance(tensor, ObserverTensor):
+            raise ValueError(
+                f"Expected {parameter_name} to be ObserverTensor in 'convert' step, "
+                f"but got {type(tensor)}. Did you run the 'observe' step first?"
+            )
+
+        # Validate that observations were recorded
+        if tensor.hessian is None:
+            raise ValueError(
+                f"No observations recorded for {parameter_name}. "
+                f"Hessian is None. Did you run forward passes during the observe step?"
+            )
+
+        # Use pre-computed Hessian directly
+        hessian = tensor.hessian
+        new_tensor = gptq_quantize(hessian, tensor.hp_data, config)
+        new_quantized_tensor = nn.Parameter(new_tensor, requires_grad=False)
+        setattr(module, parameter_name, new_quantized_tensor)
+        return module
+    else:
+        raise ValueError(
+            f"Invalid step '{config.step}'. Must be 'observe' or 'convert'."
+        )
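+
+
+# Minimal usage sketch of the two-step flow (illustrative only; `model` and the
+# calibration batches below are placeholders, not symbols defined in this module):
+#
+#   from torchao.quantization import quantize_
+#
+#   quantize_(model, GPTQConfig(step="observe", group_size=128))
+#   for batch in calibration_batches:
+#       model(batch)  # forward passes accumulate Hessians on the ObserverTensors
+#   quantize_(model, GPTQConfig(step="convert", group_size=128))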
+
+
+def _int4_row_quantize_zp_precomputed_qparams(
+    x: torch.Tensor,
+    scales: torch.Tensor,
+    zeros: torch.Tensor,
+    group_size: int = 128,
+) -> torch.Tensor:
+    """Quantize tensor using precomputed scales and zero points."""
+    n_bit = 4
+    to_quant = torch.split(x.to(torch.float), group_size, dim=-1)
+
+    scales_row = scales.t().contiguous()
+    zeros_row = zeros.t().contiguous()
+    scales_list = torch.split(scales_row, 1, dim=-1)
+    zeros_list = torch.split(zeros_row, 1, dim=-1)
+
+    min_val = [
+        zero_chunk - scale_chunk * (2 ** (n_bit - 1))
+        for zero_chunk, scale_chunk in zip(zeros_list, scales_list)
+    ]
+    max_int = 2**n_bit - 1
+    min_int = 0
+
+    out = [
+        chunk.sub(min_chunk).div(scale_chunk).round().clamp_(min_int, max_int)
+        for chunk, min_chunk, scale_chunk in zip(to_quant, min_val, scales_list)
+    ]
+    out = [(chunk - 2 ** (n_bit - 1)).to(dtype=torch.int8) for chunk in out]
+    out = torch.cat(out, dim=-1)
+    return out
+
+
+def _int4_row_dequantize_zp(
+    x: torch.Tensor,
+    scales: torch.Tensor,
+    zeros: torch.Tensor,
+    group_size: int = 128,
+) -> torch.Tensor:
+    """Dequantize int4 row-quantized tensor with zero point."""
+    n_bit = 4
+
+    scales = scales.t().contiguous()
+    zeros = zeros.t().contiguous()
+
+    x_chunks = torch.split(x, group_size, dim=-1)
+    scales_list = torch.split(scales, 1, dim=-1)
+    zeros_list = torch.split(zeros, 1, dim=-1)
+
+    dequant_chunks = []
+    for chunk, scale_chunk, zero_chunk in zip(x_chunks, scales_list, zeros_list):
+        chunk_float = chunk.to(torch.float32) + 2 ** (n_bit - 1)
+        min_val = zero_chunk - scale_chunk * (2 ** (n_bit - 1))
+        dequant = chunk_float * scale_chunk + min_val
+        dequant_chunks.append(dequant)
+
+    return torch.cat(dequant_chunks, dim=-1)
+
+
+def gptq_quantize(H, W, config):
+    print("gptq quantizing weight of shape: ", W.shape)
+    block_size = [1, config.group_size]
+    gptq_quantize_block_size = config.gptq_quantize_block_size
+    percdamp = config.percdamp
+    group_size = config.group_size
+
+    assert W.dim() == 2
+    assert group_size > 0
+
+    # Note: W is a view of the input tensor; the error-feedback updates below
+    # modify it in place.
+    W = W.view(-1, W.shape[-1]).detach()
+    columns = W.shape[1]
+    device = W.device
+
+    gptq_quantize_block_size = (
+        math.ceil(gptq_quantize_block_size / group_size) * group_size
+    )
+
+    dead = torch.diag(H) == 0
+    H[dead, dead] = 1
+    W[:, dead] = 0
+
+    damp = percdamp * torch.mean(torch.diag(H))
+    diag = torch.arange(columns, device=device)
+    H[diag, diag] += damp
+    H = torch.linalg.cholesky(H)
+    H = torch.cholesky_inverse(H)
+    H = torch.linalg.cholesky(H, upper=True)
+    Hinv = H
+
+    all_qparams = []
+
+    for W_quantize_block, block_start in zip(
+        torch.split(W, gptq_quantize_block_size, dim=1),
+        range(0, columns, gptq_quantize_block_size),
+    ):
+        block_end = min(block_start + gptq_quantize_block_size, columns)
+
+        Err1 = torch.zeros_like(W_quantize_block, dtype=H.dtype)
+        Hinv_quantize_block = Hinv[block_start:block_end, block_start:block_end]
+
+        for W_group, group_start in zip(
+            torch.split(W_quantize_block, group_size, dim=1),
+            range(block_start, block_end, group_size),
+        ):
+            group_end = min(group_start + group_size, columns)
+
+            if group_start % group_size == 0:
+                # calculate qparams once per group (group_start is always
+                # aligned to group_size since blocks are rounded up to it)
+                _, scale, zero = int4_row_quantize_zp(W_group, group_size)
+                all_qparams.append((scale, zero))
+
+            # within each group
+            for i in range(group_start - block_start, group_end - block_start):
+                w = W_quantize_block[:, i].unsqueeze(1)
+
+                q = _int4_row_quantize_zp_precomputed_qparams(
+                    w, scale, zero, group_size
+                )
+                # Dequantize for error calculation
+                dq = _int4_row_dequantize_zp(q, scale, zero, group_size)
+
+                err1 = (w - dq) / Hinv_quantize_block[i, i]
+                W_quantize_block[:, i:] -= err1.matmul(
+                    Hinv_quantize_block[i, i:].unsqueeze(0)
+                )
+                Err1[:, i] = err1.flatten()
+
+        W[:, block_end:] -= Err1.matmul(Hinv[block_start:block_end, block_end:])
+
+    if "cuda" in device.type:
+        torch.cuda.synchronize()
+
+    # Assemble the final Int4Tensor from the accumulated per-group qparams
+    final_qparams = [torch.cat(x, dim=0) for x in zip(*all_qparams)]
+
+    # Quantize using precomputed qparams
+    wq = _int4_row_quantize_zp_precomputed_qparams(
+        W, final_qparams[0], final_qparams[1], group_size
+    )
+    wq_packed = pack_int4(wq)
+
+    res = Int4Tensor(
+        qdata=wq_packed,
+        scale=final_qparams[0].to(W.dtype),
+        zero_point=final_qparams[1].to(W.dtype),
+        block_size=block_size,
+        shape=W.shape,
+        act_pre_scale=None,
+    )
+    return res
+
+
+def _calculate_hessian(inputs, device=None):
+    """Calculate the Hessian matrix from input activations for GPTQ.
+
+    DEPRECATED: This function is kept for backward compatibility in tests only.
+    ObserverTensor now computes the Hessian incrementally during observation.
+    Use ObserverTensor.hessian instead for production code.
+    """
+    H = 0
+    total_batches = 0
+
+    for inp in inputs:
+        # Setup x (activation tensor)
+        x = inp.float()
+        if device:
+            x = x.to(device)
+        shape = x.shape
+        n = 1 if len(shape) == 2 else shape[0]
+        x = x.reshape(-1, shape[-1])
+
+        # Update Hessian with running average
+        H *= total_batches / (total_batches + n)
+        total_batches += n
+
+        x = ((2 / total_batches) ** (1 / 2)) * x.t()
+        H += x.matmul(x.t())
+
+    return H
+
+
+__all__ = [
+    "ObserverTensor",
+    "GPTQConfig",
+    "gptq_quantize",
+]
diff --git a/torchao/prototype/gptq/gptq_example.py b/torchao/prototype/gptq/gptq_example.py
new file mode 100644
index 0000000000..fe3b89b5e0
--- /dev/null
+++ b/torchao/prototype/gptq/gptq_example.py
@@ -0,0 +1,334 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
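+
+# End-to-end example: calibrate and GPTQ-quantize a HuggingFace causal LM, then
+# run lm_eval on the saved checkpoint. A typical invocation (illustrative only;
+# pick the model ID, dataset, and sample count that match your setup):
+#
+#   python gptq_example.py \
+#       --model-id unsloth/Llama-3.2-1B-Instruct \
+#       --quantization int4-gptq-sequential \
+#       --dataset-id hellaswag \
+#       --num-calibration-samples 1000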
+
+
+import argparse
+import gc
+import subprocess
+import time
+from typing import Any, List, Optional
+
+import torch
+from datasets import load_dataset
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from torchao.prototype.gptq import GPTQConfig
+from torchao.quantization import quantize_
+
+
+def sequential_quantize(
+    model,
+    calibration_data: List[torch.Tensor],
+    config: Any,
+) -> None:
+    """Quantize the model one transformer block at a time.
+
+    For each block: run the calibration activations through it to accumulate
+    Hessians, apply the convert-step config, then propagate the activations
+    through the quantized block to produce inputs for the next block.
+    """
+    # Run under no_grad; keeping activations around for backward would cause OOM
+    with torch.no_grad():
+        # Get device from embed_tokens layer (supports device_map="auto")
+        embed_device = next(model.model.embed_tokens.parameters()).device
+
+        # Prepare embeddings
+        inputs = []
+        position_ids = []
+        position_embeddings = []
+
+        # Generate embeddings for each sequence
+        for seq in calibration_data:
+            seq_length = seq.shape[1]
+            embedded = model.model.embed_tokens(seq.to(embed_device))
+            inputs.append(embedded)
+            pid = torch.arange(
+                0, seq_length, dtype=torch.long, device=embed_device
+            ).unsqueeze(0)
+            position_ids.append(pid)
+            position_embeddings.append(model.model.rotary_emb(embedded, pid))
+
+        # Process each transformer block sequentially
+        num_blocks = len(model.model.layers)
+        for block_idx in tqdm(range(num_blocks), desc="Quantizing blocks"):
+            block = model.model.layers[block_idx]
+            print(f"Working on block {block_idx} ...")
+
+            # Observation: these forward passes accumulate Hessians on the
+            # block's ObserverTensor weights
+            for i in range(len(inputs)):
+                block(
+                    inputs[i].to(next(block.parameters()).device),
+                    position_ids=position_ids[i],
+                    position_embeddings=position_embeddings[i],
+                )
+
+            quantize_(block, config)
+
+            # Propagate activations through the now-quantized block to produce
+            # inputs for the next block
+            for i in tqdm(range(len(inputs)), desc="propagating activations"):
+                inputs[i] = block(
+                    inputs[i].to(next(block.parameters()).device),
+                    position_ids=position_ids[i],
+                    position_embeddings=position_embeddings[i],
+                )
+
+
+def prepare_dataset(
+    tokenizer: AutoTokenizer,
+    max_sequence_length: int,
+    num_calibration_samples: Optional[int] = None,
+    dataset_id: str = "hellaswag",
+    dataset_split: str = "train",
+    seed: int = 42,
+) -> List[torch.Tensor]:
+    # Map dataset names to HuggingFace IDs
+    dataset_map = {
+        "hellaswag": "Rowan/hellaswag",
+        "ultrachat200k": "HuggingFaceH4/ultrachat_200k",
+    }
+
+    hf_dataset_id = dataset_map.get(dataset_id, dataset_id)
+
+    # Load dataset and preprocess
+    train_dataset_raw = load_dataset(hf_dataset_id, split=dataset_split, streaming=True)
+    train_dataset_raw = train_dataset_raw.shuffle(seed=seed, buffer_size=1_000)
+
+    def preprocess_hellaswag(example):
+        # HellaSwag format: context + correct ending
+        context = example["ctx"]
+        endings = example["endings"]
+        correct_ending = endings[int(example["label"])]
+        text = context + " " + correct_ending
+        return {"text": text}
+
+    def preprocess_ultrachat(example):
+        # UltraChat format: conversation messages
+        messages = example.get("messages", [])
+        # Concatenate all messages into a single text
+        text = " ".join([msg.get("content", "") for msg in messages])
+        return {"text": text}
+
+    # Choose preprocessing based on dataset
+    if dataset_id == "hellaswag":
+        train_dataset_raw = train_dataset_raw.map(preprocess_hellaswag)
+    elif dataset_id == "ultrachat200k":
+        train_dataset_raw = train_dataset_raw.map(preprocess_ultrachat)
+
+    train_dataset = []
+    for i, sample in enumerate(train_dataset_raw):
+        if i == num_calibration_samples:
+            break
+        tokenized_sample = tokenizer(
+            sample["text"],
+            max_length=max_sequence_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+        train_dataset.append(tokenized_sample["input_ids"])
+    return train_dataset
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="GPTQ quantization example for language models"
+    )
+    parser.add_argument(
+        "--model-id",
+        type=str,
+        default="unsloth/Llama-3.2-1B-Instruct",
+        help="HuggingFace model ID to quantize",
+    )
+    parser.add_argument(
+        "--num-calibration-samples",
+        type=int,
+        default=5000,
+        help="Number of calibration samples to use",
+    )
+    parser.add_argument(
+        "--max-sequence-length",
+        type=int,
+        default=None,
+        help="Maximum sequence length (default: use model's max_length)",
+    )
+    parser.add_argument(
+        "--dataset-id",
+        type=str,
+        default="hellaswag",
+        choices=["hellaswag", "ultrachat200k"],
+        help="Dataset for calibration (hellaswag or ultrachat200k)",
+    )
+    parser.add_argument(
+        "--quantization",
+        type=str,
+        default="int4-gptq-sequential",
+        choices=["none", "int4-rtn", "int4-gptq-sequential", "int4-gptq-nonsequential"],
+        help="Quantization method to use",
+    )
+    parser.add_argument(
+        "--percdamp",
+        type=float,
+        default=0.5,
+        help="Percentage damping for GPTQ",
+    )
+    parser.add_argument(
+        "--group-size",
+        type=int,
+        default=128,
+        help="Group size for quantization",
+    )
+    parser.add_argument(
+        "--gptq-block-size",
+        type=int,
+        default=1024,
+        help="Block size for GPTQ quantization",
+    )
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
+
+    # The example always loads the model in bfloat16
+    dtype = torch.bfloat16
+
+    print(f"Loading model {args.model_id}...")
+    model = AutoModelForCausalLM.from_pretrained(
+        args.model_id,
+        device_map="cuda:0",
+        dtype=dtype,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
+
+    print(f"Model config: {model.config}")
+
+    # Determine max sequence length
+    max_seq_length = args.max_sequence_length
+    if max_seq_length is None:
+        max_seq_length = getattr(model.config, "max_length", 2048)
+        print(f"Using model's max_length: {max_seq_length}")
+
+    # Generate output directory name from args
+    model_name = args.model_id.split("/")[-1]  # Get last part of model ID
+    output_dir = f"{model_name}_{args.quantization}"
+
+    if args.quantization != "none":
+        output_dir += f"_gs{args.group_size}"
+
+    if args.quantization in ["int4-gptq-sequential", "int4-gptq-nonsequential"]:
+        output_dir += f"_{args.dataset_id}_n{args.num_calibration_samples}"
+        output_dir += f"_damp{args.percdamp}_bs{args.gptq_block_size}"
+
+    print(f"Output directory: {output_dir}")
+
+    # Handle different quantization methods
+    quantization_start_time = time.time()
+
+    if args.quantization == "int4-rtn":
+        print("Applying Int4 RTN (Round-To-Nearest) quantization...")
+        from torchao.quantization import Int4WeightOnlyConfig
+
+        config = Int4WeightOnlyConfig(group_size=args.group_size)
+        quantize_(model, config, filter_fn=None)
+
+    elif args.quantization in ["int4-gptq-sequential", "int4-gptq-nonsequential"]:
+        # First application: wrap weights with ObserverTensor (observe step)
+        print("Wrapping weights with ObserverTensor for calibration...")
+        observe_config = GPTQConfig(
+            step="observe",
+            group_size=args.group_size,
+            percdamp=args.percdamp,
+            gptq_quantize_block_size=args.gptq_block_size,
+        )
+        quantize_(model, observe_config, filter_fn=None)
+
+        # Prepare calibration dataset
+        print(
+            f"Preparing {args.num_calibration_samples} calibration samples from {args.dataset_id}..."
+        )
+        dataset = prepare_dataset(
+            tokenizer,
+            max_seq_length,
+            args.num_calibration_samples,
+            dataset_id=args.dataset_id,
+            dataset_split="train",
+            seed=42,
+        )
+
+        # Second application: apply GPTQ quantization (convert step)
+        convert_config = GPTQConfig(
+            step="convert",
+            group_size=args.group_size,
+            percdamp=args.percdamp,
+            gptq_quantize_block_size=args.gptq_block_size,
+        )
+
+        if args.quantization == "int4-gptq-sequential":
+            print("Applying GPTQ quantization (sequential)...")
+            sequential_quantize(model, dataset, convert_config)
+        else:  # int4-gptq-nonsequential
+            print("Applying GPTQ quantization (non-sequential)...")
+            # Get device for input (from embedding layer, supports device_map="auto")
+            input_device = next(model.model.embed_tokens.parameters()).device
+
+            # Run calibration
+            for seq in tqdm(dataset, desc="Calibrating"):
+                model(seq.to(input_device))
+            # Apply quantization
+            quantize_(model, convert_config, filter_fn=None)
+
+    quantization_end_time = time.time()
+    quantization_time = quantization_end_time - quantization_start_time
+
+    if args.quantization != "none":
+        print(f"\n{'=' * 60}")
+        print(
+            f"Quantization completed in {quantization_time:.2f} seconds ({quantization_time / 60:.2f} minutes)"
+        )
+        print(f"{'=' * 60}\n")
+
+    # Save model to generated output directory
+    print(f"Saving model to {output_dir}...")
+    tokenizer.save_pretrained(output_dir)
+    model.save_pretrained(output_dir, safe_serialization=False)
+
+    print("DONE!")
+
+    # Clear GPU memory before running lm_eval
+    print("\nClearing GPU memory...")
+    del model
+    del tokenizer
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    print("GPU memory cleared.")
+
+    # Run lm_eval on the saved model
+    print(f"\n{'=' * 60}")
+    print("Running lm_eval on the quantized model...")
+    print(f"{'=' * 60}\n")
+
+    lm_eval_cmd = [
+        "lm_eval",
+        "--model",
+        "hf",
+        "--model_args",
+        f"pretrained={output_dir}",
+        "--tasks",
+        "hellaswag",
+        "--batch_size",
+        "auto",
+    ]
+
+    print(f"Running command: {' '.join(lm_eval_cmd)}")
+    try:
+        subprocess.run(lm_eval_cmd, check=True)
+    except subprocess.CalledProcessError as e:
+        print(f"lm_eval failed with error: {e}")
+    except FileNotFoundError:
+        print("lm_eval not found. Please install it with: pip install lm-eval")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/torchao/prototype/gptq/observer.py b/torchao/prototype/gptq/observer.py
new file mode 100644
index 0000000000..51366ab10c
--- /dev/null
+++ b/torchao/prototype/gptq/observer.py
@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
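+
+# ObserverTensor is a tensor subclass that leaves forward passes numerically
+# unchanged while recording second-order activation statistics: every F.linear /
+# torch.bmm call accumulates a running Hessian estimate H ~ (2 / N) * sum(x x^T)
+# over the observed activations, which the GPTQ convert step later consumes.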
+
+import torch
+import torch.nn.functional as F
+
+from torchao.utils import TorchAOBaseTensor
+
+
+class ObserverTensor(TorchAOBaseTensor):
+    tensor_data_names = ["hp_data"]
+    optional_tensor_data_names = ["hessian"]
+    tensor_attribute_names = ["total_batches"]
+
+    def __new__(cls, hp_data: torch.Tensor, total_batches: int, hessian=None):
+        shape = hp_data.shape
+        kwargs = {}
+        kwargs["device"] = hp_data.device
+        kwargs["dtype"] = hp_data.dtype
+        kwargs["requires_grad"] = False
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(self, hp_data: torch.Tensor, total_batches: int, hessian=None):
+        super().__init__()
+        self.hp_data = hp_data
+        self.hessian = hessian
+        self.total_batches = total_batches
+
+    def update(self, input: torch.Tensor):
+        """Incrementally update the Hessian matrix from input activations."""
+        # Move input to same device as hp_data and convert to float
+        x = input.float().to(self.hp_data.device)
+        shape = x.shape
+
+        # A 2D input counts as a single batch; 3D inputs contribute shape[0] batches
+        n = 1 if len(shape) == 2 else shape[0]
+        x = x.reshape(-1, shape[-1])
+
+        # Lazily initialize Hessian on first call
+        if self.hessian is None:
+            feature_dim = x.shape[-1]
+            self.hessian = torch.zeros(
+                feature_dim,
+                feature_dim,
+                dtype=torch.float32,
+                device=self.hp_data.device,
+            )
+
+        # Apply running average formula
+        if self.total_batches > 0:
+            self.hessian *= self.total_batches / (self.total_batches + n)
+
+        self.total_batches += n
+
+        # Accumulate (2 / total_batches) * x^T x into the running Hessian estimate
+        x = ((2 / self.total_batches) ** (1 / 2)) * x.t()
+        self.hessian += x.matmul(x.t())
+
+    @classmethod
+    def from_hp(cls, hp_tensor):
+        return ObserverTensor(hp_tensor, 0, None)
+
+
+implements = ObserverTensor.implements
+implements_torch_function = ObserverTensor.implements_torch_function
+aten = torch.ops.aten
+
+
+@implements(aten.linear.default)
+@implements_torch_function(torch.nn.functional.linear)
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[0],
+        args[1],
+        args[2] if len(args) > 2 else None,
+    )
+    if isinstance(weight_tensor, ObserverTensor):
+        weight_tensor.update(input_tensor.detach())
+        return F.linear(input_tensor, weight_tensor.hp_data, bias)
+    else:
+        raise ValueError(
+            f"Expected weight_tensor to be ObserverTensor, got: {type(weight_tensor)}"
+        )
+
+
+@implements(aten.bmm.default)
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor = (
+        args[0],
+        args[1],
+    )
+    if not isinstance(weight_tensor, ObserverTensor):
+        raise ValueError(
+            f"Expected weight_tensor to be ObserverTensor, got: {type(weight_tensor)}"
+        )
+    weight_tensor.update(input_tensor.detach())
+    return func(input_tensor, weight_tensor.hp_data)
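+
+
+if __name__ == "__main__":
+    # Minimal smoke test (illustrative only, not part of the library API): a single
+    # F.linear call through an ObserverTensor returns the usual output while
+    # populating an (in_features x in_features) Hessian estimate.
+    w = torch.randn(8, 16)
+    obs = ObserverTensor.from_hp(w)
+    x = torch.randn(4, 16)
+    y = F.linear(x, obs)
+    assert y.shape == (4, 8)
+    assert obs.hessian.shape == (16, 16)
+    assert obs.total_batches == 1
+    print("ObserverTensor smoke test passed")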