diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 7f276860b2..8336950d55 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -18,12 +18,11 @@ from torch.testing import FileCheck from torch.testing._internal import common_utils -from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl, preprocess_scale +from torchao.dtypes.floatx.float8_layout import preprocess_scale from torchao.float8.float8_utils import compute_error from torchao.quantization import ( Float8DynamicActivationFloat8WeightConfig, Float8StaticActivationFloat8WeightConfig, - Float8WeightOnlyConfig, quantize_, ) from torchao.quantization.granularity import ( @@ -42,7 +41,6 @@ get_current_accelerator_device, is_sm_at_least_89, is_sm_at_least_90, - is_sm_version, ) random.seed(0) @@ -68,7 +66,7 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase): not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" ) @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) - @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"]) + @common_utils.parametrize("mode", ["static"]) @common_utils.parametrize("compile", [True, False]) @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) # Inputs are (M,..), K, N @@ -110,12 +108,6 @@ def test_fp8_linear_variants( scale_dtype=torch.float32, ) mode_map = { - "dynamic": partial( - Float8DynamicActivationFloat8WeightConfig, - granularity=granularity, - version=1, - ), - "weight-only": partial(Float8WeightOnlyConfig, version=1), "static": partial( Float8StaticActivationFloat8WeightConfig, scale=scale, @@ -194,18 +186,12 @@ def test_per_row_with_float32(self): @unittest.skipIf( not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" ) - @common_utils.parametrize("mode", ["dynamic", "weight-only", "static"]) + @common_utils.parametrize("mode", ["static"]) def test_serialization(self, mode: str): # Create and quantize the model model = ToyLinearModel(16, 32).to(device=_DEVICE) mode_map = { - "dynamic": partial( - Float8DynamicActivationFloat8WeightConfig, - granularity=PerTensor(), - version=1, - ), - "weight-only": partial(Float8WeightOnlyConfig, version=1), "static": partial( Float8StaticActivationFloat8WeightConfig, scale=torch.tensor(1.0, dtype=torch.float32, device=_DEVICE), @@ -265,97 +251,6 @@ def test_serialization(self, mode: str): original_layer.weight.scale, new_layer.weight.scale ), f"Scales do not match for {layer_name}" - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf( - not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" - ) - def test_fp8_weight_dimension_warning(self): - # Create model with incompatible dimensions (not multiples of 16) - model = ToyLinearModel(10, 25).to(_DEVICE) # 10x25 and 25x10 weights - - # Set up logging capture - with self.assertLogs("torchao.quantization.utils", level="INFO") as log_context: - quantize_( - model, - Float8DynamicActivationFloat8WeightConfig( - granularity=PerTensor(), version=1 - ), - ) - print(model) - - # Verify warning messages for both layers - expected_messages = [ - "Skipping float8 quantization: weight shape torch.Size([25, 10])", - "Skipping float8 quantization: weight shape torch.Size([10, 25])", - ] - # Check that we got warnings for both incompatible layers - warning_count = sum( - 1 for msg in log_context.output if "Skipping float8 quantization" in msg - ) - 
self.assertEqual(warning_count, 2, "Expected warnings for both linear layers") - - # Check warning message content - for expected in expected_messages: - self.assertTrue( - any(expected in msg for msg in log_context.output), - f"Expected warning message containing: {expected}", - ) - - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf( - not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" - ) - @common_utils.parametrize( - "in_features,out_features", [(512, 1024), (256, 768), (1024, 512)] - ) - @common_utils.parametrize( - "leading_shape", [(1,), (8,), (16,), (2, 8,), (2, 2, 16,)] - ) # fmt: skip - @common_utils.parametrize("bias", [True, False]) - def test_mm_float8dq_per_row( - self, in_features, out_features, leading_shape, bias: bool - ): - device = _DEVICE - dtype = torch.bfloat16 - input_shape = leading_shape + (in_features,) - - ref_linear = ( - torch.nn.Linear(in_features, out_features, bias=bias).to(device).to(dtype) - ) - test_linear = copy.deepcopy(ref_linear) - quantize_( - test_linear, - Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), version=1), - ) - - quant_weight = test_linear.weight - - self.assertTrue(hasattr(quant_weight, "original_weight_tensor")) - weight_impl = quant_weight.original_weight_tensor.tensor_impl - - self.assertTrue(hasattr(weight_impl, "float8_data")) - self.assertTrue(hasattr(weight_impl, "scale")) - self.assertFalse(weight_impl.transposed) - - # Verify scale shape for row-wise quantization - expected_scale_shape = (out_features, 1) - actual_scale_shape = weight_impl.scale.shape - self.assertEqual(actual_scale_shape, expected_scale_shape) - - self.assertEqual(weight_impl.float8_data.shape, (out_features, in_features)) - - input_tensor = torch.randn(*input_shape, device=device, dtype=dtype) - - with torch.no_grad(): - ref_output = ref_linear(input_tensor) - quant_output = torch.nn.functional.linear(input_tensor, quant_weight) - - expected_output_shape = input_tensor.shape[:-1] + (out_features,) - self.assertEqual(quant_output.shape, expected_output_shape) - - error = compute_error(ref_output, quant_output) - assert error > 20, f"Quantization error is too high got a SQNR of {error}" - @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available") @unittest.skipIf( _DEVICE == "cuda" and not is_sm_at_least_89(), @@ -467,233 +362,6 @@ def test_dequantize_affine_float8_scale_broadcasting(self): # Verify shapes match self.assertEqual(dequantized.shape, input_tensor.shape) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf( - not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" - ) - @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) - def test_float8_tensor_slicing_basic(self, granularity): - """Test basic slicing operations on Float8 tensors""" - device = _DEVICE - dtype = torch.bfloat16 - - # Create and quantize a model - model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype) - quantize_( - model, - Float8DynamicActivationFloat8WeightConfig( - granularity=granularity, version=1 - ), - ) - - weight_impl = model.weight.original_weight_tensor.tensor_impl - - # Test dimension 0 slicing (rows) - sliced_0 = weight_impl[10:20] - self.assertEqual(sliced_0.shape, (10, 64)) - - # Test dimension 1 slicing (columns) - sliced_1 = weight_impl[:, 20:40] - self.assertEqual(sliced_1.shape, (32, 20)) - - # Test combined slicing - sliced_both = weight_impl[5:15, 10:30] - self.assertEqual(sliced_both.shape, 
(10, 20)) - - # Verify the sliced tensors are still Float8 tensors - self.assertTrue(isinstance(sliced_0, Float8AQTTensorImpl)) - self.assertTrue(isinstance(sliced_1, Float8AQTTensorImpl)) - self.assertTrue(isinstance(sliced_both, Float8AQTTensorImpl)) - - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf( - not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" - ) - def test_float8_tensor_slicing_per_tensor(self): - """Test slicing with per-tensor quantization (scale should not change)""" - device = _DEVICE - dtype = torch.bfloat16 - - # Create and quantize with per-tensor granularity - model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype) - quantize_( - model, - Float8DynamicActivationFloat8WeightConfig( - granularity=PerTensor(), version=1 - ), - ) - - original_weight = model.weight - original_impl = original_weight.original_weight_tensor.tensor_impl - original_scale = original_impl.scale - - # Test slicing - sliced_weight = original_weight[10:20, 20:40] - sliced_impl = sliced_weight.original_weight_tensor.tensor_impl - - # For per-tensor quantization, scale should be identical - self.assertTrue(torch.equal(original_scale, sliced_impl.scale)) - self.assertEqual(sliced_impl.scale.numel(), 1) - - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf( - not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" - ) - @unittest.skipIf( - not is_sm_at_least_90(), - "Per-row quantization requires compute capability >= 9.0", - ) - def test_float8_tensor_slicing_per_row(self): - """Test slicing with per-row quantization (scale should be sliced appropriately)""" - device = _DEVICE - dtype = torch.bfloat16 - - # Create and quantize with per-row granularity - model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype) - quantize_( - model, - Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), version=1), - ) - - original_weight = model.weight # Shape: (32, 64) - original_impl = original_weight.original_weight_tensor.tensor_impl - original_scale = original_impl.scale # Shape: (32, 1) - - # Test row slicing (dimension 0) - sliced_rows = original_weight[10:20] # Shape: (10, 64) - sliced_impl = sliced_rows.original_weight_tensor.tensor_impl - - # Scale should be sliced to match the rows - expected_scale_shape = (10, 1) - self.assertEqual(sliced_impl.scale.shape, expected_scale_shape) - - # Verify the scale values are correct (should be subset of original) - self.assertTrue(torch.equal(sliced_impl.scale, original_scale[10:20])) - - # Test column slicing (dimension 1) - scale should not change for per-row - sliced_cols = original_weight[:, 20:40] # Shape: (32, 20) - sliced_cols_impl = sliced_cols.original_weight_tensor.tensor_impl - - # Scale shape should remain the same since we're not changing rows - self.assertEqual(sliced_cols_impl.scale.shape, (32, 1)) - self.assertTrue(torch.equal(sliced_cols_impl.scale, original_scale)) - - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf( - not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" - ) - def test_float8_tensor_slicing_edge_cases(self): - """Test edge cases in slicing""" - device = _DEVICE - dtype = torch.bfloat16 - - # Create and quantize a model - model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype) - quantize_( - model, - Float8DynamicActivationFloat8WeightConfig( - granularity=PerTensor(), version=1 - ), - ) - - original_weight = model.weight - - # 
Test empty slice - empty_slice = original_weight[0:0] - self.assertEqual(empty_slice.shape, (0, 64)) - - # Test single element slice - single_row = original_weight[0:1] - self.assertEqual(single_row.shape, (1, 64)) - - # Test out of bounds (should be handled by PyTorch) - large_slice = original_weight[:100] # More than available rows - self.assertEqual(large_slice.shape, (32, 64)) # Should clamp to available - - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @unittest.skipIf( - not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9" - ) - @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) - @unittest.skipIf( - is_sm_version(8, 9), - "TODO: AssertionError: tensor(-2.1562, device='cuda:0', dtype=torch.bfloat16) not greater than 15", - ) - def test_float8_tensor_slicing_functional_correctness(self, granularity): - """Test that sliced tensors produce correct results in computations""" - device = _DEVICE - dtype = torch.bfloat16 - - # Create reference and quantized models with dimensions that are multiples of 16 - ref_model = ( - torch.nn.Linear(64, 48, bias=False).to(device).to(dtype) - ) # 48 is divisible by 16 - quant_model = copy.deepcopy(ref_model) - quantize_( - quant_model, - Float8DynamicActivationFloat8WeightConfig( - granularity=granularity, version=1 - ), - ) - - # Create input with batch size that works well with slicing - input_tensor = torch.randn(8, 64, device=device, dtype=dtype) - - ref_weight_slice = ref_model.weight[0:16, 0:32] - quant_weight_slice = quant_model.weight[0:16, 0:32] - - # Verify that the sliced weights maintain Float8 properties - self.assertTrue(hasattr(quant_weight_slice, "original_weight_tensor")) - sliced_impl = quant_weight_slice.original_weight_tensor.tensor_impl - self.assertTrue(isinstance(sliced_impl, Float8AQTTensorImpl)) - - # Verify sliced weight shapes - self.assertEqual(sliced_impl.float8_data.shape, (16, 32)) - - # Get original quantized weight implementation for scale comparison - original_quant_impl = quant_model.weight.original_weight_tensor.tensor_impl - - # Verify scale properties based on granularity - if isinstance(granularity, PerTensor): - # Per-tensor: scale should be identical to original (scalar) - self.assertEqual(sliced_impl.scale.numel(), 1) - self.assertTrue(torch.equal(sliced_impl.scale, original_quant_impl.scale)) - else: # PerRow - # Per-row: scale should be sliced to match the selected rows (0:16) - expected_scale_shape = (16, 1) - self.assertEqual(sliced_impl.scale.shape, expected_scale_shape) - # Verify the scale values are the correct slice from the original - self.assertTrue( - torch.equal(sliced_impl.scale, original_quant_impl.scale[0:16]) - ) - - # Verify that sliced quantized data matches the correct slice from original - original_float8_data_slice = original_quant_impl.float8_data[0:16, 0:32] - self.assertTrue( - torch.equal(sliced_impl.float8_data, original_float8_data_slice) - ) - - # Verify that sliced weights can be converted back to float with correct values - sliced_float_weight = quant_weight_slice.to(dtype) - self.assertEqual(sliced_float_weight.shape, (16, 32)) - self.assertEqual(sliced_float_weight.dtype, dtype) - - input_slice = input_tensor[:, 0:32] # (8, 32) to match sliced weight - - # Compute with sliced weights - with torch.no_grad(): - ref_output = torch.nn.functional.linear(input_slice, ref_weight_slice) - quant_output = torch.nn.functional.linear(input_slice, quant_weight_slice) - - # Verify shapes - expected_shape = (8, 16) # batch_size x 
out_features_sliced - self.assertEqual(ref_output.shape, expected_shape) - self.assertEqual(quant_output.shape, expected_shape) - - # Verify reasonable quantization error - error = compute_error(ref_output, quant_output) - self.assertGreater(error, 15, f"Quantization SQNR too low: {error}") - def test_preprocess_scale_3d_reshape(self): """Test that preprocess_scale correctly handles 3D scale tensors""" device = "cpu" # Use CPU for basic functionality test @@ -787,8 +455,7 @@ def test_quantize_dequantize_fp8_inductor(self, float8_dtype, hp_dtype): not is_sm_at_least_90(), "Requires GPU with compute capability >= 9.0" ) @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) - @common_utils.parametrize("float8_config_version", [1, 2]) - def test_expected_kernels_on_gpu(self, granularity, float8_config_version): + def test_expected_kernels_on_gpu(self, granularity): """ Verify that float8 quantization + torch.compile results in the expected number of kernels in the GPU trace. @@ -799,17 +466,11 @@ def test_expected_kernels_on_gpu(self, granularity, float8_config_version): m = torch.nn.Sequential( torch.nn.Linear(K, N, device=_DEVICE, dtype=torch.bfloat16) ) - if float8_config_version == 1: - config = Float8DynamicActivationFloat8WeightConfig( - granularity=granularity, version=1 - ) - else: - assert float8_config_version == 2 - config = Float8DynamicActivationFloat8WeightConfig( - granularity=granularity, - version=2, - kernel_preference=KernelPreference.TORCH, - ) + config = Float8DynamicActivationFloat8WeightConfig( + granularity=granularity, + version=2, + kernel_preference=KernelPreference.TORCH, + ) quantize_( m, config, diff --git a/test/integration/test_load_and_run_checkpoint.py b/test/integration/test_load_and_run_checkpoint.py index 6bdee4a1b8..36fad8d9a3 100644 --- a/test/integration/test_load_and_run_checkpoint.py +++ b/test/integration/test_load_and_run_checkpoint.py @@ -29,12 +29,6 @@ _HIGH_PRECISION_MODEL = "facebook/opt-125m" _DEPRECATED_SINGLE_LINEAR_MODEL_INFO = [ - # model card: https://huggingface.co/torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev - ( - "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev", - 1, - "Float8DynamicActivationFloat8WeightConfig", - ), # model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-v1-0.14.dev ( "torchao-testing/single-linear-Int4WeightOnlyConfig-v1-0.14.dev", @@ -56,12 +50,6 @@ ] _DEPRECATED_MODEL_INFO = [ - # model card: https://huggingface.co/torchao-testing/opt-125m-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev - ( - "torchao-testing/opt-125m-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev", - 1, - "Float8DynamicActivationFloat8WeightConfig", - ), # model card: https://huggingface.co/torchao-testing/opt-125m-Int4WeightOnlyConfig-v1-0.14.dev ( "torchao-testing/opt-125m-Int4WeightOnlyConfig-v1-0.14.dev", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 9fa6765ff2..1b6da29ce5 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1461,7 +1461,7 @@ class Float8WeightOnlyConfig(AOBaseConfig): Args: weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn. set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values. 
-        version (int): the version of the config, version 1 is using AffineQuantizedTensor that we plan to deprecate/split, version 2 is using Float8Tensor (default)
+        version (int): the version of the config; version 1 is no longer supported, version 2 uses Float8Tensor (default)
 
     Note:
         The actual matmul will be computed in original precision of the weight tensor.
 
@@ -1476,26 +1476,11 @@ def __post_init__(self):
 
 
 def _float8_weight_only_quant_tensor(weight, config):
-    if config.version == 1:
-        warnings.warn(
-            "Config Deprecation: version 1 of Float8WeightOnlyConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2649 for more details"
-        )
-        from torchao.dtypes import to_affine_quantized_floatx
-
-        block_size = tuple([1 for _ in range(weight.dim() - 1)] + [weight.shape[-1]])
-        new_weight = to_affine_quantized_floatx(
-            input_float=weight,
-            block_size=block_size,
-            target_dtype=config.weight_dtype,
-            scale_dtype=None,
-            _layout=Float8Layout(mm_config=None),
-        )
-    else:
-        assert config.version == 2, f"Unexpected version: {config.version}"
-        weight_dtype = config.weight_dtype
-        new_weight = Float8Tensor.from_hp(
-            weight, float8_dtype=weight_dtype, granularity=PerRow()
-        )
+    assert config.version == 2, f"Unexpected version: {config.version}"
+    weight_dtype = config.weight_dtype
+    new_weight = Float8Tensor.from_hp(
+        weight, float8_dtype=weight_dtype, granularity=PerRow()
+    )
     return new_weight
 
 
@@ -1596,7 +1581,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
         activation_value_ub (Optional[float]): the upper bound for activation value for calculating scale
         kernel_preference (KernelPreference): kernel preference for ops like matmul, grouped matmul etc. by default (KernelPreference.AUTO) it will be chosen for user based on hardware or other information, this only needs to be set in weight
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
-        version (int): the version of the config, version 1 is using AffineQuantizedTensor that we plan to deprecate/split, version 2 is using Float8Tensor (default)
+        version (int): the version of the config; version 1 is no longer supported, version 2 uses Float8Tensor (default)
     """
 
@@ -1672,49 +1657,23 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
             "PerRow quantization only works for bfloat16 precision input weight"
         )
 
-    if config.version == 1:
-        warnings.warn(
-            "Config Deprecation: version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2649 for more details"
-        )
-
-        block_size = get_block_size(weight.shape[-2:], weight_granularity)
-        if weight.dim() == 3:
-            block_size = tuple([1] + list(block_size))
-        quantized_weight = to_affine_quantized_floatx(
-            input_float=weight,
-            block_size=block_size,
-            target_dtype=weight_dtype,
-            scale_dtype=torch.float32,
-            _layout=Float8Layout(mm_config=mm_config),
-        )
-
-        input_quant_func = _input_activation_quant_func_fp8
-        input_quant_kwargs = {
-            "activation_granularity": activation_granularity,
-            "activation_dtype": activation_dtype,
-        }
-
-        quantized_weight = to_linear_activation_quantized(
-            quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs
-        )
-    else:
-        assert config.version == 2, f"Unexpected version: {config.version}"
-        act_quant_kwargs = QuantizeTensorToFloat8Kwargs(
-            activation_dtype,
-            activation_granularity,
-            hp_value_lb=activation_value_lb,
-            hp_value_ub=activation_value_ub,
-            kernel_preference=kernel_preference,
-        )
+    assert config.version == 2, f"Unexpected version: {config.version}"
+    act_quant_kwargs = QuantizeTensorToFloat8Kwargs(
+        activation_dtype,
+        activation_granularity,
+        hp_value_lb=activation_value_lb,
+        hp_value_ub=activation_value_ub,
+        kernel_preference=kernel_preference,
+    )
 
-        quantized_weight = Float8Tensor.from_hp(
-            weight,
-            float8_dtype=weight_dtype,
-            granularity=weight_granularity,
-            mm_config=mm_config,
-            kernel_preference=kernel_preference,
-            act_quant_kwargs=act_quant_kwargs,
-        )
+    quantized_weight = Float8Tensor.from_hp(
+        weight,
+        float8_dtype=weight_dtype,
+        granularity=weight_granularity,
+        mm_config=mm_config,
+        kernel_preference=kernel_preference,
+        act_quant_kwargs=act_quant_kwargs,
+    )
 
     return quantized_weight
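For reference, a minimal sketch of the version-2 path that remains after this change, using only the public imports that appear in the test file above (assumes torchao with Float8Tensor support and a CUDA GPU with compute capability >= 8.9; the layer shapes are illustrative):

import torch

from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_
from torchao.quantization.granularity import PerRow

# PerRow float8 quantization requires a bfloat16 weight, per the assert in
# _float8_dynamic_activation_float8_weight_quantize_tensor
model = torch.nn.Linear(64, 32, bias=False, device="cuda", dtype=torch.bfloat16)

# version=2 (the default) stores the weight as a Float8Tensor and attaches
# QuantizeTensorToFloat8Kwargs so activations are quantized dynamically at runtime
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), version=2))

x = torch.randn(8, 64, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    y = model(x)  # fp8 dynamic-activation / fp8-weight linear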