Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 7 additions & 17 deletions torchao/quantization/linear_quant_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,8 @@ def __init__(
self,
in_features: int,
out_features: int,
# TODO: remove dtype field, not used
bias=False,
device=None,
dtype=None,
groupsize: int = 128,
inner_k_tiles: int = 8,
precision: torch.dtype = torch.bfloat16,
Expand All @@ -110,9 +108,6 @@ def __init__(
self.precision = precision
self.scales_precision = scales_precision

if dtype is not None:
raise ValueError("Please specify 'precision' instead of 'dtype'")

assert out_features % 8 == 0, "require out_features % 8 == 0"
assert in_features % (inner_k_tiles * 16) == 0, (
"require in_features % (innerKTiles * 16) == 0"
Expand All @@ -125,7 +120,7 @@ def __init__(
out_features,
in_features // 2,
),
dtype=torch.uint8,
precision=torch.uint8,
device=device,
),
)
Expand All @@ -139,16 +134,16 @@ def __init__(
32,
inner_k_tiles // 2,
),
dtype=torch.int32,
precision=torch.int32,
device=device,
),
)
self.dtype = dtype
self.precision = precision
self.register_buffer(
"scales_and_zeros",
torch.zeros(
(in_features // groupsize, out_features, 2),
dtype=self.scales_precision,
precision=self.scales_precision,
device=device,
),
)
Expand Down Expand Up @@ -410,8 +405,6 @@ def __init__(
out_features: int,
bias=True,
device=None,
# TODO: remove this field, not used
dtype=None,
groupsize: int = 256,
precision: torch.dtype = torch.float32,
scales_precision: torch.dtype = torch.float32,
Expand All @@ -434,26 +427,23 @@ def __init__(
# that this module represents.
self.precision = precision

if dtype is not None:
raise ValueError("Please specify 'precision' instead of 'dtype'")

# currently storing unpacked int8 weights
self.register_buffer(
"weight",
torch.zeros((out_features, in_features), dtype=torch.int8),
torch.zeros((out_features, in_features), precision=torch.int8),
)
self.register_buffer(
"scales",
torch.zeros(
(out_features, in_features // groupsize),
dtype=scales_precision,
precision=scales_precision,
),
)
self.register_buffer(
"zeros",
torch.zeros(
(out_features, in_features // groupsize),
dtype=scales_precision,
precision=scales_precision,
),
)

Expand Down
Loading