diff --git a/torchao/prototype/int4_opaque_tensor/int4_opaque_tensor.py b/torchao/prototype/int4_opaque_tensor/int4_opaque_tensor.py
index 976f219167..d860fd5469 100644
--- a/torchao/prototype/int4_opaque_tensor/int4_opaque_tensor.py
+++ b/torchao/prototype/int4_opaque_tensor/int4_opaque_tensor.py
@@ -22,6 +22,7 @@ from torchao.quantization.utils import pack_tinygemm_scales_and_zeros
 from torchao.utils import (
     TorchAOBaseTensor,
+    fill_defaults,
 )
 
 __all__ = [
@@ -242,6 +243,105 @@ def _(func, types, args, kwargs):
     return y.to(orig_dtype)
 
 
+@implements(aten.slice.Tensor)
+def _(func, _types, args, _kwargs):
+    """Slice operation for CPU int4 opaque tensor"""
+    self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+    cur_shape = self.shape
+
+    assert len(cur_shape) == 2
+    assert self.qdata.dim() == 2
+    assert dim in [0, 1], (
+        f"Int4OpaqueTensor slice: attempting to run {func}, with "
+        f"dim={dim}, that is not supported"
+    )
+    assert step == 1, (
+        f"Int4OpaqueTensor slice: attempting to run {func}, with "
+        f"step={step}, that is not supported"
+    )
+    # qdata has shape (N, K/2) - packed in last dimension
+    # scale_and_zero has shape (K/group_size, N, 2)
+
+    # fill_defaults leaves start/end as None, and callers may pass an end
+    # past the end of the dimension, so normalize both
+    data_len = cur_shape[dim]
+    start = 0 if start is None else start
+    end = data_len if end is None or end > data_len else end
+
+    if dim == 0:
+        # Slicing the N dimension; neither qdata nor scale_and_zero is
+        # packed along N, so the ratios below are 1
+        qdata_len = self.qdata.shape[0]  # N
+        sz_len = self.scale_and_zero.shape[1]  # N
+
+        if qdata_len == 0 or sz_len == 0:
+            return Int4OpaqueTensor(
+                self.qdata,
+                self.scale_and_zero,
+                self.block_size,
+                self.shape,
+                act_pre_scale=self.act_pre_scale,
+            )
+
+        # ratio = logical length / storage length maps logical indices
+        # to storage indices
+        qdata_ratio = data_len / qdata_len
+        start_qdata = int(start / qdata_ratio)
+        end_qdata = int(end / qdata_ratio)
+
+        sz_ratio = data_len / sz_len
+        start_sz = int(start / sz_ratio)
+        end_sz = int(end / sz_ratio)
+
+        qdata = aten.slice(self.qdata, 0, start_qdata, end_qdata, step)
+        scale_and_zero = aten.slice(
+            self.scale_and_zero, 1, start_sz, end_sz, step
+        )
+    else:
+        # Slicing the K dimension (dim == 1); both tensors are condensed
+        # along K, so logical indices are scaled down by the ratios below
+        qdata_len = self.qdata.shape[1]  # K/2, two int4 values per byte
+        sz_len = self.scale_and_zero.shape[0]  # K/group_size
+
+        if qdata_len == 0 or sz_len == 0:
+            return Int4OpaqueTensor(
+                self.qdata,
+                self.scale_and_zero,
+                self.block_size,
+                self.shape,
+                act_pre_scale=self.act_pre_scale,
+            )
+
+        qdata_ratio = data_len / qdata_len  # 2
+        start_qdata = int(start / qdata_ratio)
+        end_qdata = int(end / qdata_ratio)
+
+        sz_ratio = data_len / sz_len  # group_size
+        start_sz = int(start / sz_ratio)
+        end_sz = int(end / sz_ratio)
+
+        qdata = aten.slice(self.qdata, 1, start_qdata, end_qdata, step)
+        scale_and_zero = aten.slice(
+            self.scale_and_zero, 0, start_sz, end_sz, step
+        )
+
+    # Calculate new shape after slicing (step == 1, so length is end - start)
+    new_shape = list(self.shape)
+    new_shape[dim] = end - start
+
+    # The group size cannot exceed the sliced dimension
+    block_size = list(self.block_size)
+    block_size[dim] = min(block_size[dim], new_shape[dim])
+
+    return Int4OpaqueTensor(
+        qdata,
+        scale_and_zero,
+        block_size,
+        new_shape,
+        act_pre_scale=self.act_pre_scale,
+    )
+
+
 Int4OpaqueTensor.__module__ = "torchao.prototype.int4_opaque_tensor"
 
 # Allow a model with Int4OpaqueTensor weights to be loaded with `weights_only=True`
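
A standalone sanity sketch of the logical-to-storage index mapping the slice kernel performs, for reviewers who want to check the arithmetic. Everything below is illustrative and not part of the patch: the shapes N=64, K=128, group_size=32 are assumptions, and plain zero tensors stand in for real packed weights.

import torch

# Illustrative shapes only - N, K and group_size are assumptions, not values
# taken from the patch
N, K, group_size = 64, 128, 32

qdata = torch.zeros(N, K // 2, dtype=torch.uint8)    # (N, K/2), two int4 per byte
scale_and_zero = torch.zeros(K // group_size, N, 2)  # (K/group_size, N, 2)

# Slice the logical K dimension (dim == 1) to [32, 96)
start, end = 32, 96
data_len = K

# ratio = logical length / storage length, as in the kernel
qdata_ratio = data_len / qdata.shape[1]              # 2.0
start_qdata, end_qdata = int(start / qdata_ratio), int(end / qdata_ratio)

sz_ratio = data_len / scale_and_zero.shape[0]        # 32.0 (= group_size)
start_sz, end_sz = int(start / sz_ratio), int(end / sz_ratio)

sliced_qdata = qdata[:, start_qdata:end_qdata]       # packed columns [16, 48)
sliced_sz = scale_and_zero[start_sz:end_sz]          # groups [1, 3)

assert sliced_qdata.shape == (N, (end - start) // 2)
assert sliced_sz.shape == ((end - start) // group_size, N, 2)

The point of the ratio formulation is that dividing a logical K index by (logical length / storage length) lands on the packed byte column, or the (scale, zero) group, that covers it; dim == 0 falls out of the same formula with a ratio of 1.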