Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions torchao/prototype/int4_opaque_tensor/int4_opaque_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from torchao.quantization.utils import pack_tinygemm_scales_and_zeros
from torchao.utils import (
TorchAOBaseTensor,
fill_defaults,
)

__all__ = [
Expand Down Expand Up @@ -242,6 +243,95 @@ def _(func, types, args, kwargs):
return y.to(orig_dtype)


def _packed_slice_bounds(start, end, logical_len, packed_len, what):
    """Translate a logical ``[start, end)`` range into packed-axis indices.

    Args:
        start: Normalized, non-negative logical start index.
        end: Normalized logical end index, ``start <= end <= logical_len``.
        logical_len: Length of the sliced dimension in the logical shape.
        packed_len: Length of the matching axis of the packed tensor
            (must be non-zero).
        what: Name of the packed tensor, used only in error messages.

    Returns:
        ``(packed_start, packed_end)`` as exact integers.

    Raises:
        ValueError: If ``logical_len`` is not a positive multiple of
            ``packed_len``, or if a bound does not fall on a packing
            boundary (which would otherwise be silently truncated).
    """
    # Pure integer arithmetic: the float ratio used previously could round
    # and hid misaligned bounds behind int() truncation.
    ratio, remainder = divmod(logical_len, packed_len)
    if ratio == 0 or remainder != 0:
        raise ValueError(
            f"Int4OpaqueTensor slice: logical length {logical_len} is not a "
            f"positive multiple of the {what} length {packed_len}"
        )
    if start % ratio != 0 or end % ratio != 0:
        raise ValueError(
            f"Int4OpaqueTensor slice: range [{start}, {end}) is not aligned "
            f"to the {what} packing factor {ratio}"
        )
    return start // ratio, end // ratio


@implements(aten.slice.Tensor)
def _(func, _types, args, _kwargs):
    """Slice operation for CPU int4 opaque tensor.

    Slices ``qdata`` and ``scale_and_zero`` consistently with the requested
    slice of the logical 2-D weight and returns a new ``Int4OpaqueTensor``.

    Args:
        func: The dispatched op; only used in the error message.
        _types: Unused.
        args: ``(self, dim, start, end, step)``; trailing entries may be
            omitted and default to ``dim=0, start=None, end=None, step=1``.
        _kwargs: Unused.

    Returns:
        A new ``Int4OpaqueTensor`` covering ``[start, end)`` along ``dim``.
        If the relevant packed axis is empty, a tensor wrapping the
        unchanged data and shape is returned.

    Raises:
        AssertionError: If the tensor is not 2-D or ``dim`` is unsupported.
        NotImplementedError: If ``step != 1``.
        ValueError: If the range is not aligned to the packing/group size.

    NOTE(review): this assumes the packed layouts can be sliced
    element-wise along both axes (as the shape comments below state);
    confirm this holds for the opaque CPU packing.
    """
    self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])

    ndim = len(self.shape)
    assert ndim == 2
    assert self.qdata.dim() == 2
    # qdata has shape (N, K/2) - packed in last dimension
    # scale_and_zero has shape (K/group_size, N, 2)

    # Accept negative dims (-1 == K, -2 == N) before validating, and validate
    # before indexing the shape so a bad dim reports the message below.
    if -ndim <= dim < 0:
        dim += ndim
    assert dim in [
        0,
        1,
    ], (
        f"Int4OpaqueTensor slice: attempting to run {func}, with "
        f"dim={dim}, that is not supported"
    )

    # Logical dim 0 (N) is axis 0 of qdata and axis 1 of scale_and_zero;
    # logical dim 1 (K) is axis 1 of qdata and axis 0 of scale_and_zero.
    if dim == 0:
        qdata_dim, sz_dim = 0, 1
    else:
        qdata_dim, sz_dim = 1, 0
    qdata_len = self.qdata.shape[qdata_dim]
    sz_len = self.scale_and_zero.shape[sz_dim]

    if qdata_len == 0 or sz_len == 0:
        # Nothing to slice: hand back a tensor over the same (empty) data.
        return Int4OpaqueTensor(
            self.qdata,
            self.scale_and_zero,
            self.block_size,
            self.shape,
            act_pre_scale=self.act_pre_scale,
        )

    if step != 1:
        # A strided slice cannot be expressed on packed data, and the
        # resulting shape below would be wrong.
        raise NotImplementedError(
            f"Int4OpaqueTensor slice: only step == 1 is supported, got {step}"
        )

    # Normalize the bounds the way aten.slice does: None means "open ended",
    # negatives count from the end, and out-of-range values are clamped
    # (e.g. ``x[:, 0:]`` arrives with end == sys.maxsize).
    data_len = self.shape[dim]
    start = 0 if start is None else start
    end = data_len if end is None else end
    if start < 0:
        start += data_len
    if end < 0:
        end += data_len
    start = min(max(start, 0), data_len)
    end = min(max(end, start), data_len)

    # The ratio between the logical length and each packed axis (1 for N,
    # 2 for K in qdata, group_size for K in scale_and_zero) is derived from
    # the actual packed lengths, so logical indices are rescaled correctly.
    start_qdata, end_qdata = _packed_slice_bounds(
        start, end, data_len, qdata_len, "qdata"
    )
    start_sz, end_sz = _packed_slice_bounds(
        start, end, data_len, sz_len, "scale_and_zero"
    )

    qdata = aten.slice(self.qdata, qdata_dim, start_qdata, end_qdata, step)
    scale_and_zero = aten.slice(
        self.scale_and_zero, sz_dim, start_sz, end_sz, step
    )

    # Calculate new shape after slicing
    new_shape = list(self.shape)
    new_shape[dim] = end - start

    # A block can never be larger than the dimension it tiles.
    block_size = list(self.block_size)
    block_size[dim] = min(block_size[dim], new_shape[dim])

    return Int4OpaqueTensor(
        qdata,
        scale_and_zero,
        block_size,
        new_shape,
        act_pre_scale=self.act_pre_scale,
    )

# Report the class as living in the package rather than in this submodule.
# NOTE(review): presumably so that serialized/qualified references to the
# class use the public package path — confirm against the loading code.
Int4OpaqueTensor.__module__ = "torchao.prototype.int4_opaque_tensor"

# Allow a model with Int4OpaqueTensor weights to be loaded with `weights_only=True`
Expand Down