diff --git a/torchao/quantization/linear_quant_modules.py b/torchao/quantization/linear_quant_modules.py index de6755a55d..98cf68a8d2 100644 --- a/torchao/quantization/linear_quant_modules.py +++ b/torchao/quantization/linear_quant_modules.py @@ -86,10 +86,8 @@ def __init__( self, in_features: int, out_features: int, - # TODO: remove dtype field, not used bias=False, device=None, - dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, precision: torch.dtype = torch.bfloat16, @@ -110,9 +108,6 @@ def __init__( self.precision = precision self.scales_precision = scales_precision - if dtype is not None: - raise ValueError("Please specify 'precision' instead of 'dtype'") - assert out_features % 8 == 0, "require out_features % 8 == 0" assert in_features % (inner_k_tiles * 16) == 0, ( "require in_features % (innerKTiles * 16) == 0" @@ -125,7 +120,7 @@ def __init__( out_features, in_features // 2, ), - dtype=torch.uint8, + precision=torch.uint8, device=device, ), ) @@ -139,16 +134,16 @@ def __init__( 32, inner_k_tiles // 2, ), - dtype=torch.int32, + precision=torch.int32, device=device, ), ) - self.dtype = dtype + self.precision = precision self.register_buffer( "scales_and_zeros", torch.zeros( (in_features // groupsize, out_features, 2), - dtype=self.scales_precision, + precision=self.scales_precision, device=device, ), ) @@ -410,8 +405,6 @@ def __init__( out_features: int, bias=True, device=None, - # TODO: remove this field, not used - dtype=None, groupsize: int = 256, precision: torch.dtype = torch.float32, scales_precision: torch.dtype = torch.float32, @@ -434,26 +427,23 @@ def __init__( # that his module represents. self.precision = precision - if dtype is not None: - raise ValueError("Please specify 'precision' instead of 'dtype'") - # currently storing unpacked int8 weights self.register_buffer( "weight", - torch.zeros((out_features, in_features), dtype=torch.int8), + torch.zeros((out_features, in_features), precision=torch.int8), ) self.register_buffer( "scales", torch.zeros( (out_features, in_features // groupsize), - dtype=scales_precision, + precision=scales_precision, ), ) self.register_buffer( "zeros", torch.zeros( (out_features, in_features // groupsize), - dtype=scales_precision, + precision=scales_precision, ), )