[mlir][tosa] Add support for CONV2D_BLOCK_SCALED operator #172294
base: main

Conversation
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

Changes

This commit adds support for an MXFP CONV2D operation, CONV2D_BLOCK_SCALED, added to the specification in arm/tosa-specification@408a5e5. This includes:

- Operator definition
- Addition of the EXT_MXFP_CONV extension
- Verification logic for the operator
- Output shape inference for the operator
- Validation checks to ensure compliance with the TOSA specification

A sketch of what the new operator might look like in IR is given below.
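The following is a hypothetical usage sketch, not taken from the patch: the operand order follows the TableGen definition in the diff, but the assembly syntax, the block_size attribute spelling, and the MLIR type names (f4E2M1FN for fp4e2m1, f8E8M0FNU for fp8ue8m0) are assumptions, since the op declares a custom assembly format.

// Hypothetical example: 1x16x16x32 fp4e2m1 input with one fp8ue8m0 scale per
// block of 32 channels, 8 output channels, 3x3 kernel, no padding, stride 1.
// Output height/width: (16 - 1 + 0 + 0 - (3 - 1) * 1) / 1 + 1 = 14.
%pad = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
%stride = tosa.const_shape {values = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2>
%dilation = tosa.const_shape {values = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2>
%out = tosa.conv2d_block_scaled %in_data, %in_scale, %w_data, %w_scale, %bias,
    %pad, %stride, %dilation {block_size = #tosa<block_size BLOCK_SIZE_32>}
    : (tensor<1x16x16x32xf4E2M1FN>, tensor<1x16x16x1xf8E8M0FNU>,
       tensor<8x3x3x32xf4E2M1FN>, tensor<8x3x3x1xf8E8M0FNU>, tensor<8xf32>,
       !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x14x14x8xf32>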
Patch is 70.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/172294.diff

17 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index e23827f8aabf2..e452723f193b9 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -512,6 +512,13 @@ extensionComplianceMap = {
{{Extension::bf16},
{{{bf16T, bf16T, bf16T, bf16T, bf16T, fp32T, bf16T},
SpecificationVersion::V_1_0}}}}},
+ {"tosa.conv2d_block_scaled",
+ {{{Extension::mxfp_conv},
+ {{{fp4e2m1T, fp8ue8m0T, fp4e2m1T, fp8ue8m0T, fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e2m3T, fp8ue8m0T, fp6e2m3T, fp8ue8m0T, fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp6e3m2T, fp8ue8m0T, fp6e3m2T, fp8ue8m0T, fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e4m3T, fp8ue8m0T, fp8e4m3T, fp8ue8m0T, fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp8e5m2T, fp8ue8m0T, fp8e5m2T, fp8ue8m0T, fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
{"tosa.conv3d",
{{{Extension::int4},
{{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index cc23955f31f23..421abc939b2e0 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -241,6 +241,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
// INEXACTROUND : Adds inexact rounding support to the RESCALE operator.
// DYNAMIC : Removes all Compile Time Constant state for CTC inputs.
// MXFP : Microscaling formats.
+// MXFP_CONV : Microscaling format convolution.
//===----------------------------------------------------------------------===//
def Tosa_NONE : I32EnumAttrCase<"none", 0>;
@@ -274,6 +275,7 @@ def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;
def Tosa_EXT_DYNAMIC : I32EnumAttrCase<"dynamic", 11>;
def Tosa_EXT_MXFP : I32EnumAttrCase<"mxfp", 12>;
def Tosa_EXT_INT64 : I32EnumAttrCase<"int64", 13>;
+def Tosa_EXT_MXFP_CONV : I32EnumAttrCase<"mxfp_conv", 14>;
def Tosa_ExtensionAttr
@@ -281,16 +283,16 @@ def Tosa_ExtensionAttr
Tosa_EXT_NONE, Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16,
Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE,
Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND,
- Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP, Tosa_EXT_INT64
+ Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP, Tosa_EXT_INT64, Tosa_EXT_MXFP_CONV,
]> {
let extraClassDeclaration = [{
- static llvm::SmallVector<Extension, 13> getAllValues() {
+ static llvm::SmallVector<Extension, 14> getAllValues() {
return {
Extension::int16, Extension::int4, Extension::bf16,
Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft,
Extension::variable, Extension::controlflow, Extension::doubleround,
Extension::inexactround, Extension::dynamic, Extension::mxfp,
- Extension::int64
+ Extension::int64, Extension::mxfp_conv
};
}
}];
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 370ce8c161d0b..edd8f0fc266bb 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -163,6 +163,43 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// Operator: conv2d_block_scaled
+//===----------------------------------------------------------------------===//
+def Tosa_Conv2DBlockScaledOp : Tosa_InferShapedTypeOp<"conv2d_block_scaled"> {
+ let summary = "Performs two dimensional convolution using block scaled tensors.";
+
+ let description = [{
+ Performs a 2D convolution over the given input data and scales, using
+ the weight data and scales. Implementations may choose to skip calculation
+ of multiplies in the padding area.
+ }];
+
+ let arguments = (ins
+ Tosa_MXFPDataTensor4D:$input_data,
+ Tosa_MXFPScaleTensor4D:$input_scale,
+ Tosa_MXFPDataTensor4D:$weight_data,
+ Tosa_MXFPScaleTensor4D:$weight_scale,
+ Tosa_Tensor1D:$bias,
+ Rank4TosaShape:$pad,
+ Rank2TosaShape:$stride,
+ Rank2TosaShape:$dilation,
+ Tosa_BlockSizeAttr:$block_size
+ );
+
+ let results = (outs
+ Tosa_Tensor4D:$output
+ );
+
+ list<Availability> availability = [
+ Profile<[Tosa_PRO_FP]>,
+ Extension<[Tosa_EXT_MXFP_CONV]>,
+ ];
+
+ let hasVerifier = 1;
+ let hasCustomAssemblyFormat = 1;
+}
+
//===----------------------------------------------------------------------===//
// Operator: conv3d
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index ea58f49b64c44..5c77bd701e416 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -149,6 +149,7 @@ class TosaProfileCompliance {
case Extension::fp8e5m2:
case Extension::fft:
case Extension::mxfp:
+ case Extension::mxfp_conv:
return {Profile::pro_fp};
case Extension::variable:
case Extension::controlflow:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 266a9e3a7d946..0468ca29e10ac 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -202,6 +202,8 @@ def Tosa_Tensor1Dto6D : AnyTypeOf<[
def Tosa_TensorUpto4D : AnyTypeOf<[
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
+def Tosa_IndexTensor1D : AnyTypeOf<[
+ Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32, Tosa_Int64], [1]>]>;
def Tosa_IndexTensor2D : AnyTypeOf<[
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32, Tosa_Int64], [2]>]>;
@@ -216,6 +218,14 @@ def Tosa_MXFPScaleTensor3D : AnyTypeOf<[
TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
TosaTensorRankOf<[Tosa_MXFPScaleNumber], [3]>
]>;
+def Tosa_MXFPDataTensor4D : AnyTypeOf<[
+ TosaUnrankedTensorOf<[Tosa_MXFPNumber]>,
+ TosaTensorRankOf<[Tosa_MXFPNumber], [4]>
+]>;
+def Tosa_MXFPScaleTensor4D : AnyTypeOf<[
+ TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
+ TosaTensorRankOf<[Tosa_MXFPScaleNumber], [4]>
+]>;
def Tosa_MXFPDataTensorAtLeast1D : AnyTypeOf<[
TosaUnrankedTensorOf<[Tosa_MXFPNumber]>,
TosaRankedTensorOf<[Tosa_MXFPNumber], [AtLeastRankOne]>],
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
index eb47e85cf9b0b..2e0a0d85d7dbe 100644
--- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -43,6 +43,7 @@ TosaSpecificationVersion getMinVersion(const Extension &extension) {
return TosaSpecificationVersion(1, 0);
case Extension::mxfp:
case Extension::int64:
+ case Extension::mxfp_conv:
return TosaSpecificationVersion(1, 1);
case Extension::none:
return TosaSpecificationVersion(0, 0);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index bead774620a4f..6382c28ed4312 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -550,6 +550,15 @@ void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
printWithEnumHandling(parser, *this);
}
+ParseResult Conv2DBlockScaledOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void Conv2DBlockScaledOp::print(OpAsmPrinter &parser) {
+ printWithEnumHandling(parser, *this);
+}
+
//===----------------------------------------------------------------------===//
// Tosa utilities.
//===----------------------------------------------------------------------===//
@@ -612,6 +621,55 @@ unsigned mlir::tosa::getBitWidth(Type type) {
return type.getIntOrFloatBitWidth();
}
+// Update dim size if current dim is dynamic, otherwise raise an error if sizes
+// do not match
+LogicalResult tryUpdateDimOrFailure(Operation *op, int64_t &currDim,
+ const int64_t newDim,
+ const StringRef operandName,
+ const StringRef dimName) {
+ if (ShapedType::isDynamic(currDim)) {
+ currDim = newDim;
+ return success();
+ } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
+ return op->emitOpError("expected ")
+ << dimName << " of " << operandName << " to match size " << currDim
+ << ", got " << newDim;
+ }
+ return success();
+}
+
+LogicalResult verifyConvOutputSize(
+ Operation *op, const int64_t inputSize, const int64_t kernelSize,
+ const int64_t outputSize, const int64_t padBefore, const int64_t padAfter,
+ const int64_t stride, const int64_t dilation, const llvm::StringRef dimName,
+ const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName,
+ const llvm::StringRef padAfterName) {
+ if (inputSize == ShapedType::kDynamic || kernelSize == ShapedType::kDynamic)
+ return success();
+
+ // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
+
+ const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
+ inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
+ stride);
+ if (!calculatedOutSizeMinusOne.has_value())
+ return op->emitOpError("expected input_")
+ << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
+ << padAfterName << " - (kernel_" << dimName << " - 1) * dilation_"
+ << dimAxis << " to be wholly divisible by stride_" << dimAxis
+ << ", got (" << inputSize << " - 1 + " << padBefore << " + "
+ << padAfter << " - (" << kernelSize << " - 1) * " << dilation
+ << ") / " << stride;
+
+ const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
+ if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
+ return op->emitOpError("calculated output ")
+ << dimName << " did not match expected: "
+ << "calculated=" << calculatedOutSize << ", expected=" << outputSize;
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//
@@ -791,53 +849,16 @@ static LogicalResult verifyConvOpErrorIf(T op) {
llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
if (inputType && weightType) {
- const auto verifyOutputSize =
- [&op](const int64_t inputSize, const int64_t kernelSize,
- const int64_t outputSize, const int64_t padBefore,
- const int64_t padAfter, const int64_t stride,
- const int64_t dilation, const llvm::StringRef dimName,
- const llvm::StringRef dimAxis,
- const llvm::StringRef padBeforeName,
- const llvm::StringRef padAfterName) -> LogicalResult {
- if (inputSize == ShapedType::kDynamic ||
- kernelSize == ShapedType::kDynamic)
- return success();
-
- // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
-
- const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
- inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
- stride);
- if (!calculatedOutSizeMinusOne.has_value())
- return op.emitOpError("expected input_")
- << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
- << padAfterName << " - (kernel_" << dimName
- << " - 1) * dilation_" << dimAxis
- << " to be wholly divisible by stride_" << dimAxis << ", got ("
- << inputSize << " - 1 + " << padBefore << " + " << padAfter
- << " - (" << kernelSize << " - 1) * " << dilation << ") / "
- << stride;
-
- const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
- if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
- return op.emitOpError("calculated output ")
- << dimName << " did not match expected: "
- << "calculated=" << calculatedOutSize
- << ", expected=" << outputSize;
-
- return success();
- };
-
// input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
- if (failed(verifyOutputSize(
- inputType.getDimSize(1), weightType.getDimSize(1),
+ if (failed(verifyConvOutputSize(
+ op, inputType.getDimSize(1), weightType.getDimSize(1),
outputType.getDimSize(1), padding[0], padding[1], strides[0],
dilations[0], "height", "y", "top", "bottom")))
return failure();
- if (failed(verifyOutputSize(
- inputType.getDimSize(2), weightType.getDimSize(2),
+ if (failed(verifyConvOutputSize(
+ op, inputType.getDimSize(2), weightType.getDimSize(2),
outputType.getDimSize(2), padding[2], padding[3], strides[1],
dilations[1], "width", "x", "left", "right")))
return failure();
@@ -845,14 +866,14 @@ static LogicalResult verifyConvOpErrorIf(T op) {
// input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
- if (failed(verifyOutputSize(
- inputType.getDimSize(1), weightType.getDimSize(0),
+ if (failed(verifyConvOutputSize(
+ op, inputType.getDimSize(1), weightType.getDimSize(0),
outputType.getDimSize(1), padding[0], padding[1], strides[0],
dilations[0], "height", "y", "top", "bottom")))
return failure();
- if (failed(verifyOutputSize(
- inputType.getDimSize(2), weightType.getDimSize(1),
+ if (failed(verifyConvOutputSize(
+ op, inputType.getDimSize(2), weightType.getDimSize(1),
outputType.getDimSize(2), padding[2], padding[3], strides[1],
dilations[1], "width", "x", "left", "right")))
return failure();
@@ -860,20 +881,20 @@ static LogicalResult verifyConvOpErrorIf(T op) {
// input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
- if (failed(verifyOutputSize(
- inputType.getDimSize(1), weightType.getDimSize(1),
+ if (failed(verifyConvOutputSize(
+ op, inputType.getDimSize(1), weightType.getDimSize(1),
outputType.getDimSize(1), padding[0], padding[1], strides[0],
dilations[0], "depth", "d", "front", "back")))
return failure();
- if (failed(verifyOutputSize(
- inputType.getDimSize(2), weightType.getDimSize(2),
+ if (failed(verifyConvOutputSize(
+ op, inputType.getDimSize(2), weightType.getDimSize(2),
outputType.getDimSize(2), padding[2], padding[3], strides[1],
dilations[1], "height", "y", "top", "bottom")))
return failure();
- if (failed(verifyOutputSize(
- inputType.getDimSize(3), weightType.getDimSize(3),
+ if (failed(verifyConvOutputSize(
+ op, inputType.getDimSize(3), weightType.getDimSize(3),
outputType.getDimSize(3), padding[4], padding[5], strides[2],
dilations[2], "width", "x", "left", "right")))
return failure();
@@ -1954,20 +1975,6 @@ LogicalResult MatmulTBlockScaledOp::verify() {
"B_data")))
return failure();
- auto tryUpdateDimOrFailure = [&](int64_t &currDim, const int64_t newDim,
- const StringRef operandName,
- const StringRef dimName) -> LogicalResult {
- if (ShapedType::isDynamic(currDim)) {
- currDim = newDim;
- return success();
- } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
- return emitOpError("expected ")
- << dimName << " of " << operandName << " to match size " << currDim
- << ", got " << newDim;
- }
- return success();
- };
-
// Verify input shape compatibility
int64_t N = ShapedType::kDynamic;
int64_t D = ShapedType::kDynamic;
@@ -1985,32 +1992,33 @@ LogicalResult MatmulTBlockScaledOp::verify() {
const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType());
if (aScaleShape.hasRank()) {
- if (failed(tryUpdateDimOrFailure(N, aScaleShape.getDimSize(0), "a_scale",
- "batch")) ||
- failed(tryUpdateDimOrFailure(H, aScaleShape.getDimSize(1), "a_scale",
- "height")))
+ if (failed(tryUpdateDimOrFailure(*this, N, aScaleShape.getDimSize(0),
+ "a_scale", "batch")) ||
+ failed(tryUpdateDimOrFailure(*this, H, aScaleShape.getDimSize(1),
+ "a_scale", "height")))
return failure();
multiplesOfC = aScaleShape.getDimSize(2);
}
const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType);
if (bDataShape.hasRank()) {
- if (failed(tryUpdateDimOrFailure(D, bDataShape.getDimSize(0), "b_data",
- "batch")) ||
- failed(tryUpdateDimOrFailure(C, bDataShape.getDimSize(2), "b_data",
- "channels")))
+ if (failed(tryUpdateDimOrFailure(*this, D, bDataShape.getDimSize(0),
+ "b_data", "batch")) ||
+ failed(tryUpdateDimOrFailure(*this, C, bDataShape.getDimSize(2),
+ "b_data", "channels")))
return failure();
W = bDataShape.getDimSize(1);
}
const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType());
if (bScaleShape.hasRank()) {
- if (failed(tryUpdateDimOrFailure(D, bScaleShape.getDimSize(0), "b_scale",
- "batch")) ||
- failed(tryUpdateDimOrFailure(W, bScaleShape.getDimSize(1), "b_scale",
- "width")) ||
- failed(tryUpdateDimOrFailure(multiplesOfC, bScaleShape.getDimSize(2),
- "b_scale", "C/block_size")))
+ if (failed(tryUpdateDimOrFailure(*this, D, bScaleShape.getDimSize(0),
+ "b_scale", "batch")) ||
+ failed(tryUpdateDimOrFailure(*this, W, bScaleShape.getDimSize(1),
+ "b_scale", "width")) ||
+ failed(tryUpdateDimOrFailure(*this, multiplesOfC,
+ bScaleShape.getDimSize(2), "b_scale",
+ "C/block_size")))
return failure();
}
@@ -3485,6 +3493,228 @@ LogicalResult Conv2DOp::verify() {
return success();
}
+LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
+ MLIRContext *context, ::std::optional<Location> location,
+ Conv2DBlockScaledOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ SmallVector<int64_t, 4> outShape(4, ShapedType::kDynamic);
+
+ int64_t inputWidth = ShapedType::kDynamic;
+ int64_t inputHeight = ShapedType::kDynamic;
+ int64_t weightWidth = ShapedType::kDynamic;
+ int64_t weightHeight = ShapedType::kDynamic;
+
+ // Input shape describes input width/height and batch.
+ const ShapeAdaptor inputDataShape(adaptor.getInputData().getType());
+ if (inputDataShape.hasRank()) {
+ outShape[0] = inputDataShape.getDimSize(0);
+ inputHeight = inputDataShape.getDimSize(1);
+ inputWidth = inputDataShape.getDimSize(2);
+ }
+ const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType());
+ if (inputScaleShape.hasRank()) {
+ outShape[0] = ShapedType::isDynamic(outShape[0])
+ ? inputScaleShape.getDimSize(0)
+ : outShape[0];
+ inputHeight = ShapedType::isDynamic(inputHeight)
+ ? ...
[truncated]
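The refactored verifyConvOutputSize helper (now shared by CONV2D, DEPTHWISE_CONV2D, CONV3D, and the new op) enforces the specification's ERROR_IF: the output extent must equal idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1, where idiv_check fails unless the division is exact. A minimal standalone sketch of that arithmetic, with illustrative names mirroring the helpers in the diff rather than the patch's actual code:

#include <cstdint>
#include <optional>

// Exact integer division, mirroring the spec's idiv_check: empty result when
// lhs is not a whole multiple of rhs.
static std::optional<int64_t> idivCheck(int64_t lhs, int64_t rhs) {
  if (lhs % rhs != 0)
    return std::nullopt;
  return lhs / rhs;
}

// Expected output extent for one spatial dimension of a convolution.
static std::optional<int64_t> convOutSize(int64_t in, int64_t kernel,
                                          int64_t padBefore, int64_t padAfter,
                                          int64_t stride, int64_t dilation) {
  std::optional<int64_t> q =
      idivCheck(in - 1 + padBefore + padAfter - (kernel - 1) * dilation, stride);
  if (!q)
    return std::nullopt; // the verifier emits the "wholly divisible" error here
  return *q + 1;
}

// e.g. convOutSize(16, 3, 0, 0, 1, 1) == 14, matching the 1x14x14x8 output above.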
✅ With the latest revision this PR passed the C/C++ code formatter.
This commit adds support for an MXFP CONV2D operation, CONV2D_BLOCK_SCALED, added to the specification in arm/tosa-specification@408a5e5. This includes:

- Operator definition
- Addition of the EXT_MXFP_CONV extension
- Verification logic for the operator
- Output shape inference for the operator
- Validation checks to ensure compliance with the TOSA specification

Change-Id: I7553f7796d2d156f43310108e9a69a593cdece33
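For background on the "block scaled" operands: in OCP MX-style microscaling formats, each block of block_size consecutive channel elements shares one fp8ue8m0 scale, an 8-bit power-of-two exponent (value 2^(e - 127)), and an element's effective value is its narrow-float payload times that scale. A rough decoding sketch under that assumption; this is an illustration, not code from the patch:

#include <cmath>
#include <cstdint>
#include <vector>

// Decode one channel run of block-scaled values: every blockSize consecutive
// payloads share a single E8M0 scale, interpreted as 2^(exponent - 127).
std::vector<float> decodeBlockScaled(const std::vector<float> &payloads,
                                     const std::vector<uint8_t> &scalesE8M0,
                                     int64_t blockSize) {
  std::vector<float> out(payloads.size());
  for (size_t i = 0; i < payloads.size(); ++i) {
    float scale = std::ldexp(1.0f, int(scalesE8M0[i / blockSize]) - 127);
    out[i] = payloads[i] * scale;
  }
  return out;
}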
Force-pushed from 0a30dd2 to ab4c945.