Add GPTQ to prototype #3517
base: main
Conversation
- Add GPTQ quantization algorithm implementation in torchao/prototype/gptq
- Add ObserverTensor for activation tracking during calibration
- Add unified GPTQConfig that handles both observation and quantization phases
- Add gptq_quantize function for weight quantization with Hessian
- Add comprehensive test suite for GPTQ and ObserverTensor
- Add example script demonstrating GPTQ quantization workflow
- All Int4 helper functions are self-contained in the gptq module
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3517
Note: Links to docs will display an error until the doc builds have completed. ❌ 12 New Failures as of commit 5b99ce2 with merge base c4273fe.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -0,0 +1,453 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
nit: put in test/prototype/gptq folder to match codebase convention
from .observer import ObserverTensor


@dataclass
nit: put logic in some_file.py instead of __init__.py
dead = torch.diag(H) == 0
H[dead, dead] = 1
W[:, dead] = 0

damp = percdamp * torch.mean(torch.diag(H))
diag = torch.arange(columns, device=device)
H[diag, diag] += damp
H = torch.linalg.cholesky(H)
H = torch.cholesky_inverse(H)
H = torch.linalg.cholesky(H, upper=True)
Hinv = H
nit: add some comments linking to the right page of the paper
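Per the review nit, here is a hedged sketch of what this dampen-then-factor step does (GPTQ, Frantar et al. 2022; the 1%-of-mean-diagonal dampening matches the paper's recommended `percdamp`). The function and variable names below are illustrative, not the exact torchao implementation:

```python
import torch

def precondition_hessian(H: torch.Tensor, percdamp: float = 0.01) -> torch.Tensor:
    # Illustrative helper name; mirrors the snippet above.
    columns = H.shape[0]
    damp = percdamp * torch.mean(torch.diag(H))
    diag = torch.arange(columns, device=H.device)
    H[diag, diag] += damp                 # dampen the diagonal for numerical stability
    L = torch.linalg.cholesky(H)          # H = L L^T
    Hinv = torch.cholesky_inverse(L)      # H^{-1} recovered from the factor
    # GPTQ consumes the upper-triangular Cholesky factor of H^{-1} row by row
    return torch.linalg.cholesky(Hinv, upper=True)

H = 2.0 * torch.eye(4)
Hinv_chol = precondition_hessian(H.clone())
```

The Cholesky-inverse route avoids forming `H^{-1}` via a general inverse, which is both cheaper and more stable for the symmetric positive-definite Hessian.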
all_qparams = []

for W_quantize_block, block_start in zip(
nit: add comments as needed to explain how this code is following the algorithm in the paper
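For reference, the per-column update this blocked loop implements can be sketched as follows (GPTQ, Frantar et al. 2022, Algorithm 1). The plain round-to-nearest `quantize` and fixed scale are illustrative stand-ins for torchao's int4 helpers:

```python
import torch

def quantize(w, scale):
    # Illustrative int4 round-to-nearest quantizer (grid [-8, 7])
    return torch.clamp(torch.round(w / scale), -8, 7) * scale

def gptq_block(W, Hinv, scale):
    # W: [rows, columns]; Hinv: upper-triangular Cholesky factor of H^{-1}
    W = W.clone()
    Q = torch.zeros_like(W)
    for i in range(W.shape[1]):
        w = W[:, i]
        q = quantize(w, scale)
        Q[:, i] = q
        # Propagate the quantization error to the not-yet-quantized columns,
        # weighted by the inverse-Hessian row -- the heart of GPTQ.
        err = (w - q) / Hinv[i, i]
        W[:, i:] -= err.unsqueeze(1) * Hinv[i, i:].unsqueeze(0)
    return Q

W = torch.randn(2, 4)
Hinv = torch.linalg.cholesky(torch.eye(4), upper=True)  # identity Hessian
Q = gptq_block(W, Hinv, scale=0.1)
```

With an identity Hessian the error propagation is a no-op and the result reduces to plain round-to-nearest; a real calibration Hessian is what lets later columns compensate for earlier columns' rounding error.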
def _calculate_hessian(inputs, device=None):
    """Calculate Hessian matrix from input activations for GPTQ.

    DEPRECATED: This function is kept for backward compatibility in tests only.
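A hedged sketch of what a Hessian helper like this computes: `H = 2 * X^T X` accumulated as a running average over calibration inputs, where each row of `X` is one token's activation vector. Names below are illustrative, not the exact torchao implementation:

```python
import torch

def calculate_hessian(inputs, columns):
    # columns = the layer's in_features
    H = torch.zeros(columns, columns)
    n = 0
    for x in inputs:                      # x: [tokens, columns]
        x = x.reshape(-1, columns).float()
        H *= n / (n + x.shape[0])         # rescale the running average
        n += x.shape[0]
        H += (2.0 / n) * x.t() @ x        # add this batch's contribution
    return H

X = torch.randn(8, 4)
H = calculate_hessian([X], columns=4)
```

The running average keeps `H` on the same scale regardless of how many calibration batches are observed, so the dampening fraction applied later behaves consistently.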
backward compatibility with what?
from torchao.utils import TorchAOBaseTensor


class ObserverTensor(TorchAOBaseTensor):
the logic here is specific to GPTQ, can we ensure the name specifies that
This PR adds GPTQ support to torchao.prototype.gptq. It is exposed via a new config, GPTQConfig, which has two steps, "observe" and "convert".

When quantize_(model, GPTQConfig(step="observe")) is run, observer tensors are attached to the weight tensors; these track linear / torch.bmm ops and update the Hessian matrix based on the observed inputs.

When quantize_(model, GPTQConfig(step="convert")) is run, we find any observer tensors, take the Hessian, and do int4 GPTQ quantization to compute the weights. The core of this logic is in gptq_quantize. Currently int4 is hardcoded, but if we enable dequantization and the ability to create a tensor with existing qparams, we should be able to do this for any config.

Also included is an example script, gptq_example.py, that does sequential / non-sequential quantization on hellaswag as a simple example.
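The observe-then-convert flow described above can be sketched in a self-contained way. The real API is quantize_(model, GPTQConfig(step=...)) from this PR; the observer stub and the round-to-nearest convert step below are illustrative stand-ins, not torchao code:

```python
import torch
import torch.nn as nn

class ObserverStub:
    # Stand-in for the PR's ObserverTensor: accumulates Hessian statistics.
    def __init__(self, in_features):
        self.H = torch.zeros(in_features, in_features)

    def observe(self, x):                  # x: [tokens, in_features]
        self.H += 2.0 * x.t() @ x

linear = nn.Linear(4, 2, bias=False)
obs = ObserverStub(4)

# "observe" step: run calibration data through the model
for _ in range(3):
    x = torch.randn(8, 4)
    obs.observe(x)
    linear(x)

# "convert" step: here just round-to-nearest on an int4 grid;
# real GPTQ would use obs.H to compensate rounding error per column.
scale = linear.weight.abs().max() / 7
with torch.no_grad():
    linear.weight.copy_(
        torch.clamp(torch.round(linear.weight / scale), -8, 7) * scale
    )
```

Splitting calibration and conversion into two quantize_ calls lets the same calibrated model be converted under different settings without re-running the calibration data.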