Skip to content

Conversation

@jcaip
Copy link
Contributor

@jcaip jcaip commented Dec 19, 2025

This PR adds GPTQ support to torchao.prototype.gptq.

It is exposed via a new config, GPTQConfig, which can have two steps, "observe" and "convert".

When quantize_(model, GPTQConfig(step="observe")) is run, observer tensors are attached to the weight tensors, which keep track of linear / torch.bmm ops and updates the hessian matrix based on the observed inputs.

When quantize_(model, GPTQConfig(step="convert")) is run, we will find any observer tensors, take the Hessian and do int4 GPTQ quantization to find the weights. The core of this function is in gptq_quantize.

Currently only int4 is hardcoded, but if we enable dequantization and the ability to create a tensor with existing qparams, we should be able to do this for any config.

Also included is an example script, gptq_example.py that does sequential / nonsequential quantization on helllaswag for a simple example.

- Add GPTQ quantization algorithm implementation in torchao/prototype/gptq
- Add ObserverTensor for activation tracking during calibration
- Add unified GPTQConfig that handles both observation and quantization phases
- Add gptq_quantize function for weight quantization with Hessian
- Add comprehensive test suite for GPTQ and ObserverTensor
- Add example script demonstrating GPTQ quantization workflow
- All Int4 helper functions are self-contained in gptq module
@pytorch-bot
Copy link

pytorch-bot bot commented Dec 19, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3517

Note: Links to docs will display an error until the docs builds have been completed.

❌ 12 New Failures

As of commit 5b99ce2 with merge base c4273fe (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 19, 2025
@jcaip jcaip added topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories) accuracy Accuracy related labels Dec 19, 2025
@jcaip jcaip marked this pull request as ready for review December 19, 2025 19:29
@jcaip jcaip requested review from jerryzh168 and vkuzo December 19, 2025 19:29
@@ -0,0 +1,453 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: put in test/prototype/gptq folder to match codebase convention

from .observer import ObserverTensor


@dataclass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: put logic in some_file.py instead of __init__.py

Comment on lines +160 to +170
dead = torch.diag(H) == 0
H[dead, dead] = 1
W[:, dead] = 0

damp = percdamp * torch.mean(torch.diag(H))
diag = torch.arange(columns, device=device)
H[diag, diag] += damp
H = torch.linalg.cholesky(H)
H = torch.cholesky_inverse(H)
H = torch.linalg.cholesky(H, upper=True)
Hinv = H
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: add some comments linking to the right page of the paper


all_qparams = []

for W_quantize_block, block_start in zip(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: add comments as needed to explain how this code is following the algorithm in the paper

def _calculate_hessian(inputs, device=None):
"""Calculate Hessian matrix from input activations for GPTQ.

DEPRECATED: This function is kept for backward compatibility in tests only.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

backward compatibility with what?

from torchao.utils import TorchAOBaseTensor


class ObserverTensor(TorchAOBaseTensor):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the logic here is specific to GPTQ, can we ensure the name specifies that

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

accuracy Accuracy related CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants