diff --git a/test/test_ops.py b/test/test_ops.py index 240488c637..b21e4c6dd8 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -18,7 +18,6 @@ from torch.testing._internal.optests import opcheck import torchao -from torchao.dtypes.floatx import from_scaled_tc_floatx from torchao.quantization.marlin_qqq import ( marlin_qqq_workspace, pack_to_marlin_qqq, @@ -56,72 +55,6 @@ class TestOps(TestCase): - def _create_floatx_inputs( - self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device, dtype - ): - # Randomly initialize each byte - nbits = 1 + ebits + mbits - floatx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8) - scale = torch.rand(OC).to(dtype) + 0.5 - fp16_act = torch.rand(BS, IC).to(dtype) + 0.5 - return floatx_weight.to(device), scale.to(device), fp16_act.to(device) - - @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") - @parametrize("ebits,mbits", [(3, 2), (2, 2)]) - @parametrize("dtype", [torch.half, torch.bfloat16]) - def test_quant_llm_linear(self, ebits, mbits, dtype): - BS = 2 - OC = 256 - IC = 256 - splitK = 1 - floatx_weight, scale, fp16_act = self._create_floatx_inputs( - ebits, mbits, BS, OC, IC, "cuda", dtype - ) - - # smoke test - torchao.ops.quant_llm_linear( - ebits, mbits, fp16_act, floatx_weight, scale, splitK - ) - - # comprehensive testing - test_utils = [ - "test_schema", - "test_autograd_registration", - "test_faketensor", - "test_aot_dispatch_dynamic", - ] - opcheck( - torch.ops.torchao.quant_llm_linear, - (ebits, mbits, fp16_act, floatx_weight, scale, splitK), - test_utils=test_utils, - ) - - @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") - @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)]) - @parametrize("ebits,mbits", [(3, 2), (2, 2)]) - @parametrize("dtype", [torch.half, torch.bfloat16]) - def test_quant_llm_linear_correctness( - self, ebits, mbits, BS, OC, IC, splitK, dtype - ): - # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py - floatx_weight, scale, fp16_act = self._create_floatx_inputs( - ebits, mbits, BS, OC, IC, "cuda", dtype - ) - - results_floatx = torchao.ops.quant_llm_linear( - ebits, mbits, fp16_act, floatx_weight, scale, splitK - ) - - fp16_weight = from_scaled_tc_floatx(floatx_weight, ebits, mbits, scale).to( - dtype - ) - results_fp16 = fp16_act @ fp16_weight.T - - error = (results_floatx - results_fp16).abs().mean() - gt = results_fp16.abs().mean() - relative_error = error / gt - rtol = 1e-2 if dtype == torch.bfloat16 else 1e-3 - assert relative_error < rtol def _scaled_dot_product_int8_op_ref( self, diff --git a/torchao/csrc/cuda/fp6_llm/README.md b/torchao/csrc/cuda/fp6_llm/README.md deleted file mode 100644 index 8df1fb1416..0000000000 --- a/torchao/csrc/cuda/fp6_llm/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# FP6-LLM kernel - -This kernel is adapted from https://github.com/usyd-fsalab/fp6_llm. It performs linear op (A @ W.T), where A is in FP16 or BF16 and W is in FP6 (E3M2 without infinities and NaN). - -On most hardware, this kernel is faster than FP16 linear for batch size from 1 to 128, and slower for batch size larger than or equal to 256. See https://github.com/usyd-fsalab/fp6_llm/issues/8 for a detailed discussion. - -See https://github.com/pytorch/ao/pull/223 and and https://github.com/pytorch/ao/pull/1147 for some benchmark results. 
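For context, the removed correctness test boiled down to the check below (a minimal sketch using the two APIs this PR deletes, `torchao.ops.quant_llm_linear` and `torchao.dtypes.floatx.from_scaled_tc_floatx`, so it only runs against a build that still ships them):

```python
import torch
import torchao
from torchao.dtypes.floatx import from_scaled_tc_floatx

# FP6 (E3M2): 1 sign + 3 exponent + 2 mantissa bits; 8 FPx values pack into `nbits` uint8 bytes
ebits, mbits = 3, 2
nbits = 1 + ebits + mbits
BS, OC, IC, splitK = 2, 256, 256, 1  # kernel expects OC % 256 == 0 and IC % 64 == 0

fpx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8, device="cuda")
scale = (torch.rand(OC, device="cuda") + 0.5).half()
act = (torch.rand(BS, IC, device="cuda") + 0.5).half()

# Fused FPx-weight x FP16-activation GEMM (the kernel removed by this PR)
out_fpx = torchao.ops.quant_llm_linear(ebits, mbits, act, fpx_weight, scale, splitK)

# Reference: dequantize the packed weight to FP16 and run a plain matmul
ref = act @ from_scaled_tc_floatx(fpx_weight, ebits, mbits, scale).half().T
assert (out_fpx - ref).abs().mean() / ref.abs().mean() < 1e-3
```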
diff --git a/torchao/csrc/cuda/fp6_llm/configs.h b/torchao/csrc/cuda/fp6_llm/configs.h deleted file mode 100644 index 54d4db3a69..0000000000 --- a/torchao/csrc/cuda/fp6_llm/configs.h +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the BSD 3-Clause license found in the -// LICENSE file in the root directory of this source tree. -// Copyright 2024 FP6-LLM authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/configs.h - -#ifndef CONFIGS_H -#define CONFIGS_H - -//#define DEBUG_MODE -#define PIPELINE_LEVEL_GMEM 2 -#define PIPELINE_LEVEL_SMEM 2 // only support 2 - -/************************ Hardware Parameters ************************/ -#define WARP_SIZE 32 -#define REG_BIT_WIDTH 32 -// mma: M=16 K=16 N=8 -#define MMA_8 8 -#define MMA_16 16 -// for memory access -#define THREAD_OPT_ACCESS_BIT_WIDTH_128 128 // LDS.128, cp_async.128, ... -#define BIT_WIDTH_PER_HALF 16 // Half precision: FP16 - -/******************** Register Allocation For GEMM ********************/ -#define REG_PER_THREAD_C_TENSOR_16_16 8 // 8 for FP32 Accumulation -/********************** Memory Padding Parameters **********************/ -// Eliminating bank-conflict -#define PADDING_BYTES_16 16 // Padding 16 bytes each column -#define PADDING_SHARED_MEM_FOR_B_8 8 // Padding 8 half each column, during CopyFromGlobalToShared() for B -#define PADDING_SHARED_MEM_FOR_C_4 4 // Padding 4 float each column, during StoreToSharedMemoryFromRegister() for C -/************************* WARP Tiling part-1 *************************/ -#define WARP_ROW_MMA_TENSORS 4 -#define WARP_M (WARP_ROW_MMA_TENSORS * MMA_16) // 64 -#define WARP_K_MMA_TENSORS 4 -#define WARP_K (WARP_K_MMA_TENSORS * MMA_16) // 64 -template -struct TilingConfig { - // Depending on "n" dimension of the GEMM - static constexpr int BLOCK_ROW_WARPS = BLOCK_ROW_WARPS_; - static constexpr int BLOCK_COL_WARPS = BLOCK_COL_WARPS_; - static constexpr int WARP_COL_MMA_TENSORS = WARP_COL_MMA_TENSORS_; - /************************* WARP Tiling part-2 *************************/ - static constexpr int WARP_N = WARP_COL_MMA_TENSORS * MMA_8; - /*************************Thread Block Tiling *************************/ - static constexpr int TILE_M = WARP_M * BLOCK_ROW_WARPS; - static constexpr int TILE_N = MMA_8 * WARP_COL_MMA_TENSORS * BLOCK_COL_WARPS; - static constexpr int TILE_K = WARP_K; - /********************** #Thread per Thread Block **********************/ - static constexpr int BLOCK_WARPS = BLOCK_ROW_WARPS * BLOCK_COL_WARPS; - static constexpr int BLOCK_THREADS = BLOCK_WARPS * WARP_SIZE; - /******************************* Others *******************************/ - static constexpr int SMEM_SIZE_B_TILE = TILE_N * (TILE_K + PADDING_BYTES_16) * 2 * PIPELINE_LEVEL_GMEM; // sizeof(half)=2, doubleBuffer=2 - static constexpr int 
SMEM_SIZE_C_TILE = TILE_N * (TILE_M + PADDING_BYTES_16) * 4; // sizeof(float)=4 -}; - - - -#endif // CONFIGS_H diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu deleted file mode 100644 index 26f6494220..0000000000 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ /dev/null @@ -1,293 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the BSD 3-Clause license found in the -// LICENSE file in the root directory of this source tree. -// Copyright 2024 FP6-LLM authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/fp6_linear.cu -// -// MODIFICATION NOTE (2024-09-25): added SM75 support (https://github.com/pytorch/ao/pull/942): -// - Modified the TilingConfig parameters for SM75 to deal with smaller shared memory -// - Added proper architecture check at both host and device level -// - - -#include "kernel_matmul.cuh" -#include "kernel_reduction.cuh" - -#include -#include - -#include -#include -#include -#include - - -// https://github.com/Dao-AILab/flash-attention/blob/478ee666cccbd1b8f63648633003059a8dc6827d/hopper/utils.h#L25 -#define CHECK_CUDA(call) \ - do { \ - cudaError_t status_ = call; \ - if (status_ != cudaSuccess) { \ - fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ - exit(1); \ - } \ - } while(0) - -#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) - - -template -static void Kernel_Ex(cudaStream_t stream, - const uint4 *Weight, - const half *Scales, - const half *B, - OutputDataType *C, - const size_t M_Global, - const size_t N_Global, - const size_t K_Global, - int Split_K) -{ - #ifdef DEBUG_MODE - printf("\n"); - printf("Launcher.cu->Kernel_Ex():\n"); - printf("M: %d, N: %d, K: %d, SplitK: %d\n", M_Global, N_Global, K_Global, Split_K); - printf("TILE_M: %d, TILE_K: %d, TILE_N: %d\n", TilingConfig::TILE_M, TilingConfig::TILE_K, TilingConfig::TILE_N); - #endif - static size_t SHMEM_SZ = max(TilingConfig::SMEM_SIZE_B_TILE+SMEM_SIZE_PER_TB_A_TILE, TilingConfig::SMEM_SIZE_C_TILE); - cudaFuncSetAttribute(QUANT_GEMM_Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ); - size_t dimN = (N_Global-1) / TilingConfig::TILE_N + 1; - size_t dimM = M_Global * Split_K / TilingConfig::TILE_M; - dim3 GridDim(dimN, dimM, 1); - dim3 BlockDim(WARP_SIZE * TilingConfig::BLOCK_WARPS, 1, 1); - // - #ifdef DEBUG_MODE - printf("GridDim.x: %d, GridDim.y: %d, GridDim.z: %d, BlockDim.x: %d, BlockDim.y: %d, BlockDim.z: %d SHMEM_SZ: %d\n", - GridDim.x, GridDim.y, GridDim.z, BlockDim.x, BlockDim.y, BlockDim.z, SHMEM_SZ); - printf("\n"); - #endif - QUANT_GEMM_Kernel<<>> - (Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); - CHECK_CUDA_KERNEL_LAUNCH(); -} - -template -void fpx_linear_kernel(cudaStream_t stream, - const uint4 *Weight, - const half *Scales, - const half 
*B, - InputDataType *C, - const size_t M_Global, - const size_t N_Global, - const size_t K_Global, - float *Reduction_Workspace, // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) - int Split_K) -{ - static_assert(std::is_same::value || std::is_same::value, "Type must be 'half' or '__nv_bfloat16'"); - assert(M_Global % 256 == 0); - assert(K_Global % 64 == 0); - assert(N_Global > 0); - - // Check GPU Compute Capability before proceeding - int device, major, minor; - CHECK_CUDA(cudaGetDevice(&device)); - CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); - CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device)); - - // Early exit with error for unsupported architectures - if ((major < 7) || (major == 7 && minor < 5)) { - TORCH_CHECK(false, "Quant-LLM Error: This kernel requires GPU with SM75 (Turing) or higher architecture. " - "Your current device has SM", major, minor, " which is not supported."); - } - - const bool is_sm75_gpu = (major == 7) && (minor == 5); - if (is_sm75_gpu && std::is_same::value) { - TORCH_CHECK(false, "Quant-LLM Error: BFloat16 inputs are not supported on SM75 (Turing) GPUs."); - } - - // Work around to support more N shapes: - size_t N_PowerOf2; - if(N_Global>0 && N_Global<=8) N_PowerOf2 = 8; - if(N_Global>8 && N_Global<=16) N_PowerOf2 = 16; - if(N_Global>16 && N_Global<=32) N_PowerOf2 = 32; - if(N_Global>32 && N_Global<=64) N_PowerOf2 = 64; - if(N_Global>64 && N_Global<=128) N_PowerOf2 = 128; - if(N_Global>128) N_PowerOf2 = ((N_Global-1)/128+1) * 128; - - if (is_sm75_gpu && (N_PowerOf2 == 64 || N_PowerOf2 == 128 || N_PowerOf2 % 128 == 0)) { - // For SM75 and N >= 64, we use a different TilingConfig to deal with smaller shared memory. - if (Split_K == 1) { - Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); - } else { - Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); - } - } else { - if (Split_K == 1) { - switch (N_PowerOf2) { - case 8: Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; - case 16: Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; - case 32: Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; - case 64: Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; - case 128: Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; - default: if (N_PowerOf2 % 128 != 0) { - TORCH_CHECK(false, "Quant-LLM Error: Unsupported N dimension ", N_PowerOf2); - } - Kernel_Ex, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; - } - } - else { - switch (N_PowerOf2) { - case 8: Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; - case 16: Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; - case 32: Kernel_Ex, InputDataType, 
float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; - case 64: Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; - case 128: Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; - default: if (N_PowerOf2 % 128 != 0) { - TORCH_CHECK(false, "Quant-LLM Error: Unsupported N dimension ", N_PowerOf2); - } - Kernel_Ex, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; - } - } - } - - if (Split_K != 1) { - // Reduction for SplitK - dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1, 1); - dim3 BlockDim(WARP_SIZE, 1, 1); - SplitK_Reduction<<>>(C, Reduction_Workspace, M_Global, N_Global, Split_K); - CHECK_CUDA_KERNEL_LAUNCH(); - } -} - - -// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h -#define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...) \ - switch (TYPE) { \ - case at::ScalarType::Half: { \ - using torch_t = at::Half; \ - using nv_t = half; \ - __VA_ARGS__(); \ - break; \ - } \ - case at::ScalarType::BFloat16: { \ - using torch_t = at::BFloat16; \ - using nv_t = __nv_bfloat16; \ - __VA_ARGS__(); \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } - -namespace torchao { -// MODIFICATION NOTE: dtype of _weights is changed to uint8 -/* -Computes FPx-FP16 GEMM (PyTorch interface). - -[Mathmatical Formula] -Standard definition of linear layer: Out = In * trans(W), where In, Out, and W are stored in row-major. -After Equivalent transformation : trans(Out) = W * trans(In). Note that we do not perform "transpose" during runtime, we instead interpret the In/Out as column-major matrices when calling our CUDA kernel. - -[Inputs] - _in_feats: tensor of shape [B, IC]; // half or bf16 - _weights: int tensor of shape [OC, IC // 8 * x]; // x UINT8 words contains 8 FPx weights. - _scales: tensor of shape [OC]; // half or bf16 - splitK: spliting the MatMul problem along K dimension for higher GPU utilization, default 1. -[Outputs] - _out_feats: tensor of shape [B, OC]; // half or bf16 -*/ -torch::Tensor fp_eXmY_linear_forward_cuda( - int64_t EXPONENT, - int64_t MANTISSA, - torch::Tensor _in_feats, - torch::Tensor _weights, - torch::Tensor _scales, - int64_t splitK=1) -{ - // Check GPU Compute Capability before proceeding - int device, major, minor; - CHECK_CUDA(cudaGetDevice(&device)); - CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); - CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device)); - - // Early exit with error for unsupported architectures - if ((major < 7) || (major == 7 && minor < 5)) { - TORCH_CHECK(false, "Quant-LLM Error: This kernel requires GPU with SM75 (Turing) or higher architecture. 
" - "Your current device has SM", major, minor, " which is not supported."); - } - - const bool is_sm75_gpu = (major == 7) && (minor == 5); - if (is_sm75_gpu && _in_feats.scalar_type() == at::ScalarType::BFloat16) { - TORCH_CHECK(false, "Quant-LLM Error: BFloat16 inputs are not supported on SM75 (Turing) GPUs."); - } - - const int64_t NBITS = 1 + EXPONENT + MANTISSA; - int num_in_feats = _in_feats.size(0); - int num_in_channels = _in_feats.size(1); - int num_out_channels = _weights.size(0); - TORCH_CHECK(num_in_channels % 64 == 0, "Expected in_features to be a multiple of 64, but received ", num_in_channels); - TORCH_CHECK((num_in_channels / 8 * NBITS) == _weights.size(1)); // Making sure the K dimension is matched. - // - int M = num_out_channels; - int K = num_in_channels; - int N = num_in_feats; - auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); - at::Tensor _out_feats = torch::empty({num_in_feats, num_out_channels}, options); - - options = torch::TensorOptions().dtype(torch::kFloat32).device(_in_feats.device()); - at::Tensor _workspace = torch::empty({splitK, num_in_feats, num_out_channels}, options); - auto Reduction_Workspace = reinterpret_cast(_workspace.data_ptr()); // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) - - // MODIFICATION NOTE: use at::cuda::getCurrentCUDAStream() instead of default stream (0) - // this fixes problem with CUDA graphs when used with torch.compile() - auto stream = at::cuda::getCurrentCUDAStream(); - - DISPATCH_HALF_AND_BF16(_in_feats.scalar_type(), "fpx_linear_kernel", [&] { - auto weight = reinterpret_cast(_weights.data_ptr()); // weights is [OC, IC] but in FP6. - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto scales = reinterpret_cast(_scales.data_ptr()); - auto out_feats = reinterpret_cast(_out_feats.data_ptr()); - - // officially supported in Quant-LLM - if (EXPONENT == 3 && MANTISSA == 2) - fpx_linear_kernel(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); - else if (EXPONENT == 2 && MANTISSA == 2) - fpx_linear_kernel(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); - - // experimental - else if (EXPONENT == 2 && MANTISSA == 3) - fpx_linear_kernel(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); - else if (EXPONENT == 3 && MANTISSA == 1) - fpx_linear_kernel(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); - // else if (EXPONENT == 2 && MANTISSA == 1) - // fpx_linear_kernel(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); - // else if (EXPONENT == 3 && MANTISSA == 0) - // fpx_linear_kernel(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); - // else if (EXPONENT == 2 && MANTISSA == 0) - // fpx_linear_kernel(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); - - else - TORCH_CHECK(false, "FP", NBITS, " E", EXPONENT, "M", MANTISSA, " is not supported."); - }); - - return _out_feats; -} - -TORCH_LIBRARY_IMPL(torchao, CUDA, m) { - m.impl("torchao::quant_llm_linear", &fp_eXmY_linear_forward_cuda); -} - -} // namespace torchao diff --git a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh deleted file mode 100644 index 096bdc0d7f..0000000000 --- a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh +++ /dev/null @@ -1,244 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. 
-// All rights reserved. -// -// This source code is licensed under the BSD 3-Clause license found in the -// LICENSE file in the root directory of this source tree. -// Copyright 2024 FP6-LLM authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/kernel_matmul.cuh -// -// MODIFICATION NOTE (2024-09-25): added SM75 support (https://github.com/pytorch/ao/pull/942): -// - Added __CUDA_ARCH__ guards such that async operations are only executed for SM80 and up -// - -#include "configs.h" -#include "utils_gmem.cuh" -#include "utils_core.cuh" - -/************************** Bitwidth of Weight Segments ************************/ -#define BIT_WIDTH_1 1 -#define BIT_WIDTH_2 2 -#define BIT_WIDTH_4 4 -/*************************** 64*64 Weghts of Weight Matrix *********************/ -#define WEIGHT_PER_WARP (WARP_M*WARP_K) // 64*64 = 4096 -#define SMEM_SIZE_PER_WARP_1BIT (WEIGHT_PER_WARP*BIT_WIDTH_1/8) // 512 Bytes, doubleBuffer not taken into consideration -#define SMEM_SIZE_PER_WARP_2BIT (WEIGHT_PER_WARP*BIT_WIDTH_2/8) // 1024 Bytes, doubleBuffer not taken into consideration -#define SMEM_SIZE_PER_WARP_4BIT (WEIGHT_PER_WARP*BIT_WIDTH_4/8) // 2048 Bytes, doubleBuffer not taken into consideration -#define SMEM_SIZE_PER_TB_1BIT (SMEM_SIZE_PER_WARP_1BIT*TilingConfig::BLOCK_WARPS*PIPELINE_LEVEL_GMEM) // #WARP=4; Trible-Buffer for 3-level pipeline for A = 6 KB; double buffer for 2-level pipeline A= 4 KB. -#define SMEM_SIZE_PER_TB_2BIT (SMEM_SIZE_PER_WARP_2BIT*TilingConfig::BLOCK_WARPS*PIPELINE_LEVEL_GMEM) // #WARP=4; Trible-Buffer for 3-level pipeline for A = 12 KB; double buffer for 2-level pipeline A= 8 KB. -#define SMEM_SIZE_PER_TB_4BIT (SMEM_SIZE_PER_WARP_4BIT*TilingConfig::BLOCK_WARPS*PIPELINE_LEVEL_GMEM) // #WARP=4; Trible-Buffer for 3-level pipeline for A = 24 KB; double buffer for 2-level pipeline A= 16 KB. -#define SMEM_SIZE_PER_TB_A_TILE (SMEM_SIZE_PER_TB_1BIT+SMEM_SIZE_PER_TB_2BIT+SMEM_SIZE_PER_TB_4BIT) // used in fp6_linear.cu, Kernel_Ex(). 
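As a quick sanity check of the shared-memory budget implied by the 1+2+4-bit weight split above (a back-of-the-envelope sketch, assuming the 4-warp thread block and 2-level global pipeline stated in the comments):

```python
WARP_M = WARP_K = 64
WEIGHT_PER_WARP = WARP_M * WARP_K          # 4096 weights per warp
BLOCK_WARPS = 4                            # "#WARP=4" per the comments above
PIPELINE_LEVEL_GMEM = 2                    # double-buffered global->shared pipeline

def smem_per_tb(bit_width: int) -> int:
    per_warp_bytes = WEIGHT_PER_WARP * bit_width // 8   # single buffer, per warp
    return per_warp_bytes * BLOCK_WARPS * PIPELINE_LEVEL_GMEM

# 1-bit: 4 KB, 2-bit: 8 KB, 4-bit: 16 KB -> SMEM_SIZE_PER_TB_A_TILE = 28 KB
assert [smem_per_tb(b) // 1024 for b in (1, 2, 4)] == [4, 8, 16]
```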
-/******************** Gloabl Memory Layout For QUANTIZED DATA *******************/ -#define NUM_INT4_PER_WARP_1BIT (WEIGHT_PER_WARP*BIT_WIDTH_1/128) // 32 -#define NUM_INT4_PER_WARP_2BIT (WEIGHT_PER_WARP*BIT_WIDTH_2/128) // 64 -#define NUM_INT4_PER_WARP_4BIT (WEIGHT_PER_WARP*BIT_WIDTH_4/128) // 128 - -/* - * C = A*B - * A: row major with ahead-of-time layout transformation, FP6 - * B: col major, FP16 - * C: col major, FP16 - */ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 -template -__global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, - const half *B, - OutputDataType* C, - const size_t M_Global, const size_t N_Global, const size_t K_Global, - int Split_K) -{ - #ifdef DEBUG_MODE - assert(K_Global%TilingConfig::TILE_K==0); - assert(M_Global%TilingConfig::TILE_M==0); - assert( gridDim.y == Split_K * (M_Global/TilingConfig::TILE_M)); - #endif - // 1+2+4 weight split - constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA; - constexpr int USE_SEG_1BIT = BIT_WIDTH & 1; - constexpr int USE_SEG_2BIT = BIT_WIDTH & 2; - constexpr int USE_SEG_4BIT = BIT_WIDTH & 4; - const uint4* Weight_1bit = Weight; - const uint4* Weight_2bit = Weight_1bit + (USE_SEG_1BIT ? M_Global*K_Global*BIT_WIDTH_1/128 : 0); - const uint4* Weight_4bit = Weight_2bit + (USE_SEG_2BIT ? M_Global*K_Global*BIT_WIDTH_2/128 : 0); - // Dynamic shared memory for FP16 A tiles, 128 Bytes aligned - extern __shared__ __align__(128) half smem[]; - half (*smem_array)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = reinterpret_cast ( smem + SMEM_SIZE_PER_TB_A_TILE/2 ); // Dynamic shared memory for FP16 B tiles - __shared__ half QuantScales[64*TilingConfig::BLOCK_WARPS]; // static shared memory for quantization scales, 64 row per warp * 4 warps = 512 Bytes - // Thread Block Mapping, considering SplitK - const size_t BatchID = blockIdx.y / (M_Global/TilingConfig::TILE_M); - const size_t x = blockIdx.x; // Output Block ID: (BlockID_Row = y; BlockID_Col = x ) - const size_t y = blockIdx.y % (M_Global/TilingConfig::TILE_M); // Output Block ID: (BlockID_Row = y; BlockID_Col = x ) - const size_t Tile_Start_M = y * TilingConfig::TILE_M; - const size_t Tile_Start_N = x * TilingConfig::TILE_N; - const size_t NumColumnToCopy = (N_Global-Tile_Start_N) < TilingConfig::TILE_N ? 
(N_Global-Tile_Start_N) : TilingConfig::TILE_N; - const size_t NumBlock_K = K_Global/TilingConfig::TILE_K; - const size_t AverageNumBlock_K = NumBlock_K/Split_K; - const size_t ExtraNumBlock_K = NumBlock_K - AverageNumBlock_K * Split_K; - size_t NumIter = AverageNumBlock_K; - size_t StartBlockID_K = AverageNumBlock_K*BatchID; - if(BatchID(smem); - uint32_t* AFrag_2BIT_SPTR = AFrag_1BIT_SPTR + SMEM_SIZE_PER_TB_1BIT/4; - uint32_t* AFrag_4BIT_SPTR = AFrag_2BIT_SPTR + SMEM_SIZE_PER_TB_2BIT/4; // 8 buffers including double buffers, 12 for trible buffers - // StartSPTR for each WARP - AFrag_1BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_1BIT/4; - AFrag_2BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_2BIT/4; - AFrag_4BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_4BIT/4; - // Pre-fetch of A tile - for(int i=0; i(AFrag_1BIT_SPTR+i*SMEM_SIZE_PER_WARP_1BIT/4*4, WARP_StartGPTR_A_1BIT); - if(USE_SEG_2BIT) CopyFromGlobalToShared_A(AFrag_2BIT_SPTR+i*SMEM_SIZE_PER_WARP_2BIT/4*4, WARP_StartGPTR_A_2BIT); - if(USE_SEG_4BIT) CopyFromGlobalToShared_A(AFrag_4BIT_SPTR+i*SMEM_SIZE_PER_WARP_4BIT/4*4, WARP_StartGPTR_A_4BIT); - WARP_StartGPTR_A_1BIT += SMEM_SIZE_PER_WARP_1BIT/16; - WARP_StartGPTR_A_2BIT += SMEM_SIZE_PER_WARP_2BIT/16; - WARP_StartGPTR_A_4BIT += SMEM_SIZE_PER_WARP_4BIT/16; - } - // Global Memory Address for Matrix A (QuantScale) ///////////////////////////////////////////////////////////////////// - const half* TB_StartGPTR_A_Scale = Scales + (y*TilingConfig::BLOCK_ROW_WARPS) * 64; - const half* WARP_StartGPTR_A_Scales = TB_StartGPTR_A_Scale + WARP_i * 64; - CopyFromGlobalToShared_Scales(QuantScales+WARP_i*64, WARP_StartGPTR_A_Scales); - // Copying B tile from Global to Shared, considering SplitK ///////////////////////////////////////////////////////////// - const half *BTile_GPTR = B + Tile_Start_N * K_Global + StartBlockID_K * TilingConfig::TILE_K; - for(int i=0; i (smem_array+i*TilingConfig::TILE_N, BTile_GPTR, K_Global, NumColumnToCopy); - BTile_GPTR += TilingConfig::TILE_K; - } - // Register Allocation for A,B, and C, Initilazed to Zeros ///////////////////////////////////////////////////////////////////// - constexpr int NumRegSets_a = WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block - constexpr int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS==1) ? 1 : TilingConfig::WARP_COL_MMA_TENSORS/2; // 1 set = 4 registers, containing a 16*16 MMA block - uint32_t a [NumRegSets_a * PIPELINE_LEVEL_SMEM][4]; // double/Trible buffer is used // Registers to store decompressed FP6 - uint32_t b [NumRegSets_b * PIPELINE_LEVEL_SMEM][4]; // double/Triple buffer is used // Register to store FP16 B matrix (a slice) - float c[NumRegSets_a * NumRegSets_b][REG_PER_THREAD_C_TENSOR_16_16]; - for(int i=0; i= 800 - cp_async_wait_all(); - #endif - __syncthreads(); - - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - uint32_t Scales_RPTR[4]; // 4 Registers per thread for Quantization Scales - ExtractFromSharedToReg_Scales(Scales_RPTR, QuantScales + WARP_i*64); - // Initializing the Software Pipeline: writing registers. //////////////////////////////////////////////////////////////////////////////////////////////// - constexpr bool USE_BF16 = std::is_same::value; - initialize_mma_slice(a, b, AFrag_1BIT_SPTR, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array, Scales_RPTR); - // The outer loop. 
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - #pragma unroll(1) - for (size_t tile_id_k = 0; tile_id_k < NumIter; tile_id_k++) - { - // Trible-Buffer for A Tile - uint32_t* __restrict__ read_SPTR_Frag_1bit = AFrag_1BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_1BIT/4*4; // 512 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 - uint32_t* __restrict__ read_SPTR_Frag_2bit = AFrag_2BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_2BIT/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 - uint32_t* __restrict__ read_SPTR_Frag_4bit = AFrag_4BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_4BIT/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 - uint32_t* __restrict__ read2_SPTR_Frag_1bit = AFrag_1BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_1BIT/4*4; - uint32_t* __restrict__ read2_SPTR_Frag_2bit = AFrag_2BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_2BIT/4*4; - uint32_t* __restrict__ read2_SPTR_Frag_4bit = AFrag_4BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_4BIT/4*4; - uint32_t* __restrict__ write_SPTR_Frag_1bit = AFrag_1BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_1BIT/4*4; // 512 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 - uint32_t* __restrict__ write_SPTR_Frag_2bit = AFrag_2BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_2BIT/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 - uint32_t* __restrict__ write_SPTR_Frag_4bit = AFrag_4BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_4BIT/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 - // Trible-Buffer for B Tile - // MODIFICATION NOTE: to support MSVC, half __restrict__ (*read_SPTR ) is changed to below. similarly for read2_SPTR and write_SPTR. 
- half (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; - half (* __restrict__ read2_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; - half (* __restrict__ write_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; - // - bool GlobalCopy = (tile_id_k+PIPELINE_LEVEL_GMEM-1) < NumIter; - // Copying A tile from Global to Register, Bypassing L1, using double-buffer - if(USE_SEG_1BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_1bit, WARP_StartGPTR_A_1BIT, GlobalCopy); - if(USE_SEG_2BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_2bit, WARP_StartGPTR_A_2BIT, GlobalCopy); - if(USE_SEG_4BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_4bit, WARP_StartGPTR_A_4BIT, GlobalCopy); - // copying B tile from GlobalMemory to SharedMemory - CopyFromGlobalToShared (write_SPTR, BTile_GPTR, K_Global, NumColumnToCopy, GlobalCopy); - #if __CUDA_ARCH__ >= 800 - cp_async_group_commit(); - #endif - core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 1); // read_SPTR_Frag_2bit, read_SPTR_Frag_4bit are different for each WARP; read_SPTR is shared among WARPs - core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 2); - core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 3); - // Barriers and Synchronizations - #if __CUDA_ARCH__ >= 800 - cp_async_wait_group(); - #endif - __syncthreads(); - core_mma_slice(c, a, b, read2_SPTR_Frag_1bit, read2_SPTR_Frag_2bit, read2_SPTR_Frag_4bit, read2_SPTR, Scales_RPTR, 0); - // Updating global PTRs - WARP_StartGPTR_A_1BIT += SMEM_SIZE_PER_WARP_1BIT/16; // 2KB/16=128 (1)/16: int4*+1 = char*+16 - WARP_StartGPTR_A_2BIT += SMEM_SIZE_PER_WARP_2BIT/16; // 4KB/16=256 (1)/16: int4*+1 = char*+16 - WARP_StartGPTR_A_4BIT += SMEM_SIZE_PER_WARP_4BIT/16; // 8KB/16=512 (1)/16: int4*+1 = char*+16 - BTile_GPTR += TilingConfig::TILE_K; - } - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - // Store the C fragments to shared memory. - float (*smem_CFrag) [TilingConfig::TILE_M+PADDING_SHARED_MEM_FOR_C_4] = - reinterpret_cast (smem); - StoreToSharedMemoryFromRegister(smem_CFrag, c); - __syncthreads(); - // Now that shared memory contains all the D tiles, stream them to global memory. 
- OutputDataType* BlockGlobalPTR = C + BatchID*(M_Global*N_Global) + Tile_Start_M + Tile_Start_N*M_Global; - for(size_t i=warpId; i::value) { - BlockGlobalPTR[j+i*M_Global] = __float2half_rn(smem_CFrag[i][j]); - } else if constexpr (std::is_same::value) { - #if __CUDA_ARCH__ >= 800 - BlockGlobalPTR[j+i*M_Global] = __float2bfloat16_rn(smem_CFrag[i][j]); - #endif - } else { - BlockGlobalPTR[j+i*M_Global] = smem_CFrag[i][j]; - } - } -} -#else -// Stub implementation for older architectures -template -__global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, - const half *B, - OutputDataType* C, - const size_t M_Global, const size_t N_Global, const size_t K_Global, - int Split_K) -{ -// NOOP, should never actually be called -} -#endif diff --git a/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh b/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh deleted file mode 100644 index 5dc4d02e77..0000000000 --- a/torchao/csrc/cuda/fp6_llm/kernel_reduction.cuh +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the BSD 3-Clause license found in the -// LICENSE file in the root directory of this source tree. -// Copyright 2024 FP6-LLM authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/kernel_reduction.cuh - -/*************************************************************************** - * Copyright 2023 The FLash-LLM Authors. All rights reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * http://www.apache.org/licenses/LICENSE-2.0 - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- ***************************************************************************/ -// Used for the reduction of result matrix if Split-K is used -// Reduction_Workspace: (Split_K, M_Global, N_Global), column major -// C: (M_Global, N_Global), column major -// Each thread deals with 8 output elements, each elements is the sum of Split_K elements -// Read Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8 float_per_thread (256bit) -> 256 float per warp -// Write Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8 half_per_thread (128bit) -> 256 half per warp -// GridSize = (M_Global*N_Global) / 256 - -#include -#include -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 -#include -#endif -#include - -#define REDUCTION_ELEMENT_PER_THREADBLOCK 256 -#define HALF_PER_128BIT 8 - -template -__global__ void SplitK_Reduction(T* C, float* Reduction_Workspace, size_t M_Global, size_t N_Global, int Split_K) -{ - T* WARP_GPTR_C = C + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; - float* WARP_GPTR_R = Reduction_Workspace + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; - T* THREAD_GPTR_C = WARP_GPTR_C + threadIdx.x * HALF_PER_128BIT; - float* THREAD_GPTR_R = WARP_GPTR_R + threadIdx.x * HALF_PER_128BIT; - // Initializing Thread-Local Results - float Results[HALF_PER_128BIT]; - #pragma unroll - for (int i = 0; i < HALF_PER_128BIT; i++) Results[i] = 0.0f; - // Reduction - for (int i = 0; i < Split_K; i++) { - #pragma unroll - for (int j = 0; j < HALF_PER_128BIT; j++) Results[j] += THREAD_GPTR_R[j]; - THREAD_GPTR_R += M_Global * N_Global; - } - // Writing to global memory - if constexpr (std::is_same::value) { - #pragma unroll - for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2half_rn(Results[i]); - } else { // __nv_bfloat16> - #if __CUDA_ARCH__ >= 800 - #pragma unroll - for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2bfloat16_rn(Results[i]); - #endif - } -} diff --git a/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh b/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh deleted file mode 100644 index a99d7acb5c..0000000000 --- a/torchao/csrc/cuda/fp6_llm/ptx_cp.async.cuh +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the BSD 3-Clause license found in the -// LICENSE file in the root directory of this source tree. -// Copyright 2024 FP6-LLM authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/ptx_cp.async.cuh - -/*************************************************************************** - * Copyright 2023 The FLash-LLM Authors. All rights reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * http://www.apache.org/licenses/LICENSE-2.0 - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - ***************************************************************************/ -// Extended from CUTLASS's source code - -#ifndef PTX_CP_ASYNC_CUH -#define PTX_CP_ASYNC_CUH - -#include -#include -#include - -template -__device__ __forceinline__ void cp_async(half* smem_ptr, const half* global_ptr, bool pred_guard = true) -{ - static_assert(SizeInBytes == 16, "Size is not supported"); - unsigned smem_int_ptr = __cvta_generic_to_shared(smem_ptr); - asm volatile("{ \n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred_guard), - "r"(smem_int_ptr), - "l"(global_ptr), - "n"(SizeInBytes)); -} - -/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block. -__device__ __forceinline__ void cp_async_group_commit() -{ - asm volatile("cp.async.commit_group;\n" ::); -} - -/// Blocks until all but previous cp.async.commit_group operations have committed. -template -__device__ __forceinline__ void cp_async_wait_group() -{ - asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); -} - -/// Blocks until all previous cp.async.commit_group operations have committed. -// cp.async.wait_all is equivalent to : -// cp.async.commit_group; -// cp.async.wait_group 0; -__device__ __forceinline__ void cp_async_wait_all() -{ - asm volatile("cp.async.wait_all;\n" ::); -} - -#endif diff --git a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh deleted file mode 100644 index d5f937e207..0000000000 --- a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh +++ /dev/null @@ -1,147 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the BSD 3-Clause license found in the -// LICENSE file in the root directory of this source tree. -// Copyright 2024 FP6-LLM authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/ptx_mma.cuh -// -// MODIFICATION NOTE (2024-09-25): added SM75 support (https://github.com/pytorch/ao/pull/942): -// - Replaced m16n8k16 Tensor core operation with two m16n8k8 operations -// - Accounted for a difference in expected parameters for the ldmatrix operation - -/*************************************************************************** - * Copyright 2023 The FLash-LLM Authors. All rights reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * http://www.apache.org/licenses/LICENSE-2.0 - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - ***************************************************************************/ -#ifndef PTX_MMA_CUH -#define PTX_MMA_CUH - -#include -#include -#include - -#include -#include "configs.h" - -// MODIFICATION NOTE: to support MSVC -// - uint32_t __restrict__ Reg[][4] is changed to uint32_t (* __restrict__ Reg)[4] -// - half __restrict__ (*read_SPTR) is changed to half (* __restrict__ read_SPTR) -template -__device__ __forceinline__ void B_FromSharedToReg(uint32_t (* __restrict__ Reg)[4], - half (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], - int slice_id) { - #ifdef DEBUG_MODE - static_assert( (TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0) ); - #endif - - const int warpId = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; - int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS; - int warp_start_col = TilingConfig::WARP_COL_MMA_TENSORS * MMA_8 * WARP_j; // each warp may start from reading warp_start_col'th column of the B tile in shared memory - #ifdef DEBUG_MODE - assert( warp_start_col==0 ); - #endif - - #if __CUDA_ARCH__ == 750 - if (TilingConfig::WARP_COL_MMA_TENSORS==1) { - // For .target sm_75, all threads must contain valid addresses for the 'ldmatrix' op. below. Otherwise, the behavior is undefined. - // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-load-instruction-ldmatrix - // To avoid this, we make threads 16-32 point to the same smem addresses as threads 0-15 by changing the lane id. - lane_id = lane_id % 16; - } - #endif - int col = (lane_id%8) + (lane_id/16)*8; - int row = (lane_id%16) / 8 * 8; - uint32_t smem_local_ptr = static_cast(__cvta_generic_to_shared(&read_SPTR[warp_start_col+col][slice_id*MMA_16 + row])); - if(TilingConfig::WARP_COL_MMA_TENSORS==1) { - asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" - : "=r"(Reg[0][0]), "=r"(Reg[0][1]) - : "r"(smem_local_ptr)); - } - else { - #pragma unroll - for (int i = 0; i < TilingConfig::WARP_COL_MMA_TENSORS/2; i++) - { - asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" - : "=r"(Reg[i][0]), "=r"(Reg[i][1]), "=r"(Reg[i][2]), "=r"(Reg[i][3]) - : "r"(smem_local_ptr)); - smem_local_ptr += 16 * (WARP_K+PADDING_SHARED_MEM_FOR_B_8) * sizeof(half); - } - } -} - -// MODIFICATION NOTE: to support MSVC, the function signature is changed from -// MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a, uint32_t __restrict__ *b). -template -__device__ __forceinline__ void -MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t * __restrict__ b) -{ - #if __CUDA_ARCH__ == 750 - // m16n8k16 op. requires >=sm_80, so instead we use two m16n8k8 ops. 
- asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{ %0, %1, %2, %3}," - "{ %4, %5}," - "{ %6 }," - "{ %7, %8, %9, %10 };" - : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) - : "r"(a[0]), "r"(a[1]), - "r"(b[0]), - "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); - asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{ %0, %1, %2, %3}," - "{ %4, %5}," - "{ %6 }," - "{ %7, %8, %9, %10 };" - : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) - : "r"(a[2]), "r"(a[3]), - "r"(b[1]), - "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); - - #else - if constexpr (USE_BF16) { - asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32" - "{ %0, %1, %2, %3}," - "{ %4, %5, %6, %7 }," - "{ %8, %9 }," - "{ %10, %11, %12, %13 };" - : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), - "r"(b[0]), "r"(b[1]), - "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); - } else { // FP16 - asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{ %0, %1, %2, %3}," - "{ %4, %5, %6, %7 }," - "{ %8, %9 }," - "{ %10, %11, %12, %13 };" - : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), - "r"(b[0]), "r"(b[1]), - "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); - } - #endif -} - -#endif diff --git a/torchao/csrc/cuda/fp6_llm/utils_core.cuh b/torchao/csrc/cuda/fp6_llm/utils_core.cuh deleted file mode 100644 index 24231d3f88..0000000000 --- a/torchao/csrc/cuda/fp6_llm/utils_core.cuh +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the BSD 3-Clause license found in the -// LICENSE file in the root directory of this source tree. -// Copyright 2024 FP6-LLM authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// -// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/utils_core.cuh - -#ifndef UTILS_CORE_CUH -#define UTILS_CORE_CUH - -#include - -#include "configs.h" -#include "ptx_mma.cuh" -#include "utils_parallel_dequant.cuh" - - -template -__device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], uint32_t* SPTR, int slice_id) { - SPTR += slice_id * (NUM_INT_PER_THREAD*WARP_SIZE); - int lane_id = threadIdx.x % WARP_SIZE; - #pragma unroll - for(int i=0; i -__device__ __forceinline__ void initialize_mma_slice(uint32_t (*a)[4], - uint32_t (*b)[4], - uint32_t* __restrict__ A_1BIT_SPTR_read, - uint32_t* __restrict__ A_2BIT_SPTR_read, - uint32_t* __restrict__ A_4BIT_SPTR_read, - half (* __restrict__ B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], - uint32_t* RPTR_Scales) -{ - // 1+2+4 weight split - constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA; - constexpr int USE_SEG_1BIT = BIT_WIDTH & 1; - constexpr int USE_SEG_2BIT = BIT_WIDTH & 2; - constexpr int USE_SEG_4BIT = BIT_WIDTH & 4; - // Writing registers - // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 register per thread; - uint32_t a_1bit[1]; // NO double buffer - uint32_t a_2bit[2]; // NO double buffer - uint32_t a_4bit[4]; // NO double buffer - if(USE_SEG_1BIT) CopyFromSharedToRegister_AFrag<1> (a_1bit, A_1BIT_SPTR_read, 0); - if(USE_SEG_2BIT) CopyFromSharedToRegister_AFrag<2> (a_2bit, A_2BIT_SPTR_read, 0); - if(USE_SEG_4BIT) CopyFromSharedToRegister_AFrag<4> (a_4bit, A_4BIT_SPTR_read, 0); - Dequant_32FP6_4Way(a, a_1bit, a_2bit, a_4bit, RPTR_Scales); // SIMT Dequant: dequantizing FPx to FP16 at register level, dequantizing a slice each time - B_FromSharedToReg(b, B_SPTR_read, 0); // Loading B from shared to registers -} - -// MODIFICATION NOTE: to support MSVC, half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below. -template -__device__ __forceinline__ void core_mma_slice(float c[][REG_PER_THREAD_C_TENSOR_16_16], - uint32_t (*a)[4], - uint32_t (*b)[4], - uint32_t* __restrict__ A_1bit_SPTR_read, - uint32_t* __restrict__ A_2bit_SPTR_read, - uint32_t* __restrict__ A_4bit_SPTR_read, - half (* __restrict__ B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], - uint32_t* RPTR_Scales, - int slice_id) // writing slice[slice_id] to registers, k=0 -> slice_id=1 for prefetching -{ - // 1+2+4 weight split - constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA; - constexpr int USE_SEG_1BIT = BIT_WIDTH & 1; - constexpr int USE_SEG_2BIT = BIT_WIDTH & 2; - constexpr int USE_SEG_4BIT = BIT_WIDTH & 4; - - #ifdef DEBUG_MODE - assert((TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0)); // if WARP_COL_MMA_TENSORS == 1, B tile in registers is padded to a 16*16 MMA block - #endif - const int NumRegSets_a = WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block - const int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS==1) ? 
1 : TilingConfig::WARP_COL_MMA_TENSORS/2; // 1 set = 4 registers, containing a 16*16 MMA block - uint32_t (*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] = reinterpret_cast(c); // Reigsters for accumulated FP32 results - - // Setting RPTRs for double buffers - uint32_t (*a_read )[4] = a; - uint32_t (*a_write)[4] = a; - uint32_t (*b_read )[4] = b; - uint32_t (*b_write)[4] = b; - if(slice_id%2==1) { b_write += NumRegSets_b; a_write += NumRegSets_a;} - else { b_read += NumRegSets_b; a_read += NumRegSets_a;} - - // Reading registers and issuing core tensor core computations (a slice of A and B tile in shared memory) - #pragma unroll - for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { - if(TilingConfig::WARP_COL_MMA_TENSORS==1) { - MMA_FP16_M16N8K16( c_uint_ptr[i], a_read[i], b_read[0] ); - } - else { - #pragma unroll - for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS/2; j++) { - MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a_read[i], b_read[j] ); - MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4, a_read[i], b_read[j] + 2 ); // c+4; b+2 - } - } - } - // Writing registers - // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 register per thread; - uint32_t a_1bit[1]; // NO double buffer - uint32_t a_2bit[2]; // NO double buffer - uint32_t a_4bit[4]; // NO double buffer - if(USE_SEG_1BIT) CopyFromSharedToRegister_AFrag<1> (a_1bit, A_1bit_SPTR_read, slice_id); - if(USE_SEG_2BIT) CopyFromSharedToRegister_AFrag<2> (a_2bit, A_2bit_SPTR_read, slice_id); - if(USE_SEG_4BIT) CopyFromSharedToRegister_AFrag<4> (a_4bit, A_4bit_SPTR_read, slice_id); - Dequant_32FP6_4Way(a_write, a_1bit, a_2bit, a_4bit, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at register level, dequantizing a slice each time - B_FromSharedToReg (b_write, B_SPTR_read, slice_id); // Loading B from shared to registers -} - -template -__device__ __forceinline__ void StoreToSharedMemoryFromRegister(float (*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4], - float c[][REG_PER_THREAD_C_TENSOR_16_16]) -{ - const int lane_id = threadIdx.x % WARP_SIZE; - const int warpId = threadIdx.x / WARP_SIZE; - int warp_row_offset = warpId * (MMA_16 * WARP_ROW_MMA_TENSORS); - #pragma unroll - for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { - #pragma unroll - for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS; j++) { // Dealing with one 16*8 Tensor - int RegSetID = i + (j/2)*WARP_ROW_MMA_TENSORS; - int RegOffset = (j%2)*(REG_PER_THREAD_C_TENSOR_16_16/2); - int Tensor_row_offset = warp_row_offset + i * MMA_16; - int Tensor_col_offset = j * MMA_8; - #pragma unroll - for (int r = 0; r < REG_PER_THREAD_C_TENSOR_16_16/2; r++) { - int row_offset = lane_id / 4; - if (r >= 2) row_offset += 8; - int col_offset = (lane_id % 4) * 2; - if (r%2==1) col_offset += 1; - smem_CFrag[Tensor_col_offset + col_offset][Tensor_row_offset + row_offset] = c[RegSetID][r + RegOffset]; - } - } - } -} - -#endif diff --git a/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh b/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh deleted file mode 100644 index 9de9250299..0000000000 --- a/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the BSD 3-Clause license found in the -// LICENSE file in the root directory of this source tree. 
-// Copyright 2024 FP6-LLM authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_gmem.cuh -// -// MODIFICATION NOTE (2024-09-25): added SM75 support (https://github.com/pytorch/ao/pull/942): -// - Replaced asynchronous copy operations with vectorized loads -// - -#ifndef UTILS_GMEM_CUH -#define UTILS_GMEM_CUH - -#include -#include "configs.h" -#include "ptx_cp.async.cuh" - -/* - * Copying A1/A2 from global memory to shared memory. - * Usually 1024 or 2048 Bytes - */ -template -__device__ __forceinline__ void CopyFromGlobalToShared_A(uint32_t* SPTR, - const uint4* GPTR, - bool pred_guard = true) { - #ifdef DEBUG_MODE - static_assert(SMEM_SIZE_IN_BYTES_PER_WARP/WARP_SIZE % 16 == 0); - #endif - int lane_id = threadIdx.x % WARP_SIZE; - half* SPTR_HALF = reinterpret_cast(SPTR); - const half* GPTR_HALF = reinterpret_cast(GPTR); - SPTR_HALF += lane_id*8; - GPTR_HALF += lane_id*8; - #pragma unroll - for(int i=0; i(SPTR_HALF); - const float4* GPTR_VEC = reinterpret_cast(GPTR_HALF); - SPTR_VEC[0] = GPTR_VEC[0]; - } - #else - cp_async<16>( SPTR_HALF, GPTR_HALF, pred_guard); - #endif - SPTR_HALF += 256; // Forward 512 Bytes - GPTR_HALF += 256; // Forward 512 Bytes - } - -} - -/* - * Copying 64 Quant Scales (FP16) from global memory to shared memory. - */ -__device__ __forceinline__ void CopyFromGlobalToShared_Scales(half* SPTR_QuantScales, - const half* GPTR_A_Scales) { - int lane_id = threadIdx.x % WARP_SIZE; - int Offset_Shared = lane_id*2; - int Offset_Global = lane_id/4 + (lane_id%4)*16; - for(int i=0; i<2; i++) SPTR_QuantScales[Offset_Shared+i] = GPTR_A_Scales[Offset_Global+i*8]; -} - -// MODIFICATION NOTE: to support MSVC, half __restrict__ (*SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below. -/* - * (1) Copying X rows * 64 columns of FP16 values, originally in row major - * (2) Copying 64 rows * X columns of FP16 values, originally in column major - * 16 Bytes per thread -> 512 Bytes per WARP = 4 line per WARP = 1 line per 8 Threads - */ -template -__device__ __forceinline__ void CopyFromGlobalToShared(half (* __restrict__ SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], - const half* GlobalPTR, - const int GlobalStride, - const int NumOfLinesLeft, // To support arbitrary N dimensions. 
- bool Pred = true) { - // static parameters: 1 Group (8 Threads) can copy 1 line (64 FP16) each time - const int NumOfThreads = BLOCK_WARPS * WARP_SIZE; - const int NumOfGroups = NumOfThreads / 8; - const int MaxIteration = (MaxNumOfLinesToCopy-1) / NumOfGroups + 1; - // runtime variables - const int line_id = threadIdx.x / 8; - const int line_offset = (threadIdx.x%8) * 8; - // PTR for source global memory and target shared memory - GlobalPTR += line_id * GlobalStride + line_offset; - SharedPTR += line_id; - #pragma unroll - for (int i = 0; i < MaxIteration; i++) { - bool AsyncCopyPred = (line_id+i*NumOfGroups) < NumOfLinesLeft && Pred; - #if __CUDA_ARCH__ == 750 - if (AsyncCopyPred) { - float4* SharedPtrVec = reinterpret_cast<float4*>(&(*SharedPTR)[line_offset]); - const float4* GlobalPtrVec = reinterpret_cast<const float4*>(GlobalPTR); - SharedPtrVec[0] = GlobalPtrVec[0]; - } - #else - cp_async<16>( &(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred); - #endif - GlobalPTR += NumOfGroups * GlobalStride; - SharedPTR += NumOfGroups; - } -} - -#endif diff --git a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh deleted file mode 100644 index 63afa0694c..0000000000 --- a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh +++ /dev/null @@ -1,160 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the BSD 3-Clause license found in the -// LICENSE file in the root directory of this source tree. -// Copyright 2024 FP6-LLM authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/utils_parallel_dequant.cuh -// To support MSVC, all instances of u_int32_t are changed to uint32_t. - -#ifndef UTILS_PARALLELDEQUANT_CUH -#define UTILS_PARALLELDEQUANT_CUH - -#include <cuda.h> -#include <cuda_fp16.h> -#include <cuda_runtime.h> -#include <cuda_bf16.h> - -/* - * Input: R1 - * Outputs: R1, R2 - * Note: Simplified Exponent calculation is applied. - */ -template<int EXPONENT, int MANTISSA, bool USE_BF16> -__device__ __forceinline__ void FPx_FP16_Cast_4Way(uint32_t *In, uint32_t *Out1, uint32_t *Out2) { - // - constexpr int RIGHT_SHIFT = USE_BF16 ?
8 - EXPONENT : 5 - EXPONENT; - constexpr int MASK1 = 0x80000000; - constexpr int MASK2 = MASK1 >> EXPONENT + MANTISSA; - constexpr int MASK3 = MASK2 & 0x7fffffff; - constexpr int MASK = MASK3 | MASK3 >> 16; - // - *Out1 = *In & 0x80008000; - *Out1 |= ( (*In) & MASK ) >> RIGHT_SHIFT; - // - *In = (*In) << 8; - *Out2 = *In & 0x80008000; - *Out2 |= ( (*In) & MASK ) >> RIGHT_SHIFT; -} - -template<int EXPONENT, int MANTISSA> -__device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair, half Scale) { - constexpr int BIAS_OFFSET = (int(1) << (5-1)) - (int(1) << (EXPONENT-1)); - constexpr int BIAS = int(1) << BIAS_OFFSET; - // - half* FP16_1 = reinterpret_cast<half*>(&PackedFP16Pair); - half* FP16_2 = FP16_1 + 1; - uint32_t output; - half* output_half_ptr = reinterpret_cast<half*>(&output); - output_half_ptr[0] = __hmul( __hmul(*FP16_1,__float2half(1.0f*BIAS)), Scale); - output_half_ptr[1] = __hmul( __hmul(*FP16_2,__float2half(1.0f*BIAS)), Scale); - return output; -} - -constexpr float power_of_two(int n) { - return (n == 0) ? 1.0f : 2.0f * power_of_two(n - 1); -} - -template<int EXPONENT, int MANTISSA> -__device__ __forceinline__ uint32_t MultScale(uint32_t PackedBF16Pair, __nv_bfloat16 Scale) { -#if __CUDA_ARCH__ >= 800 - constexpr int BIAS_OFFSET = (int(1) << (8-1)) - (int(1) << (EXPONENT-1)); - constexpr float BIAS = power_of_two(BIAS_OFFSET); - __nv_bfloat16* BF16_1 = reinterpret_cast<__nv_bfloat16*>(&PackedBF16Pair); - __nv_bfloat16* BF16_2 = BF16_1 + 1; - uint32_t output; - __nv_bfloat16* output_bf16_ptr = reinterpret_cast<__nv_bfloat16*>(&output); - output_bf16_ptr[0] = __hmul( __hmul(*BF16_1,__float2bfloat16(BIAS)), Scale); - output_bf16_ptr[1] = __hmul( __hmul(*BF16_2,__float2bfloat16(BIAS)), Scale); - return output; -#endif -} - -// MODIFICATION NOTE: to support MSVC -// - u_int32_t __restrict__ Reg[][4] is changed to below. -// - u_int32_t __restrict__ *read_RPTR_1bit is changed to below.
similarly for read_RPTR_2bit and read_RPTR_4bit -template<int EXPONENT, int MANTISSA, bool USE_BF16> -__device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (* __restrict__ Reg)[4], - uint32_t * __restrict__ read_RPTR_1bit, - uint32_t * __restrict__ read_RPTR_2bit, - uint32_t * __restrict__ read_RPTR_4bit, - uint32_t * Scales) { - // 1+2+4 weight split - constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA; - constexpr int USE_SEG_1BIT = BIT_WIDTH & 1; - constexpr int USE_SEG_2BIT = BIT_WIDTH & 2; - constexpr int USE_SEG_4BIT = BIT_WIDTH & 4; - // - uint32_t *OutputRegs = reinterpret_cast<uint32_t*> (Reg); - uint32_t *Frag_PTR_1bit = read_RPTR_1bit; - uint32_t *Frag_PTR_2bit = read_RPTR_2bit; - uint32_t *Frag_PTR_4bit = read_RPTR_4bit; - using scalar_t = typename std::conditional<USE_BF16, __nv_bfloat16, half>::type; - scalar_t *Scale_RPTR = reinterpret_cast<scalar_t*>(Scales); - // Dequantizing 32 FP6 values; each loop iteration dequantizes 4 of them - #pragma unroll(8) - for(int i=0; i<8; i++) { - uint32_t Packed_FP6 = 0; - uint32_t tmp = 0; - // 1bit Frag - if(USE_SEG_1BIT) { - tmp = (*Frag_PTR_1bit) & 0x80808080; - Packed_FP6 |= tmp >> (BIT_WIDTH & 0); - if(i%8==7) Frag_PTR_1bit++; - else (*Frag_PTR_1bit) = (*Frag_PTR_1bit) << 1; - } - // 2bit Frag - if(USE_SEG_2BIT) { - tmp = (*Frag_PTR_2bit) & 0xc0c0c0c0; - Packed_FP6 |= tmp >> (BIT_WIDTH & 1); - if(i%4==3) Frag_PTR_2bit++; - else (*Frag_PTR_2bit) = (*Frag_PTR_2bit) << 2; - } - // 4bit Frag - if(USE_SEG_4BIT) { - tmp = (*Frag_PTR_4bit) & 0xf0f0f0f0; - Packed_FP6 |= tmp >> (BIT_WIDTH & 3); - if(i%2==1) Frag_PTR_4bit++; - else (*Frag_PTR_4bit) = (*Frag_PTR_4bit) << 4; - } - uint32_t out1, out2; - FPx_FP16_Cast_4Way<EXPONENT, MANTISSA, USE_BF16>(&Packed_FP6, &out1, &out2); - // - *OutputRegs = MultScale<EXPONENT, MANTISSA>(out1, Scale_RPTR[0] ); // Multiply by FP16/BF16 scales - OutputRegs += 1; - *OutputRegs = MultScale<EXPONENT, MANTISSA>(out2, Scale_RPTR[1]); // Multiply by FP16/BF16 scales - OutputRegs += 1; - // Updating offset for FP16/BF16 scales for every two iterations - if(i%2==1) Scale_RPTR += 2; - } - -} - -/* - * - */ -__device__ __forceinline__ void ExtractFromSharedToReg_Scales(uint32_t* Scales, half* WARP_SPTR_Scales) { - int lane_id = threadIdx.x % WARP_SIZE; - uint32_t* SPTR_uint = reinterpret_cast<uint32_t*>(WARP_SPTR_Scales); - uint32_t tmpReg = SPTR_uint[lane_id]; - #pragma unroll - for(int i=0; i<4; i++) { - // T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize); - Scales[i] = __shfl_sync(0xffffffff, tmpReg, i, 4); - } -} - -#endif diff --git a/torchao/ops.py b/torchao/ops.py index 6748565fe4..a138755b8e 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -10,9 +10,6 @@ from torch import Tensor lib = torch.library.Library("torchao", "FRAGMENT") -lib.define( - "quant_llm_linear(int EXPONENT, int MANTISSA, Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor" -) lib.define( "unpack_tensor_core_tiled_layout(Tensor packed_w, int inner_k_tiles) -> Tensor" ) @@ -100,81 +97,6 @@ def cached_compute_capability(): return compute_capability -def quant_llm_linear( - EXPONENT: int, - MANTISSA: int, - _in_feats: Tensor, - _weights: Tensor, - _scales: Tensor, - splitK: int = 1, -) -> Tensor: - """ - Quant-LLM linear layer A @ W.T. See https://arxiv.org/abs/2401.14112 for more details.
- - Arguments - EXPONENT: number of exponent bits - MANTISSA: number of mantissa bits - _in_feats: input activations in FP16 or BF16 - _weights: packed Floatx weights - _scales: per-output-channel scale factors (FP16 or BF16) - splitK: split-K factor for the GEMM - - Returns - output of the linear layer - """ - # Check if we're on a supported architecture (sm7.5 or higher) - compute_capability = cached_compute_capability() - torch._check( - compute_capability >= 75, - lambda: f"quant_llm_linear requires sm7.5+ GPU architecture, but current device has sm{compute_capability}", - ) - return torch.ops.torchao.quant_llm_linear.default( - EXPONENT, MANTISSA, _in_feats, _weights, _scales, splitK - ) - - -@register_custom_op("torchao::quant_llm_linear") -def _( - EXPONENT: int, - MANTISSA: int, - _in_feats: Tensor, - _weights: Tensor, - _scales: Tensor, - splitK: int = 1, -) -> Tensor: - torch._check( - _in_feats.dim() == 2, - lambda: f"input should be a 2d tensor, got {_in_feats.dim()}D", - ) - torch._check( - _in_feats.dtype in (torch.float16, torch.bfloat16), - lambda: f"input must be FP16 or BF16, got {_in_feats.dtype}", - ) - torch._check( - _weights.dim() == 2, - lambda: f"weight should be a 2d tensor, got {_weights.dim()}D", - ) - torch._check( - _weights.dtype is torch.uint8, - lambda: f"weight must be UINT8, got {_weights.dtype}", - ) - torch._check( - _scales.dim() == 1, lambda: f"scale should be a 1d tensor, got {_scales.dim()}D" - ) - torch._check( - _scales.dtype in (torch.float16, torch.bfloat16), - lambda: f"scale must be FP16 or BF16, got {_scales.dtype}", - ) - - BS, IC = _in_feats.shape - OC, _ = _weights.shape - N_BITS = 1 + EXPONENT + MANTISSA - torch._check(IC // 8 * N_BITS == _weights.shape[1], lambda: "Dimensions mismatched") - torch._check(OC == _scales.shape[0], lambda: "Dimensions mismatched") - - return _in_feats.new_empty((BS, OC)) - - def qscaled_dot_product( query: Tensor, key: Tensor,